1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17
18 #include <vector>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/tuple_util.h"
31 #include "tensorflow/compiler/xla/service/while_util.h"
32 #include "tensorflow/compiler/xla/shape_tree.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 namespace xla {
38
39 namespace {
40 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
                                           const Shape& wide_shape) {
43 TF_RET_CHECK(wide_shape.IsTuple());
44 const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
45 if (Shape::Equal()(wide_shape, narrow_shape)) {
46 // No need to widen the computation.
47 return narrow_comp;
48 }
49 HloComputation* wide_comp = [&]() {
50 HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
51 builder.AddInstruction(
52 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
53 return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
54 }();
55
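  // To widen the computation, take a prefix of the wide parameter that matches
  // the narrow parameter shape, call the original narrow computation on it,
  // and then inline that call so the result is a single flat computation whose
  // root mirrors the narrow computation's root.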
56 HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
57 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
58 wide_parameter, narrow_shape.tuple_shapes_size());
59 HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
60 HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
61 {truncated_parameter}, narrow_comp));
62 wide_comp->set_root_instruction(call_narrow_comp,
63 /*accept_different_shape=*/true);
64 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
65 return wide_comp;
66 }
67 } // namespace
68
69 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
70 public:
  explicit DynamicDimensionInferenceVisitor(
72 const DynamicParameterBinding& param_bindings,
73 DynamicDimensionInference* parent,
74 DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler)
75 : param_bindings_(param_bindings),
76 parent_(parent),
77 custom_call_handler_(std::move(custom_call_handler)) {}
78
79 Status DefaultAction(HloInstruction* hlo) override;
80
  static Status Run(HloComputation* computation,
82 const DynamicParameterBinding& param_bindings,
83 DynamicDimensionInference* parent,
84 DynamicDimensionInference::CustomCallInferenceHandler
85 custom_call_handler = nullptr) {
86 DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
87 std::move(custom_call_handler));
88 return computation->Accept(&visitor);
89 }
90
91 Status HandleParameter(HloInstruction* hlo) override;
92
93 Status HandleReduce(HloInstruction* hlo) override;
94
95 Status HandleDot(HloInstruction* hlo) override;
96
97 Status HandleTuple(HloInstruction* hlo) override;
98
99 Status HandleTranspose(HloInstruction* hlo) override;
100
101 Status HandleDynamicReshape(HloInstruction* hlo) override;
102
103 Status HandleReshape(HloInstruction* hlo) override;
104
105 Status HandleSort(HloInstruction* hlo) override;
106
107 Status HandlePad(HloInstruction* hlo) override;
108
109 Status HandleCustomCall(HloInstruction* hlo) override;
110
111 Status HandleBroadcast(HloInstruction* hlo) override;
112
113 Status HandleGetDimensionSize(HloInstruction* hlo) override;
114
115 Status HandleSetDimensionSize(HloInstruction* hlo) override;
116
117 Status HandleSelect(HloInstruction* hlo) override;
118
119 Status HandleConvolution(HloInstruction* hlo) override;
120
121 Status HandleConcatenate(HloInstruction* hlo) override;
122
123 Status HandleReduceWindow(HloInstruction* hlo) override;
124
125 Status HandleReverse(HloInstruction* hlo) override;
126
127 Status HandleSelectAndScatter(HloInstruction* hlo) override;
128
129 Status HandleGetTupleElement(HloInstruction* hlo) override;
130
131 Status HandleElementwiseUnary(HloInstruction* hlo) override;
132
133 Status HandleElementwiseBinary(HloInstruction* hlo) override;
134
135 Status HandleClamp(HloInstruction* hlo) override;
136
137 Status HandleConditional(HloInstruction* hlo) override;
138
139 Status HandleWhile(HloInstruction* hlo) override;
140
141 Status HandleSlice(HloInstruction* hlo) override;
142
143 Status HandleDynamicSlice(HloInstruction* hlo) override;
144
145 Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
146
147 Status HandleGather(HloInstruction* hlo) override;
148
149 Status HandleScatter(HloInstruction* hlo) override;
150
151 Status HandleDomain(HloInstruction* hlo) override;
152
153 private:
154 using OperandDynamicDimensionFn = std::function<Status(
155 HloInstruction* operand, ShapeIndex index, int64_t dimension,
156 int64_t operand_index, HloInstruction* dynamic_size)>;
157
158 using DynamicDimensionFn = std::function<Status(
159 ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size)>;
160
161 Status HandleDynamicConvolutionForward(HloInstruction* hlo,
162 int64_t operand_index,
163 int64_t dimension,
164 HloInstruction* dynamic_size);
165
166 Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
167 int64_t operand_index,
168 int64_t dimension);
169
170 Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
171 int64_t operand_index,
172 int64_t dimension);
173
174 Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
175 HloInstruction* dynamic_size,
176 int64_t operand_index,
177 int64_t dimension);
178
179 Status ForEachOperandDynamicDimension(HloInstruction* inst,
180 const OperandDynamicDimensionFn&);
181 Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
182 int64_t operand_index,
183 const OperandDynamicDimensionFn&);
184 Status ForEachDynamicDimension(HloInstruction* inst,
185 const DynamicDimensionFn& fn);
186
187 // Pass through a dynamic dimension from the input to the output with the
188 // same value and index in the shape. This is a helper function to handle
189 // trivial instructions like elementwise operations.
190 Status PassThroughDynamicDimension(HloInstruction*);
191
192 // The dynamic parameter bindings of this computation.
193 const DynamicParameterBinding& param_bindings_;
194
195 // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
196 DynamicDimensionInference* parent_;
197
198 // A handler for custom calls.
199 DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
200 };
201
Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
203 return ForEachOperandDynamicDimension(
204 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
205 int64_t operand_index, HloInstruction* dynamic_size) {
206 return UnimplementedStrCat(
207 "Asked to propagate a dynamic dimension from hlo ", operand->name(),
208 "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
209 ", which is not implemented.");
210 });
211 }
212
Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
214 HloInstruction* hlo) {
215 return ForEachOperandDynamicDimension(
216 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
217 int64_t operand_index, HloInstruction* dynamic_size) {
218 if (hlo->tuple_index() == index[0]) {
219 ShapeIndex new_index =
220 ShapeIndexView(index).ConsumeFront().ToShapeIndex();
221 parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
222 }
223 return Status::OK();
224 });
225 }
226
Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
228 return ForEachOperandDynamicDimension(
229 hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension,
230 int64_t operand_index, HloInstruction* dynamic_size) {
231 index.push_front(operand_index);
232 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
233 return Status::OK();
234 });
235 }
236
Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
238 return ForEachOperandDynamicDimension(
239 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
240 int64_t operand_index, HloInstruction* dynamic_size) {
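        // For a broadcast, dimensions(i) gives the output dimension that
        // operand dimension i maps to, so the dynamic size moves to that
        // output dimension unchanged.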
241 int64_t broadcast_dim = hlo->dimensions(dimension);
242 parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
243 return Status::OK();
244 });
245 }
246
Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
248 if (hlo->custom_call_target() == "PadToStatic") {
249 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
250 if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
251 HloInstruction* dynamic_size =
252 hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
253 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
        // PadToStatic converts a dynamic dimension to a static dimension. It then
255 // returns the padded data output and the dynamic sizes of input
256 // dimensions.
257 ShapeIndex data_output = {0};
258 parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
259 }
260 }
261 return Status::OK();
262 }
263 if (custom_call_handler_) {
264 return custom_call_handler_(hlo, parent_);
265 }
266
267 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
    // If the input feature is dynamic and the kernel feature is static, we can
    // infer that the input feature is actually static.
    // E.g.,:
    //   lhs = [B, X, Y, ?]
    //   rhs = [X, Y, I, O]
    //   dim_labels = b01f_01io
    // We can infer that the dynamic feature dimension in lhs is actually the
    // static size I.
275 const ConvolutionDimensionNumbers& dnums =
276 hlo->convolution_dimension_numbers();
277 HloInstruction* input_feature = parent_->GetDynamicSize(
278 hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
279 HloInstruction* kernel_feature = parent_->GetDynamicSize(
280 hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
281
282 if (input_feature != nullptr && kernel_feature == nullptr) {
283 if (hlo->mutable_operand(0)->shape().dimensions(
284 dnums.input_feature_dimension()) ==
285 hlo->mutable_operand(1)->shape().dimensions(
286 dnums.kernel_input_feature_dimension()))
287 parent_->SetDynamicSize(hlo->mutable_operand(0), {},
288 dnums.input_feature_dimension(), nullptr);
289 }
290 }
291 return ForEachOperandDynamicDimension(
292 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
293 int64_t operand_index, HloInstruction* dynamic_size) {
294 // Resize custom call should propagate dynamic batch (0) and channel (3)
295 // dimensions.
296 if (hlo->custom_call_target() == "SliceToDynamic" ||
297 hlo->custom_call_target() == "Sharding" ||
298 (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
299 (dimension == 0 || dimension == 3))) {
300 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
301 return Status::OK();
302 }
303 if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
304 if (hlo->operand_count() > 2) {
305 return Unimplemented(
306 "DynamicReduceWindowSamePadding doesn't support variadic "
307 "reduce window %s",
308 hlo->ToString());
309 }
310 return HandleDynamicWindowSamePadding(hlo, dynamic_size,
311 operand_index, dimension);
312 }
313
314 if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
315 if (operand_index == 1) {
316 // Operand 0 (input) determines dynamic output size. We ignore the
317 // dynamic size in the operand 1 (output gradient).
318 return Status::OK();
319 }
320 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
321 return Status::OK();
322 }
323
324 if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
325 return HandleDynamicConvolutionInputGrad(hlo, operand_index,
326 dimension);
327 }
328
329 if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
330 return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
331 dimension);
332 }
333
334 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
335 return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
336 dynamic_size);
337 }
338 return Unimplemented(
339 "CustomCall \"%s\" is not supported to have a dynamic dimension",
340 hlo->custom_call_target());
341 });
342 }
343
Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
345 return ForEachOperandDynamicDimension(
346 hlo,
347 [&](HloInstruction* operand, ShapeIndex index, int64_t dynamic_dimension,
348 int64_t operand_index, HloInstruction* dynamic_size) {
349 HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
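        // A sort with extra value operands produces a tuple with one element
        // per operand, so the dynamic size goes on the tuple element that
        // corresponds to this operand; a key-only sort produces an array
        // result.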
350 if (sort->values_count() == 0) {
351 parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
352 } else {
353 parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
354 dynamic_size);
355 }
356
357 return Status::OK();
358 });
359 }
360
Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
362 return ForEachOperandDynamicDimension(
363 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
364 int64_t operand_index, HloInstruction* dynamic_size) {
365 if (operand_index != 0) {
366 return Unimplemented(
367 "Dynamic dimension on padding value is not supported");
368 }
369 const PaddingConfig_PaddingConfigDimension& padding_config =
370 hlo->padding_config().dimensions(dimension);
371
372 HloInstruction* dynamic_size_adjusted = dynamic_size;
373 if (padding_config.interior_padding() != 0) {
374 // Adjust for interior padding :
375 // Size' = max((Size - 1), 0) * interior_padding + Size
376 HloInstruction* one = hlo->parent()->AddInstruction(
377 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
378 HloInstruction* zero = hlo->parent()->AddInstruction(
379 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
380 HloInstruction* interior_padding = hlo->parent()->AddInstruction(
381 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
382 padding_config.interior_padding())));
383 dynamic_size_adjusted =
384 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
385 dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
386 dynamic_size_adjusted, one));
387 dynamic_size_adjusted =
388 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
389 dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
390 dynamic_size_adjusted, zero));
391 dynamic_size_adjusted =
392 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
393 dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
394 dynamic_size_adjusted, interior_padding));
395 dynamic_size_adjusted =
396 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
397 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
398 dynamic_size_adjusted, dynamic_size));
399 }
400 HloInstruction* adjustment = hlo->parent()->AddInstruction(
401 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
402 padding_config.edge_padding_low() +
403 padding_config.edge_padding_high())));
404 dynamic_size_adjusted =
405 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
406 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
407 dynamic_size_adjusted, adjustment));
408 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
409 return Status::OK();
410 });
411 }
412
Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
414 return ForEachOperandDynamicDimension(
415 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
416 int64_t operand_index, HloInstruction* dynamic_size) {
417 auto* reduce = Cast<HloReduceInstruction>(hlo);
418 int64_t operand_count = reduce->operand_count();
419 CHECK_EQ(operand_count % 2, 0);
420 if (operand_index >= reduce->input_count()) {
          // Init values don't have dynamic sizes.
422 return Status::OK();
423 }
424 if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
425 // Dimension is to be reduced, stop tracing.
426 return Status::OK();
427 }
428
429 // Find out the new dynamic dimension after reduce.
430 int64_t dimensions_not_reduced_count = 0;
431 for (int i = 0; i < operand->shape().rank(); ++i) {
432 if (dimension == i) {
433 // The dimensions of all data operands of a variadic reduce have
434 // to be the same. This means that if one operand of variadic
435 // reduce has a dynamic dimension, we set all outputs to use the
436 // same dynamic size in corresponding dimensions.
437 ShapeUtil::ForEachSubshape(
438 reduce->shape(),
439 [&](const Shape& subshape, ShapeIndex reduce_result_index) {
440 if (!ShapeUtil::IsLeafIndex(reduce->shape(),
441 reduce_result_index)) {
442 return;
443 }
444 parent_->SetDynamicSize(reduce, reduce_result_index,
445 dimensions_not_reduced_count,
446 dynamic_size);
447 });
448
449 return Status::OK();
450 }
451 if (absl::c_count(reduce->dimensions(), i) == 0) {
452 dimensions_not_reduced_count++;
453 }
454 }
455
456 return Status::OK();
457 });
458 }
459
Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
461 return ForEachOperandDynamicDimension(
462 hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
463 int64_t operand_dimension, int64_t operand_index,
464 HloInstruction* dynamic_size) {
465 // There are three types of dimensions in a dot:
466 // A. batch dims
467 // B. contracting dims
468 // C. non-batch non-contracting dims.
469 // The output dimensions of a dot has three parts with the following
470 // order:
471 // [(type A), (lhs type C), (rhs type C)]
472 //
473 // Note that both lhs and rhs have the same dimension sizes for batch,
474 // but the dimension index could be different.
475 //
476 // Given one dynamic input dimension, either lhs or rhs, we use a
477 // mapping to find the corresponding output dimension.
478 HloInstruction* dot = hlo;
479 const DotDimensionNumbers& dimension_numbers =
480 dot->dot_dimension_numbers();
481 // A map from the operand dimensions to result dimension.
482 absl::flat_hash_map<int64, int64> result_dim_mapping;
483 int64_t current_result_dims = 0;
484
485 bool lhs = operand_index == 0;
486
        // The first loop keeps track of the batch dimensions. RHS and LHS
        // could have different batch dimension numbers.
489 if (lhs) {
490 for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
491 result_dim_mapping[i] = current_result_dims++;
492 }
493 } else {
494 for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
495 result_dim_mapping[i] = current_result_dims++;
496 }
497 }
498
499 // Handle dimensions in the lhs.
500 for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
501 // Look for non-contracting and non-batching dimension.
502 if (absl::c_linear_search(
503 dimension_numbers.lhs_contracting_dimensions(), i)) {
504 continue;
505 }
506 if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
507 i)) {
508 continue;
509 }
510 if (lhs) {
511 result_dim_mapping[i] = current_result_dims;
512 }
513 current_result_dims++;
514 }
515
516 // Handle dimensions in the rhs.
517 for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
518 // Look for non-contracting and non-batching dimension.
519 if (absl::c_linear_search(
520 dimension_numbers.rhs_contracting_dimensions(), i)) {
521 continue;
522 }
523 if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
524 i)) {
525 continue;
526 }
527 if (!lhs) {
528 result_dim_mapping[i] = current_result_dims;
529 }
530 current_result_dims++;
531 }
532
533 // Check if the operand dim is in the result shape. If so, add another
534 // work item to trace that dimension.
535 auto iter = result_dim_mapping.find(operand_dimension);
536 if (iter != result_dim_mapping.end()) {
537 parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
538 }
539
540 return Status::OK();
541 });
542 }
543
Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
545 return ForEachOperandDynamicDimension(
546 hlo,
547 [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
548 int64_t operand_index, HloInstruction* dynamic_size) -> Status {
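        // For a transpose, dimensions()[i] is the operand dimension that
        // becomes output dimension i, so find the output position whose source
        // is the dynamic operand dimension.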
549 int64_t permuted_dim = -1;
550 for (int64_t i = 0; i < hlo->dimensions().size(); ++i) {
551 if (hlo->dimensions()[i] == dimension) {
552 TF_RET_CHECK(permuted_dim == -1);
553 permuted_dim = i;
554 }
555 }
556 parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
557 return Status::OK();
558 });
559 }
560
Status DynamicDimensionInferenceVisitor::HandleConvolution(
562 HloInstruction* hlo) {
563 return ForEachOperandDynamicDimension(
564 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
565 int64_t operand_index, HloInstruction* dynamic_size) {
566 HloInstruction* conv = hlo;
567 const ConvolutionDimensionNumbers& dimension_numbers =
568 conv->convolution_dimension_numbers();
569 if (operand_index == 0) {
570 if (dimension == dimension_numbers.input_batch_dimension()) {
571 parent_->SetDynamicSize(conv, {},
572 dimension_numbers.output_batch_dimension(),
573 dynamic_size);
574 return Status::OK();
575 }
576
577 if (dimension == dimension_numbers.input_feature_dimension()) {
578 return Status::OK();
579 }
580 } else {
581 if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
582 return Status::OK();
583 }
584 }
585
586 return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
587 conv->ToString());
588 });
589 }
590
Status DynamicDimensionInferenceVisitor::HandleConcatenate(
592 HloInstruction* hlo) {
593 // First handle concatenate dimensions. We do this by iterating through all
594 // operands while tracking both dynamic and static dimensions.
595
  // static_size is used to keep track of the concatenated size of static
  // dimensions.
598 int64_t static_size = 0;
599 std::vector<HloInstruction*> dynamic_concat_dims;
600 for (int64_t i = 0; i < hlo->operand_count(); ++i) {
601 HloInstruction* dynamic_size = parent_->GetDynamicSize(
602 hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
603 if (dynamic_size == nullptr) {
604 // This is a static dimension.
605 static_size +=
606 hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
607 } else {
608 dynamic_concat_dims.push_back(dynamic_size);
609 }
610 }
611 // If concat dimension is dynamic, calculate its size by summing up static
612 // dims and dynamic dims together.
613 if (!dynamic_concat_dims.empty()) {
614 HloInstruction* dim_size_total =
615 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
616 LiteralUtil::CreateR0<int32>(static_size)));
617 for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
618 dim_size_total = hlo->parent()->AddInstruction(
619 HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
620 dim_size_total, dynamic_dim));
621 }
622 parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
623 dim_size_total);
624 }
625
626 // Simply pass through non-concat dynamic dimensions.
627 return ForEachOperandDynamicDimension(
628 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
629 int64_t operand_index, HloInstruction* dynamic_size) {
630 int64_t concatenate_dimension = hlo->concatenate_dimension();
631 if (concatenate_dimension == dimension) {
632 return Status::OK();
633 }
634 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
635 return Status::OK();
636 });
637 }
638
Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
640 HloInstruction* gds) {
641 // Dynamic dimension doesn't propagate through GetDimensionSize:
642 //
643 // Input: F32[x, y, z]
644 // |
645 // GetDimensionSize(1): S32[]
646 //
647 // The returned value is a scalar, which doesn't have any dynamic dimension in
648 // the shape (although the value contains the real size of the dynamic
649 // dimension of the input).
650 int64_t dim = gds->dimension();
651 HloInstruction* operand = gds->mutable_operand(0);
652 HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
653 HloComputation* computation = gds->parent();
654 if (dynamic_size != nullptr) {
655 TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
656 // The dependency between an instruction and its dynamic dimensions is not
657 // modeled in the IR. As instr is being replaced by dynamic_size, also tell
658 // dynamic dimension inference that the instruction is being replaced.
659 parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
660 } else {
661 TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
662 int32_t size = gds->operand(0)->shape().dimensions(dim);
663 HloInstruction* new_instr = computation->AddInstruction(
664 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
665 TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
666 parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
667 }
668 return Status::OK();
669 }
670
Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
672 HloInstruction* hlo) {
673 bool dimension_is_static = false;
674 const HloInstruction* size = hlo->operand(1);
675 if (size->opcode() == HloOpcode::kConstant) {
  // Check if we are setting a dimension size to its static size. If so,
  // remove the dynamic dimension.
678 //
679 // size = s32[] constant(5)
680 // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
681 // dimensions={1}
682 // The result shape has no dynamic dimension.
683 TF_RET_CHECK(size->shape().rank() == 0);
684 if (size->literal().Get<int32>({}) ==
685 hlo->shape().dimensions(hlo->dimension())) {
686 dimension_is_static = true;
687 }
688 }
689
690 if (!dimension_is_static) {
691 // Propagate dynamic dimension indicated by this set dimension size
692 // instruction.
693 parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
694 }
695
  // Also propagate dynamic dimensions already set by operands.
697 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
698 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
699 int64_t operand_index, HloInstruction* dynamic_size) {
700 if (dimension != hlo->dimension()) {
701 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
702 }
703 return Status::OK();
704 }));
705
706 return Status::OK();
707 }
708
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
710 HloInstruction* hlo, int64_t operand_index, int64_t dimension,
711 HloInstruction* dynamic_size) {
712 TF_RET_CHECK(operand_index == 0);
713 const ConvolutionDimensionNumbers& dimension_numbers =
714 hlo->convolution_dimension_numbers();
715
716 if (dimension == dimension_numbers.input_batch_dimension()) {
717 // Batch dimension is propagated without any changes.
718 parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
719 dynamic_size);
720 return Status::OK();
721 }
722
723 for (int64_t spatial_dim_index = 0;
724 spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
725 ++spatial_dim_index) {
726 int64_t input_spatial_dim =
727 dimension_numbers.input_spatial_dimensions(spatial_dim_index);
728 int64_t output_spatial_dim =
729 dimension_numbers.output_spatial_dimensions(spatial_dim_index);
730 if (dimension == input_spatial_dim) {
731 // This is a dynamic spatial dimension. Calculate the output size.
732 WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
733 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
734 dynamic_size, window_dim.size(), window_dim.window_dilation(),
735 window_dim.stride(), hlo->padding_type());
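      // Base dilation is not accounted for in the dynamic output size
      // computation above, so only unit base dilation is supported here.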
736 TF_RET_CHECK(window_dim.base_dilation() == 1);
737 parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
738 dynamic_window_dims.output_size);
739 return Status::OK();
740 }
741 }
742 // Input Feature dim disappears after convolution.
743 return Status::OK();
744 }
745
Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
747 HloInstruction* hlo, HloInstruction* dynamic_size, int64_t operand_index,
748 int64_t dimension) {
749 const Window& window = hlo->window();
750 const WindowDimension& window_dim = window.dimensions(dimension);
751 if (!window_util::IsTrivialWindowDimension(window_dim)) {
752 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
753 dynamic_size, window_dim.size(), window_dim.window_dilation(),
754 window_dim.stride(), PaddingType::PADDING_SAME);
755 parent_->SetDynamicSize(hlo, {}, dimension,
756 dynamic_window_dims.output_size);
757 return Status::OK();
758 }
759
760 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
761
762 return Status::OK();
763 }
764
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
766 HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
  // The output size of a convolution input gradient is the corresponding
  // input size.
768 HloInstruction* input_sizes = hlo->mutable_operand(0);
769 HloComputation* comp = hlo->parent();
770 TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
771 TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
772 TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
773 hlo->shape().dimensions_size())
774 << hlo->ToString();
775 // Slice to get corresponding input size.
776 HloInstruction* slice = comp->AddInstruction(
777 HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
778 {dimension}, {dimension + 1}, {1}));
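  // The sliced size is a one-element S32 vector; reshape it to the scalar form
  // that dynamic dimension sizes use.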
779 HloInstruction* reshape = comp->AddInstruction(
780 HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
781 parent_->SetDynamicSize(hlo, {}, dimension, reshape);
782 return Status::OK();
783 }
784
Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
786 HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
787 // Dynamic convolution kernel grad produces static shape outputs.
788 return Status::OK();
789 }
790
Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
792 HloInstruction* hlo) {
793 return ForEachOperandDynamicDimension(
794 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
795 int64_t operand_index, HloInstruction* dynamic_size) {
796 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
797 return Status::OK();
798 });
799 }
800
Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
802 return PassThroughDynamicDimension(hlo);
803 }
804
Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
806 HloInstruction* hlo) {
807 return PassThroughDynamicDimension(hlo);
808 }
809
Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
811 return PassThroughDynamicDimension(hlo);
812 }
813
Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
815 HloInstruction* hlo) {
816 HloComputation* comp = hlo->parent();
817 return ForEachOperandDynamicDimension(
818 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
819 int64_t operand_index, HloInstruction* dynamic_size) {
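        // If both operands carry a dynamic size for this output dimension and
        // the sizes differ, conservatively use the maximum of the two as the
        // output's dynamic size.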
820 HloInstruction* existing_size =
821 parent_->GetDynamicSize(hlo, index, dimension);
822 if (existing_size == nullptr || existing_size == dynamic_size) {
823 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
824 } else {
825 HloInstruction* max =
826 comp->AddInstruction(HloInstruction::CreateBinary(
827 ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
828 dynamic_size, existing_size));
829 parent_->SetDynamicSize(hlo, index, dimension, max);
830 }
831 return Status::OK();
832 });
833 }
834
Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
836 return PassThroughDynamicDimension(hlo);
837 }
838
Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
840 HloInstruction* hlo) {
841 HloDynamicReshapeInstruction* dynamic_reshape =
842 Cast<HloDynamicReshapeInstruction>(hlo);
843 for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
844 if (hlo->shape().is_dynamic_dimension(i)) {
845 parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
846 }
847 }
848 return Status::OK();
849 }
850
Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
852 return ForEachOperandDynamicDimension(
853 hlo,
854 [&](HloInstruction* operand, ShapeIndex index,
855 int64_t input_dynamic_dimension, int64_t operand_index,
856 HloInstruction* operand_dynamic_size) -> Status {
857 HloInstruction* reshape = hlo;
858 if (reshape->shape().rank() == 0) {
859 VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
860 "undefined behavior when input size is 0. The offending "
861 "instruction is: "
862 << reshape->ToString();
863 return Status::OK();
864 }
865 auto common_factors = CommonFactors(operand->shape().dimensions(),
866 reshape->shape().dimensions());
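        // common_factors holds boundary pairs (input_dim, output_dim) between
        // maximal groups of dimensions with equal element counts. E.g., for an
        // operand of shape [6, 4] reshaped to [2, 3, 4] the boundaries are
        // (0,0), (1,2), (2,3): input dim 0 corresponds to output dims [0, 2)
        // and input dim 1 to output dim [2, 3).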
867 int64_t input_dim_start = -1;
868 int64_t input_dim_end = -1;
869 int64_t output_dim_start = -1;
870 int64_t output_dim_end = -1;
871 // Find common_factors that the input belongs to.
872 for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
873 auto start = common_factors[i];
874 auto end = common_factors[i + 1];
875 if (input_dynamic_dimension >= start.first &&
876 input_dynamic_dimension < end.first) {
877 // Found the common_factor group that the input_dim belongs to.
878 input_dim_start = start.first;
879 input_dim_end = end.first;
880 output_dim_start = start.second;
881 output_dim_end = end.second;
882 }
883 }
884
885 VLOG(2) << "Input dim start: " << input_dim_start
886 << " Input dim end: " << input_dim_end
887 << " output dim start: " << output_dim_start
888 << " output dim end: " << output_dim_end;
889
890 if ((input_dim_end - input_dim_start) > 1 &&
891 (output_dim_end - output_dim_start) > 1) {
892 // We don't support the case when a dynamic dimension is both combined
          // with and split into other dimensions:
894 //
895 // [x, yz]
896 // | Reshape
897 // [xy, z]
898 //
899 // TODO(yunxing): This can be supported by canonicalizing
900 // the offending reshape into two reshapes:
901 //
902 // [x,yz]
903 // | Reshape
904 // [x, y, z]
905 // | Reshape
906 // [xy, z]
907 //
908 return Unimplemented(
909 "Dynamic input dimension to reshape that is both splitted and "
910 "combined is not supported %s",
911 hlo->ToString());
912 }
913
914 for (auto common_factor : common_factors) {
915 // Expand common factor to include degenerated output dimensions.
916 if (common_factor.first == input_dim_start) {
917 output_dim_start = std::min(output_dim_start, common_factor.second);
918 }
919 if (common_factor.first == input_dim_end) {
920 output_dim_end = std::max(output_dim_end, common_factor.second);
921 }
922 }
923
924 int64_t output_dynamic_dimension = -1;
925
926 if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
          // If the dynamic dimension has size 1, it can only be the most-major
          // or most-minor dimension.
929 if (input_dynamic_dimension == 0) {
930 output_dynamic_dimension = 0;
931 }
932 if (input_dynamic_dimension == operand->shape().rank() - 1) {
933 output_dynamic_dimension = reshape->shape().rank() - 1;
934 }
935
936 if (output_dynamic_dimension == -1) {
937 return Unimplemented(
938 "Dynamic degenerated dimension that's not most-minor nor "
939 "most-major is not supported %s",
940 reshape->ToString());
941 }
942 }
943
944 if (output_dynamic_dimension == -1 &&
945 output_dim_end - output_dim_start == 1) {
946 // Only one possible output dimension.
947 output_dynamic_dimension = output_dim_start;
948 }
949
950 if (output_dynamic_dimension == -1 &&
951 output_dim_end - output_dim_start > 1) {
            // One input dimension is split into multiple output dimensions.
            // The output dimensions are decomposed from the input's most-major
            // dimension.
954 // In this case, we don't know which one is dynamic, e.g., when we
955 // have:
956 //
957 // [<=a/c, c, b]
958 // | Reshape
959 // [<=a, b] // a is dynamic, has to be multiple of c.
960 // | Reshape
961 // [1, 1, ... , a/c, c, b]
962 //
963 // Any dimension from the first '1' to 'a/c' can be dynamic.
964 //
            // We use the following logic to disambiguate:
966 // 1. If the user sets "inferred_dimension", then use that as
967 // dynamic dimension.
            // 2. If one dimension in the reshape is dynamic, use that as the
969 // dynamic dimension.
970 // E.g.:
971 // [<=4]
972 // |
973 // reshape
974 // |
975 // [1, <=2, 2]
976 // We use second dim as dynamic dimension.
977 //
            // 3. If the logic above cannot disambiguate, e.g.,:
979 //
980 // [<=1]
981 // |
982 // reshape
983 // |
984 // [1, 1, 1]
985 //
986 // We bail out and return an error.
987 // TODO(yunxing): Further simplify this, remove 1. and fully rely
988 // on 2.
989 output_dynamic_dimension = reshape->inferred_dimension();
990 if (output_dynamic_dimension == -1) {
991 // Try find dynamic dimension from the result shape.
992 for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
993 if (reshape->shape().is_dynamic_dimension(i)) {
994 output_dynamic_dimension = i;
995 }
996 }
997 }
998
999 if (output_dynamic_dimension == -1) {
1000 std::vector<int64> output_non_degenerated;
1001 for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1002 if (reshape->shape().dimensions(i) != 1) {
1003 output_non_degenerated.push_back(i);
1004 }
1005 }
1006 if (output_non_degenerated.size() == 1) {
1007 output_dynamic_dimension = output_non_degenerated[0];
1008 }
1009 }
1010
1011 if (output_dynamic_dimension == -1) {
1012 return InvalidArgument(
1013 "Reshape's input dynamic dimension is decomposed into "
1014 "multiple output dynamic dimensions, but the constraint is "
1015 "ambiguous and XLA can't infer the output dimension %s. ",
1016 hlo->ToString());
1017 }
1018 }
1019
1020 CHECK_NE(output_dynamic_dimension, -1);
1021 const int64_t input_dim_size =
1022 operand->shape().dimensions(input_dynamic_dimension);
1023 const int64_t output_dim_size =
1024 reshape->shape().dimensions(output_dynamic_dimension);
1025 VLOG(2) << "input_dim_size: " << input_dim_size
1026 << " output_dim_size: " << output_dim_size;
1027
1028 if (input_dim_size == output_dim_size) {
1029 // Simply forward dynamic dimension.
1030 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1031 operand_dynamic_size);
1032 }
1033
1034 if (input_dim_size > output_dim_size) {
1035 TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1036 << reshape->ToString();
1037 const int64_t divisor = input_dim_size / output_dim_size;
1038 HloInstruction* divisor_hlo =
1039 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1040 LiteralUtil::CreateR0<int32>(divisor)));
1041
1042 HloInstruction* new_dynamic_size =
1043 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1044 operand_dynamic_size->shape(), HloOpcode::kDivide,
1045 operand_dynamic_size, divisor_hlo));
1046
1047 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1048 new_dynamic_size);
1049 }
1050
1051 if (input_dim_size < output_dim_size) {
1052 // Input dimension is combined with other input dimensions.
1053 //
1054 // Adjust the output size by the ratio of dynamic_input_dim /
1055 // static_input_dim.
1056 //
1057 // For example if we have [<=3, 3] -> [9], if the dynamic size is 2,
            // the new output dynamic size is 9 / 3 * 2 = 6.
1059 //
1060 // If it turns out the second dimension is also dynamic:
1061 // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1062 // dynamic size is 6 / 3 * 2 = 4.
1063 //
1064 //
1065 HloInstruction* output_dynamic_size =
1066 parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1067 if (output_dynamic_size == nullptr) {
1068 output_dynamic_size =
1069 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1070 LiteralUtil::CreateR0<int32>(output_dim_size)));
1071 }
1072 HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1073 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1074 operand->shape().dimensions(input_dynamic_dimension))));
1075
1076 HloInstruction* new_dynamic_size =
1077 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1078 output_dynamic_size->shape(), HloOpcode::kDivide,
1079 output_dynamic_size, divisor_hlo));
1080
1081 new_dynamic_size =
1082 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1083 output_dynamic_size->shape(), HloOpcode::kMultiply,
1084 new_dynamic_size, operand_dynamic_size));
1085 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1086 new_dynamic_size);
1087 }
1088
1089 return Status::OK();
1090 });
1091 }
1092
Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
1094 HloInstruction* hlo) {
1095 return ForEachOperandDynamicDimension(
1096 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1097 int64_t operand_index, HloInstruction* dynamic_size) {
1098 auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
1099 const WindowDimension& window_dim =
1100 reduce_window->window().dimensions(dimension);
1101
1102 if (operand_index >= reduce_window->input_count()) {
          // Init values don't have dynamic sizes.
1104 return Status::OK();
1105 }
1106
1107 if (!window_util::IsTrivialWindowDimension(window_dim)) {
1108 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1109 dynamic_size, window_dim.size(), window_dim.window_dilation(),
1110 window_dim.stride(), PaddingType::PADDING_VALID);
1111 dynamic_size = dynamic_window_dims.output_size;
1112 }
1113
1114 // The dimensions of all data operands of a variadic reduce window have
1115 // to be the same. This means that if one operand of variadic
1116 // reduce has a dynamic dimension, we set all outputs to use the
1117 // same dynamic size in corresponding dimensions.
1118 ShapeUtil::ForEachSubshape(
1119 reduce_window->shape(),
1120 [&](const Shape& subshape, ShapeIndex reduce_window_result_index) {
1121 if (!ShapeUtil::IsLeafIndex(reduce_window->shape(),
1122 reduce_window_result_index)) {
1123 return;
1124 }
1125 parent_->SetDynamicSize(reduce_window, reduce_window_result_index,
1126 dimension, dynamic_size);
1127 });
1128
1129 return Status::OK();
1130 });
1131 }
1132
Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
1134 HloInstruction* hlo) {
1135 return ForEachOperandDynamicDimension(
1136 hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1137 int64_t operand_index, HloInstruction* dynamic_size) {
1138 if (operand_index == 1) {
1139 // Operand 0 (input) determines dynamic output size. We ignore the
1140 // dynamic size in the operand 1 (output gradient).
1141 return Status::OK();
1142 }
1143 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1144
1145 return Status::OK();
1146 });
1147 }
1148
Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1150 return ForEachOperandDynamicDimension(
1151 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
1152 int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1153 if (hlo->slice_starts(dimension) != 0 ||
1154 hlo->slice_strides(dimension) != 1 ||
1155 hlo->slice_limits(dimension) !=
1156 operand->shape().dimensions(dimension)) {
          // Slicing out only part of the dimension eliminates the dynamic
          // dimension.
1158 return Status::OK();
1159 }
1160
1161 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1162
1163 return Status::OK();
1164 });
1165 }
1166
Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
1168 HloInstruction* hlo) {
1169 return ForEachOperandDynamicDimension(
1170 hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64_t dimension,
1171 int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1172 if (hlo->shape().dimensions(dimension) !=
1173 hlo->operand(0)->shape().dimensions(dimension)) {
1174 // Slicing a single element out kills the dynamic dimension.
1175 if (hlo->shape().dimensions(dimension) == 1) {
1176 return Status::OK();
1177 }
1178 return Unimplemented(
1179 "Dynamic dimension propagation on DynamicSlice where a partial "
1180 "dimension is selected %s",
1181 hlo->ToString());
1182 }
1183
1184 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1185
1186 return Status::OK();
1187 });
1188 }
1189
Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
1191 HloInstruction* hlo) {
1192 return ForEachOperandDynamicDimension(
1193 hlo,
1194 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1195 int64_t operand_index, HloInstruction* dynamic_size) {
1196 if (hlo->shape().dimensions(dimension) !=
1197 hlo->operand(0)->shape().dimensions(dimension)) {
1198 return Unimplemented(
1199 "Dynamic dimension propagation on DynamicUpdateSlice where a "
1200 "partial dimension is selected %s",
1201 hlo->ToString());
1202 }
1203
1204 if (operand_index == 1 &&
1205 hlo->operand(1)->shape().dimensions(dimension) <
1206 hlo->operand(0)->shape().dimensions(dimension)) {
1207 // DUS(input=[A], update=[<=B])
1208 //
1209 // If update dim is smaller than input dim (B < A) , then we are doing
1210 // a partial update, no need to set the output dynamic dimension.
1211 //
1212 // The dynamic shape in `update` doesn't change output dynamic shape.
1213 return Status::OK();
1214 }
1215
1216 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1217
1218 return Status::OK();
1219 });
1220 }
1221
Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1223 return PassThroughDynamicDimension(hlo);
1224 }
1225
Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1227 return ForEachOperandDynamicDimension(
1228 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1229 int64_t input_dynamic_dimension, int64_t operand_index,
1230 HloInstruction* dynamic_size) {
1231 const GatherDimensionNumbers& gather_dims =
1232 hlo->gather_dimension_numbers();
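        // Operand 0 of gather is the data input; operand 1 is the start
        // indices. Dynamic dimensions are handled differently for the two.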
1233 if (operand_index != 1) {
1234 if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1235 // Gathering a size 1 dimension out of a dynamic dimension removes
1236 // the dynamicity.
1237 return Status::OK();
1238 }
1239 if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1240 operand->shape().dimensions(input_dynamic_dimension)) {
1241 // Gathering a full-sized dimension out of a dynamic dimension
1242 // propagates the dynamicity to output.
1243 int64_t output_dimension = input_dynamic_dimension;
1244 for (int64_t collapsed_dim : gather_dims.collapsed_slice_dims()) {
1245 if (collapsed_dim < input_dynamic_dimension) {
1246 // This output dimension is collapsed.
1247 output_dimension--;
1248 }
1249 }
1250 parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1251 return Status::OK();
1252 }
1253 return Unimplemented(
1254 "Detects a dynamic dimension on the data input of gather, which "
1255 "is not supported: %s, %lld",
1256 hlo->ToString(), input_dynamic_dimension);
1257 }
1258 // A mapping from output to input batch dim number. -1 means not a batch
1259 // dimension.
1260 int64_t indices_rank = hlo->operand(1)->shape().rank();
1261 int64_t output_rank = hlo->shape().rank();
1262
1263 // indices_dim is an iterator over indices dimensions.
1264 int64_t indices_dim = 0;
1265 // Find the corresponding batch dimension in the output.
1266 for (int64_t output_dim = 0; output_dim < output_rank; ++output_dim) {
1267 if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1268 // Skips index vector dimension.
1269 if (indices_dim == gather_dims.index_vector_dim()) {
1270 indices_dim++;
1271 }
1272 if (indices_dim++ == input_dynamic_dimension) {
1273 parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1274 return Status::OK();
1275 }
1276 }
1277 }
1278 CHECK(indices_dim == indices_rank);
1279
1280 return Unimplemented(
1281 "Detects a non-batch dynamic dimension of gather, "
1282 "which is not supported: %s",
1283 hlo->ToString());
1284 });
1285 }
1286
Status DynamicDimensionInferenceVisitor::HandleConditional(
1288 HloInstruction* hlo) {
1289 // Conditionals are handled by producing additional inputs and outputs of
1290 // the conditional instruction.
1291 std::vector<HloComputation*> new_branch_computations;
1292 std::vector<HloInstruction*> new_operands;
  // If the output of the conditional contains a dynamic dimension, we send the
  // dynamic dimension size out by adding an additional root element. A mapping
  // from the root instruction's dynamic dimension index (represented by a shape
  // index as the output index and an int64 dimension number) to the output
  // index (represented by an int64) is tracked for the conditional instruction
  // (all branches should have the same mapping).
1299 ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1300 hlo->shape());
1301
1302 bool need_rewrite = false;
1303 for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1304 ++branch_index) {
1305 std::vector<HloInstruction*> operands_to_add;
1306
1307 absl::flat_hash_map<HloInstruction*, int64>
1308 dynamic_size_to_operand_id_index_map;
1309 // Only look at branch_index + 1, the correct operand index for a
1310 // given branch.
1311 const int64_t operand_index = branch_index + 1;
1312
1313 int64_t operand_count =
1314 hlo->operand(operand_index)->shape().tuple_shapes_size();
1315 // Prepare to pass dynamic dimension into the new computation and add
1316 // dynamic dimension sizes as parameters to the new tuple.
1317 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1318 hlo, operand_index,
1319 [&](HloInstruction*, ShapeIndex, int64_t, int64_t,
1320 HloInstruction* dynamic_size) -> Status {
1321 TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
1322 << "Only tuple typed inputs can have dynamic dimension. Please "
1323 "file a bug against XLA team.";
1324 const HloInstruction* tuple_operand = hlo->operand(operand_index);
1325 for (int64_t i = 0; i < tuple_operand->operand_count(); ++i) {
1326 // If the dynamic size is already an operand to the computation,
1327 // skip adding it to the computation input again.
1328 if (dynamic_size == tuple_operand->operand(i)) {
1329 dynamic_size_to_operand_id_index_map[dynamic_size] = i;
1330 return Status::OK();
1331 }
1332 }
1333 auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
1334 if (iter == dynamic_size_to_operand_id_index_map.end()) {
1335 operands_to_add.push_back(dynamic_size);
1336 dynamic_size_to_operand_id_index_map[dynamic_size] =
1337 operand_count++;
1338 }
1339 return Status::OK();
1340 }));
1341
1342 HloInstruction* original_input = hlo->mutable_operand(operand_index);
1343 HloComputation* branch_computation = hlo->branch_computation(branch_index);
1344
1345 HloComputation* new_computation = branch_computation;
1346 HloInstruction* new_operand = hlo->mutable_operand(operand_index);
1347 if (!operands_to_add.empty()) {
1348 TF_RET_CHECK(original_input->shape().IsTuple());
1349 need_rewrite = true;
1350 new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
1351 TF_ASSIGN_OR_RETURN(
1352 new_computation,
1353 WidenComputation(branch_computation, new_operand->shape()));
1354 }
1355 // Set the dynamic dimensions for the newly created branch computation's
1356 // parameters so that the hlos inside the computation can see dynamic
1357 // dimensions.
1358 DynamicParameterBinding dynamic_parameter_binding;
1359 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1360 hlo, operand_index,
1361 [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1362 int64_t operand_index, HloInstruction* dynamic_size) {
1363 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1364 0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
1365 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1366 0, {index}, dimension};
1367 TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
1368 dynamic_dimension));
1369
1370 return Status::OK();
1371 }));
1372 VLOG(2) << "dynamic_parameter_binding for conditional branch"
1373 << dynamic_parameter_binding;
1374 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1375 new_computation, dynamic_parameter_binding, parent_));
1376
1377 new_branch_computations.push_back(new_computation);
1378 new_operands.push_back(new_operand);
1379 }
1380 int64_t tuple_count = hlo->shape().tuple_shapes_size();
1381 // The dynamism of the output of branches can be different.
1382 // E.g.,
1383 // true_branch (s32[<=4])
1384 // false_branch (s32[4])
1385 //
  // The following loop populates dynamic_output_mapping and accounts for
1387 // dynamism across all branches.
1388 ShapeUtil::ForEachSubshape(
1389 hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1390 if (!subshape.IsArray()) {
1391 return;
1392 }
1393 for (int64 i = 0; i < subshape.rank(); ++i) {
1394 for (int64 j = 0; j < new_branch_computations.size(); ++j) {
1395 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1396 new_branch_computations[j]->root_instruction(), index, i);
1397 if (dynamic_size) {
1398 if (dynamic_output_mapping.element(index).contains(i)) {
1399 continue;
1400 }
1401 dynamic_output_mapping.mutable_element(index)->emplace(
1402 i, tuple_count++);
1403 }
1404 }
1405 }
1406 });
1407 for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1408 ++branch_index) {
1409 std::vector<HloInstruction*> hlos_to_add_in_root;
    // There may be some dynamic dimensions coming out of the computation; wire
    // them into the root instruction as additional tuple elements.
1412 ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape,
1413 const ShapeIndex& index) {
1414 if (!subshape.IsArray()) {
1415 return;
1416 }
1417 for (int64 i = 0; i < subshape.rank(); ++i) {
1418 if (dynamic_output_mapping.element(index).contains(i)) {
1419 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1420 new_branch_computations[branch_index]->root_instruction(), index,
1421 i);
1422 if (dynamic_size) {
1423 hlos_to_add_in_root.push_back(dynamic_size);
1424 } else {
1425 HloInstruction* constant_size =
1426 new_branch_computations[branch_index]->AddInstruction(
1427 HloInstruction::CreateConstant(
1428 LiteralUtil::CreateR0<int32>(subshape.dimensions(i))));
1429 hlos_to_add_in_root.push_back(constant_size);
1430 }
1431 }
1432 }
1433 });
1434
1435 VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
1436 if (!hlos_to_add_in_root.empty()) {
1437 need_rewrite = true;
1438 HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
1439 new_branch_computations[branch_index]->root_instruction(),
1440 hlos_to_add_in_root);
1441 new_branch_computations[branch_index]->set_root_instruction(
1442 new_branch_root,
1443 /*accept_different_shape=*/true);
1444 }
1445 }
1446
1447 if (!need_rewrite) {
1448 return Status::OK();
1449 }
1450 // Create a new conditional with the new operations and computations.
1451 HloInstruction* new_conditional =
1452 hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
1453 new_branch_computations[0]->root_instruction()->shape(),
1454 hlo->mutable_operand(0), new_branch_computations, new_operands));
1455
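  // Users of the original conditional expect the original (narrower) tuple
  // shape, so extract the prefix of the widened result before replacing all
  // uses of the old instruction with it.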
1456 HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
1457 new_conditional, hlo->shape().tuple_shapes_size());
1458 // Now set the dynamic dimensions of the newly created conditional.
1459 dynamic_output_mapping.ForEachElement(
1460 [&](const ShapeIndex& index,
1461 const absl::flat_hash_map<int64, int64>& dim_to_output) {
1462 for (auto iter : dim_to_output) {
1463 int64_t dim = iter.first;
1464 int64_t output_index = iter.second;
1465 HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
1466 HloInstruction::CreateGetTupleElement(
1467 ShapeUtil::MakeScalarShape(S32), new_conditional,
1468 output_index));
1469 parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
1470 parent_->SetDynamicSize(new_conditional_extracted, index, dim,
1471 dynamic_size);
1472 }
1473 });
1474
1475 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
1476   // Remove the original instruction even if it has side effects.
1477 TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
1478 SetVisited(*new_conditional);
1479 SetVisited(*new_conditional_extracted);
1480 return Status::OK();
1481 }
1482
1483 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
1484 return ForEachOperandDynamicDimension(
1485 hlo,
1486 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1487 int64_t operand_index, HloInstruction* operand_dynamic_size) {
1488 if (operand_index == 0) {
1489 parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
1490 return Status::OK();
1491 }
1492
1493 const ScatterDimensionNumbers& scatter_dims =
1494 hlo->scatter_dimension_numbers();
1495 if (operand_index == 2 &&
1496 absl::c_linear_search(scatter_dims.update_window_dims(),
1497 dimension)) {
1498 // Dynamic update window dimension is only allowed if it is exactly
1499 // the same as the corresponding operand dimension.
1500 std::vector<int64> update_window_dims_in_operand;
1501 for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
1502 if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
1503 continue;
1504 }
1505 update_window_dims_in_operand.push_back(i);
1506 }
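          // Illustrative example (hypothetical dimension numbers): for a
          // rank-3 operand with inserted_window_dims = {1}, the remaining
          // operand dims are {0, 2}, so the i-th update window dim corresponds
          // to operand dim update_window_dims_in_operand[i].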
1507
1508 for (int64_t i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
1509 if (scatter_dims.update_window_dims(i) == dimension) {
1510 const Shape& operand_shape = hlo->operand(0)->shape();
1511 const Shape& update_shape = hlo->operand(2)->shape();
1512 int64_t dim_in_operand = update_window_dims_in_operand[i];
1513 if (operand_shape.dimensions(dim_in_operand) !=
1514 update_shape.dimensions(dimension) ||
1515 !operand_shape.is_dynamic_dimension(dim_in_operand)) {
1516 return Unimplemented(
1517 "Dynamic dimension of update window dims that are not the "
1518 "same as corresponding operand dim is not supported: "
1519 "%s",
1520 hlo->ToString());
1521 }
1522 HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
1523 hlo->mutable_operand(0), {}, dim_in_operand);
1524 if (base_dynamic_size != operand_dynamic_size) {
1525 return Unimplemented(
1526 "Dynamic dimension size of update window dims that are not "
1527 "the same as corresponding operand dim is not supported: "
1528 "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
1529 "update: %s",
1530 hlo->ToString(), base_dynamic_size->ToString(),
1531 operand_dynamic_size->ToString());
1532 }
1533 }
1534 }
1535 }
1536 // The dynamic dimension is collapsed and won't show up in the output.
1537 // Do nothing here.
1538 return Status::OK();
1539 });
1540 }
1541
1542 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
1543   // If the output of the kWhile contains a dynamic dimension, we send the
1544   // dynamic dimension sizes into the while body by adding additional root/body
1545   // elements. A mapping from the root instruction's dynamic dimension index
1546   // (represented by a shape index and an int64 dimension number) to the
1547   // appended output index (represented by an int64) is tracked for the while
1548   // instruction.
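  // Illustrative sketch (hypothetical shapes): a while whose carried tuple is
  // (f32[<=N], s32[]) with a dynamic first dimension is rewritten so the carry
  // becomes (f32[<=N], s32[], s32[]), where the trailing s32[] holds the
  // runtime size and is threaded through the body, the condition and the root.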
1549 ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1550 hlo->shape());
1551 std::vector<HloInstruction*> operands_to_add;
1552 const int64_t original_tuple_count = hlo->shape().tuple_shapes_size();
1553 int64_t operand_count = original_tuple_count;
1554 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1555 hlo, [&](HloInstruction*, ShapeIndex index, int64_t dim, int64_t,
1556 HloInstruction* dynamic_size) {
1557 operands_to_add.push_back(dynamic_size);
1558 dynamic_output_mapping.mutable_element(index)->emplace(dim,
1559 operand_count++);
1560 return Status::OK();
1561 }));
1562
1563 DynamicParameterBinding binding_for_while;
1564 if (!operands_to_add.empty()) {
1565 // Only replace the while loop if there are new parameters to add.
1566 HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
1567 TF_ASSIGN_OR_RETURN(
1568 WhileUtil::MakeInstructionsLiveInResult result,
1569 WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
1570 // WhileUtil creates a new while hlo and tuple. Update the dynamic size
1571 // mapping for the newly created tuple.
1572 HloInstruction* new_tuple_operand =
1573 result.new_while_instr->mutable_operand(0);
1574 parent_->CopyMapping(/*from=*/old_tuple_operand,
1575 /*to=*/new_tuple_operand);
1576 hlo = result.new_while_instr;
1577     // We have replaced the while loop; now set the dynamic dimensions for the
1578     // newly created while loop so that the hlos that consume the while loop
1579     // can see the dynamic dimensions. Also set the dynamic parameter binding
1580     // for running inference in the while loop.
1581 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1582 hlo,
1583 [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1584 int64_t operand_index, HloInstruction* dynamic_size) -> Status {
1585 TF_RET_CHECK(!operands_to_add.empty());
1586 const int64_t output_dynamic_size_index =
1587 dynamic_output_mapping.element(index).at(dimension);
1588 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1589 operand_index, {output_dynamic_size_index}};
1590 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1591 operand_index, index, dimension};
1592 TF_RETURN_IF_ERROR(
1593 binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
1594           // This is the updated output dynamic size coming out of the while loop
1595           // hlo.
1596 HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
1597 HloInstruction::CreateGetTupleElement(
1598 ShapeUtil::MakeScalarShape(S32), hlo,
1599 output_dynamic_size_index));
1600 parent_->SetDynamicSize(result.replacement_instr, index, dimension,
1601 output_dynamic_size);
1602 return Status::OK();
1603 }));
1604 // Set the replacement instruction as visited to avoid visiting it again.
1605 SetVisited(*result.replacement_instr);
1606 }
1607
1608 // Run inference in while body and condition.
1609 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1610 hlo->while_body(), binding_for_while, parent_));
1611 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1612 hlo->while_condition(), binding_for_while, parent_));
1613
1614 if (operands_to_add.empty()) {
1615 // No dynamic dimension in the inputs and outputs.
1616 return Status::OK();
1617 }
1618
1619   // The dynamic dimension size could have been changed in the loop body (e.g.,
1620   // in a loop that pushes items onto a stack, the stack size increases with
1621   // each iteration). Rewrite the dynamic dimension sizes at the root.
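  // Illustrative sketch (hypothetical tuple layout): if the body root was
  // tuple(data, index) and the data's dynamic dimension was assigned output
  // index 2, the new root becomes tuple(gte(root, 0), gte(root, 1),
  // updated_size), where updated_size is the size value tracked inside the
  // body.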
1622 HloInstruction* body_root = hlo->while_body()->root_instruction();
1623 std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
1624 nullptr);
1625
1626 // Original non-dynamic-dim operands of root are pass-through.
1627 for (int64_t i = 0; i < original_tuple_count; ++i) {
1628 new_root_operands[i] =
1629 hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
1630 body_root->shape().tuple_shapes(i), body_root, i));
1631 }
1632   // Add the dynamic dimension sizes as the appended root operands.
1633 TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1634 hlo->while_body()->root_instruction(),
1635 [&](ShapeIndex index, int64_t dim,
1636 HloInstruction* dynamic_size) -> Status {
1637 const int64_t output_index =
1638 dynamic_output_mapping.element(index).at(dim);
1639 new_root_operands[output_index] = dynamic_size;
1640 return Status::OK();
1641 }));
1642 for (auto operand : new_root_operands) {
1643 TF_RET_CHECK(operand != nullptr);
1644 }
1645 HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
1646 HloInstruction::CreateTuple(new_root_operands));
1647 hlo->while_body()->set_root_instruction(new_body_root);
1648 return Status::OK();
1649 }
1650
1651 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
1652 return param_bindings_.ForEachBinding(
1653 [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1654 const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
1655 if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
1656 return Status::OK();
1657 }
1658 HloComputation* computation = hlo->parent();
1659 HloInstruction* target_parameter =
1660 computation->parameter_instruction(dynamic_dimension.parameter_num);
1661
1662 HloInstruction* dynamic_size =
1663 computation->parameter_instruction(dynamic_parameter.parameter_num);
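        // Walk down dynamic_parameter.parameter_index with get-tuple-element
        // ops to reach the scalar size inside the (possibly nested) tuple
        // parameter, e.g. a parameter_index of {1, 0} yields
        // GTE(GTE(param, 1), 0) (illustrative).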
1664 for (int64_t i : dynamic_parameter.parameter_index) {
1665 dynamic_size =
1666 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
1667 ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
1668 dynamic_size, i));
1669 }
1670
1671 parent_->SetDynamicSize(target_parameter,
1672 dynamic_dimension.parameter_index,
1673 dynamic_dimension.dimension, dynamic_size);
1674 return Status::OK();
1675 });
1676 }
1677
1678 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
1679 HloInstruction* inst, const DynamicDimensionFn& fn) {
1680 auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
1681 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1682 for (auto& dynamic_dimension : iter->second) {
1683 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1684 dynamic_dimension.inst, dynamic_dimension.index,
1685 dynamic_dimension.dim);
1686 TF_RETURN_IF_ERROR(
1687 fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
1688 }
1689 }
1690 return Status::OK();
1691 }
1692
1693 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
1694 HloInstruction* inst, int64_t operand_index,
1695 const OperandDynamicDimensionFn& fn) {
1696 auto iter =
1697 parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
1698 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1699 for (auto& dynamic_dimension : iter->second) {
1700 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1701 dynamic_dimension.inst, dynamic_dimension.index,
1702 dynamic_dimension.dim);
1703 TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
1704 dynamic_dimension.dim, operand_index,
1705 dynamic_size));
1706 }
1707 }
1708 return Status::OK();
1709 }
1710
1711 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
1712 HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
1713 for (int64_t operand_index = 0; operand_index < inst->operand_count();
1714 ++operand_index) {
1715 TF_RETURN_IF_ERROR(
1716 ForEachDynamicDimensionInOperand(inst, operand_index, fn));
1717 }
1718 return Status::OK();
1719 }
1720
1721 void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
1722 const ShapeIndex& index,
1723 int64_t dim,
1724 HloInstruction* size) {
1725 VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
1726 << index.ToString() << "@" << dim << " to " << size->ToShortString();
1727 Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
1728 CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
1729 CHECK(dim < subshape.rank() && dim >= 0)
1730 << "Asked to set invalid dynamic dimension. Shape: "
1731 << subshape.ToString() << ", Dimension: " << dim;
1732 DynamicDimension dynamic_dimension{inst, index, dim};
1733 // Updating a dynamic dimension twice overwrites the previous one.
1734 dynamic_mapping_[dynamic_dimension] = size;
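  // Also record the dimension in the per-instruction set so the ForEach*
  // helpers can enumerate every dynamic dimension of this instruction.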
1735 auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
1736 iter.first->second.emplace(dynamic_dimension);
1737 }
1738
1739 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
1740 HloInstruction* to) {
1741 auto iter = per_hlo_dynamic_dimensions_.find(from);
1742 if (iter != per_hlo_dynamic_dimensions_.end()) {
1743 for (auto& dynamic_dimension : iter->second) {
1744 HloInstruction* dynamic_size =
1745 GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
1746 dynamic_dimension.dim);
1747 SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
1748 dynamic_size);
1749 }
1750 }
1751 }
1752
1753 /* static */
1754 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
1755 HloModule* module, CustomCallInferenceHandler custom_call_handler) {
1756 VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
1757 DynamicDimensionInference inference(module, std::move(custom_call_handler));
1758 TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
1759 return inference;
1760 }
1761
1762 string DynamicDimensionInference::ToString() const {
1763 std::vector<string> pieces;
1764 pieces.push_back("DynamicDimensionInference: ");
1765 for (const auto& mapping : dynamic_mapping_) {
1766 const DynamicDimension& dynamic_dimension = mapping.first;
1767 pieces.push_back(absl::StrFormat(
1768 " -- instruction %s at %s has dim %lld as dynamic"
1769 " dimension, which is represented by instruction %s",
1770 dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
1771 dynamic_dimension.dim, mapping.second->ToString()));
1772 }
1773 return absl::StrJoin(pieces, "\n");
1774 }
1775
1776 DynamicDimensionInference::DynamicDimensionInference(
1777 HloModule* module, CustomCallInferenceHandler custom_call_handler)
1778 : module_(module), custom_call_handler_(std::move(custom_call_handler)) {}
1779
1780 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
1781 return DynamicDimensionInferenceVisitor::Run(
1782 module_->entry_computation(), module_->dynamic_parameter_binding(), this,
1783 custom_call_handler_);
1784 }
1785
1786 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
1787 HloInstruction* replace, HloInstruction* with) {
1788 CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
1789 ShapeUtil::MakeScalarShape(S32)));
1790 CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
1791 ShapeUtil::MakeScalarShape(S32)));
1792 for (auto& kv : dynamic_mapping_) {
1793 if (kv.second == replace) {
1794 kv.second = with;
1795 }
1796 }
1797 }
1798
1799 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
1800 HloInstruction* new_inst,
1801 const ShapeIndex& index) {
1802 CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
1803
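  // Copy every dynamic size registered for `inst` at `index` onto `new_inst`
  // at the same shape index and dimension; dimensions with no registered size
  // are left unchanged.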
1804 for (int64_t dim = 0; dim < inst->shape().rank(); ++dim) {
1805 DynamicDimension dynamic_dimension_new{new_inst, index, dim};
1806 DynamicDimension dynamic_dimension{inst, index, dim};
1807 auto iter = dynamic_mapping_.find(dynamic_dimension);
1808 if (iter != dynamic_mapping_.end()) {
1809 dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
1810 auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
1811 iter.first->second.emplace(dynamic_dimension_new);
1812 }
1813 }
1814
1815 return Status::OK();
1816 }
1817
1818 bool DynamicDimensionInference::HasDynamicDimension(
1819 HloInstruction* inst, ShapeIndexView index) const {
1820 bool has_dynamic_dim = false;
1821 ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
1822 const ShapeIndex& subindex) {
1823 if (subshape.IsTuple()) {
1824 return;
1825 }
1826 if (!ShapeIndexView(subindex).StartsWith(index)) {
1827 return;
1828 }
1829 for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
1830 HloInstruction* operand_dynamic_size = GetDynamicSize(inst, subindex, i);
1831 if (operand_dynamic_size != nullptr) {
1832 has_dynamic_dim = true;
1833 }
1834 }
1835 });
1836 return has_dynamic_dim;
1837 }
1838
1839 Status DynamicDimensionInference::Update(HloInstruction* inst) {
1840 DynamicParameterBinding parameter_binding;
1841 DynamicDimensionInferenceVisitor visitor(parameter_binding, this,
1842 custom_call_handler_);
1843 return inst->Visit(&visitor);
1844 }
1845
1846 HloInstruction* DynamicDimensionInference::GetDynamicSize(
1847 HloInstruction* inst, const ShapeIndex& index, int64_t dim) const {
1848 auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
1849 if (iter != dynamic_mapping_.end()) {
1850 return iter->second;
1851 }
1852 return nullptr;
1853 }
1854
1855 std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
1856 HloInstruction* inst, const ShapeIndex& index) const {
1857 CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
1858 const int64_t rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
1859 std::vector<HloInstruction*> result(rank, nullptr);
1860 for (int64_t i = 0; i < rank; ++i) {
1861     result[i] = GetDynamicSize(inst, index, i);
1862 }
1863 return result;
1864 }
1865
1866 } // namespace xla
1867