• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/comparison_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
30 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
31 #include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
35 #include "tensorflow/compiler/xla/service/hlo_dce.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
38 #include "tensorflow/compiler/xla/service/hlo_module.h"
39 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
40 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
41 #include "tensorflow/compiler/xla/service/shape_inference.h"
42 #include "tensorflow/compiler/xla/shape_util.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/window_util.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/platform/errors.h"
49 #include "tensorflow/core/platform/statusor.h"
50 
51 namespace xla {
52 
53 namespace {
54 
// ChooseIdentityValue looks at the instruction's operand, returns an
// identity value which, when padded, doesn't change the result of the
// instruction.
//
// nullptr (wrapped in StatusOr) is returned if the padded area doesn't need
// to be reset to a particular value; an error status is returned when padding
// for the instruction is unsupported.
StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
                                              int64_t operand_number) {
  HloComputation* comp = inst->parent();
  // Padding on elementwise operation doesn't affect the result of the effective
  // data.
  if (inst->IsElementwise()) {
    return nullptr;
  }
  if (inst->opcode() == HloOpcode::kSelectAndScatter ||
      inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
    if (operand_number == 1) {
      // The `source` operand is padded with the init value (operand 2), the
      // identity of the scatter computation.
      return inst->mutable_operand(2);
    }
    TF_RET_CHECK(operand_number == 0);
    HloComputation* select = inst->called_computations()[0];

    // Only a `max`-style select (compare(param, param) with kGe) is
    // recognized; for it, padding the data operand with the minimum value
    // guarantees padded elements are never selected.
    if (Match(select->root_instruction(),
              match::Compare(match::Parameter(), match::Parameter())
                  .WithComparisonDirection(ComparisonDirection::kGe))) {
      return comp->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MinValue(inst->operand(0)->shape().element_type())));
    } else {
      return Unimplemented(
          "Only select and scatter with `max` as select function is "
          "supported, got %s",
          select->ToString());
    }
  }
  switch (inst->opcode()) {
    case HloOpcode::kReduce: {
      auto* reduce = Cast<HloReduceInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // Variadic reduce has different init value for different operand, given
      // a data operand number, find the init value index.
      int64_t init_value_index = reduce->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }
    case HloOpcode::kReduceWindow: {
      auto* reduce_window = Cast<HloReduceWindowInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce_window->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // Variadic reduce has different init value for different operand, given
      // a data operand number, find the init value index.
      int64_t init_value_index = reduce_window->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }

    case HloOpcode::kConvolution:
    case HloOpcode::kDot: {
      // Use 0 as padding value for convolution and dot: zero contributions
      // don't change the accumulated result.
      PrimitiveType ptype = inst->shape().element_type();
      return comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
    }

    case HloOpcode::kPad: {
      // Reuse the pad's own padding value.
      return inst->mutable_operand(1);
    }
    case HloOpcode::kScatter: {
      if (operand_number != 1) {
        return nullptr;
      }
      // Pad the indices operand with the maximum representable index so the
      // padded entries fall out of bounds and (per XLA scatter semantics)
      // their updates are discarded.
      PrimitiveType indices_ptype =
          inst->operand(operand_number)->shape().element_type();

      return comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::MaxValue(indices_ptype)));
    }
    // The following ops are insensitive to the value in the padded area (or
    // are rewritten elsewhere), so no reset is needed.
    case HloOpcode::kParameter:
    case HloOpcode::kGather:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTuple:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kBroadcast:
    case HloOpcode::kTranspose:
    case HloOpcode::kSort:
    case HloOpcode::kSlice:
    case HloOpcode::kDomain:
      return nullptr;
    case HloOpcode::kCustomCall:
      // Assume that custom calls created by the client are valid with padded
      // dynamic dimensions.
      return nullptr;
    default:
      return UnimplementedStrCat("Unimplemented padding for instruction: ",
                                 inst->ToString());
  }
}
156 
ReplaceGetSize(HloInstruction * instr,DynamicDimensionInference * dynamic_dimension_inference)157 StatusOr<bool> ReplaceGetSize(
158     HloInstruction* instr,
159     DynamicDimensionInference* dynamic_dimension_inference) {
160   if (instr->opcode() != HloOpcode::kGetDimensionSize) {
161     return false;
162   }
163   HloComputation* computation = instr->parent();
164 
165   TF_ASSIGN_OR_RETURN(auto legal_shape,
166                       ShapeInference::InferGetDimensionSizeShape(
167                           instr->operand(0)->shape(), instr->dimension()));
168   TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
169       << "instr->shape() " << instr->shape().ToString() << " , "
170       << "legal_shape " << legal_shape.ToString();
171   TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
172   HloInstruction* operand = instr->mutable_operand(0);
173   int64_t dim = instr->dimension();
174   HloInstruction* dynamic_size =
175       dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
176   if (dynamic_size != nullptr) {
177     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
178     // The dependency between a instruction and its dynamic dimensions is not
179     // modeled in the IR. As instr is being replaced by dynamic_size, also tell
180     // dynamic dimension inference that the instruction is being replaced.
181     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
182         instr, dynamic_size);
183   } else {
184     int32_t size = instr->operand(0)->shape().dimensions(dim);
185     HloInstruction* new_instr = computation->AddInstruction(
186         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
187     TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
188     dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
189                                                                     new_instr);
190   }
191   return true;
192 }
193 
ReplaceSetSize(HloInstruction * instr)194 StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
195   if (instr->opcode() != HloOpcode::kSetDimensionSize) {
196     return false;
197   }
198 
199   TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
200       instr->shape(), instr->operand(0)->shape()))
201       << "instr->shape() " << instr->shape().ToString() << " , "
202       << "instruction operand shape " << instr->operand(0)->shape();
203   HloInstruction* operand = instr->mutable_operand(0);
204 
205   TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
206   return true;
207 }
208 
ReplaceSetBound(HloInstruction * instr)209 StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
210   if (instr->opcode() != HloOpcode::kCustomCall ||
211       instr->custom_call_target() != "SetBound") {
212     return false;
213   }
214 
215   TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
216       instr->shape(), instr->operand(0)->shape()))
217       << "instr->shape() " << instr->shape().ToString() << " , "
218       << "instruction operand shape " << instr->operand(0)->shape();
219   HloInstruction* operand = instr->mutable_operand(0);
220 
221   TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
222   return true;
223 }
224 
ShouldSkipPadOnOperand(const HloInstruction * inst,int64_t operand_num,int64_t dimension)225 bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64_t operand_num,
226                             int64_t dimension) {
227   if (inst->opcode() == HloOpcode::kSelectAndScatter && operand_num == 0 &&
228       inst->window().dimensions(dimension).size() == 1) {
229     return true;
230   }
231 
232   if (auto* reduce_window = DynCast<HloReduceWindowInstruction>(inst)) {
233     if (operand_num < reduce_window->input_count() &&
234         inst->window().dimensions(dimension).size() == 1) {
235       return true;
236     }
237   }
238 
239   if (operand_num == 0 && inst->opcode() == HloOpcode::kConvolution &&
240       inst->convolution_dimension_numbers().input_batch_dimension() ==
241           dimension) {
242     return true;
243   }
244   return false;
245 }
246 
247 // Generates a mask representing the effective area of data and padded area of
248 // data using iota and dynamic_size. For example, given a dimension of 7
249 // elements and 5 effective elements:
250 //
251 // iota = [0, 1, 2, 3, 4, 5, 6]
252 // broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
253 // mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
254 //
255 // Once the mask is generated, the input data is then padded using the
256 // mask and pad value.
257 //
PadWithScalar(HloInstruction * inst,int64_t dim,HloInstruction * dynamic_size,HloInstruction * padding_scalar)258 HloInstruction* PadWithScalar(HloInstruction* inst, int64_t dim,
259                               HloInstruction* dynamic_size,
260                               HloInstruction* padding_scalar) {
261   CHECK(inst != nullptr && dynamic_size != nullptr &&
262         padding_scalar != nullptr);
263   const Shape mask_shape =
264       ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
265   const Shape pred_shape =
266       ShapeUtil::ChangeElementType(inst->shape(), xla::PRED);
267   HloComputation* computation = inst->parent();
268   HloInstruction* iota =
269       computation->AddInstruction(HloInstruction::CreateIota(mask_shape, dim));
270 
271   HloInstruction* broadcasted_effective_size = computation->AddInstruction(
272       HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {}));
273   HloInstruction* pred =
274       computation->AddInstruction(HloInstruction::CreateCompare(
275           pred_shape, iota, broadcasted_effective_size,
276           ComparisonDirection::kLt));
277 
278   HloInstruction* broadcasted_identity_value = computation->AddInstruction(
279       HloInstruction::CreateBroadcast(inst->shape(), padding_scalar, {}));
280   HloInstruction* padded = computation->AddInstruction(
281       HloInstruction::CreateTernary(inst->shape(), HloOpcode::kSelect, pred,
282                                     inst, broadcasted_identity_value));
283   return padded;
284 }
285 
286 // In a reshape if a dynamic dimension is splitted into multiple output
287 // dimensions, we need to rewrite the input of the reshape.
288 //
289 // The reason for this is that a continuous input may not be evenly reshaped
290 // into output.  Image we have [<=6] where valid data has size 4 and padding (P)
291 // data has size 2: [a,b,c,d,P,P]
292 //
293 // And we have a reshape that produces dynamic output dimensions.
294 //
295 // [<=6]
296 //  |
297 // Reshape
298 //  |
299 // [2, <=3]
300 //
301 // This should produce the same result as if the data has no padding:
302 //
303 // [4]     // [a, b, c, d]
304 //  |
305 // Reshape
306 //  |
307 // [2, 2]  // [[a,b], [c,d]]
308 //
309 // Without reshape rewriting, the result looks like:
310 //
311 // [[a,b,c]
312 //  [d,P,P]], which is incorrect.
313 //
314 // We need to rewrite the reshape such that it produces:
315 // [[a,b,P]
316 //  [c,d,P]]
317 //
318 // The way we do this is by a 5-steps cumsum-gather algorithm:
319 //
320 // 1.First we use the output shape to generate a binary 0-1 masking, which masks
321 // out the padded area of the output:
322 // [[1,1,0]
323 //  [1,1,0]]
324 //
325 // 2.Then we do an inverse reshape to reshape it from output shape back to input
326 // shape [2,3]->[6]:
327 //  [1,1,0,1,1,0]
328 //
329 // 3.We then do a cumsum with the mask:
330 //  [1,2,2,3,4,4] and subtract it with 1:
331 //  [0,1,1,2,3,3]
332 //
333 // 4.Use the result of cumsum as gather indices to rearrange the original
334 // data. Feed the original input [a,b,c,d,P,P] and indices into gather.
335 //
336 //  operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
337 //     |                    |
338 //   Gather-----------------+
339 //     |
340 //     v
341 //  value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as padding value
342 //  doesn't matter.
343 //
344 //
345 // 5.Feed the sorted input to original reshape[6]->[2,3], we can now get the
346 // correct result:
347 //  [[a,b,P]
348 //   [c,d,P]]
349 //
RewriteDynamicReshapeSplitInput(HloInstruction * reshape,int64_t input_dim,absl::Span<const int64> output_dims,absl::Span<HloInstruction * > output_dynamic_dims,DynamicDimensionInference * dynamic_dimension_inference)350 Status RewriteDynamicReshapeSplitInput(
351     HloInstruction* reshape, int64_t input_dim,
352     absl::Span<const int64> output_dims,
353     absl::Span<HloInstruction*> output_dynamic_dims,
354     DynamicDimensionInference* dynamic_dimension_inference) {
355   VLOG(2) << "Reshaping input dim " << input_dim << "to "
356           << VectorString(output_dims);
357   const Shape operand_shape = reshape->operand(0)->shape();
358   TF_RET_CHECK(output_dims.size() > 1);
359 
360   HloComputation* comp = reshape->parent();
361   const Shape mask_input_shape =
362       ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});
363 
364   std::vector<int64> reshaped_dims;
365   for (int64_t output_dim : output_dims) {
366     reshaped_dims.push_back(reshape->shape().dimensions(output_dim));
367   }
368 
369   const Shape mask_reshaped_shape =
370       ShapeUtil::MakeShape(xla::S32, reshaped_dims);
371 
372   HloInstruction* zero = comp->AddInstruction(
373       HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
374   HloInstruction* one = comp->AddInstruction(
375       HloInstruction::CreateConstant(LiteralUtil::One(S32)));
376   // Step 1 -- generate binary mask.
377   // Mask starts with all one, each dynamic dimension sets that dimension of the
378   // mask to partially zero in the end.
379   HloInstruction* binary_mask = comp->AddInstruction(
380       HloInstruction::CreateBroadcast(mask_reshaped_shape, one, {}));
381 
382   bool need_rewrite = false;
383 
384   // Pad the effective dimension with 1.
385   //
386   // Index starts from 1 since there is no need to rewrite a major output
387   // dimension.
388   for (int64_t i = 1; i < output_dims.size(); ++i) {
389     const int64_t output_dim = output_dims[i];
390     HloInstruction* dynamic_size = output_dynamic_dims[output_dim];
391     if (dynamic_size == nullptr) {
392       continue;
393     }
394     // If there is dynamic dimension in the output, need to rewrite the input.
395     need_rewrite = true;
396 
397     binary_mask = PadWithScalar(binary_mask, i, dynamic_size, zero);
398   }
399   if (!need_rewrite) {
400     return Status::OK();
401   }
402   // Step 2.
403   // Do a reverse reshape to flatten the binary mask (with output shape) back to
404   // input shape.
405   HloInstruction* input_shape_binary_mask = comp->AddInstruction(
406       HloInstruction::CreateReshape(mask_input_shape, binary_mask));
407 
408   // Step 3. Do a cumsum on the binary mask.
409   auto embedded_builder = HloComputation::Builder("add");
410   {
411     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
412         0, ShapeUtil::MakeShape(S32, {}), "lhs"));
413     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
414         1, ShapeUtil::MakeShape(S32, {}), "rhs"));
415     embedded_builder.AddInstruction(
416         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
417   }
418 
419   HloComputation* add =
420       reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
421   Window cumsum_window;
422   // First dimension is unchanged.
423   WindowDimension* dim = cumsum_window.add_dimensions();
424   dim->set_size(operand_shape.dimensions(input_dim));
425   dim->set_stride(1);
426   dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
427   dim->set_padding_high(0);
428   dim->set_window_dilation(1);
429   dim->set_base_dilation(1);
430   HloInstruction* cumsum =
431       comp->AddInstruction(HloInstruction::CreateReduceWindow(
432           mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));
433 
434   HloInstruction* broadcast_ones = comp->AddInstruction(
435       HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
436   cumsum = comp->AddInstruction(HloInstruction::CreateBinary(
437       mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));
438 
439   GatherDimensionNumbers gather_dim_numbers;
440   // Use gather to rearrange the input dim dimension.
441   for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
442     // Offset dim is every dimension including newly added size 1 dim, except
443     // for input_dim, which acts as a batch_dim.
444     if (i != input_dim) {
445       gather_dim_numbers.add_offset_dims(i);
446     }
447   }
448   // The dimension to rewrite is the index dim.
449   gather_dim_numbers.add_start_index_map(input_dim);
450   gather_dim_numbers.set_index_vector_dim(1);
451   gather_dim_numbers.add_collapsed_slice_dims(input_dim);
452 
453   // Step 4. Gather.
454 
455   // Temporarily removes dynamic dimension before entering gather -- we want the
456   // gather to ignore dynamic dimension.
457   HloInstruction* operand_static_dim_size =
458       comp->AddInstruction(HloInstruction::CreateConstant(
459           LiteralUtil::CreateR0<int32>(operand_shape.dimensions(input_dim))));
460   HloInstruction* operand_static =
461       comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
462           operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
463           input_dim));
464 
465   std::vector<int64> slice_sizes(operand_shape.dimensions().begin(),
466                                  operand_shape.dimensions().end());
467   slice_sizes[input_dim] = 1;
468   HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
469       ShapeUtil::MakeShape(operand_shape.element_type(),
470                            operand_shape.dimensions()),
471       operand_static, cumsum, gather_dim_numbers, slice_sizes, true));
472 
473   // Step 6: Feed gather input to original reshape.
474 
475   TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));
476 
477   HloInstruction* reshape_dynamic = reshape;
478 
479   auto users = reshape->users();
480 
481   // Forward the output dynamic dimension.
482   for (int64_t output_dim : output_dims) {
483     HloInstruction* output_dynamic_size =
484         dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
485     if (output_dynamic_size != nullptr) {
486       reshape_dynamic =
487           comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
488               reshape->shape(), reshape_dynamic, output_dynamic_size,
489               output_dim));
490     }
491   }
492 
493   for (auto* user : users) {
494     TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic));
495   }
496   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
497       reshape, reshape_dynamic, {}));
498 
499   return Status::OK();
500 }
501 
// RewriteDynamicReshapeCombineInput is similar to
// RewriteDynamicReshapeSplitInput, in a reshape if multiple dimensions are
// combined into one dimension, we need to rewrite the output.
//
// The reason for this is that a continuous input may not be evenly reshaped
// into output.  Imagine we have [2, <=3] where second dimension has size 2 and
// padding(P) data has size 1:
// [[a,b,P]
//  [c,d,P]]
//
// And we have a reshape that combines this two input dimensions.
//
// [2, <=3]
//  |
// Reshape
//  |
// [6]
//
// This should produce the same result as if the data has no padding:
//
// [2, 2]     // [[a, b], [c, d]]
//  |
// Reshape
//  |
// [4]  // [a,b,c,d]
//
// Without rewriting, the result would be:
//
// [a,b,P,c,d,P], which is incorrect.
//
// We need to rewrite the reshape such that it produces:
// [a,b,c,d,P,P]
//
// The way we do this is by a 5-step sort-gather algorithm:
//
// 1.First we use the input shape to generate a binary 0-1 masking, which marks
// the padded area of the input with ones:
// [[0,0,1]
//  [0,0,1]]
//
// 2.Then we do a reshape to reshape the mask from input shape to output
// shape [2,3]->[6]:
//  [0,0,1,0,0,1]
//
// 3.We then generate an iota mask using the output shape:
//  [0,1,2,3,4,5]
//
// 4.Stable sort the iota mask using the binary mask as key:
//  key  [0,0,1,0,0,1]
//  value[0,1,2,3,4,5]
//     | Sort by key
//     v
//  key  [0,0,0,0,1,1]
//  value[0,1,3,4,2,5]
//
// 5.Gather the original output [a,b,P,c,d,P] using the sorted iota mask:
//      original output       gather indices
//       [a,b,P,c,d,P]         [0,1,3,4,2,5]
//            |                    |
//          Gather ----------------+
//            |
//       [a,b,c,d,P,P]
//
Status RewriteDynamicReshapeCombineInput(
    HloInstruction* reshape, absl::Span<const int64> input_dims,
    int64_t output_dim, absl::Span<HloInstruction*> input_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  // Rewrite dynamic reshape into reshape followed by a sort, all padded
  // data will be moved to the end.
  HloComputation* comp = reshape->parent();
  HloInstruction* zero = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* one = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(S32)));
  const Shape output_shape = reshape->shape();
  const Shape input_shape = reshape->operand(0)->shape();
  const Shape mask_output_shape =
      ShapeUtil::MakeShape(xla::S32, {output_shape.dimensions(output_dim)});
  std::vector<int64> input_dim_sizes;
  for (int64_t input_dim : input_dims) {
    input_dim_sizes.push_back(input_shape.dimensions(input_dim));
  }

  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, input_dim_sizes);

  // Step 1 -- generate binary mask.
  // Mask starts with all zero, each dynamic dimension sets that dimension of
  // the mask to partially ones in the end.
  HloInstruction* binary_mask = comp->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, zero, {}));

  bool need_rewrite = false;

  // Mark the padded area of each dynamic input dimension with 1.
  //
  // Index starts from 1 since there is no need to rewrite a major input
  // dimension.
  for (int64_t i = 1; i < input_dims.size(); ++i) {
    const int64_t input_dim = input_dims[i];
    HloInstruction* dynamic_size = input_dynamic_dims[input_dim];
    if (dynamic_size == nullptr) {
      continue;
    }
    // If there is a dynamic dimension in the input, need to rewrite the output.
    need_rewrite = true;

    binary_mask = PadWithScalar(binary_mask, i, dynamic_size, one);
  }
  if (!need_rewrite) {
    VLOG(2) << "No need to rewrite";
    return Status::OK();
  }

  // Step 2.
  // Do a reshape to flatten the binary mask into output_shape
  HloInstruction* output_shape_binary_mask = comp->AddInstruction(
      HloInstruction::CreateReshape(mask_output_shape, binary_mask));

  // Step 3.
  // Generate an iota with output shape.
  HloInstruction* iota =
      comp->AddInstruction(HloInstruction::CreateIota(mask_output_shape, 0));

  // Step 4.
  // Stable sort the iota mask using the binary mask as key and iota as value:

  // Build computation for sort, key is the mask, value is the iota.
  HloComputation::Builder comp_builder("compare");
  HloInstruction* lhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeScalarShape(S32), "lhs_key"));
  HloInstruction* rhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeScalarShape(S32), "rhs_key"));

  // Values for lhs and rhs. The value parameters are unused by the comparator
  // but required by sort's arity (two sorted operands -> four parameters).
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeScalarShape(S32), "lhs_value"));
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      3, ShapeUtil::MakeScalarShape(S32), "rhs_value"));
  comp_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
                                    rhs_key, ComparisonDirection::kLt));
  HloComputation* compare =
      comp->parent()->AddEmbeddedComputation(comp_builder.Build());

  // Use mask_reshaped as key, sort reshaped data as value.
  HloInstruction* sort = comp->AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({mask_output_shape, mask_output_shape}), 0,
      {output_shape_binary_mask, iota}, compare,
      /*is_stable=*/true));

  HloInstruction* gather_indices = comp->AddInstruction(
      HloInstruction::CreateGetTupleElement(mask_output_shape, sort, 1));

  // Step 5. Gather the original output using the sorted iota mask:

  GatherDimensionNumbers gather_dim_numbers;
  // Use gather to rearrange the output dim dimension.
  for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
    // Offset dim is every dimension including newly added size 1 dim, except
    // for output_dim, which acts as a batch_dim.
    if (i != output_dim) {
      gather_dim_numbers.add_offset_dims(i);
    }
  }
  // The dimension to rewrite is the index dim.
  gather_dim_numbers.add_start_index_map(output_dim);
  gather_dim_numbers.set_index_vector_dim(1);
  gather_dim_numbers.add_collapsed_slice_dims(output_dim);

  HloInstruction* static_dim_size = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
          reshape->shape().dimensions(output_dim))));

  // Temporarily removes dynamic dimension of the reshape before we send it to
  // the sort -- we want padded area to also participate in the gather.
  HloInstruction* reshape_static =
      comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
          reshape->shape(), reshape, static_dim_size, output_dim));
  std::vector<int64> gather_slice_sizes(output_shape.dimensions().begin(),
                                        output_shape.dimensions().end());
  gather_slice_sizes[output_dim] = 1;
  HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
      output_shape, reshape_static, gather_indices, gather_dim_numbers,
      gather_slice_sizes, true));

  // Forward dynamic size to the newly created gather.
  HloInstruction* output_dynamic_size =
      dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
  TF_RET_CHECK(output_dynamic_size != nullptr);
  gather = comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
      gather->shape(), gather, output_dynamic_size, output_dim));
  auto users = reshape->users();
  for (auto* user : users) {
    // Avoid cycles by not replacing the static reshape and get_dimension_size.
    if (user != reshape_static && user != output_dynamic_size) {
      TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather));
    }
  }

  if (reshape == comp->root_instruction()) {
    comp->set_root_instruction(gather);
  }

  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {}));

  return Status::OK();
}
713 
// Rewrites one reshape "group": either a single input dimension split into
// multiple output dimensions, or multiple input dimensions combined into a
// single output dimension. Many-to-many groups are unsupported and trip the
// TF_RET_CHECK at the bottom.
Status RewriteDynamicReshapeSingleGroup(
    HloInstruction* reshape, absl::Span<const int64> input_dims,
    absl::Span<const int64> output_dims,
    absl::Span<HloInstruction*> input_dynamic_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Rewriting dynamic reshape " << reshape->ToString()
          << " input dims: " << VectorString(input_dims)
          << " output dims: " << VectorString(output_dims);

  const Shape operand_shape = reshape->operand(0)->shape();
  const Shape output_shape = reshape->shape();

  if (input_dims.size() == 1) {
    int64_t input_dim = input_dims[0];
    // Size 1 dimension doesn't need a rewrite.
    if (operand_shape.dimensions()[input_dim] == 1) {
      return Status::OK();
    }
    // One input dimension is split into multiple output dimensions.
    return RewriteDynamicReshapeSplitInput(reshape, input_dim, output_dims,
                                           output_dynamic_dims,
                                           dynamic_dimension_inference);
  }

  if (output_dims.size() == 1) {
    int64_t output_dim = output_dims[0];
    if (output_shape.dimensions()[output_dim] == 1) {
      return Status::OK();
    }
    // Multiple input dimensions are combined into one output dimension.
    return RewriteDynamicReshapeCombineInput(reshape, input_dims, output_dim,
                                             input_dynamic_dims,
                                             dynamic_dimension_inference);
  }
  // Shouldn't get here: callers only pass groups with a single dimension on
  // at least one side.
  TF_RET_CHECK(false);
  return Status::OK();
}
753 
RewriteReverse(HloInstruction * reverse,DynamicDimensionInference * dynamic_dimension_inference)754 StatusOr<bool> RewriteReverse(
755     HloInstruction* reverse,
756     DynamicDimensionInference* dynamic_dimension_inference) {
757   // When we have [A, B, C, D, E] and reverse them, we get [E, D, C, B, A].
758   // However, if the dynamic size is 2, we expect B, A to be in front:
759   // [B, A, P, P, P].
760   //
761   // We do this by running a pad and dynamic slice on the result:
762   // [A, B, C, D, E]
763   //      |
764   //    reverse
765   //      |
766   // [E, D, C, B, A]
767   //      |
768   //     pad # Use pad to double the size of the dimension to avoid OOB.
769   //      |
770   // [E, D, C, B, A, P, P, P, P, P]
771   //      |
772   //  dynamic slice
773   //      |
774   // [B, A, P, P, P]
775   auto reverse_dims = reverse->dimensions();
776   HloComputation* comp = reverse->parent();
777   const Shape& reverse_shape = reverse->shape();
778   std::set<int64> dynamic_reverse_dims;
779   for (int64_t reverse_dim : reverse_dims) {
780     HloInstruction* dynamic_size =
781         dynamic_dimension_inference->GetDynamicSize(reverse, {}, reverse_dim);
782     if (dynamic_size == nullptr) {
783       // Reverse dimension is not dynamic -- no rewrite needed.
784       continue;
785     }
786     dynamic_reverse_dims.insert(reverse_dim);
787   }
788 
789   if (dynamic_reverse_dims.empty()) {
790     // We only need to rewrite dynamic dimensions that are also reverse
791     // dimensions.
792     return false;
793   }
794 
795   PaddingConfig padding;
796   // Doubles dynamic dimension size using a pad.
797   Shape pad_shape = reverse_shape;
798   for (int i = 0; i < reverse_shape.rank(); ++i) {
799     auto dimension = padding.add_dimensions();
800     if (dynamic_reverse_dims.count(i) > 0) {
801       dimension->set_edge_padding_low(0);
802       dimension->set_edge_padding_high(reverse_shape.dimensions(i));
803       dimension->set_interior_padding(0);
804       pad_shape.set_dimensions(i, 2 * pad_shape.dimensions(i));
805     }
806   }
807   HloInstruction* cloned_reverse = comp->AddInstruction(reverse->Clone());
808   HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
809       LiteralUtil::Zero(pad_shape.element_type())));
810   HloInstruction* pad = comp->AddInstruction(
811       HloInstruction::CreatePad(pad_shape, cloned_reverse, zero, padding));
812   std::vector<HloInstruction*> start_indices;
813   start_indices.reserve(reverse_shape.rank());
814   for (int i = 0; i < reverse_shape.rank(); ++i) {
815     if (dynamic_reverse_dims.count(i) > 0) {
816       // Start at bound_size - dynamic_size.
817       HloInstruction* bound_size =
818           comp->AddInstruction(HloInstruction::CreateConstant(
819               LiteralUtil::CreateR0<int32>(reverse_shape.dimensions(i))));
820       HloInstruction* dynamic_size =
821           dynamic_dimension_inference->GetDynamicSize(reverse, {}, i);
822       HloInstruction* start_offset =
823           comp->AddInstruction(HloInstruction::CreateBinary(
824               ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract, bound_size,
825               dynamic_size));
826       start_indices.push_back(start_offset);
827     } else {
828       HloInstruction* zero = comp->AddInstruction(
829           HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
830       start_indices.push_back(zero);
831     }
832   }
833   HloInstruction* dynamic_reverse =
834       comp->AddInstruction(HloInstruction::CreateDynamicSlice(
835           reverse_shape, pad, start_indices, reverse_shape.dimensions()));
836   TF_RETURN_IF_ERROR(comp->ReplaceInstruction(reverse, dynamic_reverse));
837   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
838       reverse, dynamic_reverse, {}));
839   return true;
840 }
841 
// Materializes dynamic (SAME-style) padding of `input` as an explicit static
// pad followed by a dynamic-slice, so the windowed op `conv` can be lowered
// to its static form.
//
// `padding_before[i]` holds the dynamic low-padding amount for window
// dimension i (nullptr means that dimension needs no rewrite).
// `input_window` is mutated in place: its static padding and base dilation
// are folded into the explicit pad and then cleared.
// `window_dim_to_shape_dim` maps a window dimension index to the matching
// dimension of `input`'s shape.
//
// Returns the rewritten input instruction.
HloInstruction* RewriteInputWithDynamicPadding(
    HloInstruction* conv, HloInstruction* input, HloInstruction* padding_value,
    absl::Span<HloInstruction*> padding_before, Window* input_window,
    std::function<int64(int64_t)> window_dim_to_shape_dim) {
  HloComputation* comp = conv->parent();
  HloInstruction* zero_s32 = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  // Padded shape represents the bounded shape after dynamic padding.
  Shape padded_shape = input->shape();
  PaddingConfig padding_configs;
  for (int64_t i = 0; i < input->shape().rank(); ++i) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    *padding_configs.add_dimensions() = padding_dim;
  }
  std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
  for (int64_t dim_index = 0; dim_index < input_window->dimensions_size();
       ++dim_index) {
    if (padding_before[dim_index] == nullptr) {
      // No dynamic padding to rewrite for this window dimension.
      continue;
    }
    int64_t shape_dim = window_dim_to_shape_dim(dim_index);

    WindowDimension* window_dim = input_window->mutable_dimensions(dim_index);
    auto* padding_dim = padding_configs.mutable_dimensions(shape_dim);
    const int64_t dilated_window_size = window_util::DilatedBound(
        window_dim->size(), window_dim->window_dilation());
    // Use dilated window size as low padding and static padding_high +
    // padding_low as high padding to make sure the following dynamic slice is
    // valid and doesn't go out of bound.
    //
    // See go/xla-dynamic-spatial-dim for more details.
    padding_dim->set_edge_padding_low(dilated_window_size);
    padding_dim->set_edge_padding_high(window_dim->padding_high() +
                                       window_dim->padding_low());
    padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
    // Slice starts at (static low pad) - (dynamic low pad), selecting exactly
    // the region the dynamic padding would have produced.
    HloInstruction* slicing_start =
        comp->AddInstruction(HloInstruction::CreateBinary(
            ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
            comp->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::CreateR0<int32>(padding_dim->edge_padding_low()))),
            padding_before[dim_index]));
    start_indices[shape_dim] = slicing_start;

    // Bounded size of the padded dimension: static low/high padding plus the
    // base-dilated original size.
    padded_shape.mutable_dimensions()[shape_dim] =
        window_dim->padding_low() +
        window_util::DilatedBound(padded_shape.dimensions(shape_dim),
                                  window_dim->base_dilation()) +
        window_dim->padding_high();
    // Padding and base dilation are now materialized by the explicit pad, so
    // clear them from the window.
    window_dim->clear_padding_high();
    window_dim->clear_padding_low();
    window_dim->set_base_dilation(1);
    input->mutable_shape()->set_dynamic_dimension(shape_dim, false);
  }
  // Reconstruct dynamic padding using pad and dynamic slice.
  // NOTE(review): MakePadHlo's StatusOr is unwrapped with ValueOrDie -- a
  // malformed padding config would crash here instead of propagating an
  // error; confirm this is intended.

  HloInstruction* pad =
      MakePadHlo(input, padding_value, padding_configs).ValueOrDie();
  input = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
      padded_shape, pad, start_indices, padded_shape.dimensions()));
  return input;
}
903 
RewriteDynamicConvolutionInputGrad(HloInstruction * custom_call_conv,DynamicDimensionInference * dynamic_dimension_inference)904 StatusOr<bool> RewriteDynamicConvolutionInputGrad(
905     HloInstruction* custom_call_conv,
906     DynamicDimensionInference* dynamic_dimension_inference) {
907   HloInstruction* grad = custom_call_conv->mutable_operand(1);
908   HloInstruction* kernel = custom_call_conv->mutable_operand(2);
909   TF_RET_CHECK(kernel->shape().is_static());
910   auto dnums = custom_call_conv->convolution_dimension_numbers();
911   HloComputation* comp = custom_call_conv->parent();
912   Window window = custom_call_conv->window();
913   HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
914       LiteralUtil::Zero(custom_call_conv->shape().element_type())));
915   std::vector<HloInstruction*> padding_before(
916       dnums.input_spatial_dimensions_size(), nullptr);
917   for (int64_t spatial_dim_index = 0;
918        spatial_dim_index < dnums.input_spatial_dimensions_size();
919        ++spatial_dim_index) {
920     int64_t input_spatial_dim =
921         dnums.input_spatial_dimensions(spatial_dim_index);
922     HloInstruction* operand_dynamic_size =
923         dynamic_dimension_inference->GetDynamicSize(
924             custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
925     if (operand_dynamic_size == nullptr) {
926       continue;
927     }
928     grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
929     HloInstruction* slice = comp->AddInstruction(HloInstruction::CreateSlice(
930         ShapeUtil::MakeShape(S32, {1}), custom_call_conv->mutable_operand(0),
931         {input_spatial_dim}, {input_spatial_dim + 1}, {1}));
932     HloInstruction* dynamic_input_size = comp->AddInstruction(
933         HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
934     const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
935     // Window stride of forward prop is same as base dilation of backward prop.
936     DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
937         dynamic_input_size, /*window_size=*/window_dim.size(),
938         /*window_dilation=*/window_dim.window_dilation(),
939         /*window_stride=*/window_dim.base_dilation(),
940         custom_call_conv->padding_type());
941     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
942   }
943 
944   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
945     grad = RewriteInputWithDynamicPadding(
946         custom_call_conv, grad, zero, absl::MakeSpan(padding_before), &window,
947         [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
948   }
949 
950   PrecisionConfig precision_config;
951   if (custom_call_conv->precision_config().operand_precision_size() == 3) {
952     // We are not interested in the precision config of the first operand, which
953     // is the input_sizes.
954     *precision_config.mutable_operand_precision() = {
955         custom_call_conv->precision_config().operand_precision().begin() + 1,
956         custom_call_conv->precision_config().operand_precision().end()};
957   }
958   HloInstruction* static_conv = comp->AddInstruction(
959       HloInstruction::CreateConvolve(
960           custom_call_conv->shape(), grad, kernel,
961           custom_call_conv->feature_group_count(),
962           custom_call_conv->batch_group_count(), window,
963           custom_call_conv->convolution_dimension_numbers(),
964           custom_call_conv->precision_config()),
965       "ConvBackwardInput");
966   TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
967   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
968       custom_call_conv, static_conv, {}));
969   return true;
970 }
971 
RewriteDynamicConvolutionForward(HloInstruction * custom_call_conv,DynamicDimensionInference * dynamic_dimension_inference)972 StatusOr<bool> RewriteDynamicConvolutionForward(
973     HloInstruction* custom_call_conv,
974     DynamicDimensionInference* dynamic_dimension_inference) {
975   HloInstruction* input = custom_call_conv->mutable_operand(0);
976   HloInstruction* kernel = custom_call_conv->mutable_operand(1);
977   TF_RET_CHECK(kernel->shape().is_static());
978   TF_RET_CHECK(input->shape().is_dynamic());
979   HloComputation* comp = custom_call_conv->parent();
980   Window window = custom_call_conv->window();
981   auto dnums = custom_call_conv->convolution_dimension_numbers();
982   HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
983       LiteralUtil::Zero(custom_call_conv->shape().element_type())));
984   std::vector<HloInstruction*> padding_before(
985       dnums.input_spatial_dimensions_size(), nullptr);
986   for (int64_t spatial_dim_index = 0;
987        spatial_dim_index < dnums.input_spatial_dimensions_size();
988        ++spatial_dim_index) {
989     int64_t input_spatial_dim =
990         dnums.input_spatial_dimensions(spatial_dim_index);
991     HloInstruction* operand_dynamic_size =
992         dynamic_dimension_inference->GetDynamicSize(
993             custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
994     if (operand_dynamic_size == nullptr) {
995       continue;
996     }
997 
998     input = PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
999     const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
1000     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1001         operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
1002         window_dim.stride(), custom_call_conv->padding_type());
1003     padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
1004   }
1005   // Input feature dim can be dynamic too, reset it to zero.
1006   const int64_t input_feature_dim = dnums.input_feature_dimension();
1007   if (HloInstruction* input_feature_dynamic_size =
1008           dynamic_dimension_inference->GetDynamicSize(
1009               custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
1010     input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
1011                           zero);
1012   }
1013 
1014   if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
1015     input = RewriteInputWithDynamicPadding(
1016         custom_call_conv, input, zero, absl::MakeSpan(padding_before), &window,
1017         [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
1018   }
1019 
1020   HloInstruction* static_conv = comp->AddInstruction(
1021       HloInstruction::CreateConvolve(
1022           custom_call_conv->shape(), input, kernel,
1023           custom_call_conv->feature_group_count(),
1024           custom_call_conv->batch_group_count(), window,
1025           custom_call_conv->convolution_dimension_numbers(),
1026           custom_call_conv->precision_config()),
1027       "ConvForward");
1028   TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
1029   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1030       custom_call_conv, static_conv, {}));
1031   return true;
1032 }
1033 
// Lowers a dynamic conv-kernel-grad custom call (operands: activations,
// gradients) to a static kConvolution. Dynamic regions of both operands are
// zeroed out so they contribute nothing to the kernel gradient, and for SAME
// padding the dynamic padding of the activations is materialized explicitly.
//
// Returns true (the custom call is always rewritten).
StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* activations = custom_call_conv->mutable_operand(0);
  HloInstruction* gradients = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(activations->shape().is_dynamic());
  TF_RET_CHECK(gradients->shape().is_dynamic());
  HloComputation* comp = custom_call_conv->parent();
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    int64_t kernel_spatial_dim =
        dnums.kernel_spatial_dimensions(spatial_dim_index);
    // Zero out the dynamic region of the activations, if any.
    HloInstruction* activations_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (activations_dynamic_size != nullptr) {
      activations = PadWithScalar(activations, input_spatial_dim,
                                  activations_dynamic_size, zero);
    }

    // Zero out the dynamic region of the gradients, if any. In kernel-grad,
    // the gradient plays the role of the kernel, hence kernel_spatial_dim.
    HloInstruction* gradients_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
    if (gradients_dynamic_size != nullptr) {
      gradients = PadWithScalar(gradients, kernel_spatial_dim,
                                gradients_dynamic_size, zero);
    }
    // A spatial dimension must be dynamic in both operands or in neither;
    // a mixed state is a bug in dynamic dimension inference.
    if (activations_dynamic_size == nullptr ||
        gradients_dynamic_size == nullptr) {
      TF_RET_CHECK(activations_dynamic_size == nullptr &&
                   gradients_dynamic_size == nullptr);
      continue;
    }
    int64_t output_spatial_dim =
        dnums.output_spatial_dimensions(spatial_dim_index);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    // For kernel-grad, the roles of stride and window dilation are swapped
    // relative to the forward convolution.
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        activations_dynamic_size, /*window_size=*/
        custom_call_conv->shape().dimensions(output_spatial_dim),
        /*window_dilation=*/window_dim.stride(),
        /*window_stride=*/window_dim.window_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  // We only need to pad input feature on lhs to 0 -- it's mathematically
  // equivalent to padding both lhs and rhs to 0.
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    activations = PadWithScalar(activations, input_feature_dim,
                                input_feature_dynamic_size, zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    activations = RewriteInputWithDynamicPadding(
        custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
        &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv = comp->AddInstruction(
      HloInstruction::CreateConvolve(
          custom_call_conv->shape(), activations, gradients,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()),
      "ConvBackwardGrad");
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}
1118 
RewriteDynamicReduceWindowSamePadding(HloInstruction * hlo,DynamicDimensionInference * dynamic_dimension_inference)1119 StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
1120     HloInstruction* hlo,
1121     DynamicDimensionInference* dynamic_dimension_inference) {
1122   if (hlo->shape().IsTuple()) {
1123     // TODO (b/73062247) variadic reduce window is not yet supported here.
1124     return Unimplemented("DynamicReduceWindowSamePadding not yet supported.");
1125   }
1126   HloInstruction* input = hlo->mutable_operand(0);
1127   HloInstruction* init = hlo->mutable_operand(1);
1128   HloComputation* comp = hlo->parent();
1129   int64_t rank = hlo->shape().rank();
1130   Window window = hlo->window();
1131   std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
1132   for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
1133     HloInstruction* operand_dynamic_size =
1134         dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
1135                                                     dim_index);
1136     if (operand_dynamic_size == nullptr) {
1137       continue;
1138     }
1139     const WindowDimension& window_dim = window.dimensions(dim_index);
1140     if (window_util::IsTrivialWindowDimension(window_dim)) {
1141       continue;
1142     }
1143     input = PadWithScalar(input, dim_index, operand_dynamic_size, init);
1144 
1145     DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
1146         operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
1147         window_dim.stride(), PaddingType::PADDING_SAME);
1148     padding_before[dim_index] = dynamic_window_dims.padding_before;
1149   }
1150 
1151   input = RewriteInputWithDynamicPadding(
1152       hlo, input, init, absl::MakeSpan(padding_before), &window,
1153       [](int64_t dim) { return dim; });
1154 
1155   HloInstruction* rewritten = comp->AddInstruction(
1156       HloInstruction::CreateReduceWindow(hlo->shape(), input, init, window,
1157                                          hlo->called_computations()[0]),
1158       "DynamicReduceWindow");
1159   TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
1160   TF_RETURN_IF_ERROR(
1161       dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
1162   return true;
1163 }
1164 
// Rewrites a select-and-scatter with SAME padding over dynamic dimensions
// into a static select-and-scatter. The input is masked with the select
// computation's identity value and the source with the init value; the
// dynamic padding is materialized on the input, and the padded region is
// sliced back out of the output afterwards.
//
// Returns true (the instruction is always rewritten).
StatusOr<bool> RewriteDynamicSelectAndScatterSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* source = hlo->mutable_operand(1);
  HloInstruction* init = hlo->mutable_operand(2);
  // Identity value for the select computation (e.g. -inf for a max-select),
  // so the masked region never wins the selection.
  TF_ASSIGN_OR_RETURN(HloInstruction * input_padding_value,
                      ChooseIdentityValue(hlo, /*operand_number=*/0));
  HloComputation* comp = hlo->parent();
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      // A trivial window never mixes out-of-bounds data into the result.
      continue;
    }
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    // Mask the out-of-bounds region of the input with the select identity.
    input = PadWithScalar(input, dim_index, operand_dynamic_size,
                          input_padding_value);

    HloInstruction* source_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(1), {},
                                                    dim_index);
    if (source_dynamic_size == nullptr) {
      continue;
    }
    // Mask the out-of-bounds region of the source with the scatter init.
    source = PadWithScalar(source, dim_index, source_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, input_padding_value, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  // RewriteInputWithDynamicPadding adds padding to the input. However those
  // inputs should not be materialized in select and scatter's output and we
  // need to slice them out using dynamic slice. To prevent dynamic slice going
  // OOB, we first add some high-pad to the output to leave enough space.
  // NOTE(review): the instruction name "DynamicReduceWindow" below looks
  // copy-pasted from the reduce-window rewrite; it only affects debug names.
  HloInstruction* rewritten = comp->AddInstruction(
      HloInstruction::CreateSelectAndScatter(
          input->shape(), input, hlo->called_computations()[0], window, source,
          init, hlo->called_computations()[1]),
      "DynamicReduceWindow");
  std::vector<HloInstruction*> start_indices(
      input->shape().rank(),
      comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
  PaddingConfig padding_configs;
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    if (padding_before[dim_index] != nullptr) {
      const WindowDimension& window_dim = window.dimensions(dim_index);
      const int64_t dilated_window_size = window_util::DilatedBound(
          window_dim.size(), window_dim.window_dilation());
      // High-pad by the dilated window size so the dynamic slice below stays
      // in bounds, then slice starting at the dynamic low-padding offset.
      padding_dim.set_edge_padding_high(dilated_window_size);
      start_indices[dim_index] = padding_before[dim_index];
    }
    *padding_configs.add_dimensions() = padding_dim;
  }
  HloInstruction* padded =
      MakePadHlo(rewritten, init, padding_configs).ValueOrDie();
  rewritten = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
      hlo->shape(), padded, start_indices, hlo->shape().dimensions()));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}
