1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17
18 #include <vector>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/match.h"
22 #include "tensorflow/compiler/xla/literal_util.h"
23 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
24 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/tuple_util.h"
31 #include "tensorflow/compiler/xla/service/while_util.h"
32 #include "tensorflow/compiler/xla/shape_tree.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/window_util.h"
37 namespace xla {
38
39 namespace {
40 // Replace `narrow_comp` with a new computation with `wide_shape` as input.
WidenComputation(HloComputation * narrow_comp,const Shape & wide_shape)41 StatusOr<HloComputation*> WidenComputation(HloComputation* narrow_comp,
42 const Shape& wide_shape) {
43 TF_RET_CHECK(wide_shape.IsTuple());
44 const Shape& narrow_shape = narrow_comp->parameter_instruction(0)->shape();
45 if (Shape::Equal()(wide_shape, narrow_shape)) {
46 // No need to widen the computation.
47 return narrow_comp;
48 }
49 HloComputation* wide_comp = [&]() {
50 HloComputation::Builder builder(absl::StrCat("wide.", narrow_comp->name()));
51 builder.AddInstruction(
52 HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
53 return narrow_comp->parent()->AddEmbeddedComputation(builder.Build());
54 }();
55
56 HloInstruction* wide_parameter = wide_comp->parameter_instruction(0);
57 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
58 wide_parameter, narrow_shape.tuple_shapes_size());
59 HloInstruction* call_narrow_comp = wide_comp->AddInstruction(
60 HloInstruction::CreateCall(narrow_comp->root_instruction()->shape(),
61 {truncated_parameter}, narrow_comp));
62 wide_comp->set_root_instruction(call_narrow_comp,
63 /*accept_different_shape=*/true);
64 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_comp).status());
65 return wide_comp;
66 }
67 } // namespace
68
69 class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
70 public:
DynamicDimensionInferenceVisitor(const DynamicParameterBinding & param_bindings,DynamicDimensionInference * parent,DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler)71 explicit DynamicDimensionInferenceVisitor(
72 const DynamicParameterBinding& param_bindings,
73 DynamicDimensionInference* parent,
74 DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler)
75 : param_bindings_(param_bindings),
76 parent_(parent),
77 custom_call_handler_(std::move(custom_call_handler)) {}
78
79 Status DefaultAction(HloInstruction* hlo) override;
80
Run(HloComputation * computation,const DynamicParameterBinding & param_bindings,DynamicDimensionInference * parent,DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler=nullptr)81 static Status Run(HloComputation* computation,
82 const DynamicParameterBinding& param_bindings,
83 DynamicDimensionInference* parent,
84 DynamicDimensionInference::CustomCallInferenceHandler
85 custom_call_handler = nullptr) {
86 DynamicDimensionInferenceVisitor visitor(param_bindings, parent,
87 std::move(custom_call_handler));
88 return computation->Accept(&visitor);
89 }
90
91 Status HandleParameter(HloInstruction* hlo) override;
92
93 Status HandleReduce(HloInstruction* hlo) override;
94
95 Status HandleDot(HloInstruction* hlo) override;
96
97 Status HandleTuple(HloInstruction* hlo) override;
98
99 Status HandleTranspose(HloInstruction* hlo) override;
100
101 Status HandleDynamicReshape(HloInstruction* hlo) override;
102
103 Status HandleReshape(HloInstruction* hlo) override;
104
105 Status HandleSort(HloInstruction* hlo) override;
106
107 Status HandlePad(HloInstruction* hlo) override;
108
109 Status HandleCustomCall(HloInstruction* hlo) override;
110
111 Status HandleBroadcast(HloInstruction* hlo) override;
112
113 Status HandleGetDimensionSize(HloInstruction* hlo) override;
114
115 Status HandleSetDimensionSize(HloInstruction* hlo) override;
116
117 Status HandleSelect(HloInstruction* hlo) override;
118
119 Status HandleConvolution(HloInstruction* hlo) override;
120
121 Status HandleConcatenate(HloInstruction* hlo) override;
122
123 Status HandleReduceWindow(HloInstruction* hlo) override;
124
125 Status HandleReverse(HloInstruction* hlo) override;
126
127 Status HandleSelectAndScatter(HloInstruction* hlo) override;
128
129 Status HandleGetTupleElement(HloInstruction* hlo) override;
130
131 Status HandleElementwiseUnary(HloInstruction* hlo) override;
132
133 Status HandleElementwiseBinary(HloInstruction* hlo) override;
134
135 Status HandleClamp(HloInstruction* hlo) override;
136
137 Status HandleConditional(HloInstruction* hlo) override;
138
139 Status HandleWhile(HloInstruction* hlo) override;
140
141 Status HandleSlice(HloInstruction* hlo) override;
142
143 Status HandleDynamicSlice(HloInstruction* hlo) override;
144
145 Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
146
147 Status HandleGather(HloInstruction* hlo) override;
148
149 Status HandleScatter(HloInstruction* hlo) override;
150
151 Status HandleDomain(HloInstruction* hlo) override;
152
153 private:
154 using OperandDynamicDimensionFn = std::function<Status(
155 HloInstruction* operand, ShapeIndex index, int64 dimension,
156 int64 operand_index, HloInstruction* dynamic_size)>;
157
158 using DynamicDimensionFn = std::function<Status(
159 ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>;
160
161 Status HandleDynamicConvolutionForward(HloInstruction* hlo,
162 int64 operand_index, int64 dimension,
163 HloInstruction* dynamic_size);
164
165 Status HandleDynamicConvolutionKernelGrad(HloInstruction* hlo,
166 int64 operand_index,
167 int64 dimension);
168
169 Status HandleDynamicConvolutionInputGrad(HloInstruction* hlo,
170 int64 operand_index,
171 int64 dimension);
172
173 Status HandleDynamicWindowSamePadding(HloInstruction* hlo,
174 HloInstruction* dynamic_size,
175 int64 operand_index, int64 dimension);
176
177 Status ForEachOperandDynamicDimension(HloInstruction* inst,
178 const OperandDynamicDimensionFn&);
179 Status ForEachDynamicDimensionInOperand(HloInstruction* inst,
180 int64 operand_index,
181 const OperandDynamicDimensionFn&);
182 Status ForEachDynamicDimension(HloInstruction* inst,
183 const DynamicDimensionFn& fn);
184
185 // Pass through a dynamic dimension from the input to the output with the
186 // same value and index in the shape. This is a helper function to handle
187 // trivial instructions like elementwise operations.
188 Status PassThroughDynamicDimension(HloInstruction*);
189
190 // The dynamic parameter bindings of this computation.
191 const DynamicParameterBinding& param_bindings_;
192
193 // A pointer to DynamicDimensionInference, used to update the dynamic mapping.
194 DynamicDimensionInference* parent_;
195
196 // A handler for custom calls.
197 DynamicDimensionInference::CustomCallInferenceHandler custom_call_handler_;
198 };
199
DefaultAction(HloInstruction * hlo)200 Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) {
201 return ForEachOperandDynamicDimension(
202 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
203 int64 operand_index, HloInstruction* dynamic_size) {
204 return UnimplementedStrCat(
205 "Asked to propagate a dynamic dimension from hlo ", operand->name(),
206 "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(),
207 ", which is not implemented.");
208 });
209 }
210
HandleGetTupleElement(HloInstruction * hlo)211 Status DynamicDimensionInferenceVisitor::HandleGetTupleElement(
212 HloInstruction* hlo) {
213 return ForEachOperandDynamicDimension(
214 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
215 int64 operand_index, HloInstruction* dynamic_size) {
216 if (hlo->tuple_index() == index[0]) {
217 ShapeIndex new_index =
218 ShapeIndexView(index).ConsumeFront().ToShapeIndex();
219 parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size);
220 }
221 return Status::OK();
222 });
223 }
224
HandleTuple(HloInstruction * hlo)225 Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) {
226 return ForEachOperandDynamicDimension(
227 hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension,
228 int64 operand_index, HloInstruction* dynamic_size) {
229 index.push_front(operand_index);
230 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
231 return Status::OK();
232 });
233 }
234
HandleBroadcast(HloInstruction * hlo)235 Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
236 return ForEachOperandDynamicDimension(
237 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
238 int64 operand_index, HloInstruction* dynamic_size) {
239 int64 broadcast_dim = hlo->dimensions(dimension);
240 parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size);
241 return Status::OK();
242 });
243 }
244
HandleCustomCall(HloInstruction * hlo)245 Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) {
246 if (hlo->custom_call_target() == "PadToStatic") {
247 for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
248 if (hlo->operand(0)->shape().is_dynamic_dimension(i)) {
249 HloInstruction* dynamic_size =
250 hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
251 ShapeUtil::MakeScalarShape(S32), hlo, i + 1));
252 // PadToStatic converts a dynamic dimension to static dimension. It then
253 // returns the padded data output and the dynamic sizes of input
254 // dimensions.
255 ShapeIndex data_output = {0};
256 parent_->SetDynamicSize(hlo, data_output, i, dynamic_size);
257 }
258 }
259 return Status::OK();
260 }
261 if (custom_call_handler_) {
262 return custom_call_handler_(hlo, parent_);
263 }
264
265 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
266 // If input feature is dynamic and kernel feature is static, we can infer
267 // that input feature is also static.
268 // E.g.,:
269 // lhs = [B, X, Y, ?]
270 // rhs = [X, Y, I, O]
271 // dim_labels = b01f_01io
272 // We can infer that the dynamic dimension in rhs is static I.
273 const ConvolutionDimensionNumbers& dnums =
274 hlo->convolution_dimension_numbers();
275 HloInstruction* input_feature = parent_->GetDynamicSize(
276 hlo->mutable_operand(0), {}, dnums.input_feature_dimension());
277 HloInstruction* kernel_feature = parent_->GetDynamicSize(
278 hlo->mutable_operand(1), {}, dnums.kernel_input_feature_dimension());
279
280 if (input_feature != nullptr && kernel_feature == nullptr) {
281 if (hlo->mutable_operand(0)->shape().dimensions(
282 dnums.input_feature_dimension()) ==
283 hlo->mutable_operand(1)->shape().dimensions(
284 dnums.kernel_input_feature_dimension()))
285 parent_->SetDynamicSize(hlo->mutable_operand(0), {},
286 dnums.input_feature_dimension(), nullptr);
287 }
288 }
289 return ForEachOperandDynamicDimension(
290 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
291 int64 operand_index, HloInstruction* dynamic_size) {
292 // Resize custom call should propagate dynamic batch (0) and channel (3)
293 // dimensions.
294 if (hlo->custom_call_target() == "SliceToDynamic" ||
295 hlo->custom_call_target() == "Sharding" ||
296 (absl::StartsWith(hlo->custom_call_target(), "Resize") &&
297 (dimension == 0 || dimension == 3))) {
298 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
299 return Status::OK();
300 }
301 if (hlo->custom_call_target() == "DynamicReduceWindowSamePadding") {
302 if (hlo->operand_count() > 2) {
303 return Unimplemented(
304 "DynamicReduceWindowSamePadding doesn't support variadic "
305 "reduce window %s",
306 hlo->ToString());
307 }
308 return HandleDynamicWindowSamePadding(hlo, dynamic_size,
309 operand_index, dimension);
310 }
311
312 if (hlo->custom_call_target() == "DynamicSelectAndScatterSamePadding") {
313 if (operand_index == 1) {
314 // Operand 0 (input) determines dynamic output size. We ignore the
315 // dynamic size in the operand 1 (output gradient).
316 return Status::OK();
317 }
318 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
319 return Status::OK();
320 }
321
322 if (hlo->custom_call_target() == "DynamicConvolutionInputGrad") {
323 return HandleDynamicConvolutionInputGrad(hlo, operand_index,
324 dimension);
325 }
326
327 if (hlo->custom_call_target() == "DynamicConvolutionKernelGrad") {
328 return HandleDynamicConvolutionKernelGrad(hlo, operand_index,
329 dimension);
330 }
331
332 if (hlo->custom_call_target() == "DynamicConvolutionForward") {
333 return HandleDynamicConvolutionForward(hlo, operand_index, dimension,
334 dynamic_size);
335 }
336 return Unimplemented(
337 "CustomCall \"%s\" is not supported to have a dynamic dimension",
338 hlo->custom_call_target());
339 });
340 }
341
HandleSort(HloInstruction * hlo)342 Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) {
343 return ForEachOperandDynamicDimension(
344 hlo,
345 [&](HloInstruction* operand, ShapeIndex index, int64 dynamic_dimension,
346 int64 operand_index, HloInstruction* dynamic_size) {
347 HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
348 if (sort->values_count() == 0) {
349 parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size);
350 } else {
351 parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension,
352 dynamic_size);
353 }
354
355 return Status::OK();
356 });
357 }
358
HandlePad(HloInstruction * hlo)359 Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) {
360 return ForEachOperandDynamicDimension(
361 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
362 int64 operand_index, HloInstruction* dynamic_size) {
363 if (operand_index != 0) {
364 return Unimplemented(
365 "Dynamic dimension on padding value is not supported");
366 }
367 const PaddingConfig_PaddingConfigDimension& padding_config =
368 hlo->padding_config().dimensions(dimension);
369
370 HloInstruction* dynamic_size_adjusted = dynamic_size;
371 if (padding_config.interior_padding() != 0) {
372 // Adjust for interior padding :
373 // Size' = max((Size - 1), 0) * interior_padding + Size
374 HloInstruction* one = hlo->parent()->AddInstruction(
375 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
376 HloInstruction* zero = hlo->parent()->AddInstruction(
377 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
378 HloInstruction* interior_padding = hlo->parent()->AddInstruction(
379 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
380 padding_config.interior_padding())));
381 dynamic_size_adjusted =
382 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
383 dynamic_size_adjusted->shape(), HloOpcode::kSubtract,
384 dynamic_size_adjusted, one));
385 dynamic_size_adjusted =
386 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
387 dynamic_size_adjusted->shape(), HloOpcode::kMaximum,
388 dynamic_size_adjusted, zero));
389 dynamic_size_adjusted =
390 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
391 dynamic_size_adjusted->shape(), HloOpcode::kMultiply,
392 dynamic_size_adjusted, interior_padding));
393 dynamic_size_adjusted =
394 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
395 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
396 dynamic_size_adjusted, dynamic_size));
397 }
398 HloInstruction* adjustment = hlo->parent()->AddInstruction(
399 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
400 padding_config.edge_padding_low() +
401 padding_config.edge_padding_high())));
402 dynamic_size_adjusted =
403 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
404 dynamic_size_adjusted->shape(), HloOpcode::kAdd,
405 dynamic_size_adjusted, adjustment));
406 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted);
407 return Status::OK();
408 });
409 }
410
HandleReduce(HloInstruction * hlo)411 Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) {
412 return ForEachOperandDynamicDimension(
413 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
414 int64 operand_index, HloInstruction* dynamic_size) {
415 HloInstruction* reduce = hlo;
416 int64 operand_count = reduce->operand_count();
417 bool is_variadic_reduce = operand_count > 2;
418 CHECK_EQ(operand_count % 2, 0);
419 if (operand_index >= operand_count / 2) {
420 // Init values doesn't have dynamic size.
421 return Status::OK();
422 }
423 if ((absl::c_count(reduce->dimensions(), dimension) != 0)) {
424 // Dimension is to be reduced, stop tracing.
425 return Status::OK();
426 }
427
428 // Find out the new dynamic dimension after reduce.
429 int64 dimensions_not_reduced_count = 0;
430 for (int i = 0; i < operand->shape().rank(); ++i) {
431 if (dimension == i) {
432 ShapeIndex result_index = {};
433
434 if (is_variadic_reduce) {
435 // The dimensions of all data operands of a variadic reduce have
436 // to be the same. This means that if one operand of variadic
437 // reduce has a dynamic dimension, we set all outputs to use the
438 // same dynamic size in corresponding dimensions.
439 for (int64 i = 0; i < operand_count / 2; ++i) {
440 parent_->SetDynamicSize(
441 reduce, {i}, dimensions_not_reduced_count, dynamic_size);
442 }
443 } else {
444 parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count,
445 dynamic_size);
446 }
447
448 return Status::OK();
449 }
450 if (absl::c_count(reduce->dimensions(), i) == 0) {
451 dimensions_not_reduced_count++;
452 }
453 }
454
455 return Status::OK();
456 });
457 }
458
HandleDot(HloInstruction * hlo)459 Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) {
460 return ForEachOperandDynamicDimension(
461 hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index,
462 int64 operand_dimension, int64 operand_index,
463 HloInstruction* dynamic_size) {
464 // There are three types of dimensions in a dot:
465 // A. batch dims
466 // B. contracting dims
467 // C. non-batch non-contracting dims.
468 // The output dimensions of a dot has three parts with the following
469 // order:
470 // [(type A), (lhs type C), (rhs type C)]
471 //
472 // Note that both lhs and rhs have the same dimension sizes for batch,
473 // but the dimension index could be different.
474 //
475 // Given one dynamic input dimension, either lhs or rhs, we use a
476 // mapping to find the corresponding output dimension.
477 HloInstruction* dot = hlo;
478 const DotDimensionNumbers& dimension_numbers =
479 dot->dot_dimension_numbers();
480 // A map from the operand dimensions to result dimension.
481 absl::flat_hash_map<int64, int64> result_dim_mapping;
482 int64 current_result_dims = 0;
483
484 bool lhs = operand_index == 0;
485
486 // The first loop keep tracks of batch dimension. RHS and LHS could have
487 // different batch dimension numbers.
488 if (lhs) {
489 for (int64 i : dimension_numbers.lhs_batch_dimensions()) {
490 result_dim_mapping[i] = current_result_dims++;
491 }
492 } else {
493 for (int64 i : dimension_numbers.rhs_batch_dimensions()) {
494 result_dim_mapping[i] = current_result_dims++;
495 }
496 }
497
498 // Handle dimensions in the lhs.
499 for (int64 i = 0; i < dot->operand(0)->shape().rank(); i++) {
500 // Look for non-contracting and non-batching dimension.
501 if (absl::c_linear_search(
502 dimension_numbers.lhs_contracting_dimensions(), i)) {
503 continue;
504 }
505 if (absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
506 i)) {
507 continue;
508 }
509 if (lhs) {
510 result_dim_mapping[i] = current_result_dims;
511 }
512 current_result_dims++;
513 }
514
515 // Handle dimensions in the rhs.
516 for (int64 i = 0; i < dot->operand(1)->shape().rank(); i++) {
517 // Look for non-contracting and non-batching dimension.
518 if (absl::c_linear_search(
519 dimension_numbers.rhs_contracting_dimensions(), i)) {
520 continue;
521 }
522 if (absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
523 i)) {
524 continue;
525 }
526 if (!lhs) {
527 result_dim_mapping[i] = current_result_dims;
528 }
529 current_result_dims++;
530 }
531
532 // Check if the operand dim is in the result shape. If so, add another
533 // work item to trace that dimension.
534 auto iter = result_dim_mapping.find(operand_dimension);
535 if (iter != result_dim_mapping.end()) {
536 parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size);
537 }
538
539 return Status::OK();
540 });
541 }
542
HandleTranspose(HloInstruction * hlo)543 Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) {
544 return ForEachOperandDynamicDimension(
545 hlo,
546 [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
547 int64 operand_index, HloInstruction* dynamic_size) -> Status {
548 int64 permuted_dim = -1;
549 for (int64 i = 0; i < hlo->dimensions().size(); ++i) {
550 if (hlo->dimensions()[i] == dimension) {
551 TF_RET_CHECK(permuted_dim == -1);
552 permuted_dim = i;
553 }
554 }
555 parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size);
556 return Status::OK();
557 });
558 }
559
HandleConvolution(HloInstruction * hlo)560 Status DynamicDimensionInferenceVisitor::HandleConvolution(
561 HloInstruction* hlo) {
562 return ForEachOperandDynamicDimension(
563 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
564 int64 operand_index, HloInstruction* dynamic_size) {
565 HloInstruction* conv = hlo;
566 const ConvolutionDimensionNumbers& dimension_numbers =
567 conv->convolution_dimension_numbers();
568 if (operand_index == 0) {
569 if (dimension == dimension_numbers.input_batch_dimension()) {
570 parent_->SetDynamicSize(conv, {},
571 dimension_numbers.output_batch_dimension(),
572 dynamic_size);
573 return Status::OK();
574 }
575
576 if (dimension == dimension_numbers.input_feature_dimension()) {
577 return Status::OK();
578 }
579 } else {
580 if (dimension == dimension_numbers.kernel_input_feature_dimension()) {
581 return Status::OK();
582 }
583 }
584
585 return Unimplemented("Dynamic Spatial Convolution is not supported: %s",
586 conv->ToString());
587 });
588 }
589
HandleConcatenate(HloInstruction * hlo)590 Status DynamicDimensionInferenceVisitor::HandleConcatenate(
591 HloInstruction* hlo) {
592 // First handle concatenate dimensions. We do this by iterating through all
593 // operands while tracking both dynamic and static dimensions.
594
595 // static_size is used to keep track of the concated size of static
596 // dimensions.
597 int64 static_size = 0;
598 std::vector<HloInstruction*> dynamic_concat_dims;
599 for (int64 i = 0; i < hlo->operand_count(); ++i) {
600 HloInstruction* dynamic_size = parent_->GetDynamicSize(
601 hlo->mutable_operand(i), {}, hlo->concatenate_dimension());
602 if (dynamic_size == nullptr) {
603 // This is a static dimension.
604 static_size +=
605 hlo->operand(i)->shape().dimensions(hlo->concatenate_dimension());
606 } else {
607 dynamic_concat_dims.push_back(dynamic_size);
608 }
609 }
610 // If concat dimension is dynamic, calculate its size by summing up static
611 // dims and dynamic dims together.
612 if (!dynamic_concat_dims.empty()) {
613 HloInstruction* dim_size_total =
614 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
615 LiteralUtil::CreateR0<int32>(static_size)));
616 for (HloInstruction* dynamic_dim : dynamic_concat_dims) {
617 dim_size_total = hlo->parent()->AddInstruction(
618 HloInstruction::CreateBinary(dim_size_total->shape(), HloOpcode::kAdd,
619 dim_size_total, dynamic_dim));
620 }
621 parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(),
622 dim_size_total);
623 }
624
625 // Simply pass through non-concat dynamic dimensions.
626 return ForEachOperandDynamicDimension(
627 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
628 int64 operand_index, HloInstruction* dynamic_size) {
629 int64 concatenate_dimension = hlo->concatenate_dimension();
630 if (concatenate_dimension == dimension) {
631 return Status::OK();
632 }
633 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
634 return Status::OK();
635 });
636 }
637
HandleGetDimensionSize(HloInstruction *)638 Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize(
639 HloInstruction*) {
640 // Dynamic dimension doesn't propagate through GetDimensionSize:
641 //
642 // Input: F32[x, y, z]
643 // |
644 // GetDimensionSize(1): S32[]
645 //
646 // The returned value is a scalar, which doesn't have any dynamic dimension in
647 // the shape (although the value contains the real size of the dynamic
648 // dimension of the input).
649 return Status::OK();
650 }
651
HandleSetDimensionSize(HloInstruction * hlo)652 Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize(
653 HloInstruction* hlo) {
654 bool dimension_is_static = false;
655 const HloInstruction* size = hlo->operand(1);
656 if (size->opcode() == HloOpcode::kConstant) {
657 // Check if we are setting a dimension size to its static size. If so,
658 // removes the dynamic dimension.
659 //
660 // size = s32[] constant(5)
661 // s32[2, 5] = set-dimension-size(s32[2,<=5]{1,0} %param, s32[] %size),
662 // dimensions={1}
663 // The result shape has no dynamic dimension.
664 TF_RET_CHECK(size->shape().rank() == 0);
665 if (size->literal().Get<int32>({}) ==
666 hlo->shape().dimensions(hlo->dimension())) {
667 dimension_is_static = true;
668 }
669 }
670
671 if (!dimension_is_static) {
672 // Propagate dynamic dimension indicated by this set dimension size
673 // instruction.
674 parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1));
675 }
676
677 // Also Propagate dynamic dimension already set by operands.
678 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
679 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
680 int64 operand_index, HloInstruction* dynamic_size) {
681 if (dimension != hlo->dimension()) {
682 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
683 }
684 return Status::OK();
685 }));
686
687 return Status::OK();
688 }
689
HandleDynamicConvolutionForward(HloInstruction * hlo,int64 operand_index,int64 dimension,HloInstruction * dynamic_size)690 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionForward(
691 HloInstruction* hlo, int64 operand_index, int64 dimension,
692 HloInstruction* dynamic_size) {
693 TF_RET_CHECK(operand_index == 0);
694 const ConvolutionDimensionNumbers& dimension_numbers =
695 hlo->convolution_dimension_numbers();
696
697 if (dimension == dimension_numbers.input_batch_dimension()) {
698 // Batch dimension is propagated without any changes.
699 parent_->SetDynamicSize(hlo, {}, dimension_numbers.output_batch_dimension(),
700 dynamic_size);
701 return Status::OK();
702 }
703
704 for (int64 spatial_dim_index = 0;
705 spatial_dim_index < dimension_numbers.input_spatial_dimensions_size();
706 ++spatial_dim_index) {
707 int64 input_spatial_dim =
708 dimension_numbers.input_spatial_dimensions(spatial_dim_index);
709 int64 output_spatial_dim =
710 dimension_numbers.output_spatial_dimensions(spatial_dim_index);
711 if (dimension == input_spatial_dim) {
712 // This is a dynamic spatial dimension. Calculate the output size.
713 WindowDimension window_dim = hlo->window().dimensions(spatial_dim_index);
714 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
715 dynamic_size, window_dim.size(), window_dim.window_dilation(),
716 window_dim.stride(), hlo->padding_type());
717 TF_RET_CHECK(window_dim.base_dilation() == 1);
718 parent_->SetDynamicSize(hlo, {}, output_spatial_dim,
719 dynamic_window_dims.output_size);
720 return Status::OK();
721 }
722 }
723 // Input Feature dim disappears after convolution.
724 return Status::OK();
725 }
726
HandleDynamicWindowSamePadding(HloInstruction * hlo,HloInstruction * dynamic_size,int64 operand_index,int64 dimension)727 Status DynamicDimensionInferenceVisitor::HandleDynamicWindowSamePadding(
728 HloInstruction* hlo, HloInstruction* dynamic_size, int64 operand_index,
729 int64 dimension) {
730 const Window& window = hlo->window();
731 const WindowDimension& window_dim = window.dimensions(dimension);
732 if (!window_util::IsTrivialWindowDimension(window_dim)) {
733 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
734 dynamic_size, window_dim.size(), window_dim.window_dilation(),
735 window_dim.stride(), PaddingType::PADDING_SAME);
736 parent_->SetDynamicSize(hlo, {}, dimension,
737 dynamic_window_dims.output_size);
738 return Status::OK();
739 }
740
741 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
742
743 return Status::OK();
744 }
745
HandleDynamicConvolutionInputGrad(HloInstruction * hlo,int64 operand_index,int64 dimension)746 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionInputGrad(
747 HloInstruction* hlo, int64 operand_index, int64 dimension) {
748 // The output size of convolution input grad is corresponding input size.
749 HloInstruction* input_sizes = hlo->mutable_operand(0);
750 HloComputation* comp = hlo->parent();
751 TF_RET_CHECK(input_sizes->shape().rank() == 1) << hlo->ToString();
752 TF_RET_CHECK(input_sizes->shape().element_type() == S32) << hlo->ToString();
753 TF_RET_CHECK(input_sizes->shape().dimensions(0) ==
754 hlo->shape().dimensions_size())
755 << hlo->ToString();
756 // Slice to get corresponding input size.
757 HloInstruction* slice = comp->AddInstruction(
758 HloInstruction::CreateSlice(ShapeUtil::MakeShape(S32, {1}), input_sizes,
759 {dimension}, {dimension + 1}, {1}));
760 HloInstruction* reshape = comp->AddInstruction(
761 HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
762 parent_->SetDynamicSize(hlo, {}, dimension, reshape);
763 return Status::OK();
764 }
765
HandleDynamicConvolutionKernelGrad(HloInstruction * hlo,int64 operand_index,int64 dimension)766 Status DynamicDimensionInferenceVisitor::HandleDynamicConvolutionKernelGrad(
767 HloInstruction* hlo, int64 operand_index, int64 dimension) {
768 // Dynamic convolution kernel grad produces static shape outputs.
769 return Status::OK();
770 }
771
PassThroughDynamicDimension(HloInstruction * hlo)772 Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension(
773 HloInstruction* hlo) {
774 return ForEachOperandDynamicDimension(
775 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
776 int64 operand_index, HloInstruction* dynamic_size) {
777 parent_->SetDynamicSize(hlo, index, dimension, dynamic_size);
778 return Status::OK();
779 });
780 }
781
HandleDomain(HloInstruction * hlo)782 Status DynamicDimensionInferenceVisitor::HandleDomain(HloInstruction* hlo) {
783 return PassThroughDynamicDimension(hlo);
784 }
785
HandleElementwiseUnary(HloInstruction * hlo)786 Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary(
787 HloInstruction* hlo) {
788 return PassThroughDynamicDimension(hlo);
789 }
790
HandleSelect(HloInstruction * hlo)791 Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) {
792 return PassThroughDynamicDimension(hlo);
793 }
794
HandleElementwiseBinary(HloInstruction * hlo)795 Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary(
796 HloInstruction* hlo) {
797 return PassThroughDynamicDimension(hlo);
798 }
799
HandleClamp(HloInstruction * hlo)800 Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) {
801 return PassThroughDynamicDimension(hlo);
802 }
803
HandleDynamicReshape(HloInstruction * hlo)804 Status DynamicDimensionInferenceVisitor::HandleDynamicReshape(
805 HloInstruction* hlo) {
806 HloDynamicReshapeInstruction* dynamic_reshape =
807 Cast<HloDynamicReshapeInstruction>(hlo);
808 for (int64 i = 0; i < hlo->shape().rank(); ++i) {
809 if (hlo->shape().is_dynamic_dimension(i)) {
810 parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i));
811 }
812 }
813 return Status::OK();
814 }
815
HandleReshape(HloInstruction * hlo)816 Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
817 return ForEachOperandDynamicDimension(
818 hlo,
819 [&](HloInstruction* operand, ShapeIndex index,
820 int64 input_dynamic_dimension, int64 operand_index,
821 HloInstruction* operand_dynamic_size) -> Status {
822 HloInstruction* reshape = hlo;
823 if (reshape->shape().rank() == 0) {
824 VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has "
825 "undefined behavior when input size is 0. The offending "
826 "instruction is: "
827 << reshape->ToString();
828 return Status::OK();
829 }
830 auto common_factors = CommonFactors(operand->shape().dimensions(),
831 reshape->shape().dimensions());
832 int64 input_dim_start = -1;
833 int64 input_dim_end = -1;
834 int64 output_dim_start = -1;
835 int64 output_dim_end = -1;
836 // Find common_factors that the input belongs to.
837 for (int64 i = 0; i < common_factors.size() - 1; ++i) {
838 auto start = common_factors[i];
839 auto end = common_factors[i + 1];
840 if (input_dynamic_dimension >= start.first &&
841 input_dynamic_dimension < end.first) {
842 // Found the common_factor group that the input_dim belongs to.
843 input_dim_start = start.first;
844 input_dim_end = end.first;
845 output_dim_start = start.second;
846 output_dim_end = end.second;
847 }
848 }
849
850 VLOG(2) << "Input dim start: " << input_dim_start
851 << " Input dim end: " << input_dim_end
852 << " output dim start: " << output_dim_start
853 << " output dim end: " << output_dim_end;
854
855 if ((input_dim_end - input_dim_start) > 1 &&
856 (output_dim_end - output_dim_start) > 1) {
857 // We don't support the case when a dynamic dimension is both combined
858 // with and splitted into other dimensions:
859 //
860 // [x, yz]
861 // | Reshape
862 // [xy, z]
863 //
864 // TODO(yunxing): This can be supported by canonicalizing
865 // the offending reshape into two reshapes:
866 //
867 // [x,yz]
868 // | Reshape
869 // [x, y, z]
870 // | Reshape
871 // [xy, z]
872 //
873 return Unimplemented(
874 "Dynamic input dimension to reshape that is both splitted and "
875 "combined is not supported %s",
876 hlo->ToString());
877 }
878
879 for (auto common_factor : common_factors) {
880 // Expand common factor to include degenerated output dimensions.
881 if (common_factor.first == input_dim_start) {
882 output_dim_start = std::min(output_dim_start, common_factor.second);
883 }
884 if (common_factor.first == input_dim_end) {
885 output_dim_end = std::max(output_dim_end, common_factor.second);
886 }
887 }
888
889 int64 output_dynamic_dimension = -1;
890
891 if (operand->shape().dimensions(input_dynamic_dimension) == 1) {
892 // If dynamic dimension is 1, it can only be most-major or
893 // most-minor.
894 if (input_dynamic_dimension == 0) {
895 output_dynamic_dimension = 0;
896 }
897 if (input_dynamic_dimension == operand->shape().rank() - 1) {
898 output_dynamic_dimension = reshape->shape().rank() - 1;
899 }
900
901 if (output_dynamic_dimension == -1) {
902 return Unimplemented(
903 "Dynamic degenerated dimension that's not most-minor nor "
904 "most-major is not supported %s",
905 reshape->ToString());
906 }
907 }
908
909 if (output_dynamic_dimension == -1 &&
910 output_dim_end - output_dim_start == 1) {
911 // Only one possible output dimension.
912 output_dynamic_dimension = output_dim_start;
913 }
914
915 if (output_dynamic_dimension == -1 &&
916 output_dim_end - output_dim_start > 1) {
917 // One input dimension is splitted into multiple output dimensions.
918 // Output dimension is decomposed from input most major dimension.
919 // In this case, we don't know which one is dynamic, e.g., when we
920 // have:
921 //
922 // [<=a/c, c, b]
923 // | Reshape
924 // [<=a, b] // a is dynamic, has to be multiple of c.
925 // | Reshape
926 // [1, 1, ... , a/c, c, b]
927 //
928 // Any dimension from the first '1' to 'a/c' can be dynamic.
929 //
930 // We use the following logics to disambiguate:
931 // 1. If the user sets "inferred_dimension", then use that as
932 // dynamic dimension.
933 // 2. If the one dimension in the reshape is dynamic, use that as
934 // dynamic dimension.
935 // E.g.:
936 // [<=4]
937 // |
938 // reshape
939 // |
940 // [1, <=2, 2]
941 // We use second dim as dynamic dimension.
942 //
943 // 3. If all logics above cannot disambiguate, e.g.,:
944 //
945 // [<=1]
946 // |
947 // reshape
948 // |
949 // [1, 1, 1]
950 //
951 // We bail out and return an error.
952 // TODO(yunxing): Further simplify this, remove 1. and fully rely
953 // on 2.
954 output_dynamic_dimension = reshape->inferred_dimension();
955 if (output_dynamic_dimension == -1) {
956 // Try find dynamic dimension from the result shape.
957 for (int64 i = output_dim_start; i < output_dim_end; ++i) {
958 if (reshape->shape().is_dynamic_dimension(i)) {
959 output_dynamic_dimension = i;
960 }
961 }
962 }
963
964 if (output_dynamic_dimension == -1) {
965 std::vector<int64> output_non_degenerated;
966 for (int64 i = output_dim_start; i < output_dim_end; ++i) {
967 if (reshape->shape().dimensions(i) != 1) {
968 output_non_degenerated.push_back(i);
969 }
970 }
971 if (output_non_degenerated.size() == 1) {
972 output_dynamic_dimension = output_non_degenerated[0];
973 }
974 }
975
976 if (output_dynamic_dimension == -1) {
977 return InvalidArgument(
978 "Reshape's input dynamic dimension is decomposed into "
979 "multiple output dynamic dimensions, but the constraint is "
980 "ambiguous and XLA can't infer the output dimension %s. ",
981 hlo->ToString());
982 }
983 }
984
985 CHECK_NE(output_dynamic_dimension, -1);
986 const int64 input_dim_size =
987 operand->shape().dimensions(input_dynamic_dimension);
988 const int64 output_dim_size =
989 reshape->shape().dimensions(output_dynamic_dimension);
990 VLOG(2) << "input_dim_size: " << input_dim_size
991 << " output_dim_size: " << output_dim_size;
992
993 if (input_dim_size == output_dim_size) {
994 // Simply forward dynamic dimension.
995 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
996 operand_dynamic_size);
997 }
998
999 if (input_dim_size > output_dim_size) {
1000 TF_RET_CHECK(input_dim_size % output_dim_size == 0)
1001 << reshape->ToString();
1002 const int64 divisor = input_dim_size / output_dim_size;
1003 HloInstruction* divisor_hlo =
1004 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1005 LiteralUtil::CreateR0<int32>(divisor)));
1006
1007 HloInstruction* new_dynamic_size =
1008 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1009 operand_dynamic_size->shape(), HloOpcode::kDivide,
1010 operand_dynamic_size, divisor_hlo));
1011
1012 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1013 new_dynamic_size);
1014 }
1015
1016 if (input_dim_size < output_dim_size) {
1017 // Input dimension is combined with other input dimensions.
1018 //
1019 // Adjust the output size by the ratio of dynamic_input_dim /
1020 // static_input_dim.
1021 //
1022 // For example if we have [<=3, 3] -> [9], if the dynamic size is 2,
1023 // the new output dynamic isze is 9 / 3 * 2 = 6.
1024 //
1025 // If it turns out the second dimension is also dynamic:
1026 // [<=3, <=3] -> [9], and the dynamic size is also 2, the new output
1027 // dynamic size is 6 / 3 * 2 = 4.
1028 //
1029 //
1030 HloInstruction* output_dynamic_size =
1031 parent_->GetDynamicSize(reshape, {}, output_dynamic_dimension);
1032 if (output_dynamic_size == nullptr) {
1033 output_dynamic_size =
1034 hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
1035 LiteralUtil::CreateR0<int32>(output_dim_size)));
1036 }
1037 HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(
1038 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1039 operand->shape().dimensions(input_dynamic_dimension))));
1040
1041 HloInstruction* new_dynamic_size =
1042 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1043 output_dynamic_size->shape(), HloOpcode::kDivide,
1044 output_dynamic_size, divisor_hlo));
1045
1046 new_dynamic_size =
1047 hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
1048 output_dynamic_size->shape(), HloOpcode::kMultiply,
1049 new_dynamic_size, operand_dynamic_size));
1050 parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension,
1051 new_dynamic_size);
1052 }
1053
1054 return Status::OK();
1055 });
1056 }
1057
HandleReduceWindow(HloInstruction * hlo)1058 Status DynamicDimensionInferenceVisitor::HandleReduceWindow(
1059 HloInstruction* hlo) {
1060 return ForEachOperandDynamicDimension(
1061 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
1062 int64 operand_index, HloInstruction* dynamic_size) {
1063 HloInstruction* reduce_window = hlo;
1064 const WindowDimension& window_dim =
1065 reduce_window->window().dimensions(dimension);
1066
1067 if (!window_util::IsTrivialWindowDimension(window_dim)) {
1068 DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1069 dynamic_size, window_dim.size(), window_dim.window_dilation(),
1070 window_dim.stride(), PaddingType::PADDING_VALID);
1071 parent_->SetDynamicSize(hlo, {}, dimension,
1072 dynamic_window_dims.output_size);
1073 return Status::OK();
1074 }
1075
1076 parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size);
1077
1078 return Status::OK();
1079 });
1080 }
1081
HandleSelectAndScatter(HloInstruction * hlo)1082 Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter(
1083 HloInstruction* hlo) {
1084 return ForEachOperandDynamicDimension(
1085 hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension,
1086 int64 operand_index, HloInstruction* dynamic_size) {
1087 if (operand_index == 1) {
1088 // Operand 0 (input) determines dynamic output size. We ignore the
1089 // dynamic size in the operand 1 (output gradient).
1090 return Status::OK();
1091 }
1092 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1093
1094 return Status::OK();
1095 });
1096 }
1097
HandleSlice(HloInstruction * hlo)1098 Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) {
1099 return ForEachOperandDynamicDimension(
1100 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension,
1101 int64 /*operand_index*/, HloInstruction* dynamic_size) {
1102 if (hlo->slice_starts(dimension) != 0 ||
1103 hlo->slice_strides(dimension) != 1 ||
1104 hlo->slice_limits(dimension) !=
1105 operand->shape().dimensions(dimension)) {
1106 // Slicing a partial element out eliminates the dynamic dimension.
1107 return Status::OK();
1108 }
1109
1110 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1111
1112 return Status::OK();
1113 });
1114 }
1115
HandleDynamicSlice(HloInstruction * hlo)1116 Status DynamicDimensionInferenceVisitor::HandleDynamicSlice(
1117 HloInstruction* hlo) {
1118 return ForEachOperandDynamicDimension(
1119 hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension,
1120 int64 /*operand_index*/, HloInstruction* dynamic_size) {
1121 if (hlo->shape().dimensions(dimension) !=
1122 hlo->operand(0)->shape().dimensions(dimension)) {
1123 // Slicing a single element out kills the dynamic dimension.
1124 if (hlo->shape().dimensions(dimension) == 1) {
1125 return Status::OK();
1126 }
1127 return Unimplemented(
1128 "Dynamic dimension propagation on DynamicSlice where a partial "
1129 "dimension is selected %s",
1130 hlo->ToString());
1131 }
1132
1133 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1134
1135 return Status::OK();
1136 });
1137 }
1138
HandleDynamicUpdateSlice(HloInstruction * hlo)1139 Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice(
1140 HloInstruction* hlo) {
1141 return ForEachOperandDynamicDimension(
1142 hlo,
1143 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
1144 int64 operand_index, HloInstruction* dynamic_size) {
1145 if (hlo->shape().dimensions(dimension) !=
1146 hlo->operand(0)->shape().dimensions(dimension)) {
1147 return Unimplemented(
1148 "Dynamic dimension propagation on DynamicUpdateSlice where a "
1149 "partial dimension is selected %s",
1150 hlo->ToString());
1151 }
1152
1153 if (operand_index == 1 &&
1154 hlo->operand(1)->shape().dimensions(dimension) <
1155 hlo->operand(0)->shape().dimensions(dimension)) {
1156 // DUS(input=[A], update=[<=B])
1157 //
1158 // If update dim is smaller than input dim (B < A) , then we are doing
1159 // a partial update, no need to set the output dynamic dimension.
1160 //
1161 // The dynamic shape in `update` doesn't change output dynamic shape.
1162 return Status::OK();
1163 }
1164
1165 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1166
1167 return Status::OK();
1168 });
1169 }
1170
HandleReverse(HloInstruction * hlo)1171 Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) {
1172 return ForEachOperandDynamicDimension(
1173 hlo,
1174 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
1175 int64 /*operand_index*/, HloInstruction* dynamic_size) {
1176 if (absl::c_linear_search(hlo->dimensions(), dimension)) {
1177 return Unimplemented(
1178 "Dynamic dimension propagation on reversed dimension is not "
1179 "supported %s",
1180 hlo->ToString());
1181 }
1182 parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size);
1183
1184 return Status::OK();
1185 });
1186 }
1187
HandleGather(HloInstruction * hlo)1188 Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) {
1189 return ForEachOperandDynamicDimension(
1190 hlo, [&](HloInstruction* operand, ShapeIndex /*index*/,
1191 int64 input_dynamic_dimension, int64 operand_index,
1192 HloInstruction* dynamic_size) {
1193 const GatherDimensionNumbers& gather_dims =
1194 hlo->gather_dimension_numbers();
1195 if (operand_index != 1) {
1196 if (hlo->gather_slice_sizes()[input_dynamic_dimension] == 1) {
1197 // Gathering a size 1 dimension out of a dynamic dimension removes
1198 // the dynamicity.
1199 return Status::OK();
1200 }
1201 if (hlo->gather_slice_sizes()[input_dynamic_dimension] ==
1202 operand->shape().dimensions(input_dynamic_dimension)) {
1203 // Gathering a full-sized dimension out of a dynamic dimension
1204 // propagates the dynamicity to output.
1205 int64 output_dimension = input_dynamic_dimension;
1206 for (int64 collapsed_dim : gather_dims.collapsed_slice_dims()) {
1207 if (collapsed_dim < input_dynamic_dimension) {
1208 // This output dimension is collapsed.
1209 output_dimension--;
1210 }
1211 }
1212 parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size);
1213 return Status::OK();
1214 }
1215 return Unimplemented(
1216 "Detects a dynamic dimension on the data input of gather, which "
1217 "is not supported: %s, %lld",
1218 hlo->ToString(), input_dynamic_dimension);
1219 }
1220 // A mapping from output to input batch dim number. -1 means not a batch
1221 // dimension.
1222 int64 indices_rank = hlo->operand(1)->shape().rank();
1223 int64 output_rank = hlo->shape().rank();
1224
1225 // indices_dim is an iterator over indices dimensions.
1226 int64 indices_dim = 0;
1227 // Find the corresponding batch dimension in the output.
1228 for (int64 output_dim = 0; output_dim < output_rank; ++output_dim) {
1229 if (!absl::c_linear_search(gather_dims.offset_dims(), output_dim)) {
1230 // Skips index vector dimension.
1231 if (indices_dim == gather_dims.index_vector_dim()) {
1232 indices_dim++;
1233 }
1234 if (indices_dim++ == input_dynamic_dimension) {
1235 parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size);
1236 return Status::OK();
1237 }
1238 }
1239 }
1240 CHECK(indices_dim == indices_rank);
1241
1242 return Unimplemented(
1243 "Detects a non-batch dynamic dimension of gather, "
1244 "which is not supported: %s",
1245 hlo->ToString());
1246 });
1247 }
1248
HandleConditional(HloInstruction * hlo)1249 Status DynamicDimensionInferenceVisitor::HandleConditional(
1250 HloInstruction* hlo) {
1251 // Conditionals are handled by producing additional inputs and outputs of
1252 // the conditional instruction.
1253 std::vector<HloComputation*> new_branch_computations;
1254 std::vector<HloInstruction*> new_operands;
1255 // If the output of the conditional contains dynamic dimension. We send
1256 // dynamic dimension size out by adding additional root element. A mapping
1257 // from the root instruction's dynamic dimension index (represented by a shape
1258 // index as output index and a int64 dimension number) to output index
1259 // (represented by an int64) is tracked for the conditional intsruction (all
1260 // branches should have the same mapping).
1261 ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1262 hlo->shape());
1263
1264 bool need_rewrite = false;
1265
1266 for (int64 branch_index = 0; branch_index < hlo->branch_count();
1267 ++branch_index) {
1268 std::vector<HloInstruction*> operands_to_add;
1269
1270 absl::flat_hash_map<HloInstruction*, int64>
1271 dynamic_size_to_operand_id_index_map;
1272 // Only look at branch_index + 1, the correct operand index for a
1273 // given branch.
1274 const int64 operand_index = branch_index + 1;
1275
1276 int64 operand_count =
1277 hlo->operand(operand_index)->shape().tuple_shapes_size();
1278 // Prepare to pass dynamic dimension into the new computation and add
1279 // dynamic dimension sizes as parameters to the new tuple.
1280 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1281 hlo, operand_index,
1282 [&](HloInstruction*, ShapeIndex, int64, int64,
1283 HloInstruction* dynamic_size) -> Status {
1284 TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple())
1285 << "Only tuple typed inputs can have dynamic dimension. Please "
1286 "file a bug against XLA team.";
1287 const HloInstruction* tuple_operand = hlo->operand(operand_index);
1288 for (int64 i = 0; i < tuple_operand->operand_count(); ++i) {
1289 // If the dynamic size is already an operand to the computation,
1290 // skip adding it to the computation input again.
1291 if (dynamic_size == tuple_operand->operand(i)) {
1292 dynamic_size_to_operand_id_index_map[dynamic_size] = i;
1293 return Status::OK();
1294 }
1295 }
1296 auto iter = dynamic_size_to_operand_id_index_map.find(dynamic_size);
1297 if (iter == dynamic_size_to_operand_id_index_map.end()) {
1298 operands_to_add.push_back(dynamic_size);
1299 dynamic_size_to_operand_id_index_map[dynamic_size] =
1300 operand_count++;
1301 }
1302 return Status::OK();
1303 }));
1304
1305 HloInstruction* original_input = hlo->mutable_operand(operand_index);
1306 HloComputation* branch_computation = hlo->branch_computation(branch_index);
1307
1308 HloComputation* new_computation = branch_computation;
1309 HloInstruction* new_operand = hlo->mutable_operand(operand_index);
1310 if (!operands_to_add.empty()) {
1311 TF_RET_CHECK(original_input->shape().IsTuple());
1312 need_rewrite = true;
1313 new_operand = TupleUtil::AppendSuffix(original_input, operands_to_add);
1314 TF_ASSIGN_OR_RETURN(
1315 new_computation,
1316 WidenComputation(branch_computation, new_operand->shape()));
1317 }
1318 // Set the dynamic dimensions for the newly created branch computation's
1319 // parameters so that the hlos inside the computation can see dynamic
1320 // dimensions.
1321 DynamicParameterBinding dynamic_parameter_binding;
1322 TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand(
1323 hlo, operand_index,
1324 [&](HloInstruction*, ShapeIndex index, int64 dimension,
1325 int64 operand_index, HloInstruction* dynamic_size) {
1326 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1327 0, {dynamic_size_to_operand_id_index_map[dynamic_size]}};
1328 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1329 0, {index}, dimension};
1330 TF_RETURN_IF_ERROR(dynamic_parameter_binding.Bind(dynamic_parameter,
1331 dynamic_dimension));
1332
1333 return Status::OK();
1334 }));
1335 VLOG(2) << "dynamic_parameter_binding for conditional branch"
1336 << dynamic_parameter_binding;
1337 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1338 new_computation, dynamic_parameter_binding, parent_));
1339 std::vector<HloInstruction*> hlos_to_add_in_root;
1340 int64 original_tuple_count = hlo->shape().tuple_shapes_size();
1341 // There may be some dynamic dimensions coming out of the computation, wire
1342 // that into the root instruction as additional tuple elements.
1343 TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1344 new_computation->root_instruction(),
1345 [&](ShapeIndex index, int64 dim,
1346 HloInstruction* dynamic_size) -> Status {
1347 TF_RET_CHECK(hlo->shape().IsTuple())
1348 << "Only tuple typed conditionals can have dynamic dimension. "
1349 "Please file a bug against XLA team.";
1350 dynamic_output_mapping.mutable_element(index)->emplace(
1351 dim, original_tuple_count++);
1352 hlos_to_add_in_root.push_back(dynamic_size);
1353 return Status::OK();
1354 }));
1355
1356 VLOG(2) << "hlos_to_add_in_root:" << hlos_to_add_in_root.size();
1357 if (!hlos_to_add_in_root.empty()) {
1358 need_rewrite = true;
1359 HloInstruction* new_branch_root = TupleUtil::AppendSuffix(
1360 new_computation->root_instruction(), hlos_to_add_in_root);
1361 new_computation->set_root_instruction(new_branch_root,
1362 /*accept_different_shape=*/true);
1363 }
1364
1365 new_branch_computations.push_back(new_computation);
1366 new_operands.push_back(new_operand);
1367 }
1368 if (!need_rewrite) {
1369 return Status::OK();
1370 }
1371 // Create a new conditional with the new operations and computations.
1372 HloInstruction* new_conditional =
1373 hlo->parent()->AddInstruction(HloInstruction::CreateConditional(
1374 new_branch_computations[0]->root_instruction()->shape(),
1375 hlo->mutable_operand(0), new_branch_computations, new_operands));
1376
1377 HloInstruction* new_conditional_extracted = TupleUtil::ExtractPrefix(
1378 new_conditional, hlo->shape().tuple_shapes_size());
1379 // Now set the dynamic dimensions of the newly created conditional.
1380 dynamic_output_mapping.ForEachElement(
1381 [&](const ShapeIndex& index,
1382 const absl::flat_hash_map<int64, int64>& dim_to_output) {
1383 for (auto iter : dim_to_output) {
1384 int64 dim = iter.first;
1385 int64 output_index = iter.second;
1386 HloInstruction* dynamic_size = hlo->parent()->AddInstruction(
1387 HloInstruction::CreateGetTupleElement(
1388 ShapeUtil::MakeScalarShape(S32), new_conditional,
1389 output_index));
1390 parent_->SetDynamicSize(new_conditional, index, dim, dynamic_size);
1391 parent_->SetDynamicSize(new_conditional_extracted, index, dim,
1392 dynamic_size);
1393 }
1394 });
1395
1396 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_conditional_extracted));
1397 // Remove the original instruction even if has side-effects.
1398 TF_RETURN_IF_ERROR(hlo->parent()->RemoveInstruction(hlo));
1399 SetVisited(*new_conditional);
1400 SetVisited(*new_conditional_extracted);
1401 return Status::OK();
1402 }
1403
HandleScatter(HloInstruction * hlo)1404 Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) {
1405 return ForEachOperandDynamicDimension(
1406 hlo,
1407 [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension,
1408 int64 operand_index, HloInstruction* operand_dynamic_size) {
1409 if (operand_index == 0) {
1410 parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size);
1411 return Status::OK();
1412 }
1413
1414 const ScatterDimensionNumbers& scatter_dims =
1415 hlo->scatter_dimension_numbers();
1416 if (operand_index == 2 &&
1417 absl::c_linear_search(scatter_dims.update_window_dims(),
1418 dimension)) {
1419 return Unimplemented(
1420 "Dynamic dimension of update window dims is not supported: %s",
1421 hlo->ToString());
1422 }
1423 // The dynamic dimension is collapsed and won't show up in the output.
1424 // Do nothing here.
1425 return Status::OK();
1426 });
1427 }
1428
HandleWhile(HloInstruction * hlo)1429 Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
1430   // If the output of the kWhile contains a dynamic dimension, we send the
1431   // dynamic dimension size through the while body by adding additional
1432   // root/body elements. A mapping from the root instruction's dynamic
1433   // dimension index (represented by a shape index as output index and an
1434   // int64 dimension number) to the output index (represented by an int64) is
1435   // tracked for the while instruction.
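  // A rough sketch of the resulting rewrite (hypothetical shapes, for
  // exposition only): a while whose state is (f32[<=10]) with dimension 0
  // dynamic is replaced by a while whose state is (f32[<=10], s32[]), where
  // the trailing s32[] carries the runtime size through every iteration;
  // consumers of the original while then read the size back with a
  // get-tuple-element on the appended index.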
1436 ShapeTree<absl::flat_hash_map<int64, int64>> dynamic_output_mapping(
1437 hlo->shape());
1438 std::vector<HloInstruction*> operands_to_add;
1439 const int64 original_tuple_count = hlo->shape().tuple_shapes_size();
1440 int64 operand_count = original_tuple_count;
1441 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1442 hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64,
1443 HloInstruction* dynamic_size) {
1444 operands_to_add.push_back(dynamic_size);
1445 dynamic_output_mapping.mutable_element(index)->emplace(dim,
1446 operand_count++);
1447 return Status::OK();
1448 }));
1449
1450 DynamicParameterBinding binding_for_while;
1451 if (!operands_to_add.empty()) {
1452 // Only replace the while loop if there are new parameters to add.
1453 HloInstruction* old_tuple_operand = hlo->mutable_operand(0);
1454 TF_ASSIGN_OR_RETURN(
1455 WhileUtil::MakeInstructionsLiveInResult result,
1456 WhileUtil::MakeInstructionsLiveIn(hlo, operands_to_add));
1457 // WhileUtil creates a new while hlo and tuple. Update the dynamic size
1458 // mapping for the newly created tuple.
1459 HloInstruction* new_tuple_operand =
1460 result.new_while_instr->mutable_operand(0);
1461 parent_->CopyMapping(/*from=*/old_tuple_operand,
1462 /*to=*/new_tuple_operand);
1463 hlo = result.new_while_instr;
1464     // We have replaced the while loop; now set the dynamic dimensions for the
1465     // newly created while loop so that the HLOs that consume the while loop
1466     // can see the dynamic dimensions. Also set the dynamic parameter binding
1467     // used for running inference inside the while loop.
1468 TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension(
1469 hlo,
1470 [&](HloInstruction*, ShapeIndex index, int64 dimension,
1471 int64 operand_index, HloInstruction* dynamic_size) -> Status {
1472 TF_RET_CHECK(!operands_to_add.empty());
1473 const int64 output_dynamic_size_index =
1474 dynamic_output_mapping.element(index).at(dimension);
1475 DynamicParameterBinding::DynamicParameter dynamic_parameter{
1476 operand_index, {output_dynamic_size_index}};
1477 DynamicParameterBinding::DynamicDimension dynamic_dimension{
1478 operand_index, index, dimension};
1479 TF_RETURN_IF_ERROR(
1480 binding_for_while.Bind(dynamic_parameter, dynamic_dimension));
1481               // This is the updated output dynamic size coming out of the
1482               // while loop.
1483 HloInstruction* output_dynamic_size = hlo->parent()->AddInstruction(
1484 HloInstruction::CreateGetTupleElement(
1485 ShapeUtil::MakeScalarShape(S32), hlo,
1486 output_dynamic_size_index));
1487 parent_->SetDynamicSize(result.replacement_instr, index, dimension,
1488 output_dynamic_size);
1489 return Status::OK();
1490 }));
1491 // Set the replacement instruction as visited to avoid visiting it again.
1492 SetVisited(*result.replacement_instr);
1493 }
1494
1495 // Run inference in while body and condition.
1496 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1497 hlo->while_body(), binding_for_while, parent_));
1498 TF_RETURN_IF_ERROR(DynamicDimensionInferenceVisitor::Run(
1499 hlo->while_condition(), binding_for_while, parent_));
1500
1501 if (operands_to_add.empty()) {
1502 // No dynamic dimension in the inputs and outputs.
1503 return Status::OK();
1504 }
1505
1506   // The dynamic dimension size could have been changed in the loop body (e.g.,
1507   // a loop that inserts items into a stack; the stack size increases with each
1508   // iteration). Rewrite the dynamic dimension size at the root.
1509 HloInstruction* body_root = hlo->while_body()->root_instruction();
1510 std::vector<HloInstruction*> new_root_operands(body_root->operand_count(),
1511 nullptr);
1512
1513 // Original non-dynamic-dim operands of root are pass-through.
1514 for (int64 i = 0; i < original_tuple_count; ++i) {
1515 new_root_operands[i] =
1516 hlo->while_body()->AddInstruction(HloInstruction::CreateGetTupleElement(
1517 body_root->shape().tuple_shapes(i), body_root, i));
1518 }
1519   // Add the dynamic dimension sizes as new root operands.
1520 TF_RETURN_IF_ERROR(ForEachDynamicDimension(
1521 hlo->while_body()->root_instruction(),
1522 [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size) -> Status {
1523 const int64 output_index =
1524 dynamic_output_mapping.element(index).at(dim);
1525 new_root_operands[output_index] = dynamic_size;
1526 return Status::OK();
1527 }));
1528 for (auto operand : new_root_operands) {
1529 TF_RET_CHECK(operand != nullptr);
1530 }
1531 HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
1532 HloInstruction::CreateTuple(new_root_operands));
1533 hlo->while_body()->set_root_instruction(new_body_root);
1534 return Status::OK();
1535 }
1536
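// A minimal sketch of the kind of binding this handler consumes (parameter
// numbers and indices below are hypothetical):
//   DynamicParameterBinding binding;
//   // "scalar parameter 1 holds the size of dimension 0 of parameter 0"
//   TF_RETURN_IF_ERROR(binding.Bind(
//       DynamicParameterBinding::DynamicParameter{/*parameter_num=*/1,
//                                                 /*parameter_index=*/{}},
//       DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
//                                                 /*parameter_index=*/{},
//                                                 /*dimension=*/0}));
// Given such a binding, the visitor below records parameter 1 as the dynamic
// size of parameter 0's dimension 0 via SetDynamicSize.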
1537 Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) {
1538 return param_bindings_.ForEachBinding(
1539 [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1540 const DynamicParameterBinding::DynamicDimension& dynamic_dimension) {
1541 if (dynamic_dimension.parameter_num != hlo->parameter_number()) {
1542 return Status::OK();
1543 }
1544 HloComputation* computation = hlo->parent();
1545 HloInstruction* target_parameter =
1546 computation->parameter_instruction(dynamic_dimension.parameter_num);
1547
1548 HloInstruction* dynamic_size =
1549 computation->parameter_instruction(dynamic_parameter.parameter_num);
1550 for (int64 i : dynamic_parameter.parameter_index) {
1551 dynamic_size =
1552 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
1553 ShapeUtil::GetSubshape(dynamic_size->shape(), {i}),
1554 dynamic_size, i));
1555 }
1556
1557 parent_->SetDynamicSize(target_parameter,
1558 dynamic_dimension.parameter_index,
1559 dynamic_dimension.dimension, dynamic_size);
1560 return Status::OK();
1561 });
1562 }
1563
1564 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension(
1565 HloInstruction* inst, const DynamicDimensionFn& fn) {
1566 auto iter = parent_->per_hlo_dynamic_dimensions_.find(inst);
1567 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1568 for (auto& dynamic_dimension : iter->second) {
1569 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1570 dynamic_dimension.inst, dynamic_dimension.index,
1571 dynamic_dimension.dim);
1572 TF_RETURN_IF_ERROR(
1573 fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size));
1574 }
1575 }
1576 return Status::OK();
1577 }
1578
1579 Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand(
1580 HloInstruction* inst, int64 operand_index,
1581 const OperandDynamicDimensionFn& fn) {
1582 auto iter =
1583 parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index));
1584 if (iter != parent_->per_hlo_dynamic_dimensions_.end()) {
1585 for (auto& dynamic_dimension : iter->second) {
1586 HloInstruction* dynamic_size = parent_->GetDynamicSize(
1587 dynamic_dimension.inst, dynamic_dimension.index,
1588 dynamic_dimension.dim);
1589 TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index,
1590 dynamic_dimension.dim, operand_index,
1591 dynamic_size));
1592 }
1593 }
1594 return Status::OK();
1595 }
1596
1597 Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension(
1598 HloInstruction* inst, const OperandDynamicDimensionFn& fn) {
1599 for (int64 operand_index = 0; operand_index < inst->operand_count();
1600 ++operand_index) {
1601 TF_RETURN_IF_ERROR(
1602 ForEachDynamicDimensionInOperand(inst, operand_index, fn));
1603 }
1604 return Status::OK();
1605 }
1606
1607 void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst,
1608 const ShapeIndex& index,
1609 int64 dim,
1610 HloInstruction* size) {
1611 VLOG(1) << "Set dimension inst " << inst->ToString() << " index "
1612 << index.ToString() << "@" << dim << " to " << size->ToShortString();
1613 Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index);
1614 CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension";
1615 CHECK(dim < subshape.rank() && dim >= 0)
1616 << "Asked to set invalid dynamic dimension. Shape: "
1617 << subshape.ToString() << ", Dimension: " << dim;
1618 DynamicDimension dynamic_dimension{inst, index, dim};
1619 // Updating a dynamic dimension twice overwrites the previous one.
1620 dynamic_mapping_[dynamic_dimension] = size;
1621 auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst);
1622 iter.first->second.emplace(dynamic_dimension);
1623 }
1624
1625 void DynamicDimensionInference::CopyMapping(HloInstruction* from,
1626 HloInstruction* to) {
1627 auto iter = per_hlo_dynamic_dimensions_.find(from);
1628 if (iter != per_hlo_dynamic_dimensions_.end()) {
1629 for (auto& dynamic_dimension : iter->second) {
1630 HloInstruction* dynamic_size =
1631 GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index,
1632 dynamic_dimension.dim);
1633 SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim,
1634 dynamic_size);
1635 }
1636 }
1637 }
1638
1639 /* static */
1640 StatusOr<DynamicDimensionInference> DynamicDimensionInference::Run(
1641 HloModule* module, CustomCallInferenceHandler custom_call_handler) {
1642 VLOG(2) << "Param Config " << module->dynamic_parameter_binding().ToString();
1643 DynamicDimensionInference inference(module, std::move(custom_call_handler));
1644 TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions());
1645 return inference;
1646 }
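// A minimal usage sketch (assumptions: `module` is an HloModule* whose
// dynamic_parameter_binding() has already been populated, `some_inst` stands
// for any instruction of interest, and no custom-call handler is needed, so
// nullptr is passed for it):
//   TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
//                       DynamicDimensionInference::Run(module, nullptr));
//   HloInstruction* size =
//       inference.GetDynamicSize(some_inst, /*index=*/{}, /*dim=*/0);
//   // `size` is the S32 scalar holding the runtime extent of dimension 0,
//   // or nullptr if that dimension is static.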
1647
1648 string DynamicDimensionInference::ToString() const {
1649 std::vector<string> pieces;
1650 pieces.push_back("DynamicDimensionInference: ");
1651 for (const auto& mapping : dynamic_mapping_) {
1652 const DynamicDimension& dynamic_dimension = mapping.first;
1653 pieces.push_back(absl::StrFormat(
1654 " -- instruction %s at %s has dim %lld as dynamic"
1655 " dimension, which is represented by instruction %s",
1656 dynamic_dimension.inst->ToString(), dynamic_dimension.index.ToString(),
1657 dynamic_dimension.dim, mapping.second->ToString()));
1658 }
1659 return absl::StrJoin(pieces, "\n");
1660 }
1661
1662 DynamicDimensionInference::DynamicDimensionInference(
1663 HloModule* module, CustomCallInferenceHandler custom_call_handler)
1664 : module_(module), custom_call_handler_(std::move(custom_call_handler)) {}
1665
1666 Status DynamicDimensionInference::AnalyzeDynamicDimensions() {
1667 return DynamicDimensionInferenceVisitor::Run(
1668 module_->entry_computation(), module_->dynamic_parameter_binding(), this,
1669 custom_call_handler_);
1670 }
1671
1672 void DynamicDimensionInference::ReplaceAllDynamicDimensionUsesWith(
1673 HloInstruction* replace, HloInstruction* with) {
1674 CHECK(Shape::Equal().IgnoreLayout()(replace->shape(),
1675 ShapeUtil::MakeScalarShape(S32)));
1676 CHECK(Shape::Equal().IgnoreLayout()(with->shape(),
1677 ShapeUtil::MakeScalarShape(S32)));
1678 for (auto& kv : dynamic_mapping_) {
1679 if (kv.second == replace) {
1680 kv.second = with;
1681 }
1682 }
1683 }
1684
1685 Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst,
1686 HloInstruction* new_inst,
1687 const ShapeIndex& index) {
1688 CHECK(Shape::Equal()(inst->shape(), new_inst->shape()));
1689
1690 for (int64 dim = 0; dim < inst->shape().rank(); ++dim) {
1691 DynamicDimension dynamic_dimension_new{new_inst, index, dim};
1692 DynamicDimension dynamic_dimension{inst, index, dim};
1693 auto iter = dynamic_mapping_.find(dynamic_dimension);
1694 if (iter != dynamic_mapping_.end()) {
1695 dynamic_mapping_.insert({dynamic_dimension_new, iter->second});
1696 auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst);
1697 iter.first->second.emplace(dynamic_dimension_new);
1698 }
1699 }
1700
1701 return Status::OK();
1702 }
1703
1704 bool DynamicDimensionInference::HasDynamicDimension(
1705 HloInstruction* inst) const {
1706 bool has_dynamic_dim = false;
1707 ShapeUtil::ForEachSubshape(
1708 inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
1709 if (subshape.IsTuple()) {
1710 return;
1711 }
1712 for (int64 i = 0; i < subshape.dimensions_size(); ++i) {
1713 HloInstruction* operand_dynamic_size = GetDynamicSize(inst, index, i);
1714 if (operand_dynamic_size != nullptr) {
1715 has_dynamic_dim = true;
1716 }
1717 }
1718 });
1719 return has_dynamic_dim;
1720 }
1721
1722 HloInstruction* DynamicDimensionInference::GetDynamicSize(
1723 HloInstruction* inst, const ShapeIndex& index, int64 dim) const {
1724 auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim});
1725 if (iter != dynamic_mapping_.end()) {
1726 return iter->second;
1727 }
1728 return nullptr;
1729 }
1730
1731 std::vector<HloInstruction*> DynamicDimensionInference::GetDynamicSizes(
1732 HloInstruction* inst, const ShapeIndex& index) const {
1733 CHECK(ShapeUtil::IndexIsValid(inst->shape(), index));
1734 const int64 rank = ShapeUtil::GetSubshape(inst->shape(), index).rank();
1735 std::vector<HloInstruction*> result(rank, nullptr);
1736 for (int64 i = 0; i < rank; ++i) {
1737     result[i] = GetDynamicSize(inst, index, i);
1738 }
1739 return result;
1740 }
1741
1742 } // namespace xla
1743