1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17 
18 #include <vector>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/tuple_util.h"
31 #include "tensorflow/compiler/xla/service/while_util.h"
32 #include "tensorflow/compiler/xla/shape_tree.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 namespace xla {
38 
39 namespace {
40 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
41 StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
42                                            const Shape& wide_shape) {
43   TF_RET_CHECK(wide_shape.IsTuple());
44   const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
45   if (Shape::Equal()(wide_shape, narrow_shape)) {
46     // No need to widen the computation.
47     return narrow_comp;
48   }
49   HloComputation* wide_comp = [&]() {
50     HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
51     builder.AddInstruction(
52         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
53     return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
54   }();
55 
56   HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
57   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
58       wide_parameter, narrow_shape.tuple_shapes_size());
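  // Call the narrow computation on the prefix of the wide parameter that
  // matches its original parameter shape, then inline the call so the body
  // reads directly from the widened parameter.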
59   HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
60       HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
61                                  {truncated_parameter}, narrow_comp));
62   wide_comp->set_root_instruction(call_narrow_comp,
63                                   /*accept_different_shape=*/true);
64   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
65   return wide_comp;
66 }
67 }  // namespace
68 
69 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
70  public:
71   explicit DynamicDimensionInferenceVisitor(
72       const DynamicParameterBinding& param_bindings,
73       DynamicDimensionInference* parent,
74       DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler)
75       : param_bindings_(param_bindings),
76         parent_(parent),
77         custom_call_handler_(std::move(custom_call_handler)) {}
78 
79   Status DefaultAction(HloInstruction* hlo) override;
80 
81   static Status Run(HloComputation* computation,
82                     const DynamicParameterBinding& param_bindings,
83                     DynamicDimensionInference* parent,
84                     DynamicDimensionInference::CustomCallInferenceHandler
85                         custom_call_handler = nullptr) {
86     DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
87                                              std::move(custom_call_handler));
88     return computation->Accept(&visitor);
89   }
90 
91   Status HandleParameter(HloInstruction* hlo) override;
92 
93   Status HandleReduce(HloInstruction* hlo) override;
94 
95   Status HandleDot(HloInstruction* hlo) override;
96 
97   Status HandleTuple(HloInstruction* hlo) override;
98 
99   Status HandleTranspose(HloInstruction* hlo) override;
100 
101   Status HandleDynamicReshape(HloInstruction* hlo) override;
102 
103   Status HandleReshape(HloInstruction* hlo) override;
104 
105   Status HandleSort(HloInstruction* hlo) override;
106 
107   Status HandlePad(HloInstruction* hlo) override;
108 
109   Status HandleCustomCall(HloInstruction* hlo) override;
110 
111   Status HandleBroadcast(HloInstruction* hlo) override;
112 
113   Status HandleGetDimensionSize(HloInstruction* hlo) override;
114 
115   Status HandleSetDimensionSize(HloInstruction* hlo) override;
116 
117   Status HandleSelect(HloInstruction* hlo) override;
118 
119   Status HandleConvolution(HloInstruction* hlo) override;
120 
121   Status HandleConcatenate(HloInstruction* hlo) override;
122 
123   Status HandleReduceWindow(HloInstruction* hlo) override;
124 
125   Status HandleReverse(HloInstruction* hlo) override;
126 
127   Status HandleSelectAndScatter(HloInstruction* hlo) override;
128 
129   Status HandleGetTupleElement(HloInstruction* hlo) override;
130 
131   Status HandleElementwiseUnary(HloInstruction* hlo) override;
132 
133   Status HandleElementwiseBinary(HloInstruction* hlo) override;
134 
135   Status HandleClamp(HloInstruction* hlo) override;
136 
137   Status HandleConditional(HloInstruction* hlo) override;
138 
139   Status HandleWhile(HloInstruction* hlo) override;
140 
141   Status HandleSlice(HloInstruction* hlo) override;
142 
143   Status HandleDynamicSlice(HloInstruction* hlo) override;
144 
145   Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
146 
147   Status HandleGather(HloInstruction* hlo) override;
148 
149   Status HandleScatter(HloInstruction* hlo) override;
150 
151   Status HandleDomain(HloInstruction* hlo) override;
152 
153  private:
154   using OperandDynamicDimensionFn = std::function<Status(
155       HloInstruction* operand, ShapeIndex index, int64_t dimension,
156       int64_t operand_index, HloInstruction* dynamic_size)>;
157 
158   using DynamicDimensionFn = std::function<Status(
159       ShapeIndex index, int64_t dimension, HloInstruction* dynamic_size)>;
160 
161   Status HandleDynamicConvolutionForward(HloInstruction* hlo,
162                                          int64_t operand_index,
163                                          int64_t dimension,
164                                          HloInstruction* dynamic_size);
165 
166   Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
167                                             int64_t operand_index,
168                                             int64_t dimension);
169 
170   Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
171                                            int64_t operand_index,
172                                            int64_t dimension);
173 
174   Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
175                                         HloInstruction* dynamic_size,
176                                         int64_t operand_index,
177                                         int64_t dimension);
178 
179   Status ForEachOperandDynamicDimension(HloInstruction* inst,
180                                         const OperandDynamicDimensionFn&);
181   Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
182                                           int64_t operand_index,
183                                           const OperandDynamicDimensionFn&);
184   Status ForEachDynamicDimension(HloInstruction* inst,
185                                  const DynamicDimensionFn& fn);
186 
187   // Pass through a dynamic dimension from the input to the output with the
188   // same value and index in the shape. This is a helper function to handle
189   // trivial instructions like elementwise operations.
190   Status PassThroughDynamicDimension(HloInstruction*);
191 
192   // The dynamic parameter bindings of this computation.
193   const DynamicParameterBinding& param_bindings_;
194 
195   // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
196   DynamicDimensionInference* parent_;
197 
198   // A handler for custom calls.
199   DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
200 };
201 
202 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
203   return ForEachOperandDynamicDimension(
204       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
205                int64_t operand_index, HloInstruction* dynamic_size) {
206         return UnimplementedStrCat(
207             "Asked to propagate a dynamic dimension from hlo ", operand->name(),
208             "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
209             ", which is not implemented.");
210       });
211 }
212 
213 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
214     HloInstruction* hlo) {
215   return ForEachOperandDynamicDimension(
216       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
217                int64_t operand_index, HloInstruction* dynamic_size) {
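        // Only dynamic dimensions that live under the tuple element selected by
        // this GTE are forwarded; the leading tuple index is dropped so the new
        // shape index is rooted at the GTE result.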
218         if (hlo->tuple_index() == index[0]) {
219           ShapeIndex new_index =
220               ShapeIndexView(index).ConsumeFront().ToShapeIndex();
221           parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
222         }
223         return Status::OK();
224       });
225 }
226 
227 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
228   return ForEachOperandDynamicDimension(
229       hlo, [&](HloInstruction*, ShapeIndex index, int64_t dimension,
230                int64_t operand_index, HloInstruction* dynamic_size) {
231         index.push_front(operand_index);
232         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
233         return Status::OK();
234       });
235 }
236 
237 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
238   return ForEachOperandDynamicDimension(
239       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
240                int64_t operand_index, HloInstruction* dynamic_size) {
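        // For a broadcast, dimensions(i) is the output dimension that operand
        // dimension i maps to, so the dynamic size carries over to that output
        // dimension unchanged.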
241         int64_t broadcast_dim = hlo->dimensions(dimension);
242         parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
243         return Status::OK();
244       });
245 }
246 
247 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
248   if (hlo->custom_call_target() == "PadToStatic") {
249     for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
250       if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
251         HloInstruction* dynamic_size =
252             hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
253                 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
254         // PadToStatic converts a dynamic dimension to a static dimension. It then
255         // returns the padded data output and the dynamic sizes of input
256         // dimensions.
257         ShapeIndex data_output = {0};
258         parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
259       }
260     }
261     return Status::OK();
262   }
263   if (custom_call_handler_) {
264     return custom_call_handler_(hlo, parent_);
265   }
266 
267   if (hlo->custom_call_target() == "DynamicConvolutionForward") {
268     // If input feature is dynamic and kernel feature is static, we can infer
269     // that input feature is also static.
270     // E.g.,:
271     // lhs = [B, X, Y, ?]
272     // rhs = [X, Y, I, O]
273     // dim_labels = b01f_01io
274     // We can infer that the dynamic dimension in lhs is actually the static size I.
275     const ConvolutionDimensionNumbers& dnums =
276         hlo->convolution_dimension_numbers();
277     HloInstruction* input_feature = parent_->GetDynamicSize(
278         hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
279     HloInstruction* kernel_feature = parent_->GetDynamicSize(
280         hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
281 
282     if (input_feature != nullptr && kernel_feature == nullptr) {
283       if (hlo->mutable_operand(0)->shape().dimensions(
284               dnums.input_feature_dimension()) ==
285           hlo->mutable_operand(1)->shape().dimensions(
286               dnums.kernel_input_feature_dimension()))
287         parent_->SetDynamicSize(hlo->mutable_operand(0), {},
288                                 dnums.input_feature_dimension(), nullptr);
289     }
290   }
291   return ForEachOperandDynamicDimension(
292       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
293                int64_t operand_index, HloInstruction* dynamic_size) {
294         // Resize custom call should propagate dynamic batch (0) and channel (3)
295         // dimensions.
296         if (hlo->custom_call_target() == "SliceToDynamic" ||
297             hlo->custom_call_target() == "Sharding" ||
298             (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
299              (dimension == 0 || dimension == 3))) {
300           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
301           return Status::OK();
302         }
303         if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
304           if (hlo->operand_count() > 2) {
305             return Unimplemented(
306                 "DynamicReduceWindowSamePadding doesn't support variadic "
307                 "reduce window %s",
308                 hlo->ToString());
309           }
310           return HandleDynamicWindowSamePadding(hlo, dynamic_size,
311                                                 operand_index, dimension);
312         }
313 
314         if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
315           if (operand_index == 1) {
316             // Operand 0 (input) determines dynamic output size. We ignore the
317             // dynamic size in operand 1 (output gradient).
318             return Status::OK();
319           }
320           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
321           return Status::OK();
322         }
323 
324         if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
325           return HandleDynamicConvolutionInputGrad(hlo, operand_index,
326                                                    dimension);
327         }
328 
329         if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
330           return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
331                                                     dimension);
332         }
333 
334         if (hlo->custom_call_target() == "DynamicConvolutionForward") {
335           return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
336                                                  dynamic_size);
337         }
338         return Unimplemented(
339             "CustomCall \"%s\" is not supported to have a dynamic dimension",
340             hlo->custom_call_target());
341       });
342 }
343 
344 Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
345   return ForEachOperandDynamicDimension(
346       hlo,
347       [&](HloInstruction* operand, ShapeIndex index, int64_t dynamic_dimension,
348           int64_t operand_index, HloInstruction* dynamic_size) {
349         HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
350         if (sort->values_count() == 0) {
351           parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
352         } else {
353           parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
354                                   dynamic_size);
355         }
356 
357         return Status::OK();
358       });
359 }
360 
361 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
362   return ForEachOperandDynamicDimension(
363       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
364                int64_t operand_index, HloInstruction* dynamic_size) {
365         if (operand_index != 0) {
366           return Unimplemented(
367               "Dynamic dimension on padding value is not supported");
368         }
369         const PaddingConfig_PaddingConfigDimension& padding_config =
370             hlo->padding_config().dimensions(dimension);
371 
372         HloInstruction* dynamic_size_adjusted = dynamic_size;
373         if (padding_config.interior_padding() != 0) {
374           // Adjust for interior padding:
375           // Size' = max((Size - 1), 0) * interior_padding + Size
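          // For example, with Size = 4 and interior_padding = 2 there are three
          // gaps of two elements each: max(4 - 1, 0) * 2 + 4 = 10.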
376           HloInstruction* one = hlo->parent()->AddInstruction(
377               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
378           HloInstruction* zero = hlo->parent()->AddInstruction(
379               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
380           HloInstruction* interior_padding = hlo->parent()->AddInstruction(
381               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
382                   padding_config.interior_padding())));
383           dynamic_size_adjusted =
384               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
385                   dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
386                   dynamic_size_adjusted, one));
387           dynamic_size_adjusted =
388               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
389                   dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
390                   dynamic_size_adjusted, zero));
391           dynamic_size_adjusted =
392               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
393                   dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
394                   dynamic_size_adjusted, interior_padding));
395           dynamic_size_adjusted =
396               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
397                   dynamic_size_adjusted->shape(), HloOpcode::kAdd,
398                   dynamic_size_adjusted, dynamic_size));
399         }
400         HloInstruction* adjustment = hlo->parent()->AddInstruction(
401             HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
402                 padding_config.edge_padding_low() +
403                 padding_config.edge_padding_high())));
404         dynamic_size_adjusted =
405             hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
406                 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
407                 dynamic_size_adjusted, adjustment));
408         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
409         return Status::OK();
410       });
411 }
412 
413 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
414   return ForEachOperandDynamicDimension(
415       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
416                int64_t operand_index, HloInstruction* dynamic_size) {
417         auto* reduce = Cast<HloReduceInstruction>(hlo);
418         int64_t operand_count = reduce->operand_count();
419         CHECK_EQ(operand_count % 2, 0);
420         if (operand_index >= reduce->input_count()) {
421           // Init values don't have dynamic sizes.
422           return Status::OK();
423         }
424         if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
425           // Dimension is to be reduced, stop tracing.
426           return Status::OK();
427         }
428 
429         // Find out the new dynamic dimension after reduce.
430         int64_t dimensions_not_reduced_count = 0;
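        // Walk the operand dimensions in order; the dynamic dimension's index in
        // the reduce output is its position among the dimensions that are not
        // reduced away.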
431         for (int i = 0; i < operand->shape().rank(); ++i) {
432           if (dimension == i) {
433             // The dimensions of all data operands of a variadic reduce have
434             // to be the same.  This means that if one operand of variadic
435             // reduce has a dynamic dimension, we set all outputs to use the
436             // same dynamic size in corresponding dimensions.
437             ShapeUtil::ForEachSubshape(
438                 reduce->shape(),
439                 [&](const Shape& subshape, ShapeIndex reduce_result_index) {
440                   if (!ShapeUtil::IsLeafIndex(reduce->shape(),
441                                               reduce_result_index)) {
442                     return;
443                   }
444                   parent_->SetDynamicSize(reduce, reduce_result_index,
445                                           dimensions_not_reduced_count,
446                                           dynamic_size);
447                 });
448 
449             return Status::OK();
450           }
451           if (absl::c_count(reduce->dimensions(), i) == 0) {
452             dimensions_not_reduced_count++;
453           }
454         }
455 
456         return Status::OK();
457       });
458 }
459 
460 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
461   return ForEachOperandDynamicDimension(
462       hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
463                int64_t operand_dimension, int64_t operand_index,
464                HloInstruction* dynamic_size) {
465         // There are three types of dimensions in a dot:
466         // A. batch dims
467         // B. contracting dims
468         // C. non-batch non-contracting dims.
469         // The output dimensions of a dot has three parts with the following
470         // order:
471         // [(type A), (lhs type C), (rhs type C)]
472         //
473         // Note that both lhs and rhs have the same dimension sizes for batch,
474         // but the dimension index could be different.
475         //
476         // Given one dynamic input dimension, either lhs or rhs, we use a
477         // mapping to find the corresponding output dimension.
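        // For example, with lhs = [B, M, K] and rhs = [B, K, N] (batch dimension
        // 0, contracting dimension K), the output is [B, M, N]; a dynamic M on
        // the lhs maps to output dimension 1, and a dynamic N on the rhs maps to
        // output dimension 2.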
478         HloInstruction* dot = hlo;
479         const DotDimensionNumbers& dimension_numbers =
480             dot->dot_dimension_numbers();
481         // A map from the operand dimensions to result dimension.
482         absl::flat_hash_map<int64, int64> result_dim_mapping;
483         int64_t current_result_dims = 0;
484 
485         bool lhs = operand_index == 0;
486 
487         // The first loop keeps track of batch dimensions. RHS and LHS could have
488         // different batch dimension numbers.
489         if (lhs) {
490           for (int64_t i : dimension_numbers.lhs_batch_dimensions()) {
491             result_dim_mapping[i] = current_result_dims++;
492           }
493         } else {
494           for (int64_t i : dimension_numbers.rhs_batch_dimensions()) {
495             result_dim_mapping[i] = current_result_dims++;
496           }
497         }
498 
499         // Handle dimensions in the lhs.
500         for (int64_t i = 0; i < dot->operand(0)->shape().rank(); i++) {
501           // Look for non-contracting and non-batching dimension.
502           if (absl::c_linear_search(
503                   dimension_numbers.lhs_contracting_dimensions(), i)) {
504             continue;
505           }
506           if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
507                                     i)) {
508             continue;
509           }
510           if (lhs) {
511             result_dim_mapping[i] = current_result_dims;
512           }
513           current_result_dims++;
514         }
515 
516         // Handle dimensions in the rhs.
517         for (int64_t i = 0; i < dot->operand(1)->shape().rank(); i++) {
518           // Look for non-contracting and non-batching dimension.
519           if (absl::c_linear_search(
520                   dimension_numbers.rhs_contracting_dimensions(), i)) {
521             continue;
522           }
523           if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
524                                     i)) {
525             continue;
526           }
527           if (!lhs) {
528             result_dim_mapping[i] = current_result_dims;
529           }
530           current_result_dims++;
531         }
532 
533         // Check if the operand dim is in the result shape. If so, add another
534         // work item to trace that dimension.
535         auto iter = result_dim_mapping.find(operand_dimension);
536         if (iter != result_dim_mapping.end()) {
537           parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
538         }
539 
540         return Status::OK();
541       });
542 }
543 
544 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
545   return ForEachOperandDynamicDimension(
546       hlo,
547       [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
548           int64_t operand_index, HloInstruction* dynamic_size) -> Status {
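        // For a transpose, dimensions()[i] names the operand dimension that
        // output dimension i is taken from, so the dynamic dimension moves to
        // the unique i whose entry equals `dimension`.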
549         int64_t permuted_dim = -1;
550         for (int64_t i = 0; i < hlo->dimensions().size(); ++i) {
551           if (hlo->dimensions()[i] == dimension) {
552             TF_RET_CHECK(permuted_dim == -1);
553             permuted_dim = i;
554           }
555         }
556         parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
557         return Status::OK();
558       });
559 }
560 
561 Status DynamicDimensionInferenceVisitor::HandleConvolution(
562     HloInstruction* hlo) {
563   return ForEachOperandDynamicDimension(
564       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
565                int64_t operand_index, HloInstruction* dynamic_size) {
566         HloInstruction* conv = hlo;
567         const ConvolutionDimensionNumbers& dimension_numbers =
568             conv->convolution_dimension_numbers();
569         if (operand_index == 0) {
570           if (dimension == dimension_numbers.input_batch_dimension()) {
571             parent_->SetDynamicSize(conv, {},
572                                     dimension_numbers.output_batch_dimension(),
573                                     dynamic_size);
574             return Status::OK();
575           }
576 
577           if (dimension == dimension_numbers.input_feature_dimension()) {
578             return Status::OK();
579           }
580         } else {
581           if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
582             return Status::OK();
583           }
584         }
585 
586         return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
587                              conv->ToString());
588       });
589 }
590 
591 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
592     HloInstruction* hlo) {
593   // First handle concatenate dimensions. We do this by iterating through all
594   // operands while tracking both dynamic and static dimensions.
595 
596   // static_size is used to keep track of the concatenated size of static
597   // dimensions.
598   int64_t static_size = 0;
599   std::vector<HloInstruction*> dynamic_concat_dims;
600   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
601     HloInstruction* dynamic_size = parent_->GetDynamicSize(
602         hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
603     if (dynamic_size == nullptr) {
604       // This is a static dimension.
605       static_size +=
606           hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
607     } else {
608       dynamic_concat_dims.push_back(dynamic_size);
609     }
610   }
611   // If concat dimension is dynamic, calculate its size by summing up static
612   // dims and dynamic dims together.
613   if (!dynamic_concat_dims.empty()) {
614     HloInstruction* dim_size_total =
615         hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
616             LiteralUtil::CreateR0<int32>(static_size)));
617     for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
618       dim_size_total = hlo->parent()->AddInstruction(
619           HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
620                                        dim_size_total, dynamic_dim));
621     }
622     parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
623                             dim_size_total);
624   }
625 
626   // Simply pass through non-concat dynamic dimensions.
627   return ForEachOperandDynamicDimension(
628       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
629                int64_t operand_index, HloInstruction* dynamic_size) {
630         int64_t concatenate_dimension = hlo->concatenate_dimension();
631         if (concatenate_dimension == dimension) {
632           return Status::OK();
633         }
634         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
635         return Status::OK();
636       });
637 }
638 
639 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
640     HloInstruction* gds) {
641   // Dynamic dimension doesn't propagate through GetDimensionSize:
642   //
643   //   Input: F32[x, y, z]
644   //     |
645   //   GetDimensionSize(1): S32[]
646   //
647   // The returned value is a scalar, which doesn't have any dynamic dimension in
648   // the shape (although the value contains the real size of the dynamic
649   // dimension of the input).
650   int64_t dim = gds->dimension();
651   HloInstruction* operand = gds->mutable_operand(0);
652   HloInstruction* dynamic_size = parent_->GetDynamicSize(operand, {}, dim);
653   HloComputation* computation = gds->parent();
654   if (dynamic_size != nullptr) {
655     TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(dynamic_size));
656     // The dependency between an instruction and its dynamic dimensions is not
657     // modeled in the IR. As instr is being replaced by dynamic_size, also tell
658     // dynamic dimension inference that the instruction is being replaced.
659     parent_->ReplaceAllDynamicDimensionUsesWith(gds, dynamic_size);
660   } else {
661     TF_RET_CHECK(dim < gds->operand(0)->shape().rank());
662     int32_t size = gds->operand(0)->shape().dimensions(dim);
663     HloInstruction* new_instr = computation->AddInstruction(
664         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
665     TF_RETURN_IF_ERROR(gds->ReplaceAllUsesWith(new_instr));
666     parent_->ReplaceAllDynamicDimensionUsesWith(gds, new_instr);
667   }
668   return Status::OK();
669 }
670 
671 Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
672     HloInstruction* hlo) {
673   bool dimension_is_static = false;
674   const HloInstruction* size = hlo->operand(1);
675   if (size->opcode() == HloOpcode::kConstant) {
676     // Check if we are setting a dimension size to its static size. If so,
677     // remove the dynamic dimension.
678     //
679     // size = s32[] constant(5)
680     // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
681     //                                                        dimensions={1}
682     // The result shape has no dynamic dimension.
683     TF_RET_CHECK(size->shape().rank() == 0);
684     if (size->literal().Get<int32>({}) ==
685         hlo->shape().dimensions(hlo->dimension())) {
686       dimension_is_static = true;
687     }
688   }
689 
690   if (!dimension_is_static) {
691     // Propagate dynamic dimension indicated by this set dimension size
692     // instruction.
693     parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
694   }
695 
696   // Also propagate dynamic dimensions already set by operands.
697   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
698       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
699                int64_t operand_index, HloInstruction* dynamic_size) {
700         if (dimension != hlo->dimension()) {
701           parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
702         }
703         return Status::OK();
704       }));
705 
706   return Status::OK();
707 }
708 
709 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
710     HloInstruction* hlo, int64_t operand_index, int64_t dimension,
711     HloInstruction* dynamic_size) {
712   TF_RET_CHECK(operand_index == 0);
713   const ConvolutionDimensionNumbers& dimension_numbers =
714       hlo->convolution_dimension_numbers();
715 
716   if (dimension == dimension_numbers.input_batch_dimension()) {
717     // Batch dimension is propagated without any changes.
718     parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
719                             dynamic_size);
720     return Status::OK();
721   }
722 
723   for (int64_t spatial_dim_index = 0;
724        spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
725        ++spatial_dim_index) {
726     int64_t input_spatial_dim =
727         dimension_numbers.input_spatial_dimensions(spatial_dim_index);
728     int64_t output_spatial_dim =
729         dimension_numbers.output_spatial_dimensions(spatial_dim_index);
730     if (dimension == input_spatial_dim) {
731       // This is a dynamic spatial dimension. Calculate the output size.
732       WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
733       DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
734           dynamic_size, window_dim.size(), window_dim.window_dilation(),
735           window_dim.stride(), hlo->padding_type());
736       TF_RET_CHECK(window_dim.base_dilation() == 1);
737       parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
738                               dynamic_window_dims.output_size);
739       return Status::OK();
740     }
741   }
742   // The input feature dim disappears after convolution.
743   return Status::OK();
744 }
745 
746 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
747     HloInstruction* hlo, HloInstruction* dynamic_size, int64_t operand_index,
748     int64_t dimension) {
749   const Window& window = hlo->window();
750   const WindowDimension& window_dim = window.dimensions(dimension);
751   if (!window_util::IsTrivialWindowDimension(window_dim)) {
752     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
753         dynamic_size, window_dim.size(), window_dim.window_dilation(),
754         window_dim.stride(), PaddingType::PADDING_SAME);
755     parent_->SetDynamicSize(hlo, {}, dimension,
756                             dynamic_window_dims.output_size);
757     return Status::OK();
758   }
759 
760   parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
761 
762   return Status::OK();
763 }
764 
765 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
766     HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
767   // The output size of convolution input grad is the corresponding input size.
768   HloInstruction* input_sizes = hlo->mutable_operand(0);
769   HloComputation* comp = hlo->parent();
770   TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
771   TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
772   TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
773                hlo->shape().dimensions_size())
774       << hlo->ToString();
775   // Slice to get corresponding input size.
776   HloInstruction* slice = comp->AddInstruction(
777       HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
778                                   {dimension}, {dimension + 1}, {1}));
779   HloInstruction* reshape = comp->AddInstruction(
780       HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
781   parent_->SetDynamicSize(hlo, {}, dimension, reshape);
782   return Status::OK();
783 }
784 
785 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
786     HloInstruction* hlo, int64_t operand_index, int64_t dimension) {
787   // Dynamic convolution kernel grad produces static shape outputs.
788   return Status::OK();
789 }
790 
791 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
792     HloInstruction* hlo) {
793   return ForEachOperandDynamicDimension(
794       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
795                int64_t operand_index, HloInstruction* dynamic_size) {
796         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
797         return Status::OK();
798       });
799 }
800 
801 Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
802   return PassThroughDynamicDimension(hlo);
803 }
804 
805 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
806     HloInstruction* hlo) {
807   return PassThroughDynamicDimension(hlo);
808 }
809 
810 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
811   return PassThroughDynamicDimension(hlo);
812 }
813 
814 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
815     HloInstruction* hlo) {
816   HloComputation* comp = hlo->parent();
817   return ForEachOperandDynamicDimension(
818       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
819                int64_t operand_index, HloInstruction* dynamic_size) {
820         HloInstruction* existing_size =
821             parent_->GetDynamicSize(hlo, index, dimension);
822         if (existing_size == nullptr || existing_size == dynamic_size) {
823           parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
824         } else {
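          // The two operands carry different dynamic sizes for this dimension;
          // conservatively use the larger of the two as the output's dynamic
          // size.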
825           HloInstruction* max =
826               comp->AddInstruction(HloInstruction::CreateBinary(
827                   ShapeUtil::MakeScalarShape(S32), HloOpcode::kMaximum,
828                   dynamic_size, existing_size));
829           parent_->SetDynamicSize(hlo, index, dimension, max);
830         }
831         return Status::OK();
832       });
833 }
834 
835 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
836   return PassThroughDynamicDimension(hlo);
837 }
838 
839 Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
840     HloInstruction* hlo) {
841   HloDynamicReshapeInstruction* dynamic_reshape =
842       Cast<HloDynamicReshapeInstruction>(hlo);
843   for (int64_t i = 0; i < hlo->shape().rank(); ++i) {
844     if (hlo->shape().is_dynamic_dimension(i)) {
845       parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
846     }
847   }
848   return Status::OK();
849 }
850 
851 Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
852   return ForEachOperandDynamicDimension(
853       hlo,
854       [&](HloInstruction* operand, ShapeIndex index,
855           int64_t input_dynamic_dimension, int64_t operand_index,
856           HloInstruction* operand_dynamic_size) -> Status {
857         HloInstruction* reshape = hlo;
858         if (reshape->shape().rank() == 0) {
859           VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
860                      "undefined behavior when input size is 0. The offending "
861                      "instruction is: "
862                   << reshape->ToString();
863           return Status::OK();
864         }
865         auto common_factors = CommonFactors(operand->shape().dimensions(),
866                                             reshape->shape().dimensions());
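        // CommonFactors returns the (input_dim, output_dim) boundaries of the
        // maximal dimension groups whose element-count products agree between
        // the two shapes; the group containing the dynamic input dimension
        // limits which output dimensions it can map to.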
867         int64_t input_dim_start = -1;
868         int64_t input_dim_end = -1;
869         int64_t output_dim_start = -1;
870         int64_t output_dim_end = -1;
871         // Find common_factors that the input belongs to.
872         for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
873           auto start = common_factors[i];
874           auto end = common_factors[i + 1];
875           if (input_dynamic_dimension >= start.first &&
876               input_dynamic_dimension < end.first) {
877             // Found the common_factor group that the input_dim belongs to.
878             input_dim_start = start.first;
879             input_dim_end = end.first;
880             output_dim_start = start.second;
881             output_dim_end = end.second;
882           }
883         }
884 
885         VLOG(2) << "Input dim start: " << input_dim_start
886                 << " Input dim end: " << input_dim_end
887                 << " output dim start: " << output_dim_start
888                 << " output dim end: " << output_dim_end;
889 
890         if ((input_dim_end - input_dim_start) > 1 &&
891             (output_dim_end - output_dim_start) > 1) {
892           // We don't support the case when a dynamic dimension is both combined
893           // with and split into other dimensions:
894           //
895           //  [x, yz]
896           //     | Reshape
897           //  [xy, z]
898           //
899           // TODO(yunxing): This can be supported by canonicalizing
900           // the offending reshape into two reshapes:
901           //
902           //  [x,yz]
903           //     | Reshape
904           //  [x, y, z]
905           //     | Reshape
906           //  [xy, z]
907           //
908           return Unimplemented(
909               "Dynamic input dimension to reshape that is both splitted and "
910               "combined is not supported %s",
911               hlo->ToString());
912         }
913 
914         for (auto common_factor : common_factors) {
915           // Expand common factor to include degenerated output dimensions.
916           if (common_factor.first == input_dim_start) {
917             output_dim_start = std::min(output_dim_start, common_factor.second);
918           }
919           if (common_factor.first == input_dim_end) {
920             output_dim_end = std::max(output_dim_end, common_factor.second);
921           }
922         }
923 
924         int64_t output_dynamic_dimension = -1;
925 
926         if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
927           // If dynamic dimension is 1, it can only be most-major or
928           // most-minor.
929           if (input_dynamic_dimension == 0) {
930             output_dynamic_dimension = 0;
931           }
932           if (input_dynamic_dimension == operand->shape().rank() - 1) {
933             output_dynamic_dimension = reshape->shape().rank() - 1;
934           }
935 
936           if (output_dynamic_dimension == -1) {
937             return Unimplemented(
938                 "Dynamic degenerated dimension that's not most-minor nor "
939                 "most-major is not supported %s",
940                 reshape->ToString());
941           }
942         }
943 
944         if (output_dynamic_dimension == -1 &&
945             output_dim_end - output_dim_start == 1) {
946           // Only one possible output dimension.
947           output_dynamic_dimension = output_dim_start;
948         }
949 
950         if (output_dynamic_dimension == -1 &&
951             output_dim_end - output_dim_start > 1) {
952           // One input dimension is split into multiple output dimensions.
953           // The output dimensions are decomposed from the input's most-major dimension.
954           // In this case, we don't know which one is dynamic, e.g., when we
955           // have:
956           //
957           //           [<=a/c, c, b]
958           //              | Reshape
959           //           [<=a, b] // a is dynamic, has to be multiple of c.
960           //             |  Reshape
961           // [1, 1, ... , a/c, c, b]
962           //
963           // Any dimension from the first '1' to 'a/c' can be dynamic.
964           //
965           // We use the following logic to disambiguate:
966           // 1. If the user sets "inferred_dimension", then use that as
967           // dynamic dimension.
968           // 2. If a dimension in the reshape's result shape is already marked
969           // dynamic, use that as the dynamic dimension.
970           // E.g.:
971           //     [<=4]
972           //      |
973           //   reshape
974           //      |
975           //   [1, <=2, 2]
976           // We use second dim as dynamic dimension.
977           //
978           // 3. If the logic above cannot disambiguate, e.g.:
979           //
980           //     [<=1]
981           //      |
982           //   reshape
983           //      |
984           //   [1, 1, 1]
985           //
986           //   We bail out and return an error.
987           // TODO(yunxing): Further simplify this, remove 1. and fully rely
988           // on 2.
989           output_dynamic_dimension = reshape->inferred_dimension();
990           if (output_dynamic_dimension == -1) {
991             // Try to find the dynamic dimension from the result shape.
992             for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
993               if (reshape->shape().is_dynamic_dimension(i)) {
994                 output_dynamic_dimension = i;
995               }
996             }
997           }
998 
999           if (output_dynamic_dimension == -1) {
1000             std::vector<int64> output_non_degenerated;
1001             for (int64_t i = output_dim_start; i < output_dim_end; ++i) {
1002               if (reshape->shape().dimensions(i) != 1) {
1003                 output_non_degenerated.push_back(i);
1004               }
1005             }
1006             if (output_non_degenerated.size() == 1) {
1007               output_dynamic_dimension = output_non_degenerated[0];
1008             }
1009           }
1010 
1011           if (output_dynamic_dimension == -1) {
1012             return InvalidArgument(
1013                 "Reshape's input dynamic dimension is decomposed into "
1014                 "multiple output dynamic dimensions, but the constraint is "
1015                 "ambiguous and XLA can't infer the output dimension %s. ",
1016                 hlo->ToString());
1017           }
1018         }
1019 
1020         CHECK_NE(output_dynamic_dimension, -1);
1021         const int64_t input_dim_size =
1022             operand->shape().dimensions(input_dynamic_dimension);
1023         const int64_t output_dim_size =
1024             reshape->shape().dimensions(output_dynamic_dimension);
1025         VLOG(2) << "input_dim_size: " << input_dim_size
1026                 << " output_dim_size: " << output_dim_size;
1027 
1028         if (input_dim_size == output_dim_size) {
1029           // Simply forward dynamic dimension.
1030           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1031                                   operand_dynamic_size);
1032         }
1033 
1034         if (input_dim_size > output_dim_size) {
1035           TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1036               << reshape->ToString();
1037           const int64_t divisor = input_dim_size / output_dim_size;
1038           HloInstruction* divisor_hlo =
1039               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1040                   LiteralUtil::CreateR0<int32>(divisor)));
1041 
1042           HloInstruction* new_dynamic_size =
1043               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1044                   operand_dynamic_size->shape(), HloOpcode::kDivide,
1045                   operand_dynamic_size, divisor_hlo));
1046 
1047           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1048                                   new_dynamic_size);
1049         }
1050 
1051         if (input_dim_size < output_dim_size) {
1052           // Input dimension is combined with other input dimensions.
1053           //
1054           // Adjust the output size by the ratio of dynamic_input_dim /
1055           // static_input_dim.
1056           //
1057           // For example, if we have [<=3, 3] -> [9] and the dynamic size is 2,
1058           // the new output dynamic size is 9 / 3 * 2 = 6.
1059           //
1060           // If it turns out the second dimension is also dynamic:
1061           // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1062           // dynamic size is 6 / 3 * 2 = 4.
1063           //
1064           //
1065           HloInstruction* output_dynamic_size =
1066               parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1067           if (output_dynamic_size == nullptr) {
1068             output_dynamic_size =
1069                 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1070                     LiteralUtil::CreateR0<int32>(output_dim_size)));
1071           }
1072           HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1073               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1074                   operand->shape().dimensions(input_dynamic_dimension))));
1075 
1076           HloInstruction* new_dynamic_size =
1077               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1078                   output_dynamic_size->shape(), HloOpcode::kDivide,
1079                   output_dynamic_size, divisor_hlo));
1080 
1081           new_dynamic_size =
1082               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1083                   output_dynamic_size->shape(), HloOpcode::kMultiply,
1084                   new_dynamic_size, operand_dynamic_size));
1085           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1086                                   new_dynamic_size);
1087         }
1088 
1089         return Status::OK();
1090       });
1091 }
1092 
1093 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
1094     HloInstruction* hlo) {
1095   return ForEachOperandDynamicDimension(
1096       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1097                int64_t operand_index, HloInstruction* dynamic_size) {
1098         auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
1099         const WindowDimension& window_dim =
1100             reduce_window->window().dimensions(dimension);
1101 
1102         if (operand_index >= reduce_window->input_count()) {
1103           // Init values don't have dynamic sizes.
1104           return Status::OK();
1105         }
1106 
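        // A non-trivial window changes the extent of the windowed dimension, so
        // recompute the dynamic output size from the window parameters; a
        // trivial window passes the dynamic size through unchanged.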
1107         if (!window_util::IsTrivialWindowDimension(window_dim)) {
1108           DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1109               dynamic_size, window_dim.size(), window_dim.window_dilation(),
1110               window_dim.stride(), PaddingType::PADDING_VALID);
1111           dynamic_size = dynamic_window_dims.output_size;
1112         }
1113 
1114         // The dimensions of all data operands of a variadic reduce window have
1115         // to be the same.  This means that if one operand of variadic
1116         // reduce has a dynamic dimension, we set all outputs to use the
1117         // same dynamic size in corresponding dimensions.
1118         ShapeUtil::ForEachSubshape(
1119             reduce_window->shape(),
1120             [&](const Shape& subshape, ShapeIndex reduce_window_result_index) {
1121               if (!ShapeUtil::IsLeafIndex(reduce_window->shape(),
1122                                           reduce_window_result_index)) {
1123                 return;
1124               }
1125               parent_->SetDynamicSize(reduce_window, reduce_window_result_index,
1126                                       dimension, dynamic_size);
1127             });
1128 
1129         return Status::OK();
1130       });
1131 }
1132 
1133 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
1134     HloInstruction* hlo) {
1135   return ForEachOperandDynamicDimension(
1136       hlo, [&](HloInstruction* operand, ShapeIndex index, int64_t dimension,
1137                int64_t operand_index, HloInstruction* dynamic_size) {
1138         if (operand_index == 1) {
1139           // Operand 0 (input) determines dynamic output size. We ignore the
1140           // dynamic size in operand 1 (output gradient).
1141           return Status::OK();
1142         }
1143         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1144 
1145         return Status::OK();
1146       });
1147 }
1148 
1149 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1150   return ForEachOperandDynamicDimension(
1151       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64_t dimension,
1152                int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1153         if (hlo->slice_starts(dimension) != 0 ||
1154             hlo->slice_strides(dimension) != 1 ||
1155             hlo->slice_limits(dimension) !=
1156                 operand->shape().dimensions(dimension)) {
1157           // Slicing out only part of the dimension eliminates the dynamic dimension.
1158           return Status::OK();
1159         }
1160 
1161         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1162 
1163         return Status::OK();
1164       });
1165 }
1166 
1167 Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
1168     HloInstruction* hlo) {
1169   return ForEachOperandDynamicDimension(
1170       hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64_t dimension,
1171                int64_t /*operand_index*/, HloInstruction* dynamic_size) {
1172         if (hlo->shape().dimensions(dimension) !=
1173             hlo->operand(0)->shape().dimensions(dimension)) {
1174           // Slicing a single element out kills the dynamic dimension.
1175           if (hlo->shape().dimensions(dimension) == 1) {
1176             return Status::OK();
1177           }
1178           return Unimplemented(
1179               "Dynamic dimension propagation on DynamicSlice where a partial "
1180               "dimension is selected %s",
1181               hlo->ToString());
1182         }
1183 
1184         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1185 
1186         return Status::OK();
1187       });
1188 }
1189 
1190 Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
1191     HloInstruction* hlo) {
1192   return ForEachOperandDynamicDimension(
1193       hlo,
1194       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1195           int64_t operand_index, HloInstruction* dynamic_size) {
1196         if (hlo->shape().dimensions(dimension) !=
1197             hlo->operand(0)->shape().dimensions(dimension)) {
1198           return Unimplemented(
1199               "Dynamic dimension propagation on DynamicUpdateSlice where a "
1200               "partial dimension is selected %s",
1201               hlo->ToString());
1202         }
1203 
1204         if (operand_index == 1 &&
1205             hlo->operand(1)->shape().dimensions(dimension) <
1206                 hlo->operand(0)->shape().dimensions(dimension)) {
1207           // DUS(input=[A], update=[<=B])
1208           //
1209           // If the update dim is smaller than the input dim (B < A), this is
1210           // a partial update, so there is no need to set the output dynamic dimension.
1211           //
1212           // The dynamic shape in `update` doesn't change output dynamic shape.
1213           return Status::OK();
1214         }
1215 
1216         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1217 
1218         return Status::OK();
1219       });
1220 }
1221 
1222 Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1223   return PassThroughDynamicDimension(hlo);
1224 }
1225 
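// Two gather cases are supported below: (1) a dynamic dimension on the data
// operand that is gathered with a full-sized slice, which maps (minus any
// collapsed slice dims) to the corresponding output dimension, and (2) a
// dynamic batch dimension on the indices operand, which maps to the matching
// non-offset output dimension. Example (illustrative only): gathering rows
// from f32[8,16] with s32[<=4] indices yields f32[<=4,16], with the output's
// dimension 0 bound to the indices' dynamic size.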
1226 Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1227   return ForEachOperandDynamicDimension(
1228       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1229                int64_t input_dynamic_dimension, int64_t operand_index,
1230                HloInstruction* dynamic_size) {
1231         const GatherDimensionNumbers& gather_dims =
1232             hlo->gather_dimension_numbers();
1233         if (operand_index != 1) {
1234           if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1235             // Gathering a size 1 dimension out of a dynamic dimension removes
1236             // the dynamicity.
1237             return Status::OK();
1238           }
1239           if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1240               operand->shape().dimensions(input_dynamic_dimension)) {
1241             // Gathering a full-sized dimension out of a dynamic dimension
1242             // propagates the dynamicity to the output.
1243             int64_t output_dimension = input_dynamic_dimension;
1244             for (int64_t collapsed_dim : gather_dims.collapsed_slice_dims()) {
1245               if (collapsed_dim < input_dynamic_dimension) {
1246                 // This output dimension is collapsed.
1247                 output_dimension--;
1248               }
1249             }
1250             parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1251             return Status::OK();
1252           }
1253           return Unimplemented(
1254               "Detects a dynamic dimension on the data input of gather, which "
1255               "is not supported: %s, %lld",
1256               hlo->ToString(), input_dynamic_dimension);
1257         }
1258         // The dynamic dimension is on the indices operand. Map it to the
1259         // corresponding batch (non-offset) dimension of the output.
1260         int64_t indices_rank = hlo->operand(1)->shape().rank();
1261         int64_t output_rank = hlo->shape().rank();
1262 
1263         // indices_dim is an iterator over indices dimensions.
1264         int64_t indices_dim = 0;
1265         // Find the corresponding batch dimension in the output.
1266         for (int64_t output_dim = 0; output_dim < output_rank; ++output_dim) {
1267           if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1268             // Skip the index vector dimension.
1269             if (indices_dim == gather_dims.index_vector_dim()) {
1270               indices_dim++;
1271             }
1272             if (indices_dim++ == input_dynamic_dimension) {
1273               parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1274               return Status::OK();
1275             }
1276           }
1277         }
1278         CHECK_EQ(indices_dim, indices_rank);
1279 
1280         return Unimplemented(
1281             "Detects a non-batch dynamic dimension of gather, "
1282             "which is not supported: %s",
1283             hlo->ToString());
1284       });
1285 }
1286 
1287 Status DynamicDimensionInferenceVisitor::HandleConditional(
1288     HloInstruction* hlo) {
1289   // Conditionals are handled by producing additional inputs and outputs of
1290   // the conditional instruction.
1291   std::vector<HloComputation*> new_branch_computations;
1292   std::vector<HloInstruction*> new_operands;
1293   // If the output of the conditional contains a dynamic dimension, we send the
1294   // dynamic dimension size out by adding an additional root element. A mapping
1295   // from the root instruction's dynamic dimension index (represented by a shape
1296   // index as output index and an int64 dimension number) to output index
1297   // (represented by an int64) is tracked for the conditional instruction (all
1298   // branches should have the same mapping).
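  // For example (illustrative only), a conditional whose branches return
  // s32[<=4] is rewritten so each branch root returns (s32[<=4], s32[]),
  // where the trailing s32[] element carries the dynamic size of dimension 0;
  // the size is then read back with a get-tuple-element on the new
  // conditional.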
1299   ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1300       hlo->shape());
1301 
1302   bool need_rewrite = false;
1303   for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1304        ++branch_index) {
1305     std::vector<HloInstruction*> operands_to_add;
1306 
1307     absl::flat_hash_map<HloInstruction*, int64>
1308         dynamic_size_to_operand_id_index_map;
1309     // Only look at branch_index + 1, the correct operand index for a
1310     // given branch.
1311     const int64_t operand_index = branch_index + 1;
1312 
1313     int64_t operand_count =
1314         hlo->operand(operand_index)->shape().tuple_shapes_size();
1315     // Prepare to pass dynamic dimension into the new computation and add
1316     // dynamic dimension sizes as parameters to the new tuple.
1317     TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1318         hlo, operand_index,
1319         [&](HloInstruction*, ShapeIndex, int64_t, int64_t,
1320             HloInstruction* dynamic_size) -> Status {
1321           TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
1322               << "Only tuple typed inputs can have dynamic dimension. Please "
1323                  "file a bug against XLA team.";
1324           const HloInstruction* tuple_operand = hlo->operand(operand_index);
1325           for (int64_t i = 0; i < tuple_operand->operand_count(); ++i) {
1326             // If the dynamic size is already an operand to the computation,
1327             // skip adding it to the computation input again.
1328             if (dynamic_size == tuple_operand->operand(i)) {
1329               dynamic_size_to_operand_id_index_map[dynamic_size] = i;
1330               return Status::OK();
1331             }
1332           }
1333           auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
1334           if (iter == dynamic_size_to_operand_id_index_map.end()) {
1335             operands_to_add.push_back(dynamic_size);
1336             dynamic_size_to_operand_id_index_map[dynamic_size] =
1337                 operand_count++;
1338           }
1339           return Status::OK();
1340         }));
1341 
1342     HloInstruction* original_input = hlo->mutable_operand(operand_index);
1343     HloComputation* branch_computation = hlo->branch_computation(branch_index);
1344 
1345     HloComputation* new_computation = branch_computation;
1346     HloInstruction* new_operand = hlo->mutable_operand(operand_index);
1347     if (!operands_to_add.empty()) {
1348       TF_RET_CHECK(original_input->shape().IsTuple());
1349       need_rewrite = true;
1350       new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
1351       TF_ASSIGN_OR_RETURN(
1352           new_computation,
1353           WidenComputation(branch_computation, new_operand->shape()));
1354     }
1355     // Set the dynamic dimensions for the newly created branch computation's
1356     // parameters so that the HLOs inside the computation can see the dynamic
1357     // dimensions.
1358     DynamicParameterBinding dynamic_parameter_binding;
1359     TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1360         hlo, operand_index,
1361         [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1362             int64_t operand_index, HloInstruction* dynamic_size) {
1363           DynamicParameterBinding::DynamicParameter dynamic_parameter{
1364               0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
1365           DynamicParameterBinding::DynamicDimension dynamic_dimension{
1366               0, {index}, dimension};
1367           TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
1368                                                             dynamic_dimension));
1369 
1370           return Status::OK();
1371         }));
1372     VLOG(2) << "dynamic_parameter_binding for conditional branch"
1373             << dynamic_parameter_binding;
1374     TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1375         new_computation, dynamic_parameter_binding, parent_));
1376 
1377     new_branch_computations.push_back(new_computation);
1378     new_operands.push_back(new_operand);
1379   }
1380   int64_t tuple_count = hlo->shape().tuple_shapes_size();
1381   // The dynamism of the output of branches can be different.
1382   // E.g.,
1383   //   true_branch  (s32[<=4])
1384   //   false_branch (s32[4])
1385   //
1386   // The following loop populates dynamic_output_mapping and accounts for
1387   // dynamism across all branches.
1388   ShapeUtil::ForEachSubshape(
1389       hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1390         if (!subshape.IsArray()) {
1391           return;
1392         }
1393         for (int64 i = 0; i < subshape.rank(); ++i) {
1394           for (int64 j = 0; j < new_branch_computations.size(); ++j) {
1395             HloInstruction* dynamic_size = parent_->GetDynamicSize(
1396                 new_branch_computations[j]->root_instruction(), index, i);
1397             if (dynamic_size) {
1398               if (dynamic_output_mapping.element(index).contains(i)) {
1399                 continue;
1400               }
1401               dynamic_output_mapping.mutable_element(index)->emplace(
1402                   i, tuple_count++);
1403             }
1404           }
1405         }
1406       });
1407   for (int64_t branch_index = 0; branch_index < hlo->branch_count();
1408        ++branch_index) {
1409     std::vector<HloInstruction*> hlos_to_add_in_root;
1410     // There may be some dynamic dimensions coming out of the computation; wire
1411     // them into the root instruction as additional tuple elements.
1412     ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape,
1413                                                  const ShapeIndex& index) {
1414       if (!subshape.IsArray()) {
1415         return;
1416       }
1417       for (int64 i = 0; i < subshape.rank(); ++i) {
1418         if (dynamic_output_mapping.element(index).contains(i)) {
1419           HloInstruction* dynamic_size = parent_->GetDynamicSize(
1420               new_branch_computations[branch_index]->root_instruction(), index,
1421               i);
1422           if (dynamic_size) {
1423             hlos_to_add_in_root.push_back(dynamic_size);
1424           } else {
1425             HloInstruction* constant_size =
1426                 new_branch_computations[branch_index]->AddInstruction(
1427                     HloInstruction::CreateConstant(
1428                         LiteralUtil::CreateR0<int32>(subshape.dimensions(i))));
1429             hlos_to_add_in_root.push_back(constant_size);
1430           }
1431         }
1432       }
1433     });
1434 
1435     VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
1436     if (!hlos_to_add_in_root.empty()) {
1437       need_rewrite = true;
1438       HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
1439           new_branch_computations[branch_index]->root_instruction(),
1440           hlos_to_add_in_root);
1441       new_branch_computations[branch_index]->set_root_instruction(
1442           new_branch_root,
1443           /*accept_different_shape=*/true);
1444     }
1445   }
1446 
1447   if (!need_rewrite) {
1448     return Status::OK();
1449   }
1450   // Create a new conditional with the new operations and computations.
1451   HloInstruction* new_conditional =
1452       hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
1453           new_branch_computations[0]->root_instruction()->shape(),
1454           hlo->mutable_operand(0), new_branch_computations, new_operands));
1455 
1456   HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
1457       new_conditional, hlo->shape().tuple_shapes_size());
1458   // Now set the dynamic dimensions of the newly created conditional.
1459   dynamic_output_mapping.ForEachElement(
1460       [&](const ShapeIndex& index,
1461           const absl::flat_hash_map<int64, int64>& dim_to_output) {
1462         for (auto iter : dim_to_output) {
1463           int64_t dim = iter.first;
1464           int64_t output_index = iter.second;
1465           HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
1466               HloInstruction::CreateGetTupleElement(
1467                   ShapeUtil::MakeScalarShape(S32), new_conditional,
1468                   output_index));
1469           parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
1470           parent_->SetDynamicSize(new_conditional_extracted, index, dim,
1471                                   dynamic_size);
1472         }
1473       });
1474 
1475   TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
1476   // Remove the original instruction even if it has side-effects.
1477   TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
1478   SetVisited(*new_conditional);
1479   SetVisited(*new_conditional_extracted);
1480   return Status::OK();
1481 }
1482 
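// For scatter, a dynamic dimension on operand 0 (the base) flows straight to
// the output. Dynamic scatter index dimensions and collapsed update window
// dimensions do not appear in the output; a dynamic update window dimension
// is only accepted when it mirrors an identically-sized dynamic dimension of
// the base, which the checks below enforce.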
1483 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
1484   return ForEachOperandDynamicDimension(
1485       hlo,
1486       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64_t dimension,
1487           int64_t operand_index, HloInstruction* operand_dynamic_size) {
1488         if (operand_index == 0) {
1489           parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
1490           return Status::OK();
1491         }
1492 
1493         const ScatterDimensionNumbers& scatter_dims =
1494             hlo->scatter_dimension_numbers();
1495         if (operand_index == 2 &&
1496             absl::c_linear_search(scatter_dims.update_window_dims(),
1497                                   dimension)) {
1498           // Dynamic update window dimension is only allowed if it is exactly
1499           // the same as the corresponding operand dimension.
1500           std::vector<int64> update_window_dims_in_operand;
1501           for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
1502             if (absl::c_linear_search(scatter_dims.inserted_window_dims(), i)) {
1503               continue;
1504             }
1505             update_window_dims_in_operand.push_back(i);
1506           }
1507 
1508           for (int64_t i = 0; i < scatter_dims.update_window_dims_size(); ++i) {
1509             if (scatter_dims.update_window_dims(i) == dimension) {
1510               const Shape& operand_shape = hlo->operand(0)->shape();
1511               const Shape& update_shape = hlo->operand(2)->shape();
1512               int64_t dim_in_operand = update_window_dims_in_operand[i];
1513               if (operand_shape.dimensions(dim_in_operand) !=
1514                       update_shape.dimensions(dimension) ||
1515                   !operand_shape.is_dynamic_dimension(dim_in_operand)) {
1516                 return Unimplemented(
1517                     "Dynamic dimension of update window dims that are not the "
1518                     "same as corresponding operand dim is not supported: "
1519                     "%s",
1520                     hlo->ToString());
1521               }
1522               HloInstruction* base_dynamic_size = parent_->GetDynamicSize(
1523                   hlo->mutable_operand(0), {}, dim_in_operand);
1524               if (base_dynamic_size != operand_dynamic_size) {
1525                 return Unimplemented(
1526                     "Dynamic dimension size of update window dims that are not "
1527                     "the same as corresponding operand dim is not supported: "
1528                     "%s.\n Dynamic dim size of base: %s, dynamic dim size of "
1529                     "update: %s",
1530                     hlo->ToString(), base_dynamic_size->ToString(),
1531                     operand_dynamic_size->ToString());
1532               }
1533             }
1534           }
1535         }
1536         // The dynamic dimension is collapsed and won't show up in the output.
1537         // Do nothing here.
1538         return Status::OK();
1539       });
1540 }
1541 
1542 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
1543   // If the output of the kWhile contains a dynamic dimension, we send the
1544   // dynamic dimension size into the while body by adding additional root/body
1545   // elements. A mapping from the root instruction's dynamic dimension index
1546   // (represented by a shape index as output index and an int64 dimension
1547   // number) to output index (represented by an int64) is tracked for the
1548   // while instruction.
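  // For example (illustrative only), while-looping over (s32[<=4]) becomes a
  // loop over (s32[<=4], s32[]) where the extra s32[] element threads the
  // dynamic size of dimension 0 through the body; the body root writes the
  // possibly updated size back, and consumers read it out of the new while
  // instruction with a get-tuple-element.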
1549   ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1550       hlo->shape());
1551   std::vector<HloInstruction*> operands_to_add;
1552   const int64_t original_tuple_count = hlo->shape().tuple_shapes_size();
1553   int64_t operand_count = original_tuple_count;
1554   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1555       hlo, [&](HloInstruction*, ShapeIndex index, int64_t dim, int64_t,
1556                HloInstruction* dynamic_size) {
1557         operands_to_add.push_back(dynamic_size);
1558         dynamic_output_mapping.mutable_element(index)->emplace(dim,
1559                                                                operand_count++);
1560         return Status::OK();
1561       }));
1562 
1563   DynamicParameterBinding binding_for_while;
1564   if (!operands_to_add.empty()) {
1565     // Only replace the while loop if there are new parameters to add.
1566     HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
1567     TF_ASSIGN_OR_RETURN(
1568         WhileUtil::MakeInstructionsLiveInResult result,
1569         WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
1570     // WhileUtil creates a new while hlo and tuple. Update the dynamic size
1571     // mapping for the newly created tuple.
1572     HloInstruction* new_tuple_operand =
1573         result.new_while_instr->mutable_operand(0);
1574     parent_->CopyMapping(/*from=*/old_tuple_operand,
1575                          /*to=*/new_tuple_operand);
1576     hlo = result.new_while_instr;
1577     // We have replaced the while loop, so now set the dynamic dimensions of
1578     // the newly created while loop so that the HLOs that consume the while
1579     // loop can see the dynamic dimensions. Also set the dynamic parameter
1580     // binding for running inference in the while loop.
1581     TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1582         hlo,
1583         [&](HloInstruction*, ShapeIndex index, int64_t dimension,
1584             int64_t operand_index, HloInstruction* dynamic_size) -> Status {
1585           TF_RET_CHECK(!operands_to_add.empty());
1586           const int64_t output_dynamic_size_index =
1587               dynamic_output_mapping.element(index).at(dimension);
1588           DynamicParameterBinding::DynamicParameter dynamic_parameter{
1589               operand_index, {output_dynamic_size_index}};
1590           DynamicParameterBinding::DynamicDimension dynamic_dimension{
1591               operand_index, index, dimension};
1592           TF_RETURN_IF_ERROR(
1593               binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
1594           // This is the updated output dynamic size coming out of the while
1595           // loop.
1596           HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
1597               HloInstruction::CreateGetTupleElement(
1598                   ShapeUtil::MakeScalarShape(S32), hlo,
1599                   output_dynamic_size_index));
1600           parent_->SetDynamicSize(result.replacement_instr, index, dimension,
1601                                   output_dynamic_size);
1602           return Status::OK();
1603         }));
1604     // Set the replacement instruction as visited to avoid visiting it again.
1605     SetVisited(*result.replacement_instr);
1606   }
1607 
1608   // Run inference in while body and condition.
1609   TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1610       hlo->while_body(), binding_for_while, parent_));
1611   TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1612       hlo->while_condition(), binding_for_while, parent_));
1613 
1614   if (operands_to_add.empty()) {
1615     // No dynamic dimension in the inputs and outputs.
1616     return Status::OK();
1617   }
1618 
1619   // The dynamic dimension size could have been changed in the loop body (e.g.,
1620   // a loop that pushes items onto a stack grows the stack size with each
1621   // iteration). Rewrite the dynamic dimension size at the root.
1622   HloInstruction* body_root = hlo->while_body()->root_instruction();
1623   std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
1624                                                  nullptr);
1625 
1626   // The original (non-dynamic-size) operands of the root are passed through.
1627   for (int64_t i = 0; i < original_tuple_count; ++i) {
1628     new_root_operands[i] =
1629         hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
1630             body_root->shape().tuple_shapes(i), body_root, i));
1631   }
1632   // Add the updated dynamic dimension sizes as additional root operands.
1633   TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1634       hlo->while_body()->root_instruction(),
1635       [&](ShapeIndex index, int64_t dim,
1636           HloInstruction* dynamic_size) -> Status {
1637         const int64_t output_index =
1638             dynamic_output_mapping.element(index).at(dim);
1639         new_root_operands[output_index] = dynamic_size;
1640         return Status::OK();
1641       }));
1642   for (auto operand : new_root_operands) {
1643     TF_RET_CHECK(operand != nullptr);
1644   }
1645   HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
1646       HloInstruction::CreateTuple(new_root_operands));
1647   hlo->while_body()->set_root_instruction(new_body_root);
1648   return Status::OK();
1649 }
1650 
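// Parameters get their dynamic sizes from the DynamicParameterBinding: for
// every binding that targets this parameter, a get-tuple-element chain is
// built to pull the size value out of the size-carrying parameter, and the
// result is recorded against the bound dimension.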
1651 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
1652   return param_bindings_.ForEachBinding(
1653       [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1654           const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
1655         if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
1656           return Status::OK();
1657         }
1658         HloComputation* computation = hlo->parent();
1659         HloInstruction* target_parameter =
1660             computation->parameter_instruction(dynamic_dimension.parameter_num);
1661 
1662         HloInstruction* dynamic_size =
1663             computation->parameter_instruction(dynamic_parameter.parameter_num);
1664         for (int64_t i : dynamic_parameter.parameter_index) {
1665           dynamic_size =
1666               computation->AddInstruction(HloInstruction::CreateGetTupleElement(
1667                   ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
1668                   dynamic_size, i));
1669         }
1670 
1671         parent_->SetDynamicSize(target_parameter,
1672                                 dynamic_dimension.parameter_index,
1673                                 dynamic_dimension.dimension, dynamic_size);
1674         return Status::OK();
1675       });
1676 }
1677 
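// The ForEach* helpers below iterate over the dynamic dimensions already
// recorded in the parent inference object, either for an instruction itself
// or for each of its operands, and hand the shape index, dimension number,
// and size-carrying HLO to the callback.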
1678 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
1679     HloInstruction* inst, const DynamicDimensionFn& fn) {
1680   auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
1681   if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1682     for (auto& dynamic_dimension : iter->second) {
1683       HloInstruction* dynamic_size = parent_->GetDynamicSize(
1684           dynamic_dimension.inst, dynamic_dimension.index,
1685           dynamic_dimension.dim);
1686       TF_RETURN_IF_ERROR(
1687           fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
1688     }
1689   }
1690   return Status::OK();
1691 }
1692 
1693 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
1694     HloInstruction* inst, int64_t operand_index,
1695     const OperandDynamicDimensionFn& fn) {
1696   auto iter =
1697       parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
1698   if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1699     for (auto& dynamic_dimension : iter->second) {
1700       HloInstruction* dynamic_size = parent_->GetDynamicSize(
1701           dynamic_dimension.inst, dynamic_dimension.index,
1702           dynamic_dimension.dim);
1703       TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
1704                             dynamic_dimension.dim, operand_index,
1705                             dynamic_size));
1706     }
1707   }
1708   return Status::OK();
1709 }
1710 
1711 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
1712     HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
1713   for (int64_t operand_index = 0; operand_index < inst->operand_count();
1714        ++operand_index) {
1715     TF_RETURN_IF_ERROR(
1716         ForEachDynamicDimensionInOperand(inst, operand_index, fn));
1717   }
1718   return Status::OK();
1719 }
1720 
1721 void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
1722                                                const ShapeIndex& index,
1723                                                int64_t dim,
1724                                                HloInstruction* size) {
1725   VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
1726           << index.ToString() << "@" << dim << " to " << size->ToShortString();
1727   Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
1728   CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
1729   CHECK(dim < subshape.rank() && dim >= 0)
1730       << "Asked to set invalid dynamic dimension. Shape: "
1731       << subshape.ToString() << ", Dimension: " << dim;
1732   DynamicDimension dynamic_dimension{inst, index, dim};
1733   // Updating a dynamic dimension twice overwrites the previous one.
1734   dynamic_mapping_[dynamic_dimension] = size;
1735   auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
1736   iter.first->second.emplace(dynamic_dimension);
1737 }
1738 
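// Copies every dynamic dimension recorded for `from` onto `to`, reusing the
// same size-carrying instructions; used when an instruction is cloned or
// replaced (e.g. the while-loop rewrite above).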
1739 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
1740                                             HloInstruction* to) {
1741   auto iter = per_hlo_dynamic_dimensions_.find(from);
1742   if (iter != per_hlo_dynamic_dimensions_.end()) {
1743     for (auto& dynamic_dimension : iter->second) {
1744       HloInstruction* dynamic_size =
1745           GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
1746                          dynamic_dimension.dim);
1747       SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
1748                      dynamic_size);
1749     }
1750   }
1751 }
1752 
1753 /* static */
1754 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
1755     HloModule* module, CustomCallInferenceHandler custom_call_handler) {
1756   VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
1757   DynamicDimensionInference inference(module, std::move(custom_call_handler));
1758   TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
1759   return inference;
1760 }
1761 
1762 string DynamicDimensionInference::ToString() const {
1763   std::vector<string> pieces;
1764   pieces.push_back("DynamicDimensionInference: ");
1765   for (const auto& mapping : dynamic_mapping_) {
1766     const DynamicDimension& dynamic_dimension = mapping.first;
1767     pieces.push_back(absl::StrFormat(
1768         " -- instruction %s at %s has dim %lld as dynamic"
1769         " dimension, which is represented by instruction %s",
1770         dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
1771         dynamic_dimension.dim, mapping.second->ToString()));
1772   }
1773   return absl::StrJoin(pieces, "\n");
1774 }
1775 
1776 DynamicDimensionInference::DynamicDimensionInference(
1777     HloModule* module, CustomCallInferenceHandler custom_call_handler)
1778     : module_(module), custom_call_handler_(std::move(custom_call_handler)) {}
1779 
1780 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
1781   return DynamicDimensionInferenceVisitor::Run(
1782       module_->entry_computation(), module_->dynamic_parameter_binding(), this,
1783       custom_call_handler_);
1784 }
1785 
1786 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
1787     HloInstruction* replace, HloInstruction* with) {
1788   CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
1789                                       ShapeUtil::MakeScalarShape(S32)));
1790   CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
1791                                       ShapeUtil::MakeScalarShape(S32)));
1792   for (auto& kv : dynamic_mapping_) {
1793     if (kv.second == replace) {
1794       kv.second = with;
1795     }
1796   }
1797 }
1798 
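// Makes `new_inst` report the same dynamic sizes as `inst` at `index`; the
// two instructions must have identical shapes. Callers use this after
// replacing an instruction with an equivalently-shaped one.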
1799 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
1800                                                      HloInstruction* new_inst,
1801                                                      const ShapeIndex& index) {
1802   CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
1803 
1804   for (int64_t dim = 0; dim < inst->shape().rank(); ++dim) {
1805     DynamicDimension dynamic_dimension_new{new_inst, index, dim};
1806     DynamicDimension dynamic_dimension{inst, index, dim};
1807     auto iter = dynamic_mapping_.find(dynamic_dimension);
1808     if (iter != dynamic_mapping_.end()) {
1809       dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
1810       auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
1811       iter.first->second.emplace(dynamic_dimension_new);
1812     }
1813   }
1814 
1815   return Status::OK();
1816 }
1817 
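// Returns true if any array subshape of `inst` under `index` has a recorded
// dynamic size.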
1818 bool DynamicDimensionInference::HasDynamicDimension(
1819     HloInstruction* inst, ShapeIndexView index) const {
1820   bool has_dynamic_dim = false;
1821   ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
1822                                                 const ShapeIndex& subindex) {
1823     if (subshape.IsTuple()) {
1824       return;
1825     }
1826     if (!ShapeIndexView(subindex).StartsWith(index)) {
1827       return;
1828     }
1829     for (int64_t i = 0; i < subshape.dimensions_size(); ++i) {
1830       HloInstruction* operand_dynamic_size = GetDynamicSize(inst, subindex, i);
1831       if (operand_dynamic_size != nullptr) {
1832         has_dynamic_dim = true;
1833       }
1834     }
1835   });
1836   return has_dynamic_dim;
1837 }
1838 
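// Re-runs the visitor on a single instruction with an empty parameter binding
// so that its dynamic dimensions are recomputed from its operands.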
1839 Status DynamicDimensionInference::Update(HloInstruction* inst) {
1840   DynamicParameterBinding parameter_binding;
1841   DynamicDimensionInferenceVisitor visitor(parameter_binding, this,
1842                                            custom_call_handler_);
1843   return inst->Visit(&visitor);
1844 }
1845 
1846 HloInstruction* DynamicDimensionInference::GetDynamicSize(
1847     HloInstruction* inst, const ShapeIndex& index, int64_t dim) const {
1848   auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
1849   if (iter != dynamic_mapping_.end()) {
1850     return iter->second;
1851   }
1852   return nullptr;
1853 }
1854 
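// Returns one entry per dimension of the subshape at `index`; entries are
// null for static dimensions and point at the size-carrying HLO otherwise.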
1855 std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
1856     HloInstruction* inst, const ShapeIndex& index) const {
1857   CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
1858   const int64_t rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
1859   std::vector<HloInstruction*> result(rank, nullptr);
1860   for (int64_t i = 0; i < rank; ++i) {
1861     result[i] = GetDynamicSize(inst, index, i);
1862   }
1863   return result;
1864 }
1865 
1866 }  // namespace xla
1867