1244 
RewriteDynamicConcat(HloInstruction * concat,DynamicDimensionInference * dynamic_dimension_inference)1245 StatusOr<bool> RewriteDynamicConcat(
1246     HloInstruction* concat,
1247     DynamicDimensionInference* dynamic_dimension_inference) {
1248   const int64_t concat_dim = concat->concatenate_dimension();
1249   HloComputation* comp = concat->parent();
1250   if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) ==
1251       nullptr) {
1252     // Concat dimension is not dynamic -- no rewrite needed.
1253     return false;
1254   }
1255   std::vector<HloInstruction*> offsets;
1256   for (int64_t i = 0; i < concat->shape().dimensions_size(); ++i) {
1257     offsets.push_back(comp->AddInstruction(
1258         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))));
1259   }
1260   HloInstruction* rewritten_concat = concat;
1261   // Keep track of previous users before rewrite so that we can update their
1262   // operands later.
1263   auto prev_users = concat->users();
1264   for (int64_t i = 0; i < concat->operand_count(); ++i) {
1265     // Rewrite the concat by dynamic update slicing operand into the concat dim.
1266     HloInstruction* operand = concat->mutable_operand(i);
1267     rewritten_concat =
1268         comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1269             rewritten_concat->shape(), rewritten_concat, operand, offsets));
1270     // Update the offset of concat dimension by adding the size of the concat
1271     // dimension of the operand to it.
1272     HloInstruction* dynamic_size =
1273         dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim);
1274     if (dynamic_size == nullptr) {
1275       HloInstruction* static_size = comp->AddInstruction(
1276           HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
1277               operand->shape().dimensions(concat_dim))));
1278       offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
1279           ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
1280           static_size));
1281     } else {
1282       offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
1283           ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
1284           dynamic_size));
1285     }
1286   }
1287   TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
1288   TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1289       concat, rewritten_concat, {}));
1290   return true;
1291 }
1292 
// Rewrites a sort whose sort dimension is dynamic so that out-of-bounds
// (padded) elements always compare as "greater" and sink to the end of the
// sorted order. This is done by appending an extra iota-vs-dynamic-size
// predicate operand to the sort and augmenting the comparator to consult it.
//
// Returns true if the sort was rewritten, false if the sort dimension is
// static on every operand.
StatusOr<bool> RewriteDynamicSort(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* dynamic_size = nullptr;
  HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  HloComputation* comp = hlo->parent();
  int64_t sort_dim = sort->sort_dimension();
  // Find the dynamic dimension in the operand.
  for (auto* operand : sort->operands()) {
    if (dynamic_size == nullptr) {
      dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim);
    }
  }

  if (dynamic_size == nullptr) {
    // Not a dynamic sort, ignore.
    return false;
  }

  // Build an "in bounds" mask: iota(sort_dim) < dynamic_size, and append it
  // as an extra sorted operand so the comparator can see it.
  Shape operand_shape =
      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32);
  HloInstruction* iota =
      comp->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim));
  HloInstruction* dynamic_size_broadcasted = comp->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {}));
  HloInstruction* lt = comp->AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::ChangeElementType(operand_shape, PRED), iota,
      dynamic_size_broadcasted, ComparisonDirection::kLt));
  sort->AppendOperand(lt);

  // Clone the comparator with two extra PRED parameters for the appended
  // in-bounds operand (lhs and rhs sides).
  const int64_t param_number_before_rewritten =
      sort->called_computations()[0]->num_parameters();
  auto new_param_0 = HloInstruction::CreateParameter(
      param_number_before_rewritten, ShapeUtil::MakeScalarShape(PRED),
      "inbound_lhs");
  auto new_param_1 = HloInstruction::CreateParameter(
      param_number_before_rewritten + 1, ShapeUtil::MakeScalarShape(PRED),
      "inbound_rhs");
  std::vector<const HloInstruction*> extra_parameters{new_param_0.get(),
                                                      new_param_1.get()};
  HloComputation* sort_comp = sort->parent()->parent()->AddEmbeddedComputation(
      sort->called_computations()[0]->CloneWithReplacements(
          /*replacements=*/absl::flat_hash_map<
              const HloInstruction*, std::unique_ptr<HloInstruction>>(),
          extra_parameters));
  auto inbound_lhs =
      sort_comp->parameter_instruction(param_number_before_rewritten);
  auto inbound_rhs =
      sort_comp->parameter_instruction(param_number_before_rewritten + 1);
  sort->ReplaceCalledComputations(
      [&](HloComputation* comp) { return sort_comp; });

  // inbound_lhs & (sort_comp | !in_bound_rhs)
  // Select the lhs if it is in bounds and the rhs is out of bounds or the
  // sort_comp returns true.
  auto out_of_bound_rhs = sort_comp->AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeScalarShape(PRED), HloOpcode::kNot, inbound_rhs));
  auto sort_comp_or_out_of_bound_rhs =
      sort_comp->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(PRED), HloOpcode::kOr,
          sort_comp->root_instruction(), out_of_bound_rhs));

  auto new_root = sort_comp->AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs,
      sort_comp_or_out_of_bound_rhs));
  sort_comp->set_root_instruction(new_root);
  Shape compare_shape =
      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED);
  // The appended operand adds one more result to the sort; grow the shape
  // accordingly.
  if (sort->shape().IsTuple()) {
    // For sort that is already tuple, simply add another result to the tuple.
    *sort->mutable_shape()->add_tuple_shapes() =
        ShapeUtil::ChangeElementType(operand_shape, PRED);
  } else {
    // A single-result sort becomes a tuple sort; re-route existing users
    // through a get-tuple-element on element 0.
    auto sort_users = sort->users();
    auto sort_clone = comp->AddInstruction(sort->Clone());
    *sort_clone->mutable_shape() = ShapeUtil::MakeTupleShape(
        {sort->shape(), ShapeUtil::ChangeElementType(operand_shape, PRED)});
    auto rewritten_sort = comp->AddInstruction(
        HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0));
    for (HloInstruction* user : sort_users) {
      TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort));
    }
    TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
        sort, rewritten_sort, {}));
    if (comp->root_instruction() == sort) {
      comp->set_root_instruction(rewritten_sort);
    }
  }

  return true;
}
1385 
// Rewrites an elementwise binary op whose two operands disagree on the
// dynamic size of a dimension. Implicit-broadcast semantics require the
// operand whose runtime size is smaller (size 1) to be broadcast up to the
// other operand's runtime size; this inserts a select-guarded broadcast for
// each such dimension.
//
// Returns true iff the binary op's operands were rewritten.
StatusOr<bool> RewriteDynamicBinaryOp(
    HloInstruction* binary,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* operand_0 = binary->mutable_operand(0);
  HloInstruction* operand_1 = binary->mutable_operand(1);

  HloComputation* comp = binary->parent();
  TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
  // Per-dimension dynamic-size instructions; nullptr marks a static dimension.
  auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
  auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
  bool changed = false;
  for (int64_t i = 0; i < dims_0.size(); ++i) {
    HloInstruction* dim_0 = dims_0[i];
    HloInstruction* dim_1 = dims_1[i];

    // Only rewrite when both sides are dynamic in this dimension but their
    // sizes come from different instructions (pointer inequality): identical
    // size instructions are already guaranteed to match at runtime.
    if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
        dims_1[i] != nullptr) {
      changed = true;
      // It is possible that a dynamic dimension of one operand is size 1 while
      // the other is greater than one. According to implicit broadcast
      // semantics, we need to insert broadcast in this case to make the dynamic
      // shape match.

      // An implicit broadcast is inserted by slicing the small shape into a
      // size 1 slice, reshape out the size 1 dimension then broadcast to the
      // full shape:
      //
      // Input [2, <=5, 3]
      //   |
      // Slice [2, 1, 3]
      //   |
      // Reshape [2, 3]
      //   |
      // Broadcast [2, 5, 3]
      //
      // The broadcast result is then selected against the original operand,
      // guarded by `pred` (true at runtime iff this operand is the smaller
      // one), so only the operand that actually needs the implicit broadcast
      // is replaced.
      auto rewrite_operand = [&](HloInstruction* pred,
                                 HloInstruction* operand) -> HloInstruction* {
        Shape static_shape = operand->shape();
        static_shape.clear_dynamic_dimensions();
        // Broadcast the scalar predicate to the full (static) operand shape
        // so it can drive an elementwise select below.
        pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
        Shape slice_shape = static_shape;
        slice_shape.set_dimensions(i, 1);
        std::vector<int64> start_indices(slice_shape.rank(), 0);
        std::vector<int64> strides(slice_shape.rank(), 1);
        HloInstruction* slice = comp->AddInstruction(
            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
                                        slice_shape.dimensions(), strides));
        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
        HloInstruction* reshape = comp->AddInstruction(
            HloInstruction::CreateReshape(reshape_shape, slice));
        std::vector<int64> broadcast_dims;
        broadcast_dims.reserve(static_shape.rank() - 1);
        // Broadcast to all dims except for i.
        for (int64_t j = 0; j < static_shape.rank(); ++j) {
          if (j != i) {
            broadcast_dims.push_back(j);
          }
        }

        HloInstruction* broadcast =
            comp->AddInstruction(HloInstruction::CreateBroadcast(
                                     static_shape, reshape, broadcast_dims),
                                 "implicit_broadcast");

        // Use a select instead of conditional as elementwise operations promote
        // more fusion.
        HloInstruction* select =
            comp->AddInstruction(HloInstruction::CreateTernary(
                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
        return select;
      };
      // An operand needs the implicit broadcast iff its runtime size in this
      // dimension is strictly smaller than the other operand's.
      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
                                        dim_1, ComparisonDirection::kLt),
          "lhs_needs_implicit_broadcast");
      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);

      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
                                        dim_0, ComparisonDirection::kLt),
          "rhs_needs_implicit_broadcast");
      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
    }
  }
  if (changed) {
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
  }
  return changed;
}
1476 
// Rewrites a dynamic-update-slice whose `update` operand has dynamic
// dimensions, so the padded tail of the update does not clobber base data.
// See the diagrams below for the construction. Returns true iff a rewrite
// was performed.
StatusOr<bool> RewriteDynamicUpdateSlice(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloDynamicUpdateSliceInstruction* dus =
      Cast<HloDynamicUpdateSliceInstruction>(hlo);
  HloComputation* comp = hlo->parent();
  // Suppose we have a base area that we want to update:
  // +------------------------+
  // |                        |
  // |                  base  |
  // |                        |
  // +------------------------+
  //
  // A partial update with dynamic padding looks like this:
  //
  //           +------+-------+
  //           |update|padding|
  //           +------+-------+
  //
  // We don't want the padding to overwrite the base area:
  //
  // +------------------------+
  // |         +------+-------+
  // |<-begin->|update|padding| (what we want to avoid)
  // |         +------+-------+
  // +------------------------+
  //
  // Instead we want to keep the base area untouched except for the update
  // region:
  //
  // +------------------------+
  // |         +------+       |
  // |<-begin->|update|  base | (what we want)
  // |         +------+       |
  // +------------------------+
  //
  // We do this by dynamic slicing the base area out first with the same begin
  // index:
  //
  //           +--------------+
  // <-begin-> |         base |
  //           +--------------+
  //
  // Then replace the update's padding part with base:
  //
  //           +------+-------+
  //           |update|  base |
  //           +------+-------+
  //
  // Then do the DUS.

  HloInstruction* update = dus->mutable_operand(1);
  HloInstruction* base = dus->mutable_operand(0);
  // For each update dimension that is both smaller than the base and
  // dynamic, remember its runtime-size instruction; those are the dimensions
  // whose padding could clobber base data.
  std::vector<HloInstruction*> dynamic_dims_in_partial_update(
      update->shape().rank(), nullptr);
  bool needs_rewrite = false;
  for (int64_t i = 0; i < update->shape().rank(); ++i) {
    if (update->shape().dimensions(i) < base->shape().dimensions(i)) {
      HloInstruction* dynamic_dim =
          dynamic_dimension_inference->GetDynamicSize(update, {}, i);

      if (dynamic_dim != nullptr) {
        dynamic_dims_in_partial_update[i] = dynamic_dim;
        needs_rewrite = true;
      }
    }
  }

  if (!needs_rewrite) {
    return false;
  }
  // Operands 2..N of a DUS are the per-dimension start indices; reuse them
  // for the dynamic-slice of the base below so both address the same region.
  std::vector<HloInstruction*> indices;
  indices.reserve(dus->operand_count() - 2);
  for (int64_t i = 2; i < dus->operand_count(); ++i) {
    indices.push_back(dus->mutable_operand(i));
  }
  HloInstruction* base_slice =
      comp->AddInstruction(HloInstruction::CreateDynamicSlice(
          update->shape(), base, indices, update->shape().dimensions()));

  // For each dynamic dimension, select update data where iota < dynamic_size
  // and base data elsewhere, so the padded region carries base values.
  for (int64_t i = 0; i < dynamic_dims_in_partial_update.size(); ++i) {
    HloInstruction* dynamic_dim = dynamic_dims_in_partial_update[i];
    if (dynamic_dim != nullptr) {
      Shape mask_shape_int = ShapeUtil::ChangeElementType(update->shape(), S32);
      Shape mask_shape_pred =
          ShapeUtil::ChangeElementType(update->shape(), PRED);
      // Generate mask using iota and dynamic_dim.
      HloInstruction* iota =
          comp->AddInstruction(HloInstruction::CreateIota(mask_shape_int, i));
      HloInstruction* broadcast_dim = comp->AddInstruction(
          HloInstruction::CreateBroadcast(mask_shape_int, dynamic_dim, {}));
      HloInstruction* pred = comp->AddInstruction(HloInstruction::CreateCompare(
          mask_shape_pred, iota, broadcast_dim, ComparisonDirection::kLt));
      // Update `update` to include base.
      update = comp->AddInstruction(HloInstruction::CreateTernary(
          update->shape(), HloOpcode::kSelect, pred, update, base_slice));
    }
  }
  TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update));

  return true;
}
1579 
RewriteDynamicReshape(HloInstruction * reshape,DynamicDimensionInference * dynamic_dimension_inference)1580 StatusOr<bool> RewriteDynamicReshape(
1581     HloInstruction* reshape,
1582     DynamicDimensionInference* dynamic_dimension_inference) {
1583   bool changed = false;
1584   HloInstruction* operand = reshape->mutable_operand(0);
1585   std::vector<HloInstruction*> input_dynamic_dims;
1586   for (int64_t dim = 0; dim < operand->shape().dimensions_size(); ++dim) {
1587     input_dynamic_dims.push_back(
1588         dynamic_dimension_inference->GetDynamicSize(operand, {}, dim));
1589   }
1590 
1591   std::vector<HloInstruction*> output_dynamic_dims;
1592   for (int64_t dim = 0; dim < reshape->shape().dimensions_size(); ++dim) {
1593     output_dynamic_dims.push_back(
1594         dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim));
1595   }
1596 
1597   auto common_factors = CommonFactors(operand->shape().dimensions(),
1598                                       reshape->shape().dimensions());
1599   // Find common_factors that the input belongs to.
1600   for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
1601     auto start = common_factors[i];
1602     auto end = common_factors[i + 1];
1603     std::vector<int64> input_dims;
1604     std::vector<int64> output_dims;
1605     for (int64_t dim = start.first; dim < end.first; ++dim) {
1606       input_dims.push_back(dim);
1607     }
1608     for (int64_t dim = start.second; dim < end.second; ++dim) {
1609       output_dims.push_back(dim);
1610     }
1611 
1612     VLOG(2) << "input_dims: " << VectorString(input_dims);
1613     VLOG(2) << "output_dims: " << VectorString(output_dims);
1614 
1615     if (input_dims.empty() || output_dims.empty()) {
1616       continue;
1617     }
1618     bool has_dynamic_dimension = absl::c_any_of(output_dims, [&](int64_t dim) {
1619       HloInstruction* operand_dynamic_size =
1620           dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim);
1621 
1622       return operand_dynamic_size != nullptr ||
1623              reshape->shape().is_dynamic_dimension(dim);
1624     });
1625 
1626     if (!has_dynamic_dimension) {
1627       // Don't need to rewrite any group without dynamic dimensions.
1628       VLOG(2) << "All dimensions are static in this common factor group";
1629       continue;
1630     }
1631 
1632     if (input_dims.size() == 1 && output_dims.size() == 1) {
1633       // The dimension is unchanged. No rewrite needed.
1634       continue;
1635     }
1636     if (input_dims.size() > 1 && output_dims.size() > 1) {
1637       // We don't support the case when a dynamic dimension is both combined
1638       // with and splitted into other dimensions:
1639       //
1640       //  [x, yz]
1641       //     | Reshape
1642       //  [xy, z]
1643       //
1644       // TODO(yunxing): This can be supported by canonicalizing
1645       // the offending reshape into two reshapes:
1646       //
1647       //  [x,yz]
1648       //     | Reshape
1649       //  [x, y, z]
1650       //     | Reshape
1651       //  [xy, z]
1652       //
1653       return Unimplemented(
1654           "Dynamic input dimension to reshape that is both splitted and "
1655           "combined is not supported %s",
1656           reshape->ToString());
1657     }
1658 
1659     TF_RETURN_IF_ERROR(RewriteDynamicReshapeSingleGroup(
1660         reshape, input_dims, output_dims, absl::MakeSpan(input_dynamic_dims),
1661         absl::MakeSpan(output_dynamic_dims), dynamic_dimension_inference));
1662   }
1663 
1664   return changed;
1665 }
1666 
1667 // Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
1668 // Recurse into tuple instructions.
InsertPadToStaticOnInstruction(HloInstruction * inst)1669 StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
1670   if (inst->shape().is_static()) {
1671     return inst;
1672   }
1673   HloComputation* comp = inst->parent();
1674   if (!inst->shape().IsTuple()) {
1675     // The output shape of pad static is a tuple. The 0th element is the data
1676     // output, which is the same as input shape, but without dynamic dimensions;
1677     // i-th element is the dynamic dimension size for i-1th input dimension.
1678     Shape data_output_shape = inst->shape();  // 0th element.
1679     data_output_shape.clear_dynamic_dimensions();
1680     Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
1681     for (int64_t i = 0; i < inst->shape().rank(); ++i) {
1682       ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
1683                                     &output_shape);
1684     }
1685     HloInstruction* pad_to_static =
1686         comp->AddInstruction(HloInstruction::CreateCustomCall(
1687             output_shape, {inst}, "PadToStatic", ""));
1688     HloInstruction* data_output =
1689         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
1690             data_output_shape, pad_to_static, 0));
1691     return data_output;
1692   }
1693 
1694   TF_RET_CHECK(inst->shape().IsTuple());
1695   std::vector<HloInstruction*> static_tuple_elements;
1696   for (int64_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
1697     // For each tuple element, if it is static, pass it through. If it is
1698     // dynamic, recursively call this function again.
1699     HloInstruction* gte =
1700         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
1701             inst->shape().tuple_shapes(i), inst, i));
1702 
1703     if (gte->shape().is_static()) {
1704       static_tuple_elements.push_back(gte);
1705     } else {
1706       TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
1707                           InsertPadToStaticOnInstruction(gte));
1708       static_tuple_elements.push_back(static_gte);
1709     }
1710   }
1711 
1712   return comp->AddInstruction(
1713       HloInstruction::CreateTuple(static_tuple_elements));
1714 }
1715 
InsertPadToStaticAfterModuleInputs(HloModule * module)1716 Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
1717   std::vector<HloInstruction*> params;
1718   HloComputation* entry = module->entry_computation();
1719   for (int64_t i = 0; i < entry->num_parameters(); ++i) {
1720     HloInstruction* param =
1721         module->entry_computation()->parameter_instruction(i);
1722     auto users = param->users();
1723     TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
1724                         InsertPadToStaticOnInstruction(param));
1725     for (auto* user : users) {
1726       TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
1727     }
1728     if (param == entry->root_instruction()) {
1729       module->entry_computation()->set_root_instruction(static_param);
1730     }
1731   }
1732   return Status::OK();
1733 }
1734 
1735 // Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
1736 //
1737 // After this visitor the entry computation then looks like:
1738 //  Param(dynamic)
1739 //    |
1740 //   GTE (dynamic)
1741 //    |
1742 //  PadToStatic(static)
1743 //    |
1744 //   .... regular computation with static shapes.
1745 //    |
1746 //  SliceToDynamic(dynamic)
1747 //    |
1748 // ROOT tuple (dynamic)
class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit DynamicShapeRemovingVisitor(
      const DynamicPadder::OpSupportsDynamismHandler&
          op_supports_dynamism_handler,
      DynamicDimensionInference* dynamic_dimension_inference)
      : op_supports_dynamism_handler_(op_supports_dynamism_handler),
        dynamic_dimension_inference_(dynamic_dimension_inference) {}

  // Generic op: decides per-op whether operands/output should be in static
  // or dynamic form, inserting PadToStatic / SliceToDynamic as needed.
  Status DefaultAction(HloInstruction* hlo) override;

  // "SliceToDynamic" and "PadToStatic" pass through untouched; every other
  // custom call falls back to DefaultAction.
  Status HandleCustomCall(HloInstruction* hlo) override;

  // Shape-propagation handlers: refresh the instruction's shape from its
  // (possibly already rewritten) operands.
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;

  Status HandleParameter(HloInstruction* hlo) override;

  // Runs the visitor over `computation`. If `require_dynamic_output` is set
  // and the root has dynamic dimensions, a static-to-dynamic conversion is
  // inserted and becomes the new root.
  static Status Run(HloComputation* computation,
                    const DynamicPadder::OpSupportsDynamismHandler&
                        op_supports_dynamism_handler,
                    DynamicDimensionInference* dynamic_shape_inference,
                    bool require_dynamic_output) {
    DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
                                        dynamic_shape_inference);
    TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    // If the outputs is required to be dynamic form, insert static to dynamic
    // conversion as root.
    if (require_dynamic_output) {
      HloInstruction* root = computation->root_instruction();
      if (dynamic_shape_inference->HasDynamicDimension(root)) {
        TF_ASSIGN_OR_RETURN(HloInstruction * new_root,
                            visitor.ConvertToDynamic(root));
        computation->set_root_instruction(new_root);
      }
    }
    return Status::OK();
  }

 private:
  // If a tensor produced by `inst` is in dynamic form, convert it to static and
  // returns the new instruction.
  StatusOr<HloInstruction*> ConvertToStatic(HloInstruction* inst);

  // If a tensor produced by `inst` is in static form, convert it to dynamic and
  // returns the new instruction.
  StatusOr<HloInstruction*> ConvertToDynamic(HloInstruction* inst);

  // Callback deciding whether an op supports/requires dynamic shapes; may be
  // empty, in which case DefaultAction treats every op as "no support".
  const DynamicPadder::OpSupportsDynamismHandler& op_supports_dynamism_handler_;

  DynamicDimensionInference* dynamic_dimension_inference_;
};
1801 
ConvertToDynamic(HloInstruction * inst)1802 StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToDynamic(
1803     HloInstruction* inst) {
1804   auto* comp = inst->parent();
1805   const Shape& shape = inst->shape();
1806   if (shape.IsTuple()) {
1807     std::vector<HloInstruction*> dynamic_operands;
1808     for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
1809       auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
1810           shape.tuple_shapes(i), inst, i));
1811       if (dynamic_dimension_inference_->HasDynamicDimension(inst, {i})) {
1812         TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
1813         TF_ASSIGN_OR_RETURN(auto dynamic, ConvertToDynamic(gte));
1814         dynamic_operands.push_back(dynamic);
1815       } else {
1816         dynamic_operands.push_back(gte);
1817       }
1818     }
1819     return comp->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
1820   } else {
1821     // Collect the data input, as well as dimension sizes, and feed them to
1822     // slice to dynamic to create a dynamic tensor.
1823     Shape output_shape = shape;  // 0th element.
1824     CHECK(output_shape.is_static());
1825     std::vector<HloInstruction*> slice_operand;
1826     slice_operand.push_back(inst);
1827     for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
1828       auto dimension_size =
1829           dynamic_dimension_inference_->GetDynamicSize(inst, {}, i);
1830       if (dimension_size == nullptr) {
1831         dimension_size = comp->AddInstruction(HloInstruction::CreateConstant(
1832             LiteralUtil::CreateR0<int32>(output_shape.dimensions(i))));
1833       } else {
1834         output_shape.set_dynamic_dimension(i, true);
1835       }
1836       slice_operand.push_back(dimension_size);
1837     }
1838     return comp->AddInstruction(HloInstruction::CreateCustomCall(
1839         output_shape, slice_operand, "SliceToDynamic"));
1840   }
1841 }
1842 
ConvertToStatic(HloInstruction * inst)1843 StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToStatic(
1844     HloInstruction* inst) {
1845   auto* comp = inst->parent();
1846   const Shape& shape = inst->shape();
1847   CHECK(shape.is_dynamic());
1848   if (shape.IsTuple()) {
1849     std::vector<HloInstruction*> static_operands;
1850     for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
1851       auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
1852           shape.tuple_shapes(i), inst, i));
1853       TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
1854       auto operand = inst->mutable_operand(i);
1855       if (shape.tuple_shapes(i).is_dynamic()) {
1856         TF_ASSIGN_OR_RETURN(auto static_inst, ConvertToStatic(gte));
1857         static_operands.push_back(static_inst);
1858       } else {
1859         static_operands.push_back(operand);
1860       }
1861     }
1862     return comp->AddInstruction(HloInstruction::CreateTuple(static_operands));
1863   } else {
1864     // The output shape of pad static is a tuple. The 0th element is the data
1865     // output, which is the same as input shape, but without dynamic dimensions.
1866     // i-th element is the dynamic dimension size for i-1th input dimension.
1867     Shape data_output_shape = shape;  // 0th element.
1868     data_output_shape.clear_dynamic_dimensions();
1869     Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
1870     for (int64_t i = 0; i < shape.rank(); ++i) {
1871       ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
1872                                     &output_shape);
1873     }
1874     HloInstruction* pad_to_static =
1875         comp->AddInstruction(HloInstruction::CreateCustomCall(
1876             output_shape, {inst}, "PadToStatic", ""));
1877     HloInstruction* data_output =
1878         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
1879             data_output_shape, pad_to_static, 0));
1880     return data_output;
1881   }
1882 }
1883 
// For a generic op, reconciles the op's dynamism support with the form of
// its inputs: ops without dynamic support get static inputs (via
// ConvertToStatic) and a static output shape; ops that require dynamic form
// get dynamic inputs (via ConvertToDynamic).
Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
  const bool input_is_dynamic = absl::c_any_of(
      hlo->operands(),
      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });

  // By default, ops don't support dynamic lowering.
  OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
  if (op_supports_dynamism_handler_) {
    op_support = op_supports_dynamism_handler_(hlo);
  }
  if (op_support == OpDynamismSupport::kNoSupport) {
    // The op is lowered statically, so parameters of any computation it
    // calls (e.g. a reduce or map body) are made static as well.
    for (auto* sub_computation : hlo->called_computations()) {
      for (auto* param : sub_computation->parameter_instructions()) {
        param->mutable_shape()->clear_dynamic_dimensions();
      }
    }
  }
  // If the input to an op is static and the op doesn't support
  // dynamic output, remove dynamism in output -- dynamic_padder should have
  // rewritten it to support static shapes.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return Status::OK();
  }

  // Op doesn't support dynamic tensor: For each operand rewrite dynamic input
  // into static input using pad_to_static.
  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_operand,
                            ConvertToStatic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
      }
    }
    // This op doesn't support dynamic lowering so the op has to be static.
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return Status::OK();
  }

  // If the op requires dynamic tensor and input is static -- construct a
  // dynamic tensor from the static tensor to feed it.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
    VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      auto operand = hlo->mutable_operand(i);
      if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
        TF_ASSIGN_OR_RETURN(auto dynamic_operand,
                            ConvertToDynamic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
      }
    }
    return Status::OK();
  }

  // Remaining cases: the op can consume its inputs in their current form, so
  // nothing needs to change.
  return Status::OK();
}
1942 
HandleGetTupleElement(HloInstruction * hlo)1943 Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
1944   *hlo->mutable_shape() =
1945       hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
1946   return Status::OK();
1947 }
1948 
HandleTuple(HloInstruction * hlo)1949 Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
1950   for (int64_t i = 0; i < hlo->operand_count(); ++i) {
1951     *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
1952   }
1953   return Status::OK();
1954 }
1955 
// Parameters are left untouched: their shapes are handled separately (see
// InsertPadToStaticAfterModuleInputs for entry parameters, and DefaultAction
// for called-computation parameters).
Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
  return Status::OK();
}
1959 
HandleCustomCall(HloInstruction * hlo)1960 Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
1961   if (hlo->custom_call_target() == "SliceToDynamic" ||
1962       hlo->custom_call_target() == "PadToStatic") {
1963     // Those ops support are created to handle dynamic tensors so by their
1964     // nature they support dynamic lowering.
1965     return Status::OK();
1966   }
1967 
1968   return DefaultAction(hlo);
1969 }
1970 
1971 }  // namespace
1972 
Run(HloModule * module)1973 StatusOr<bool> DynamicPadder::Run(HloModule* module) {
1974   bool changed = false;
1975   VLOG(2) << "Pre DynamicPadder HLO:";
1976   XLA_VLOG_LINES(2, module->ToString());
1977   // Removes dynamic dimensions on parameters if there is already a binding for
1978   // it. We do this because we have two different APIs to express a dynamic
1979   // dimension:
1980   //
1981   // 1. Dynamic dimension as specified directly in the shape -- Needed for
1982   // PyTorch.
1983   //
1984   // 2. Dynamic dimension using dynamic parameter binding object. This
1985   // is needed for tensorflow.
1986   //
1987   // For case 1, we will insert "pad-to-static" instruction in the
1988   // beginning of xla execution, to make it into a static layout.
1989   //
1990   // For case 2, since it already has a static layout, we remove the
1991   // dynamic dimension.
1992   //
1993   // TODO(b/145140571): Convert all API invocations to case 1.
1994   //
1995   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
1996       [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
1997           const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
1998           -> Status {
1999         HloInstruction* parameter =
2000             module->entry_computation()->parameter_instruction(
2001                 dynamic_dimension.parameter_num);
2002         ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
2003                                           dynamic_dimension.parameter_index,
2004                                           dynamic_dimension.dimension, false);
2005         return Status::OK();
2006       }));
2007 
2008   TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
2009   TF_ASSIGN_OR_RETURN(
2010       DynamicDimensionInference dynamic_dimension_inference,
2011       DynamicDimensionInference::Run(module, custom_call_handler_));
2012 
2013   for (HloComputation* computation : module->computations()) {
2014     for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
2015       OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport;
2016       if (op_supports_dynamism_handler_ != nullptr) {
2017         has_dynamism_support = op_supports_dynamism_handler_(inst);
2018       }
2019       // This op support dynamic lowering, no padding is required.
2020       if (has_dynamism_support != OpDynamismSupport::kNoSupport) {
2021         continue;
2022       }
2023       if (inst->opcode() == HloOpcode::kConcatenate) {
2024         TF_ASSIGN_OR_RETURN(
2025             changed, RewriteDynamicConcat(inst, &dynamic_dimension_inference));
2026         continue;
2027       }
2028       if (inst->opcode() == HloOpcode::kReverse) {
2029         TF_ASSIGN_OR_RETURN(changed,
2030                             RewriteReverse(inst, &dynamic_dimension_inference));
2031         continue;
2032       }
2033       if (inst->opcode() == HloOpcode::kSort) {
2034         TF_ASSIGN_OR_RETURN(
2035             changed, RewriteDynamicSort(inst, &dynamic_dimension_inference));
2036         continue;
2037       }
2038       if (inst->opcode() == HloOpcode::kReshape) {
2039         TF_ASSIGN_OR_RETURN(
2040             changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
2041         continue;
2042       }
2043 
2044       // Elementwise binary with dynamic shapes have implicit broadcast
2045       // semantics.
2046       if (inst->IsElementwiseBinary()) {
2047         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicBinaryOp(
2048                                          inst, &dynamic_dimension_inference));
2049         continue;
2050       }
2051 
2052       if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) {
2053         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicUpdateSlice(
2054                                          inst, &dynamic_dimension_inference));
2055         continue;
2056       }
2057 
2058       if (inst->opcode() == HloOpcode::kDynamicReshape) {
2059         TF_ASSIGN_OR_RETURN(
2060             changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference));
2061         auto* static_reshape =
2062             computation->AddInstruction(HloInstruction::CreateReshape(
2063                 inst->shape(), inst->mutable_operand(0)));
2064         TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape));
2065         TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize(
2066             inst, static_reshape, {}));
2067         continue;
2068       }
2069       if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
2070         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionInputGrad(
2071                                          inst, &dynamic_dimension_inference));
2072         continue;
2073       }
2074 
2075       if (inst->IsCustomCall("DynamicConvolutionForward")) {
2076         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionForward(
2077                                          inst, &dynamic_dimension_inference));
2078         continue;
2079       }
2080 
2081       if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
2082         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicConvolutionKernelGrad(
2083                                          inst, &dynamic_dimension_inference));
2084         continue;
2085       }
2086 
2087       if (inst->IsCustomCall("DynamicReduceWindowSamePadding")) {
2088         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicReduceWindowSamePadding(
2089                                          inst, &dynamic_dimension_inference));
2090         continue;
2091       }
2092 
2093       if (inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
2094         TF_ASSIGN_OR_RETURN(changed, RewriteDynamicSelectAndScatterSamePadding(
2095                                          inst, &dynamic_dimension_inference));
2096         continue;
2097       }
2098 
2099       for (int64_t operand_num = 0; operand_num < inst->operand_count();
2100            ++operand_num) {
2101         HloInstruction* original_operand = inst->mutable_operand(operand_num);
2102         HloInstruction* operand = original_operand;
2103         if (!operand->shape().IsArray()) {
2104           continue;
2105         }
2106 
2107         for (int64_t input_dim = 0; input_dim < operand->shape().rank();
2108              ++input_dim) {
2109           HloInstruction* operand_dynamic_size =
2110               dynamic_dimension_inference.GetDynamicSize(original_operand, {},
2111                                                          input_dim);
2112           if (operand_dynamic_size == nullptr) {
2113             continue;
2114           }
2115           VLOG(2) << "Has dynamic dimension of operand" << operand_num << " @"
2116                   << input_dim;
2117 
2118           if (ShouldSkipPadOnOperand(inst, operand_num, input_dim)) {
2119             continue;
2120           }
2121 
2122           TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
2123                               ChooseIdentityValue(inst, operand_num));
2124           if (identity_value == nullptr) {
2125             continue;
2126           }
2127 
2128           HloInstruction* padded = PadWithScalar(
2129               operand, input_dim, operand_dynamic_size, identity_value);
2130           TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
2131           operand = inst->mutable_operand(operand_num);
2132           changed = true;
2133         }
2134       }
2135     }
2136   }
2137   if (changed == true) {
2138     module->set_is_dynamic(true);
2139   }
2140 
  // There are ops that only support dynamic lowering and ops that only support
  // static lowering; insert dynamic<->static tensor conversions around the
  // boundary between those ops, as well as around the root instruction.
