1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17 
18 #include <vector>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/tuple_util.h"
31 #include "tensorflow/compiler/xla/service/while_util.h"
32 #include "tensorflow/compiler/xla/shape_tree.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 namespace xla {
38 
39 namespace {
40 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
41 StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
42                                            const Shape& wide_shape) {
43   TF_RET_CHECK(wide_shape.IsTuple());
44   const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
45   if (Shape::Equal()(wide_shape, narrow_shape)) {
46     // No need to widen the computation.
47     return narrow_comp;
48   }
49   HloComputation* wide_comp = [&]() {
50     HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
51     builder.AddInstruction(
52         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
53     return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
54   }();
55 
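  // The block below forwards only the leading tuple elements of the wide
  // parameter that correspond to the narrow computation's parameter, calls the
  // narrow computation on that prefix, and then inlines the call so the result
  // is a self-contained widened computation.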
56   HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
57   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
58       wide_parameter, narrow_shape.tuple_shapes_size());
59   HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
60       HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
61                                  {truncated_parameter}, narrow_comp));
62   wide_comp->set_root_instruction(call_narrow_comp,
63                                   /*accept_different_shape=*/true);
64   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
65   return wide_comp;
66 }
67 }  // namespace
68 
69 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
70  public:
71   explicit DynamicDimensionInferenceVisitor(
72       const DynamicParameterBinding& param_bindings,
73       DynamicDimensionInference* parent,
74       DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler)
75       : param_bindings_(param_bindings),
76         parent_(parent),
77         custom_call_handler_(std::move(custom_call_handler)) {}
78 
79   Status DefaultAction(HloInstruction* hlo) override;
80 
81   static Status Run(HloComputation* computation,
82                     const DynamicParameterBinding& param_bindings,
83                     DynamicDimensionInference* parent,
84                     DynamicDimensionInference::CustomCallInferenceHandler
85                         custom_call_handler = nullptr) {
86     DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
87                                              std::move(custom_call_handler));
88     return computation->Accept(&visitor);
89   }
90 
91   Status HandleParameter(HloInstruction* hlo) override;
92 
93   Status HandleReduce(HloInstruction* hlo) override;
94 
95   Status HandleDot(HloInstruction* hlo) override;
96 
97   Status HandleTuple(HloInstruction* hlo) override;
98 
99   Status HandleTranspose(HloInstruction* hlo) override;
100 
101   Status HandleDynamicReshape(HloInstruction* hlo) override;
102 
103   Status HandleReshape(HloInstruction* hlo) override;
104 
105   Status HandleSort(HloInstruction* hlo) override;
106 
107   Status HandlePad(HloInstruction* hlo) override;
108 
109   Status HandleCustomCall(HloInstruction* hlo) override;
110 
111   Status HandleBroadcast(HloInstruction* hlo) override;
112 
113   Status HandleGetDimensionSize(HloInstruction* hlo) override;
114 
115   Status HandleSetDimensionSize(HloInstruction* hlo) override;
116 
117   Status HandleSelect(HloInstruction* hlo) override;
118 
119   Status HandleConvolution(HloInstruction* hlo) override;
120 
121   Status HandleConcatenate(HloInstruction* hlo) override;
122 
123   Status HandleReduceWindow(HloInstruction* hlo) override;
124 
125   Status HandleReverse(HloInstruction* hlo) override;
126 
127   Status HandleSelectAndScatter(HloInstruction* hlo) override;
128 
129   Status HandleGetTupleElement(HloInstruction* hlo) override;
130 
131   Status HandleElementwiseUnary(HloInstruction* hlo) override;
132 
133   Status HandleElementwiseBinary(HloInstruction* hlo) override;
134 
135   Status HandleClamp(HloInstruction* hlo) override;
136 
137   Status HandleConditional(HloInstruction* hlo) override;
138 
139   Status HandleWhile(HloInstruction* hlo) override;
140 
141   Status HandleSlice(HloInstruction* hlo) override;
142 
143   Status HandleDynamicSlice(HloInstruction* hlo) override;
144 
145   Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
146 
147   Status HandleGather(HloInstruction* hlo) override;
148 
149   Status HandleScatter(HloInstruction* hlo) override;
150 
151   Status HandleDomain(HloInstruction* hlo) override;
152 
153  private:
154   using OperandDynamicDimensionFn = std::function<Status(
155       HloInstruction* operand, ShapeIndex index, int64 dimension,
156       int64 operand_index, HloInstruction* dynamic_size)>;
157 
158   using DynamicDimensionFn = std::function<Status(
159       ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
160 
161   Status HandleDynamicConvolutionForward(HloInstruction* hlo,
162                                          int64 operand_index, int64 dimension,
163                                          HloInstruction* dynamic_size);
164 
165   Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
166                                             int64 operand_index,
167                                             int64 dimension);
168 
169   Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
170                                            int64 operand_index,
171                                            int64 dimension);
172 
173   Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
174                                         HloInstruction* dynamic_size,
175                                         int64 operand_index, int64 dimension);
176 
177   Status ForEachOperandDynamicDimension(HloInstruction* inst,
178                                         const OperandDynamicDimensionFn&);
179   Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
180                                           int64 operand_index,
181                                           const OperandDynamicDimensionFn&);
182   Status ForEachDynamicDimension(HloInstruction* inst,
183                                  const DynamicDimensionFn& fn);
184 
185   // Pass through a dynamic dimension from the input to the output with the
186   // same value and index in the shape. This is a helper function to handle
187   // trivial instructions like elementwise operations.
188   Status PassThroughDynamicDimension(HloInstruction*);
189 
190   // The dynamic parameter bindings of this computation.
191   const DynamicParameterBinding& param_bindings_;
192 
193   // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
194   DynamicDimensionInference* parent_;
195 
196   // A handler for custom calls.
197   DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
198 };
199 
200 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
201   return ForEachOperandDynamicDimension(
202       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
203                int64 operand_index, HloInstruction* dynamic_size) {
204         return UnimplementedStrCat(
205             "Asked to propagate a dynamic dimension from hlo ", operand->name(),
206             "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
207             ", which is not implemented.");
208       });
209 }
210 
211 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
212     HloInstruction* hlo) {
213   return ForEachOperandDynamicDimension(
214       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
215                int64 operand_index, HloInstruction* dynamic_size) {
216         if (hlo->tuple_index() == index[0]) {
217           ShapeIndex new_index =
218               ShapeIndexView(index).ConsumeFront().ToShapeIndex();
219           parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
220         }
221         return Status::OK();
222       });
223 }
224 
225 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
226   return ForEachOperandDynamicDimension(
227       hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
228                int64 operand_index, HloInstruction* dynamic_size) {
229         index.push_front(operand_index);
230         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
231         return Status::OK();
232       });
233 }
234 
235 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
236   return ForEachOperandDynamicDimension(
237       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
238                int64 operand_index, HloInstruction* dynamic_size) {
239         int64 broadcast_dim = hlo->dimensions(dimension);
240         parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
241         return Status::OK();
242       });
243 }
244 
245 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
246   if (hlo->custom_call_target() == "PadToStatic") {
247     for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
248       if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
249         HloInstruction* dynamic_size =
250             hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
251                 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
252         // PadToStatic converts a dynamic dimension to a static dimension. It then
253         // returns the padded data output and the dynamic sizes of input
254         // dimensions.
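        // The PadToStatic result is a tuple: element 0 is the data padded to a
        // static shape and element i + 1 holds the runtime size of dynamic
        // dimension i, which is why the get-tuple-element above uses index i + 1.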
255         ShapeIndex data_output = {0};
256         parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
257       }
258     }
259     return Status::OK();
260   }
261   if (custom_call_handler_) {
262     return custom_call_handler_(hlo, parent_);
263   }
264 
265   if (hlo->custom_call_target() == "DynamicConvolutionForward") {
266     // If the input feature dimension is dynamic but the kernel input feature
267     // dimension is static, we can infer that the input feature dimension is
268     // actually static. E.g.,:
269     // lhs = [B, X, Y, ?]
270     // rhs = [X, Y, I, O]
271     // dim_labels = b01f_01io
272     // The dynamic dimension ? in lhs must equal the static size I.
273     const ConvolutionDimensionNumbers& dnums =
274         hlo->convolution_dimension_numbers();
275     HloInstruction* input_feature = parent_->GetDynamicSize(
276         hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
277     HloInstruction* kernel_feature = parent_->GetDynamicSize(
278         hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
279 
280     if (input_feature != nullptr && kernel_feature == nullptr) {
281       if (hlo->mutable_operand(0)->shape().dimensions(
282               dnums.input_feature_dimension()) ==
283           hlo->mutable_operand(1)->shape().dimensions(
284               dnums.kernel_input_feature_dimension()))
285         parent_->SetDynamicSize(hlo->mutable_operand(0), {},
286                                 dnums.input_feature_dimension(), nullptr);
287     }
288   }
289   return ForEachOperandDynamicDimension(
290       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
291                int64 operand_index, HloInstruction* dynamic_size) {
292         // Resize custom call should propagate dynamic batch (0) and channel (3)
293         // dimensions.
294         if (hlo->custom_call_target() == "SliceToDynamic" ||
295             hlo->custom_call_target() == "Sharding" ||
296             (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
297              (dimension == 0 || dimension == 3))) {
298           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
299           return Status::OK();
300         }
301         if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
302           if (hlo->operand_count() > 2) {
303             return Unimplemented(
304                 "DynamicReduceWindowSamePadding doesn't support variadic "
305                 "reduce window %s",
306                 hlo->ToString());
307           }
308           return HandleDynamicWindowSamePadding(hlo, dynamic_size,
309                                                 operand_index, dimension);
310         }
311 
312         if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
313           if (operand_index == 1) {
314             // Operand 0 (input) determines dynamic output size. We ignore the
315             // dynamic size of operand 1 (the output gradient).
316             return Status::OK();
317           }
318           parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
319           return Status::OK();
320         }
321 
322         if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
323           return HandleDynamicConvolutionInputGrad(hlo, operand_index,
324                                                    dimension);
325         }
326 
327         if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
328           return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
329                                                     dimension);
330         }
331 
332         if (hlo->custom_call_target() == "DynamicConvolutionForward") {
333           return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
334                                                  dynamic_size);
335         }
336         return Unimplemented(
337             "CustomCall \"%s\" is not supported to have a dynamic dimension",
338             hlo->custom_call_target());
339       });
340 }
341 
342 Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
343   return ForEachOperandDynamicDimension(
344       hlo,
345       [&](HloInstruction* operand, ShapeIndex index, int64 dynamic_dimension,
346           int64 operand_index, HloInstruction* dynamic_size) {
347         HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
348         if (sort->values_count() == 0) {
349           parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
350         } else {
351           parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
352                                   dynamic_size);
353         }
354 
355         return Status::OK();
356       });
357 }
358 
359 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
360   return ForEachOperandDynamicDimension(
361       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
362                int64 operand_index, HloInstruction* dynamic_size) {
363         if (operand_index != 0) {
364           return Unimplemented(
365               "Dynamic dimension on padding value is not supported");
366         }
367         const PaddingConfig_PaddingConfigDimension& padding_config =
368             hlo->padding_config().dimensions(dimension);
369 
370         HloInstruction* dynamic_size_adjusted = dynamic_size;
371         if (padding_config.interior_padding() != 0) {
372           // Adjust for interior padding :
373           // Size' = max((Size - 1), 0) * interior_padding + Size
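          // For example, with Size = 4 and interior_padding = 2:
          //   Size' = max(4 - 1, 0) * 2 + 4 = 10.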
374           HloInstruction* one = hlo->parent()->AddInstruction(
375               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
376           HloInstruction* zero = hlo->parent()->AddInstruction(
377               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
378           HloInstruction* interior_padding = hlo->parent()->AddInstruction(
379               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
380                   padding_config.interior_padding())));
381           dynamic_size_adjusted =
382               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
383                   dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
384                   dynamic_size_adjusted, one));
385           dynamic_size_adjusted =
386               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
387                   dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
388                   dynamic_size_adjusted, zero));
389           dynamic_size_adjusted =
390               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
391                   dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
392                   dynamic_size_adjusted, interior_padding));
393           dynamic_size_adjusted =
394               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
395                   dynamic_size_adjusted->shape(), HloOpcode::kAdd,
396                   dynamic_size_adjusted, dynamic_size));
397         }
398         HloInstruction* adjustment = hlo->parent()->AddInstruction(
399             HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
400                 padding_config.edge_padding_low() +
401                 padding_config.edge_padding_high())));
402         dynamic_size_adjusted =
403             hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
404                 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
405                 dynamic_size_adjusted, adjustment));
406         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
407         return Status::OK();
408       });
409 }
410 
411 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
412   return ForEachOperandDynamicDimension(
413       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
414                int64 operand_index, HloInstruction* dynamic_size) {
415         HloInstruction* reduce = hlo;
416         int64 operand_count = reduce->operand_count();
417         bool is_variadic_reduce = operand_count > 2;
418         CHECK_EQ(operand_count % 2, 0);
419         if (operand_index >= operand_count / 2) {
421           // Init values don't have dynamic sizes.
421           return Status::OK();
422         }
423         if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
424           // Dimension is to be reduced, stop tracing.
425           return Status::OK();
426         }
427 
428         // Find out the new dynamic dimension after reduce.
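        // For example, reducing a [<=4, 8, 16] operand over dimension {1}
        // produces a [<=4, 16] result, so dynamic dimension 0 of the operand
        // maps to dimension 0 of the result.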
429         int64 dimensions_not_reduced_count = 0;
430         for (int i = 0; i < operand->shape().rank(); ++i) {
431           if (dimension == i) {
432             ShapeIndex result_index = {};
433 
434             if (is_variadic_reduce) {
435               // The dimensions of all data operands of a variadic reduce have
436               // to be the same.  This means that if one operand of variadic
437               // reduce has a dynamic dimension, we set all outputs to use the
438               // same dynamic size in corresponding dimensions.
439               for (int64 i = 0; i < operand_count / 2; ++i) {
440                 parent_->SetDynamicSize(
441                     reduce, {i}, dimensions_not_reduced_count, dynamic_size);
442               }
443             } else {
444               parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
445                                       dynamic_size);
446             }
447 
448             return Status::OK();
449           }
450           if (absl::c_count(reduce->dimensions(), i) == 0) {
451             dimensions_not_reduced_count++;
452           }
453         }
454 
455         return Status::OK();
456       });
457 }
458 
459 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
460   return ForEachOperandDynamicDimension(
461       hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
462                int64 operand_dimension, int64 operand_index,
463                HloInstruction* dynamic_size) {
464         // There are three types of dimensions in a dot:
465         // A. batch dims
466         // B. contracting dims
467         // C. non-batch non-contracting dims.
468         // The output dimensions of a dot has three parts with the following
469         // order:
470         // [(type A), (lhs type C), (rhs type C)]
471         //
472         // Note that both lhs and rhs have the same dimension sizes for batch,
473         // but the dimension index could be different.
474         //
475         // Given one dynamic input dimension, either lhs or rhs, we use a
476         // mapping to find the corresponding output dimension.
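        // For example, for dot(lhs[b, m, k], rhs[b, k, n]) with batch dimension
        // b and contracting dimension k, the result is [b, m, n]: a dynamic m
        // in lhs maps to result dimension 1 and a dynamic n in rhs maps to
        // result dimension 2.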
477         HloInstruction* dot = hlo;
478         const DotDimensionNumbers& dimension_numbers =
479             dot->dot_dimension_numbers();
480         // A map from the operand dimensions to result dimension.
481         absl::flat_hash_map<int64, int64> result_dim_mapping;
482         int64 current_result_dims = 0;
483 
484         bool lhs = operand_index == 0;
485 
486         // The first loop keeps track of batch dimensions. RHS and LHS could have
487         // different batch dimension numbers.
488         if (lhs) {
489           for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
490             result_dim_mapping[i] = current_result_dims++;
491           }
492         } else {
493           for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
494             result_dim_mapping[i] = current_result_dims++;
495           }
496         }
497 
498         // Handle dimensions in the lhs.
499         for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
500           // Look for non-contracting and non-batching dimension.
501           if (absl::c_linear_search(
502                   dimension_numbers.lhs_contracting_dimensions(), i)) {
503             continue;
504           }
505           if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
506                                     i)) {
507             continue;
508           }
509           if (lhs) {
510             result_dim_mapping[i] = current_result_dims;
511           }
512           current_result_dims++;
513         }
514 
515         // Handle dimensions in the rhs.
516         for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
517           // Look for non-contracting and non-batching dimension.
518           if (absl::c_linear_search(
519                   dimension_numbers.rhs_contracting_dimensions(), i)) {
520             continue;
521           }
522           if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
523                                     i)) {
524             continue;
525           }
526           if (!lhs) {
527             result_dim_mapping[i] = current_result_dims;
528           }
529           current_result_dims++;
530         }
531 
532         // Check if the operand dim is in the result shape. If so, add another
533         // work item to trace that dimension.
534         auto iter = result_dim_mapping.find(operand_dimension);
535         if (iter != result_dim_mapping.end()) {
536           parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
537         }
538 
539         return Status::OK();
540       });
541 }
542 
543 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
544   return ForEachOperandDynamicDimension(
545       hlo,
546       [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
547           int64 operand_index, HloInstruction* dynamic_size) -> Status {
548         int64 permuted_dim = -1;
549         for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
550           if (hlo->dimensions()[i] == dimension) {
551             TF_RET_CHECK(permuted_dim == -1);
552             permuted_dim = i;
553           }
554         }
555         parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
556         return Status::OK();
557       });
558 }
559 
560 Status DynamicDimensionInferenceVisitor::HandleConvolution(
561     HloInstruction* hlo) {
562   return ForEachOperandDynamicDimension(
563       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
564                int64 operand_index, HloInstruction* dynamic_size) {
565         HloInstruction* conv = hlo;
566         const ConvolutionDimensionNumbers& dimension_numbers =
567             conv->convolution_dimension_numbers();
568         if (operand_index == 0) {
569           if (dimension == dimension_numbers.input_batch_dimension()) {
570             parent_->SetDynamicSize(conv, {},
571                                     dimension_numbers.output_batch_dimension(),
572                                     dynamic_size);
573             return Status::OK();
574           }
575 
576           if (dimension == dimension_numbers.input_feature_dimension()) {
577             return Status::OK();
578           }
579         } else {
580           if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
581             return Status::OK();
582           }
583         }
584 
585         return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
586                              conv->ToString());
587       });
588 }
589 
590 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
591     HloInstruction* hlo) {
592   // First handle concatenate dimensions. We do this by iterating through all
593   // operands while tracking both dynamic and static dimensions.
594 
595   // static_size is used to keep track of the concatenated size of static
596   // dimensions.
597   int64 static_size = 0;
598   std::vector<HloInstruction*> dynamic_concat_dims;
599   for (int64 i = 0; i < hlo->operand_count(); ++i) {
600     HloInstruction* dynamic_size = parent_->GetDynamicSize(
601         hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
602     if (dynamic_size == nullptr) {
603       // This is a static dimension.
604       static_size +=
605           hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
606     } else {
607       dynamic_concat_dims.push_back(dynamic_size);
608     }
609   }
610   // If concat dimension is dynamic, calculate its size by summing up static
611   // dims and dynamic dims together.
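  // For example, concatenating operands whose concat-dimension sizes are
  // [<=4] (actual size 2), [3], and [<=5] (actual size 4) gives static_size = 3
  // and an output dynamic size of 3 + 2 + 4 = 9.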
612   if (!dynamic_concat_dims.empty()) {
613     HloInstruction* dim_size_total =
614         hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
615             LiteralUtil::CreateR0<int32>(static_size)));
616     for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
617       dim_size_total = hlo->parent()->AddInstruction(
618           HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
619                                        dim_size_total, dynamic_dim));
620     }
621     parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
622                             dim_size_total);
623   }
624 
625   // Simply pass through non-concat dynamic dimensions.
626   return ForEachOperandDynamicDimension(
627       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
628                int64 operand_index, HloInstruction* dynamic_size) {
629         int64 concatenate_dimension = hlo->concatenate_dimension();
630         if (concatenate_dimension == dimension) {
631           return Status::OK();
632         }
633         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
634         return Status::OK();
635       });
636 }
637 
638 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
639     HloInstruction*) {
640   // Dynamic dimension doesn't propagate through GetDimensionSize:
641   //
642   //   Input: F32[x, y, z]
643   //     |
644   //   GetDimensionSize(1): S32[]
645   //
646   // The returned value is a scalar, which doesn't have any dynamic dimension in
647   // the shape (although the value contains the real size of the dynamic
648   // dimension of the input).
649   return Status::OK();
650 }
651 
652 Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
653     HloInstruction* hlo) {
654   bool dimension_is_static = false;
655   const HloInstruction* size = hlo->operand(1);
656   if (size->opcode() == HloOpcode::kConstant) {
657     // Check if we are setting a dimension size to its static size. If so,
658     // remove the dynamic dimension.
659     //
660     // size = s32[] constant(5)
661     // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
662     //                                                        dimensions={1}
663     // The result shape has no dynamic dimension.
664     TF_RET_CHECK(size->shape().rank() == 0);
665     if (size->literal().Get<int32>({}) ==
666         hlo->shape().dimensions(hlo->dimension())) {
667       dimension_is_static = true;
668     }
669   }
670 
671   if (!dimension_is_static) {
672     // Propagate dynamic dimension indicated by this set dimension size
673     // instruction.
674     parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
675   }
676 
677   // Also propagate dynamic dimensions already set by operands.
678   TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
679       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
680                int64 operand_index, HloInstruction* dynamic_size) {
681         if (dimension != hlo->dimension()) {
682           parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
683         }
684         return Status::OK();
685       }));
686 
687   return Status::OK();
688 }
689 
690 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
691     HloInstruction* hlo, int64 operand_index, int64 dimension,
692     HloInstruction* dynamic_size) {
693   TF_RET_CHECK(operand_index == 0);
694   const ConvolutionDimensionNumbers& dimension_numbers =
695       hlo->convolution_dimension_numbers();
696 
697   if (dimension == dimension_numbers.input_batch_dimension()) {
698     // Batch dimension is propagated without any changes.
699     parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
700                             dynamic_size);
701     return Status::OK();
702   }
703 
704   for (int64 spatial_dim_index = 0;
705        spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
706        ++spatial_dim_index) {
707     int64 input_spatial_dim =
708         dimension_numbers.input_spatial_dimensions(spatial_dim_index);
709     int64 output_spatial_dim =
710         dimension_numbers.output_spatial_dimensions(spatial_dim_index);
711     if (dimension == input_spatial_dim) {
712       // This is a dynamic spatial dimension. Calculate the output size.
713       WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
714       DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
715           dynamic_size, window_dim.size(), window_dim.window_dilation(),
716           window_dim.stride(), hlo->padding_type());
717       TF_RET_CHECK(window_dim.base_dilation() == 1);
718       parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
719                               dynamic_window_dims.output_size);
720       return Status::OK();
721     }
722   }
723   // Input Feature dim disappears after convolution.
724   return Status::OK();
725 }
726 
727 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
728     HloInstruction* hlo, HloInstruction* dynamic_size, int64 operand_index,
729     int64 dimension) {
730   const Window& window = hlo->window();
731   const WindowDimension& window_dim = window.dimensions(dimension);
732   if (!window_util::IsTrivialWindowDimension(window_dim)) {
733     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
734         dynamic_size, window_dim.size(), window_dim.window_dilation(),
735         window_dim.stride(), PaddingType::PADDING_SAME);
736     parent_->SetDynamicSize(hlo, {}, dimension,
737                             dynamic_window_dims.output_size);
738     return Status::OK();
739   }
740 
741   parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
742 
743   return Status::OK();
744 }
745 
746 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
747     HloInstruction* hlo, int64 operand_index, int64 dimension) {
748   // The output size of convolution input grad is the corresponding input size.
749   HloInstruction* input_sizes = hlo->mutable_operand(0);
750   HloComputation* comp = hlo->parent();
751   TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
752   TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
753   TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
754                hlo->shape().dimensions_size())
755       << hlo->ToString();
756   // Slice to get corresponding input size.
757   HloInstruction* slice = comp->AddInstruction(
758       HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
759                                   {dimension}, {dimension + 1}, {1}));
760   HloInstruction* reshape = comp->AddInstruction(
761       HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
762   parent_->SetDynamicSize(hlo, {}, dimension, reshape);
763   return Status::OK();
764 }
765 
766 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
767     HloInstruction* hlo, int64 operand_index, int64 dimension) {
768   // Dynamic convolution kernel grad produces static shape outputs.
769   return Status::OK();
770 }
771 
772 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
773     HloInstruction* hlo) {
774   return ForEachOperandDynamicDimension(
775       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
776                int64 operand_index, HloInstruction* dynamic_size) {
777         parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
778         return Status::OK();
779       });
780 }
781 
782 Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
783   return PassThroughDynamicDimension(hlo);
784 }
785 
786 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
787     HloInstruction* hlo) {
788   return PassThroughDynamicDimension(hlo);
789 }
790 
791 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
792   return PassThroughDynamicDimension(hlo);
793 }
794 
795 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
796     HloInstruction* hlo) {
797   return PassThroughDynamicDimension(hlo);
798 }
799 
800 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
801   return PassThroughDynamicDimension(hlo);
802 }
803 
804 Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
805     HloInstruction* hlo) {
806   HloDynamicReshapeInstruction* dynamic_reshape =
807       Cast<HloDynamicReshapeInstruction>(hlo);
808   for (int64 i = 0; i < hlo->shape().rank(); ++i) {
809     if (hlo->shape().is_dynamic_dimension(i)) {
810       parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
811     }
812   }
813   return Status::OK();
814 }
815 
816 Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
817   return ForEachOperandDynamicDimension(
818       hlo,
819       [&](HloInstruction* operand, ShapeIndex index,
820           int64 input_dynamic_dimension, int64 operand_index,
821           HloInstruction* operand_dynamic_size) -> Status {
822         HloInstruction* reshape = hlo;
823         if (reshape->shape().rank() == 0) {
824           VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
825                      "undefined behavior when input size is 0. The offending "
826                      "instruction is: "
827                   << reshape->ToString();
828           return Status::OK();
829         }
830         auto common_factors = CommonFactors(operand->shape().dimensions(),
831                                             reshape->shape().dimensions());
832         int64 input_dim_start = -1;
833         int64 input_dim_end = -1;
834         int64 output_dim_start = -1;
835         int64 output_dim_end = -1;
836         // Find common_factors that the input belongs to.
837         for (int64 i = 0; i < common_factors.size() - 1; ++i) {
838           auto start = common_factors[i];
839           auto end = common_factors[i + 1];
840           if (input_dynamic_dimension >= start.first &&
841               input_dynamic_dimension < end.first) {
842             // Found the common_factor group that the input_dim belongs to.
843             input_dim_start = start.first;
844             input_dim_end = end.first;
845             output_dim_start = start.second;
846             output_dim_end = end.second;
847           }
848         }
849 
850         VLOG(2) << "Input dim start: " << input_dim_start
851                 << " Input dim end: " << input_dim_end
852                 << " output dim start: " << output_dim_start
853                 << " output dim end: " << output_dim_end;
854 
855         if ((input_dim_end - input_dim_start) > 1 &&
856             (output_dim_end - output_dim_start) > 1) {
857           // We don't support the case when a dynamic dimension is both combined
858           // with and split into other dimensions:
859           //
860           //  [x, yz]
861           //     | Reshape
862           //  [xy, z]
863           //
864           // TODO(yunxing): This can be supported by canonicalizing
865           // the offending reshape into two reshapes:
866           //
867           //  [x,yz]
868           //     | Reshape
869           //  [x, y, z]
870           //     | Reshape
871           //  [xy, z]
872           //
873           return Unimplemented(
874               "Dynamic input dimension to reshape that is both splitted and "
875               "combined is not supported %s",
876               hlo->ToString());
877         }
878 
879         for (auto common_factor : common_factors) {
880           // Expand common factor to include degenerated output dimensions.
881           if (common_factor.first == input_dim_start) {
882             output_dim_start = std::min(output_dim_start, common_factor.second);
883           }
884           if (common_factor.first == input_dim_end) {
885             output_dim_end = std::max(output_dim_end, common_factor.second);
886           }
887         }
888 
889         int64 output_dynamic_dimension = -1;
890 
891         if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
892           // If the dynamic dimension has static size 1, it can only be
893           // most-major or most-minor.
894           if (input_dynamic_dimension == 0) {
895             output_dynamic_dimension = 0;
896           }
897           if (input_dynamic_dimension == operand->shape().rank() - 1) {
898             output_dynamic_dimension = reshape->shape().rank() - 1;
899           }
900 
901           if (output_dynamic_dimension == -1) {
902             return Unimplemented(
903                 "Dynamic degenerated dimension that's not most-minor nor "
904                 "most-major is not supported %s",
905                 reshape->ToString());
906           }
907         }
908 
909         if (output_dynamic_dimension == -1 &&
910             output_dim_end - output_dim_start == 1) {
911           // Only one possible output dimension.
912           output_dynamic_dimension = output_dim_start;
913         }
914 
915         if (output_dynamic_dimension == -1 &&
916             output_dim_end - output_dim_start > 1) {
917           // One input dimension is split into multiple output dimensions.
918           // The output dimensions are decomposed from the input's most-major dimension.
919           // In this case, we don't know which one is dynamic, e.g., when we
920           // have:
921           //
922           //           [<=a/c, c, b]
923           //              | Reshape
924           //           [<=a, b] // a is dynamic, has to be multiple of c.
925           //             |  Reshape
926           // [1, 1, ... , a/c, c, b]
927           //
928           // Any dimension from the first '1' to 'a/c' can be dynamic.
929           //
930           // We use the following logic to disambiguate:
931           // 1. If the user sets "inferred_dimension", then use that as
932           // dynamic dimension.
933           // 2. If one dimension in the reshape's result shape is dynamic,
934           // use that as the dynamic dimension.
935           // E.g.:
936           //     [<=4]
937           //      |
938           //   reshape
939           //      |
940           //   [1, <=2, 2]
941           // We use second dim as dynamic dimension.
942           //
943           // 3. If the logic above cannot disambiguate, e.g.,:
944           //
945           //     [<=1]
946           //      |
947           //   reshape
948           //      |
949           //   [1, 1, 1]
950           //
951           //   We bail out and return an error.
952           // TODO(yunxing): Further simplify this, remove 1. and fully rely
953           // on 2.
954           output_dynamic_dimension = reshape->inferred_dimension();
955           if (output_dynamic_dimension == -1) {
956             // Try to find the dynamic dimension from the result shape.
957             for (int64 i = output_dim_start; i < output_dim_end; ++i) {
958               if (reshape->shape().is_dynamic_dimension(i)) {
959                 output_dynamic_dimension = i;
960               }
961             }
962           }
963 
964           if (output_dynamic_dimension == -1) {
965             std::vector<int64> output_non_degenerated;
966             for (int64 i = output_dim_start; i < output_dim_end; ++i) {
967               if (reshape->shape().dimensions(i) != 1) {
968                 output_non_degenerated.push_back(i);
969               }
970             }
971             if (output_non_degenerated.size() == 1) {
972               output_dynamic_dimension = output_non_degenerated[0];
973             }
974           }
975 
976           if (output_dynamic_dimension == -1) {
977             return InvalidArgument(
978                 "Reshape's input dynamic dimension is decomposed into "
979                 "multiple output dynamic dimensions, but the constraint is "
980                 "ambiguous and XLA can't infer the output dimension %s. ",
981                 hlo->ToString());
982           }
983         }
984 
985         CHECK_NE(output_dynamic_dimension, -1);
986         const int64 input_dim_size =
987             operand->shape().dimensions(input_dynamic_dimension);
988         const int64 output_dim_size =
989             reshape->shape().dimensions(output_dynamic_dimension);
990         VLOG(2) << "input_dim_size: " << input_dim_size
991                 << " output_dim_size: " << output_dim_size;
992 
993         if (input_dim_size == output_dim_size) {
994           // Simply forward dynamic dimension.
995           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
996                                   operand_dynamic_size);
997         }
998 
999         if (input_dim_size > output_dim_size) {
1000           TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1001               << reshape->ToString();
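          // For example, reshaping an input dimension of static size 8 with
          // runtime size 6 into an output dimension of static size 4 gives
          // divisor = 8 / 4 = 2 and a new dynamic size of 6 / 2 = 3.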
1002           const int64 divisor = input_dim_size / output_dim_size;
1003           HloInstruction* divisor_hlo =
1004               hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1005                   LiteralUtil::CreateR0<int32>(divisor)));
1006 
1007           HloInstruction* new_dynamic_size =
1008               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1009                   operand_dynamic_size->shape(), HloOpcode::kDivide,
1010                   operand_dynamic_size, divisor_hlo));
1011 
1012           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1013                                   new_dynamic_size);
1014         }
1015 
1016         if (input_dim_size < output_dim_size) {
1017           // Input dimension is combined with other input dimensions.
1018           //
1019           // Adjust the output size by the ratio of dynamic_input_dim /
1020           // static_input_dim.
1021           //
1022           // For example if we have  [<=3, 3] -> [9], if the dynamic size is 2,
1023           // the new output dynamic size is 9 / 3 * 2 = 6.
1024           //
1025           // If it turns out the second dimension is also dynamic:
1026           // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1027           // dynamic size is 6 / 3 * 2 = 4.
1028           //
1029           //
1030           HloInstruction* output_dynamic_size =
1031               parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1032           if (output_dynamic_size == nullptr) {
1033             output_dynamic_size =
1034                 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1035                     LiteralUtil::CreateR0<int32>(output_dim_size)));
1036           }
1037           HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1038               HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1039                   operand->shape().dimensions(input_dynamic_dimension))));
1040 
1041           HloInstruction* new_dynamic_size =
1042               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1043                   output_dynamic_size->shape(), HloOpcode::kDivide,
1044                   output_dynamic_size, divisor_hlo));
1045 
1046           new_dynamic_size =
1047               hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1048                   output_dynamic_size->shape(), HloOpcode::kMultiply,
1049                   new_dynamic_size, operand_dynamic_size));
1050           parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1051                                   new_dynamic_size);
1052         }
1053 
1054         return Status::OK();
1055       });
1056 }
1057 
1058 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
1059     HloInstruction* hlo) {
1060   return ForEachOperandDynamicDimension(
1061       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
1062                int64 operand_index, HloInstruction* dynamic_size) {
1063         HloInstruction* reduce_window = hlo;
1064         const WindowDimension& window_dim =
1065             reduce_window->window().dimensions(dimension);
1066 
1067         if (!window_util::IsTrivialWindowDimension(window_dim)) {
1068           DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1069               dynamic_size, window_dim.size(), window_dim.window_dilation(),
1070               window_dim.stride(), PaddingType::PADDING_VALID);
1071           parent_->SetDynamicSize(hlo, {}, dimension,
1072                                   dynamic_window_dims.output_size);
1073           return Status::OK();
1074         }
1075 
1076         parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);
1077 
1078         return Status::OK();
1079       });
1080 }
1081 
1082 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
1083     HloInstruction* hlo) {
1084   return ForEachOperandDynamicDimension(
1085       hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
1086                int64 operand_index, HloInstruction* dynamic_size) {
1087         if (operand_index == 1) {
1088           // Operand 0 (input) determines dynamic output size. We ignore the
1089           // dynamic size of operand 1 (the output gradient).
1090           return Status::OK();
1091         }
1092         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1093 
1094         return Status::OK();
1095       });
1096 }
1097 
1098 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1099   return ForEachOperandDynamicDimension(
1100       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
1101                int64 /*operand_index*/, HloInstruction* dynamic_size) {
1102         if (hlo->slice_starts(dimension) != 0 ||
1103             hlo->slice_strides(dimension) != 1 ||
1104             hlo->slice_limits(dimension) !=
1105                 operand->shape().dimensions(dimension)) {
1106           // Slicing out only part of the dimension eliminates the dynamic dimension.
1107           return Status::OK();
1108         }
1109 
1110         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1111 
1112         return Status::OK();
1113       });
1114 }
1115 
1116 Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
1117     HloInstruction* hlo) {
1118   return ForEachOperandDynamicDimension(
1119       hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension,
1120                int64 /*operand_index*/, HloInstruction* dynamic_size) {
1121         if (hlo->shape().dimensions(dimension) !=
1122             hlo->operand(0)->shape().dimensions(dimension)) {
1123           // Slicing a single element out kills the dynamic dimension.
1124           if (hlo->shape().dimensions(dimension) == 1) {
1125             return Status::OK();
1126           }
1127           return Unimplemented(
1128               "Dynamic dimension propagation on DynamicSlice where a partial "
1129               "dimension is selected %s",
1130               hlo->ToString());
1131         }
1132 
1133         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1134 
1135         return Status::OK();
1136       });
1137 }
1138 
1139 Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
1140     HloInstruction* hlo) {
1141   return ForEachOperandDynamicDimension(
1142       hlo,
1143       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
1144           int64 operand_index, HloInstruction* dynamic_size) {
1145         if (hlo->shape().dimensions(dimension) !=
1146             hlo->operand(0)->shape().dimensions(dimension)) {
1147           return Unimplemented(
1148               "Dynamic dimension propagation on DynamicUpdateSlice where a "
1149               "partial dimension is selected %s",
1150               hlo->ToString());
1151         }
1152 
1153         if (operand_index == 1 &&
1154             hlo->operand(1)->shape().dimensions(dimension) <
1155                 hlo->operand(0)->shape().dimensions(dimension)) {
1156           // DUS(input=[A], update=[<=B])
1157           //
1158           // If update dim is smaller than input dim (B < A), then we are doing
1159           // a partial update, no need to set the output dynamic dimension.
1160           //
1161           // The dynamic shape in `update` doesn't change output dynamic shape.
1162           return Status::OK();
1163         }
1164 
1165         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1166 
1167         return Status::OK();
1168       });
1169 }
1170 
1171 Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1172   return ForEachOperandDynamicDimension(
1173       hlo,
1174       [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
1175           int64 /*operand_index*/, HloInstruction* dynamic_size) {
1176         if (absl::c_linear_search(hlo->dimensions(), dimension)) {
1177           return Unimplemented(
1178               "Dynamic dimension propagation on reversed dimension is not "
1179               "supported %s",
1180               hlo->ToString());
1181         }
1182         parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1183 
1184         return Status::OK();
1185       });
1186 }
1187 
1188 Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1189   return ForEachOperandDynamicDimension(
1190       hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1191                int64 input_dynamic_dimension, int64 operand_index,
1192                HloInstruction* dynamic_size) {
1193         const GatherDimensionNumbers& gather_dims =
1194             hlo->gather_dimension_numbers();
1195         if (operand_index != 1) {
1196           if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1197             // Gathering a size 1 dimension out of a dynamic dimension removes
1198             // the dynamicity.
1199             return Status::OK();
1200           }
1201           if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1202               operand->shape().dimensions(input_dynamic_dimension)) {
1203             // Gathering a full-sized dimension out of a dynamic dimension
1204             // propagates the dynamicity to output.
1205             int64 output_dimension = input_dynamic_dimension;
1206             for (int64 collapsed_dim : gather_dims.collapsed_slice_dims()) {
1207               if (collapsed_dim < input_dynamic_dimension) {
1208                 // This output dimension is collapsed.
1209                 output_dimension--;
1210               }
1211             }
1212             parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1213             return Status::OK();
1214           }
1215           return Unimplemented(
1216               "Detects a dynamic dimension on the data input of gather, which "
1217               "is not supported: %s, %lld",
1218               hlo->ToString(), input_dynamic_dimension);
1219         }
1220         // A mapping from output to input batch dim number. -1 means not a batch
1221         // dimension.
1222         int64 indices_rank = hlo->operand(1)->shape().rank();
1223         int64 output_rank = hlo->shape().rank();
1224 
1225         // indices_dim is an iterator over indices dimensions.
1226         int64 indices_dim = 0;
1227         // Find the corresponding batch dimension in the output.
1228         for (int64 output_dim = 0; output_dim < output_rank; ++output_dim) {
1229           if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1230             // Skips index vector dimension.
1231             if (indices_dim == gather_dims.index_vector_dim()) {
1232               indices_dim++;
1233             }
1234             if (indices_dim++ == input_dynamic_dimension) {
1235               parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1236               return Status::OK();
1237             }
1238           }
1239         }
1240         CHECK(indices_dim == indices_rank);
1241 
1242         return Unimplemented(
1243             "Detects a non-batch dynamic dimension of gather, "
1244             "which is not supported: %s",
1245             hlo->ToString());
1246       });
1247 }
1248 
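// A rough sketch of the conditional rewrite performed below (operand names and
// tuple indices are illustrative only):
//
//   Before:
//     cond = conditional(pred, arg_true, arg_false)
//       where some element of cond's tuple output has a dynamic dimension.
//
//   After:
//     new_cond = conditional(pred, tuple(arg_true..., size_t...),
//                            tuple(arg_false..., size_f...))
//       with each branch widened to accept the extra size operands and rooted
//       at a tuple that also returns the dynamic sizes it produces;
//     result = extract-prefix(new_cond)   // original tuple elements only
//       and the trailing elements of new_cond are registered as the dynamic
//       sizes of `result`.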
Status DynamicDimensionInferenceVisitor::HandleConditional(
    HloInstruction* hlo) {
  // Conditionals are handled by producing additional inputs and outputs of
  // the conditional instruction.
  std::vector<HloComputation*> new_branch_computations;
  std::vector<HloInstruction*> new_operands;
  // If the output of the conditional contains a dynamic dimension, we send the
  // dynamic dimension size out by adding an additional root element. A mapping
  // from the root instruction's dynamic dimension index (represented by a
  // shape index as output index and an int64 dimension number) to output index
  // (represented by an int64) is tracked for the conditional instruction (all
  // branches should have the same mapping).
  ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
      hlo->shape());

  bool need_rewrite = false;

  for (int64 branch_index = 0; branch_index < hlo->branch_count();
       ++branch_index) {
    std::vector<HloInstruction*> operands_to_add;

    absl::flat_hash_map<HloInstruction*, int64>
        dynamic_size_to_operand_id_index_map;
    // Only look at branch_index + 1, the correct operand index for a
    // given branch.
    const int64 operand_index = branch_index + 1;

    int64 operand_count =
        hlo->operand(operand_index)->shape().tuple_shapes_size();
    // Prepare to pass dynamic dimension into the new computation and add
    // dynamic dimension sizes as parameters to the new tuple.
    TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
        hlo, operand_index,
        [&](HloInstruction*, ShapeIndex, int64, int64,
            HloInstruction* dynamic_size) -> Status {
          TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
              << "Only tuple typed inputs can have dynamic dimension. Please "
                 "file a bug against XLA team.";
          const HloInstruction* tuple_operand = hlo->operand(operand_index);
          for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
            // If the dynamic size is already an operand to the computation,
            // skip adding it to the computation input again.
            if (dynamic_size == tuple_operand->operand(i)) {
              dynamic_size_to_operand_id_index_map[dynamic_size] = i;
              return Status::OK();
            }
          }
          auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
          if (iter == dynamic_size_to_operand_id_index_map.end()) {
            operands_to_add.push_back(dynamic_size);
            dynamic_size_to_operand_id_index_map[dynamic_size] =
                operand_count++;
          }
          return Status::OK();
        }));

    HloInstruction* original_input = hlo->mutable_operand(operand_index);
    HloComputation* branch_computation = hlo->branch_computation(branch_index);

    HloComputation* new_computation = branch_computation;
    HloInstruction* new_operand = hlo->mutable_operand(operand_index);
    if (!operands_to_add.empty()) {
      TF_RET_CHECK(original_input->shape().IsTuple());
      need_rewrite = true;
      new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
      TF_ASSIGN_OR_RETURN(
          new_computation,
          WidenComputation(branch_computation, new_operand->shape()));
    }
    // Set the dynamic dimensions for the newly created branch computation's
    // parameters so that the hlos inside the computation can see dynamic
    // dimensions.
    DynamicParameterBinding dynamic_parameter_binding;
    TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
        hlo, operand_index,
        [&](HloInstruction*, ShapeIndex index, int64 dimension,
            int64 operand_index, HloInstruction* dynamic_size) {
          DynamicParameterBinding::DynamicParameter dynamic_parameter{
              0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
          DynamicParameterBinding::DynamicDimension dynamic_dimension{
              0, {index}, dimension};
          TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
                                                            dynamic_dimension));

          return Status::OK();
        }));
    VLOG(2) << "dynamic_parameter_binding for conditional branch"
            << dynamic_parameter_binding;
    TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
        new_computation, dynamic_parameter_binding, parent_));
    std::vector<HloInstruction*> hlos_to_add_in_root;
    int64 original_tuple_count = hlo->shape().tuple_shapes_size();
    // There may be some dynamic dimensions coming out of the computation; wire
    // them into the root instruction as additional tuple elements.
    TF_RETURN_IF_ERROR(ForEachDynamicDimension(
        new_computation->root_instruction(),
        [&](ShapeIndex index, int64 dim,
            HloInstruction* dynamic_size) -> Status {
          TF_RET_CHECK(hlo->shape().IsTuple())
              << "Only tuple typed conditionals can have dynamic dimension. "
                 "Please file a bug against XLA team.";
          dynamic_output_mapping.mutable_element(index)->emplace(
              dim, original_tuple_count++);
          hlos_to_add_in_root.push_back(dynamic_size);
          return Status::OK();
        }));

    VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
    if (!hlos_to_add_in_root.empty()) {
      need_rewrite = true;
      HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
          new_computation->root_instruction(), hlos_to_add_in_root);
      new_computation->set_root_instruction(new_branch_root,
                                            /*accept_different_shape=*/true);
    }

    new_branch_computations.push_back(new_computation);
    new_operands.push_back(new_operand);
  }
  if (!need_rewrite) {
    return Status::OK();
  }
  // Create a new conditional with the new operations and computations.
  HloInstruction* new_conditional =
      hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
          new_branch_computations[0]->root_instruction()->shape(),
          hlo->mutable_operand(0), new_branch_computations, new_operands));

  HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
      new_conditional, hlo->shape().tuple_shapes_size());
  // Now set the dynamic dimensions of the newly created conditional.
  dynamic_output_mapping.ForEachElement(
      [&](const ShapeIndex& index,
          const absl::flat_hash_map<int64, int64>& dim_to_output) {
        for (auto iter : dim_to_output) {
          int64 dim = iter.first;
          int64 output_index = iter.second;
          HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
              HloInstruction::CreateGetTupleElement(
                  ShapeUtil::MakeScalarShape(S32), new_conditional,
                  output_index));
          parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
          parent_->SetDynamicSize(new_conditional_extracted, index, dim,
                                  dynamic_size);
        }
      });

  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
  // Remove the original instruction even if it has side effects.
  TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
  SetVisited(*new_conditional);
  SetVisited(*new_conditional_extracted);
  return Status::OK();
}

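// Illustrative example of the scatter cases handled below (shapes are
// hypothetical): scatter(operand=[<=N, D], indices=[K], updates=[K, D]) keeps
// dimension 0 of the output dynamic with the same size as the operand. A
// dynamic dimension of `updates` that is not an update window dimension is
// scattered away and does not appear in the output, so nothing is recorded.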
Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
  return ForEachOperandDynamicDimension(
      hlo,
      [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
          int64 operand_index, HloInstruction* operand_dynamic_size) {
        if (operand_index == 0) {
          parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
          return Status::OK();
        }

        const ScatterDimensionNumbers& scatter_dims =
            hlo->scatter_dimension_numbers();
        if (operand_index == 2 &&
            absl::c_linear_search(scatter_dims.update_window_dims(),
                                  dimension)) {
          return Unimplemented(
              "Dynamic dimension of update window dims is not supported: %s",
              hlo->ToString());
        }
        // The dynamic dimension is collapsed and won't show up in the output.
        // Do nothing here.
        return Status::OK();
      });
}

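// A rough sketch of the while-loop rewrite performed below (tuple indices are
// illustrative only):
//
//   Before:  while(init=(a, b)), where `a` has a dynamic dimension of size s.
//   After:   while(init=(a, b, s)) via WhileUtil::MakeInstructionsLiveIn, so
//            the size is threaded through the loop as an extra loop-carried
//            value. Inference then runs inside the body and condition with a
//            binding that maps the new tuple element back onto a's dynamic
//            dimension, the body root is rewritten so the (possibly updated)
//            size is forwarded to the next iteration, and the size is read
//            back out of the loop result with a get-tuple-element.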
Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
  // If the output of the kWhile contains a dynamic dimension, we send the
  // dynamic dimension size into the while body by adding additional root/body
  // elements. A mapping from the root instruction's dynamic dimension index
  // (represented by a shape index as output index and an int64 dimension
  // number) to output index (represented by an int64) is tracked for the
  // while instruction.
  ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
      hlo->shape());
  std::vector<HloInstruction*> operands_to_add;
  const int64 original_tuple_count = hlo->shape().tuple_shapes_size();
  int64 operand_count = original_tuple_count;
  TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
      hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
               HloInstruction* dynamic_size) {
        operands_to_add.push_back(dynamic_size);
        dynamic_output_mapping.mutable_element(index)->emplace(dim,
                                                               operand_count++);
        return Status::OK();
      }));

  DynamicParameterBinding binding_for_while;
  if (!operands_to_add.empty()) {
    // Only replace the while loop if there are new parameters to add.
    HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
    TF_ASSIGN_OR_RETURN(
        WhileUtil::MakeInstructionsLiveInResult result,
        WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
    // WhileUtil creates a new while hlo and tuple. Update the dynamic size
    // mapping for the newly created tuple.
    HloInstruction* new_tuple_operand =
        result.new_while_instr->mutable_operand(0);
    parent_->CopyMapping(/*from=*/old_tuple_operand,
                         /*to=*/new_tuple_operand);
    hlo = result.new_while_instr;
    // We have replaced the while loop, now set the dynamic dimensions for the
    // newly created while loop so that the hlos that consume the while loop
    // can see the dynamic dimensions. Also set the dynamic parameter binding
    // for running inference in the while loop.
    TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
        hlo,
        [&](HloInstruction*, ShapeIndex index, int64 dimension,
            int64 operand_index, HloInstruction* dynamic_size) -> Status {
          TF_RET_CHECK(!operands_to_add.empty());
          const int64 output_dynamic_size_index =
              dynamic_output_mapping.element(index).at(dimension);
          DynamicParameterBinding::DynamicParameter dynamic_parameter{
              operand_index, {output_dynamic_size_index}};
          DynamicParameterBinding::DynamicDimension dynamic_dimension{
              operand_index, index, dimension};
          TF_RETURN_IF_ERROR(
              binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
          // This is the updated output dynamic size coming out of hlo while
          // loop.
          HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
              HloInstruction::CreateGetTupleElement(
                  ShapeUtil::MakeScalarShape(S32), hlo,
                  output_dynamic_size_index));
          parent_->SetDynamicSize(result.replacement_instr, index, dimension,
                                  output_dynamic_size);
          return Status::OK();
        }));
    // Set the replacement instruction as visited to avoid visiting it again.
    SetVisited(*result.replacement_instr);
  }

  // Run inference in while body and condition.
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_body(), binding_for_while, parent_));
  TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
      hlo->while_condition(), binding_for_while, parent_));

  if (operands_to_add.empty()) {
    // No dynamic dimension in the inputs and outputs.
    return Status::OK();
  }

  // The dynamic dimension size could have been changed in the loop body (e.g.,
  // a loop that pushes items onto a stack grows the stack size with each
  // iteration). Rewrite the dynamic dimension size at the root.
  HloInstruction* body_root = hlo->while_body()->root_instruction();
  std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
                                                 nullptr);

  // Original non-dynamic-dim operands of root are pass-through.
  for (int64 i = 0; i < original_tuple_count; ++i) {
    new_root_operands[i] =
        hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
            body_root->shape().tuple_shapes(i), body_root, i));
  }
  // Add the (possibly updated) dynamic dimension sizes as trailing root
  // operands.
  TF_RETURN_IF_ERROR(ForEachDynamicDimension(
      hlo->while_body()->root_instruction(),
      [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size) -> Status {
        const int64 output_index =
            dynamic_output_mapping.element(index).at(dim);
        new_root_operands[output_index] = dynamic_size;
        return Status::OK();
      }));
  for (auto operand : new_root_operands) {
    TF_RET_CHECK(operand != nullptr);
  }
  HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
      HloInstruction::CreateTuple(new_root_operands));
  hlo->while_body()->set_root_instruction(new_body_root);
  return Status::OK();
}

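// Illustrative example of a binding handled below (indices are hypothetical):
// a DynamicParameterBinding entry may state that parameter 0 at tuple index
// {1} (an S32 scalar) holds the size of dimension 0 of parameter 0 at tuple
// index {0}. HandleParameter materializes that size with get-tuple-element
// instructions and records it against the target parameter's subshape.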
Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
  return param_bindings_.ForEachBinding(
      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
          const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
        if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
          return Status::OK();
        }
        HloComputation* computation = hlo->parent();
        HloInstruction* target_parameter =
            computation->parameter_instruction(dynamic_dimension.parameter_num);

        HloInstruction* dynamic_size =
            computation->parameter_instruction(dynamic_parameter.parameter_num);
        for (int64 i : dynamic_parameter.parameter_index) {
          dynamic_size =
              computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                  ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
                  dynamic_size, i));
        }

        parent_->SetDynamicSize(target_parameter,
                                dynamic_dimension.parameter_index,
                                dynamic_dimension.dimension, dynamic_size);
        return Status::OK();
      });
}

Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
    HloInstruction* inst, const DynamicDimensionFn& fn) {
  auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
  if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
    for (auto& dynamic_dimension : iter->second) {
      HloInstruction* dynamic_size = parent_->GetDynamicSize(
          dynamic_dimension.inst, dynamic_dimension.index,
          dynamic_dimension.dim);
      TF_RETURN_IF_ERROR(
          fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
    }
  }
  return Status::OK();
}

Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
    HloInstruction* inst, int64 operand_index,
    const OperandDynamicDimensionFn& fn) {
  auto iter =
      parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
  if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
    for (auto& dynamic_dimension : iter->second) {
      HloInstruction* dynamic_size = parent_->GetDynamicSize(
          dynamic_dimension.inst, dynamic_dimension.index,
          dynamic_dimension.dim);
      TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
                            dynamic_dimension.dim, operand_index,
                            dynamic_size));
    }
  }
  return Status::OK();
}

Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
    HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
  for (int64 operand_index = 0; operand_index < inst->operand_count();
       ++operand_index) {
    TF_RETURN_IF_ERROR(
        ForEachDynamicDimensionInOperand(inst, operand_index, fn));
  }
  return Status::OK();
}

void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
                                               const ShapeIndex& index,
                                               int64 dim,
                                               HloInstruction* size) {
  VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
          << index.ToString() << "@" << dim << " to " << size->ToShortString();
  Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
  CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
  CHECK(dim < subshape.rank() && dim >= 0)
      << "Asked to set invalid dynamic dimension. Shape: "
      << subshape.ToString() << ", Dimension: " << dim;
  DynamicDimension dynamic_dimension{inst, index, dim};
  // Updating a dynamic dimension twice overwrites the previous one.
  dynamic_mapping_[dynamic_dimension] = size;
  auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
  iter.first->second.emplace(dynamic_dimension);
}

void DynamicDimensionInference::CopyMapping(HloInstruction* from,
                                            HloInstruction* to) {
  auto iter = per_hlo_dynamic_dimensions_.find(from);
  if (iter != per_hlo_dynamic_dimensions_.end()) {
    for (auto& dynamic_dimension : iter->second) {
      HloInstruction* dynamic_size =
          GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
                         dynamic_dimension.dim);
      SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
                     dynamic_size);
    }
  }
}

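// Typical usage of this analysis, sketched for exposition only (module setup
// and error handling omitted; the argument values shown are illustrative
// assumptions, not a prescribed call sequence):
//
//   TF_ASSIGN_OR_RETURN(
//       DynamicDimensionInference inference,
//       DynamicDimensionInference::Run(module, /*custom_call_handler=*/nullptr));
//   HloInstruction* size =
//       inference.GetDynamicSize(some_instruction, /*index=*/{}, /*dim=*/0);
//   // `size` is an S32 scalar HLO, or nullptr if the dimension is static.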
/* static */
StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
    HloModule* module, CustomCallInferenceHandler custom_call_handler) {
  VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
  DynamicDimensionInference inference(module, std::move(custom_call_handler));
  TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
  return inference;
}

string DynamicDimensionInference::ToString() const {
  std::vector<string> pieces;
  pieces.push_back("DynamicDimensionInference: ");
  for (const auto& mapping : dynamic_mapping_) {
    const DynamicDimension& dynamic_dimension = mapping.first;
    pieces.push_back(absl::StrFormat(
        " -- instruction %s at %s has dim %lld as dynamic"
        " dimension, which is represented by instruction %s",
        dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
        dynamic_dimension.dim, mapping.second->ToString()));
  }
  return absl::StrJoin(pieces, "\n");
}

DynamicDimensionInference::DynamicDimensionInference(
    HloModule* module, CustomCallInferenceHandler custom_call_handler)
    : module_(module), custom_call_handler_(std::move(custom_call_handler)) {}

Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
  return DynamicDimensionInferenceVisitor::Run(
      module_->entry_computation(), module_->dynamic_parameter_binding(), this,
      custom_call_handler_);
}

void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
    HloInstruction* replace, HloInstruction* with) {
  CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
                                      ShapeUtil::MakeScalarShape(S32)));
  CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
                                      ShapeUtil::MakeScalarShape(S32)));
  for (auto& kv : dynamic_mapping_) {
    if (kv.second == replace) {
      kv.second = with;
    }
  }
}

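// ForwardDynamicSize is used when an instruction is cloned or replaced by one
// with an identical shape: every dynamic dimension recorded for `inst` at
// `index` is re-registered for `new_inst`, so downstream passes keep seeing
// the same size instructions. A minimal illustrative call (the instruction
// names are hypothetical):
//
//   TF_RETURN_IF_ERROR(
//       inference.ForwardDynamicSize(old_hlo, cloned_hlo, /*index=*/{}));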
Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
                                                     HloInstruction* new_inst,
                                                     const ShapeIndex& index) {
  CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));

  for (int64 dim = 0; dim < inst->shape().rank(); ++dim) {
    DynamicDimension dynamic_dimension_new{new_inst, index, dim};
    DynamicDimension dynamic_dimension{inst, index, dim};
    auto iter = dynamic_mapping_.find(dynamic_dimension);
    if (iter != dynamic_mapping_.end()) {
      dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
      auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
      iter.first->second.emplace(dynamic_dimension_new);
    }
  }

  return Status::OK();
}

bool DynamicDimensionInference::HasDynamicDimension(
    HloInstruction* inst) const {
  bool has_dynamic_dim = false;
  ShapeUtil::ForEachSubshape(
      inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.IsTuple()) {
          return;
        }
        for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
          HloInstruction* operand_dynamic_size = GetDynamicSize(inst, index, i);
          if (operand_dynamic_size != nullptr) {
            has_dynamic_dim = true;
          }
        }
      });
  return has_dynamic_dim;
}

HloInstruction* DynamicDimensionInference::GetDynamicSize(
    HloInstruction* inst, const ShapeIndex& index, int64 dim) const {
  auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
  if (iter != dynamic_mapping_.end()) {
    return iter->second;
  }
  return nullptr;
}

std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
    HloInstruction* inst, const ShapeIndex& index) const {
  CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
  const int64 rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
  std::vector<HloInstruction*> result(rank, nullptr);
  for (int64 i = 0; i < rank; ++i) {
    // Look the size up at the requested shape index.
    result[i] = GetDynamicSize(inst, index, i);
  }
  return result;
}

}  // namespace xla