2144   auto computations = module->MakeComputationPostOrder();
  // Reverse postorder so that if a caller doesn't support dynamic tensors
  // (while, etc.), we change its called computations to take only static
  // tensors.
2147   for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
2148     HloComputation* computation = *it;
    // If slice_dynamic_output_ is set and this is the entry computation, the
    // output tensor needs to be in dynamic form.
2151     bool require_dynamic_output =
2152         slice_dynamic_output_ && computation == module->entry_computation();
2153     TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
2154         computation, op_supports_dynamism_handler_,
2155         &dynamic_dimension_inference,
2156         /*require_dynamic_output=*/require_dynamic_output));
2157   }
2158 
2159   for (auto* computation : module->computations()) {
2160     for (auto instruction : computation->MakeInstructionPostOrder()) {
2161       TF_ASSIGN_OR_RETURN(
2162           bool replaced_get_size,
2163           ReplaceGetSize(instruction, &dynamic_dimension_inference));
2164       changed = changed || replaced_get_size;
2165     }
2166   }
2167 
2168   for (auto* computation : module->computations()) {
2169     for (auto instruction : computation->MakeInstructionPostOrder()) {
2170       TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
2171       TF_ASSIGN_OR_RETURN(bool replaced_set_bound,
2172                           ReplaceSetBound(instruction));
2173       changed = changed || replaced_set_size;
2174       changed = changed || replaced_set_bound;
2175     }
2176   }
2177   HloDCE dce;
2178   TF_ASSIGN_OR_RETURN(changed, dce.Run(module));
2179 
2180   VLOG(2) << "Post DynamicPadder HLO:";
2181   XLA_VLOG_LINES(2, module->ToString());
2182   return changed;
2183 }
2184 
2185 }  // namespace xla
2186