/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_padder.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dynamic_window_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {

namespace {

// ChooseIdentityValue looks at the instruction's operand and returns an
// identity value which, when padded, doesn't change the result of the
// instruction.
//
// nullptr is returned if the padding doesn't need to be reset.
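//
// For example (as implied by the cases handled below): for a reduce, the
// identity is the reduce's own init value; for convolution and dot it is
// zero; and for a select-and-scatter whose select is `max`, it is the
// minimum value of the element type.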
StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
                                              int64_t operand_number) {
  HloComputation* comp = inst->parent();
  // Padding on an elementwise operation doesn't affect the result of the
  // effective data.
  if (inst->IsElementwise()) {
    return nullptr;
  }
  if (inst->opcode() == HloOpcode::kSelectAndScatter ||
      inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
    if (operand_number == 1) {
      return inst->mutable_operand(2);
    }
    TF_RET_CHECK(operand_number == 0);
    HloComputation* select = inst->called_computations()[0];

    if (Match(select->root_instruction(),
              match::Compare(match::Parameter(), match::Parameter())
                  .WithComparisonDirection(ComparisonDirection::kGe))) {
      return comp->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::MinValue(inst->operand(0)->shape().element_type())));
    } else {
      return Unimplemented(
          "Only select and scatter with `max` as select function is "
          "supported, got %s",
          select->ToString());
    }
  }
  switch (inst->opcode()) {
    case HloOpcode::kReduce: {
      auto* reduce = Cast<HloReduceInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // Variadic reduce has a different init value for each operand; given a
      // data operand number, find the corresponding init value index.
      int64_t init_value_index = reduce->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }
    case HloOpcode::kReduceWindow: {
      auto* reduce_window = Cast<HloReduceWindowInstruction>(inst);
      TF_RET_CHECK(operand_number < reduce_window->input_count())
          << "Only data operand with dynamic dimension is valid.";
      // Variadic reduce-window has a different init value for each operand;
      // given a data operand number, find the corresponding init value index.
      int64_t init_value_index = reduce_window->input_count() + operand_number;
      return inst->mutable_operand(init_value_index);
    }

    case HloOpcode::kConvolution:
    case HloOpcode::kDot: {
      // Use 0 as the padding value for convolution and dot.
      PrimitiveType ptype = inst->shape().element_type();
      return comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
    }

    case HloOpcode::kPad: {
      return inst->mutable_operand(1);
    }
    case HloOpcode::kScatter: {
      if (operand_number != 1) {
        return nullptr;
      }
      PrimitiveType indices_ptype =
          inst->operand(operand_number)->shape().element_type();

      return comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::MaxValue(indices_ptype)));
    }
    case HloOpcode::kParameter:
    case HloOpcode::kGather:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTuple:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kBroadcast:
    case HloOpcode::kTranspose:
    case HloOpcode::kSort:
    case HloOpcode::kSlice:
    case HloOpcode::kDomain:
      return nullptr;
    case HloOpcode::kCustomCall:
      // Assume that custom calls created by the client are valid with padded
      // dynamic dimensions.
      return nullptr;
    default:
      return UnimplementedStrCat("Unimplemented padding for instruction: ",
                                 inst->ToString());
  }
}

StatusOr<bool> ReplaceGetSize(
    HloInstruction* instr,
    DynamicDimensionInference* dynamic_dimension_inference) {
  if (instr->opcode() != HloOpcode::kGetDimensionSize) {
    return false;
  }
  HloComputation* computation = instr->parent();

  TF_ASSIGN_OR_RETURN(auto legal_shape,
                      ShapeInference::InferGetDimensionSizeShape(
                          instr->operand(0)->shape(), instr->dimension()));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "legal_shape " << legal_shape.ToString();
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
  HloInstruction* operand = instr->mutable_operand(0);
  int64_t dim = instr->dimension();
  HloInstruction* dynamic_size =
      dynamic_dimension_inference->GetDynamicSize(operand, {}, dim);
  if (dynamic_size != nullptr) {
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
    // The dependency between an instruction and its dynamic dimensions is not
    // modeled in the IR. As instr is being replaced by dynamic_size, also tell
    // dynamic dimension inference that the instruction is being replaced.
    dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(
        instr, dynamic_size);
  } else {
    int32_t size = instr->operand(0)->shape().dimensions(dim);
    HloInstruction* new_instr = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
    dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr,
                                                                    new_instr);
  }
  return true;
}

StatusOr<bool> ReplaceSetSize(HloInstruction* instr) {
  if (instr->opcode() != HloOpcode::kSetDimensionSize) {
    return false;
  }

  TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
      instr->shape(), instr->operand(0)->shape()))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "instruction operand shape " << instr->operand(0)->shape();
  HloInstruction* operand = instr->mutable_operand(0);

  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
  return true;
}

StatusOr<bool> ReplaceSetBound(HloInstruction* instr) {
  if (instr->opcode() != HloOpcode::kCustomCall ||
      instr->custom_call_target() != "SetBound") {
    return false;
  }

  TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()(
      instr->shape(), instr->operand(0)->shape()))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "instruction operand shape " << instr->operand(0)->shape();
  HloInstruction* operand = instr->mutable_operand(0);

  TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand));
  return true;
}

bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64_t operand_num,
                            int64_t dimension) {
  if (inst->opcode() == HloOpcode::kSelectAndScatter && operand_num == 0 &&
      inst->window().dimensions(dimension).size() == 1) {
    return true;
  }

  if (auto* reduce_window = DynCast<HloReduceWindowInstruction>(inst)) {
    if (operand_num < reduce_window->input_count() &&
        inst->window().dimensions(dimension).size() == 1) {
      return true;
    }
  }

  if (operand_num == 0 && inst->opcode() == HloOpcode::kConvolution &&
      inst->convolution_dimension_numbers().input_batch_dimension() ==
          dimension) {
    return true;
  }
  return false;
}

// Generates a mask representing the effective (valid) area and the padded
// area of the data, using an iota and dynamic_size. For example, given a
// dimension of 7 elements and 5 effective elements:
//
// iota = [0, 1, 2, 3, 4, 5, 6]
// broadcast_dynamic_size = [5, 5, 5, 5, 5, 5, 5]
// mask = lt(iota, broadcast_dynamic_size) = [t, t, t, t, t, f, f]
//
// Once the mask is generated, the padded area of the input data is
// overwritten with the pad value.
//
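// As a rough sketch (instruction names here are illustrative), for an f32[7]
// input with dynamic size 5, the emitted HLO looks like:
//
//   iota = s32[7] iota(), iota_dimension=0
//   size = s32[7] broadcast(dynamic_size)
//   mask = pred[7] compare(iota, size), direction=LT
//   pads = f32[7] broadcast(padding_scalar)
//   out  = f32[7] select(mask, inst, pads)
//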
HloInstruction* PadWithScalar(HloInstruction* inst, int64_t dim,
                              HloInstruction* dynamic_size,
                              HloInstruction* padding_scalar) {
  CHECK(inst != nullptr && dynamic_size != nullptr &&
        padding_scalar != nullptr);
  const Shape mask_shape =
      ShapeUtil::ChangeElementType(inst->shape(), xla::S32);
  const Shape pred_shape =
      ShapeUtil::ChangeElementType(inst->shape(), xla::PRED);
  HloComputation* computation = inst->parent();
  HloInstruction* iota =
      computation->AddInstruction(HloInstruction::CreateIota(mask_shape, dim));

  HloInstruction* broadcasted_effective_size = computation->AddInstruction(
      HloInstruction::CreateBroadcast(mask_shape, dynamic_size, {}));
  HloInstruction* pred =
      computation->AddInstruction(HloInstruction::CreateCompare(
          pred_shape, iota, broadcasted_effective_size,
          ComparisonDirection::kLt));

  HloInstruction* broadcasted_identity_value = computation->AddInstruction(
      HloInstruction::CreateBroadcast(inst->shape(), padding_scalar, {}));
  HloInstruction* padded = computation->AddInstruction(
      HloInstruction::CreateTernary(inst->shape(), HloOpcode::kSelect, pred,
                                    inst, broadcasted_identity_value));
  return padded;
}

// In a reshape, if a dynamic dimension is split into multiple output
// dimensions, we need to rewrite the input of the reshape.
//
// The reason for this is that a contiguous input may not be evenly reshaped
// into the output. Imagine we have [<=6] where the valid data has size 4 and
// the padding (P) data has size 2: [a,b,c,d,P,P]
//
// And we have a reshape that produces dynamic output dimensions.
//
// [<=6]
//  |
// Reshape
//  |
// [2, <=3]
//
// This should produce the same result as if the data had no padding:
//
// [4]   // [a, b, c, d]
//  |
// Reshape
//  |
// [2, 2]  // [[a,b], [c,d]]
//
// Without reshape rewriting, the result looks like:
//
// [[a,b,c]
//  [d,P,P]], which is incorrect.
//
// We need to rewrite the reshape such that it produces:
// [[a,b,P]
//  [c,d,P]]
//
// The way we do this is by a 5-step cumsum-gather algorithm:
//
// 1.First we use the output shape to generate a binary 0-1 mask, which masks
// out the padded area of the output:
// [[1,1,0]
//  [1,1,0]]
//
// 2.Then we do an inverse reshape to reshape the mask from the output shape
// back to the input shape [2,3]->[6]:
// [1,1,0,1,1,0]
//
// 3.We then do a cumsum with the mask:
// [1,2,2,3,4,4] and subtract 1 from it:
// [0,1,1,2,3,3]
//
// 4.Use the result of the cumsum as gather indices to rearrange the original
// data. Feed the original input [a,b,c,d,P,P] and the indices into a gather.
//
//  operand [a,b,c,d,P,P], indices [0,1,1,2,3,3]
//             |                    |
//           Gather-----------------+
//             |
//             v
//  value[a,b,b,c,d,d], which is equivalent to [a,b,P,c,d,P] as the padding
//  value doesn't matter.
//
//
// 5.Feed the rearranged input to the original reshape [6]->[2,3]; we now get
// the correct result:
// [[a,b,P]
//  [c,d,P]]
//
Status RewriteDynamicReshapeSplitInput(
    HloInstruction* reshape, int64_t input_dim,
    absl::Span<const int64> output_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Reshaping input dim " << input_dim << " to "
          << VectorString(output_dims);
  const Shape operand_shape = reshape->operand(0)->shape();
  TF_RET_CHECK(output_dims.size() > 1);

  HloComputation* comp = reshape->parent();
  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, {operand_shape.dimensions(input_dim)});

  std::vector<int64> reshaped_dims;
  for (int64_t output_dim : output_dims) {
    reshaped_dims.push_back(reshape->shape().dimensions(output_dim));
  }

  const Shape mask_reshaped_shape =
      ShapeUtil::MakeShape(xla::S32, reshaped_dims);

  HloInstruction* zero = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* one = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(S32)));
  // Step 1 -- generate the binary mask.
  // The mask starts as all ones; each dynamic dimension then zeroes out the
  // padded tail of that dimension of the mask.
  HloInstruction* binary_mask = comp->AddInstruction(
      HloInstruction::CreateBroadcast(mask_reshaped_shape, one, {}));

  bool need_rewrite = false;

  // Mark the area past the effective dimension size with 0.
  //
  // The index starts from 1 since there is no need to rewrite the most major
  // output dimension.
  for (int64_t i = 1; i < output_dims.size(); ++i) {
    const int64_t output_dim = output_dims[i];
    HloInstruction* dynamic_size = output_dynamic_dims[output_dim];
    if (dynamic_size == nullptr) {
      continue;
    }
    // If there is a dynamic dimension in the output, we need to rewrite the
    // input.
    need_rewrite = true;

    binary_mask = PadWithScalar(binary_mask, i, dynamic_size, zero);
  }
  if (!need_rewrite) {
    return Status::OK();
  }
  // Step 2.
  // Do a reverse reshape to flatten the binary mask (with output shape) back
  // to the input shape.
  HloInstruction* input_shape_binary_mask = comp->AddInstruction(
      HloInstruction::CreateReshape(mask_input_shape, binary_mask));

  // Step 3. Do a cumsum on the binary mask.
  auto embedded_builder = HloComputation::Builder("add");
  {
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(S32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(S32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
  }

  HloComputation* add =
      reshape->GetModule()->AddEmbeddedComputation(embedded_builder.Build());
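
  // Note: the cumsum below is expressed as a reduce-window over the flattened
  // dimension. With window size N, low padding N - 1, and stride 1, output
  // element i reduces input elements [0, i], i.e. an inclusive prefix sum:
  // for the mask [1,1,0,1,1,0] this produces [1,2,2,3,4,4].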
  Window cumsum_window;
  // The single window dimension covers the flattened input dimension.
  WindowDimension* dim = cumsum_window.add_dimensions();
  dim->set_size(operand_shape.dimensions(input_dim));
  dim->set_stride(1);
  dim->set_padding_low(operand_shape.dimensions(input_dim) - 1);
  dim->set_padding_high(0);
  dim->set_window_dilation(1);
  dim->set_base_dilation(1);
  HloInstruction* cumsum =
      comp->AddInstruction(HloInstruction::CreateReduceWindow(
          mask_input_shape, input_shape_binary_mask, zero, cumsum_window, add));

  HloInstruction* broadcast_ones = comp->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, one, {}));
  cumsum = comp->AddInstruction(HloInstruction::CreateBinary(
      mask_input_shape, HloOpcode::kSubtract, cumsum, broadcast_ones));

  GatherDimensionNumbers gather_dim_numbers;
  // Use gather to rearrange the input dim dimension.
  for (int64_t i = 0; i < operand_shape.dimensions_size(); ++i) {
    // Offset dim is every dimension including newly added size 1 dim, except
    // for input_dim, which acts as a batch_dim.
    if (i != input_dim) {
      gather_dim_numbers.add_offset_dims(i);
    }
  }
  // The dimension to rewrite is the index dim.
  gather_dim_numbers.add_start_index_map(input_dim);
  gather_dim_numbers.set_index_vector_dim(1);
  gather_dim_numbers.add_collapsed_slice_dims(input_dim);
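  // With these dimension numbers, each index selects a single element along
  // input_dim, so for the running example above, gathering [a,b,c,d,P,P]
  // with indices [0,1,1,2,3,3] yields [a,b,b,c,d,d].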

  // Step 4. Gather.

  // Temporarily remove the dynamic dimension before entering the gather -- we
  // want the gather to ignore the dynamic dimension.
  HloInstruction* operand_static_dim_size =
      comp->AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<int32>(operand_shape.dimensions(input_dim))));
  HloInstruction* operand_static =
      comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
          operand_shape, reshape->mutable_operand(0), operand_static_dim_size,
          input_dim));

  std::vector<int64> slice_sizes(operand_shape.dimensions().begin(),
                                 operand_shape.dimensions().end());
  slice_sizes[input_dim] = 1;
  HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
      ShapeUtil::MakeShape(operand_shape.element_type(),
                           operand_shape.dimensions()),
      operand_static, cumsum, gather_dim_numbers, slice_sizes, true));

  // Step 5: Feed the gather output to the original reshape.

  TF_RETURN_IF_ERROR(reshape->ReplaceOperandWith(0, gather));

  HloInstruction* reshape_dynamic = reshape;

  auto users = reshape->users();

  // Forward the output dynamic dimension.
  for (int64_t output_dim : output_dims) {
    HloInstruction* output_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
    if (output_dynamic_size != nullptr) {
      reshape_dynamic =
          comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
              reshape->shape(), reshape_dynamic, output_dynamic_size,
              output_dim));
    }
  }

  for (auto* user : users) {
    TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, reshape_dynamic));
  }
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      reshape, reshape_dynamic, {}));

  return Status::OK();
}


// RewriteDynamicReshapeCombineInput is similar to
// RewriteDynamicReshapeSplitInput: in a reshape, if multiple dimensions are
// combined into one dimension, we need to rewrite the output.
//
// The reason for this is that a contiguous input may not be evenly reshaped
// into the output. Imagine we have [2, <=3] where the second dimension has
// dynamic size 2 and the padding (P) data has size 1:
// [[a,b,P]
//  [c,d,P]]
//
// And we have a reshape that combines these two input dimensions.
//
// [2, <=3]
//  |
// Reshape
//  |
// [6]
//
// This should produce the same result as if the data had no padding:
//
// [2, 2]  // [[a, b], [c, d]]
//  |
// Reshape
//  |
// [4]  // [a,b,c,d]
//
// Without rewriting, the result would be:
//
// [a,b,P,c,d,P], which is incorrect.
//
// We need to rewrite the reshape such that it produces:
// [a,b,c,d,P,P]
//
// The way we do this is by a 5-step sort-gather algorithm:
//
// 1.First we use the input shape to generate a binary 0-1 mask, which marks
// the padded area of the input with ones:
// [[0,0,1]
//  [0,0,1]]
//
// 2.Then we do a reshape to reshape the mask from the input shape to the
// output shape [2,3]->[6]:
// [0,0,1,0,0,1]
//
// 3.We then generate an iota using the output shape:
// [0,1,2,3,4,5]
//
// 4.Stable sort the iota using the binary mask as key:
// key  [0,0,1,0,0,1]
// value[0,1,2,3,4,5]
//   | Sort by key
//   v
// key  [0,0,0,0,1,1]
// value[0,1,3,4,2,5]
//
// 5.Gather the original output [a,b,P,c,d,P] using the sorted iota as
// indices:
// original output  gather indices
//  [a,b,P,c,d,P]    [0,1,3,4,2,5]
//       |                |
//     Gather ------------+
//       |
//  [a,b,c,d,P,P]
//
Status RewriteDynamicReshapeCombineInput(
    HloInstruction* reshape, absl::Span<const int64> input_dims,
    int64_t output_dim, absl::Span<HloInstruction*> input_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  // Rewrite the dynamic reshape into a reshape followed by a sort; all padded
  // data will be moved to the end.
  HloComputation* comp = reshape->parent();
  HloInstruction* zero = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  HloInstruction* one = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::One(S32)));
  const Shape output_shape = reshape->shape();
  const Shape input_shape = reshape->operand(0)->shape();
  const Shape mask_output_shape =
      ShapeUtil::MakeShape(xla::S32, {output_shape.dimensions(output_dim)});
  std::vector<int64> input_dim_sizes;
  for (int64_t input_dim : input_dims) {
    input_dim_sizes.push_back(input_shape.dimensions(input_dim));
  }

  const Shape mask_input_shape =
      ShapeUtil::MakeShape(xla::S32, input_dim_sizes);

  // Step 1 -- generate the binary mask.
  // The mask starts as all zeros; each dynamic dimension then sets the padded
  // tail of that dimension of the mask to ones.
  HloInstruction* binary_mask = comp->AddInstruction(
      HloInstruction::CreateBroadcast(mask_input_shape, zero, {}));

  bool need_rewrite = false;

  // Mark the area past the effective dimension size with 1.
  //
  // The index starts from 1 since there is no need to rewrite the most major
  // input dimension.
  for (int64_t i = 1; i < input_dims.size(); ++i) {
    const int64_t input_dim = input_dims[i];
    HloInstruction* dynamic_size = input_dynamic_dims[input_dim];
    if (dynamic_size == nullptr) {
      continue;
    }
    // If there is a dynamic dimension in the input, we need to rewrite the
    // output.
    need_rewrite = true;

    binary_mask = PadWithScalar(binary_mask, i, dynamic_size, one);
  }
  if (!need_rewrite) {
    VLOG(2) << "No need to rewrite";
    return Status::OK();
  }

  // Step 2.
  // Do a reshape to flatten the binary mask into output_shape.
  HloInstruction* output_shape_binary_mask = comp->AddInstruction(
      HloInstruction::CreateReshape(mask_output_shape, binary_mask));

  // Step 3.
  // Generate an iota with the output shape.
  HloInstruction* iota =
      comp->AddInstruction(HloInstruction::CreateIota(mask_output_shape, 0));

  // Step 4.
  // Stable sort the iota using the binary mask as key and the iota as value:

  // Build the computation for the sort; the key is the mask, the value is the
  // iota.
  HloComputation::Builder comp_builder("compare");
  HloInstruction* lhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeScalarShape(S32), "lhs_key"));
  HloInstruction* rhs_key =
      comp_builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeScalarShape(S32), "rhs_key"));

  // Values for lhs and rhs.
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeScalarShape(S32), "lhs_value"));
  comp_builder.AddInstruction(HloInstruction::CreateParameter(
      3, ShapeUtil::MakeScalarShape(S32), "rhs_value"));
  comp_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs_key,
                                    rhs_key, ComparisonDirection::kLt));
  HloComputation* compare =
      comp->parent()->AddEmbeddedComputation(comp_builder.Build());

  // Use the binary mask as the sort key and the iota as the sorted value.
  HloInstruction* sort = comp->AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({mask_output_shape, mask_output_shape}), 0,
      {output_shape_binary_mask, iota}, compare,
      /*is_stable=*/true));

  HloInstruction* gather_indices = comp->AddInstruction(
      HloInstruction::CreateGetTupleElement(mask_output_shape, sort, 1));

  // Step 5. Gather the original output using the sorted iota as indices:

  GatherDimensionNumbers gather_dim_numbers;
  // Use gather to rearrange the output dim dimension.
  for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
    // Offset dim is every dimension including newly added size 1 dim, except
    // for output_dim, which acts as a batch_dim.
    if (i != output_dim) {
      gather_dim_numbers.add_offset_dims(i);
    }
  }
  // The dimension to rewrite is the index dim.
  gather_dim_numbers.add_start_index_map(output_dim);
  gather_dim_numbers.set_index_vector_dim(1);
  gather_dim_numbers.add_collapsed_slice_dims(output_dim);

  HloInstruction* static_dim_size = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
          reshape->shape().dimensions(output_dim))));

  // Temporarily remove the dynamic dimension of the reshape before we send it
  // to the gather -- we want the padded area to also participate in the
  // gather.
  HloInstruction* reshape_static =
      comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
          reshape->shape(), reshape, static_dim_size, output_dim));
  std::vector<int64> gather_slice_sizes(output_shape.dimensions().begin(),
                                        output_shape.dimensions().end());
  gather_slice_sizes[output_dim] = 1;
  HloInstruction* gather = comp->AddInstruction(HloInstruction::CreateGather(
      output_shape, reshape_static, gather_indices, gather_dim_numbers,
      gather_slice_sizes, true));

  // Forward the dynamic size to the newly created gather.
  HloInstruction* output_dynamic_size =
      dynamic_dimension_inference->GetDynamicSize(reshape, {}, output_dim);
  TF_RET_CHECK(output_dynamic_size != nullptr);
  gather = comp->AddInstruction(HloInstruction::CreateSetDimensionSize(
      gather->shape(), gather, output_dynamic_size, output_dim));
  auto users = reshape->users();
  for (auto* user : users) {
    // Avoid cycles by not replacing the static reshape and get_dimension_size.
    if (user != reshape_static && user != output_dynamic_size) {
      TF_RETURN_IF_ERROR(reshape->ReplaceUseWith(user, gather));
    }
  }

  if (reshape == comp->root_instruction()) {
    comp->set_root_instruction(gather);
  }

  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(reshape, gather, {}));

  return Status::OK();
}


Status RewriteDynamicReshapeSingleGroup(
    HloInstruction* reshape, absl::Span<const int64> input_dims,
    absl::Span<const int64> output_dims,
    absl::Span<HloInstruction*> input_dynamic_dims,
    absl::Span<HloInstruction*> output_dynamic_dims,
    DynamicDimensionInference* dynamic_dimension_inference) {
  VLOG(2) << "Rewriting dynamic reshape " << reshape->ToString()
          << " input dims: " << VectorString(input_dims)
          << " output dims: " << VectorString(output_dims);

  const Shape operand_shape = reshape->operand(0)->shape();
  const Shape output_shape = reshape->shape();

  if (input_dims.size() == 1) {
    int64_t input_dim = input_dims[0];
    // A size 1 dimension doesn't need a rewrite.
    if (operand_shape.dimensions()[input_dim] == 1) {
      return Status::OK();
    }
    // One input dimension is split into multiple output dimensions.
    return RewriteDynamicReshapeSplitInput(reshape, input_dim, output_dims,
                                           output_dynamic_dims,
                                           dynamic_dimension_inference);
  }

  if (output_dims.size() == 1) {
    int64_t output_dim = output_dims[0];
    if (output_shape.dimensions()[output_dim] == 1) {
      return Status::OK();
    }
    // Multiple input dimensions are combined into one output dimension.
    return RewriteDynamicReshapeCombineInput(reshape, input_dims, output_dim,
                                             input_dynamic_dims,
                                             dynamic_dimension_inference);
  }
  // Shouldn't get here.
  TF_RET_CHECK(false);
  return Status::OK();
}

StatusOr<bool> RewriteReverse(
    HloInstruction* reverse,
    DynamicDimensionInference* dynamic_dimension_inference) {
  // When we have [A, B, C, D, E] and reverse them, we get [E, D, C, B, A].
  // However, if the dynamic size is 2, we expect B, A to be in front:
  // [B, A, P, P, P].
  //
  // We do this by running a pad and dynamic slice on the result:
  // [A, B, C, D, E]
  //      |
  //    reverse
  //      |
  // [E, D, C, B, A]
  //      |
  //     pad # Use pad to double the size of the dimension to avoid OOB.
  //      |
  // [E, D, C, B, A, P, P, P, P, P]
  //      |
  //  dynamic slice
  //      |
  // [B, A, P, P, P]
  auto reverse_dims = reverse->dimensions();
  HloComputation* comp = reverse->parent();
  const Shape& reverse_shape = reverse->shape();
  std::set<int64> dynamic_reverse_dims;
  for (int64_t reverse_dim : reverse_dims) {
    HloInstruction* dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(reverse, {}, reverse_dim);
    if (dynamic_size == nullptr) {
      // Reverse dimension is not dynamic -- no rewrite needed.
      continue;
    }
    dynamic_reverse_dims.insert(reverse_dim);
  }

  if (dynamic_reverse_dims.empty()) {
    // We only need to rewrite dynamic dimensions that are also reverse
    // dimensions.
    return false;
  }

  PaddingConfig padding;
  // Doubles the dynamic dimension size using a pad.
  Shape pad_shape = reverse_shape;
  for (int i = 0; i < reverse_shape.rank(); ++i) {
    auto dimension = padding.add_dimensions();
    if (dynamic_reverse_dims.count(i) > 0) {
      dimension->set_edge_padding_low(0);
      dimension->set_edge_padding_high(reverse_shape.dimensions(i));
      dimension->set_interior_padding(0);
      pad_shape.set_dimensions(i, 2 * pad_shape.dimensions(i));
    }
  }
  HloInstruction* cloned_reverse = comp->AddInstruction(reverse->Clone());
  HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(pad_shape.element_type())));
  HloInstruction* pad = comp->AddInstruction(
      HloInstruction::CreatePad(pad_shape, cloned_reverse, zero, padding));
  std::vector<HloInstruction*> start_indices;
  start_indices.reserve(reverse_shape.rank());
  for (int i = 0; i < reverse_shape.rank(); ++i) {
    if (dynamic_reverse_dims.count(i) > 0) {
      // Start at bound_size - dynamic_size.
      HloInstruction* bound_size =
          comp->AddInstruction(HloInstruction::CreateConstant(
              LiteralUtil::CreateR0<int32>(reverse_shape.dimensions(i))));
      HloInstruction* dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(reverse, {}, i);
      HloInstruction* start_offset =
          comp->AddInstruction(HloInstruction::CreateBinary(
              ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract, bound_size,
              dynamic_size));
      start_indices.push_back(start_offset);
    } else {
      HloInstruction* zero = comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
      start_indices.push_back(zero);
    }
  }
  HloInstruction* dynamic_reverse =
      comp->AddInstruction(HloInstruction::CreateDynamicSlice(
          reverse_shape, pad, start_indices, reverse_shape.dimensions()));
  TF_RETURN_IF_ERROR(comp->ReplaceInstruction(reverse, dynamic_reverse));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      reverse, dynamic_reverse, {}));
  return true;
}

HloInstruction* RewriteInputWithDynamicPadding(
    HloInstruction* conv, HloInstruction* input, HloInstruction* padding_value,
    absl::Span<HloInstruction*> padding_before, Window* input_window,
    std::function<int64(int64_t)> window_dim_to_shape_dim) {
  HloComputation* comp = conv->parent();
  HloInstruction* zero_s32 = comp->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
  // Padded shape represents the bounded shape after dynamic padding.
  Shape padded_shape = input->shape();
  PaddingConfig padding_configs;
  for (int64_t i = 0; i < input->shape().rank(); ++i) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    *padding_configs.add_dimensions() = padding_dim;
  }
  std::vector<HloInstruction*> start_indices(input->shape().rank(), zero_s32);
  for (int64_t dim_index = 0; dim_index < input_window->dimensions_size();
       ++dim_index) {
    if (padding_before[dim_index] == nullptr) {
      continue;
    }
    int64_t shape_dim = window_dim_to_shape_dim(dim_index);

    WindowDimension* window_dim = input_window->mutable_dimensions(dim_index);
    auto* padding_dim = padding_configs.mutable_dimensions(shape_dim);
    const int64_t dilated_window_size = window_util::DilatedBound(
        window_dim->size(), window_dim->window_dilation());
    // Use the dilated window size as low padding, and the static padding_high
    // + padding_low as high padding, to make sure the following dynamic slice
    // is valid and doesn't go out of bounds.
    //
    // See go/xla-dynamic-spatial-dim for more details.
    padding_dim->set_edge_padding_low(dilated_window_size);
    padding_dim->set_edge_padding_high(window_dim->padding_high() +
                                       window_dim->padding_low());
    padding_dim->set_interior_padding(window_dim->base_dilation() - 1);
    HloInstruction* slicing_start =
        comp->AddInstruction(HloInstruction::CreateBinary(
            ShapeUtil::MakeScalarShape(S32), HloOpcode::kSubtract,
            comp->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::CreateR0<int32>(padding_dim->edge_padding_low()))),
            padding_before[dim_index]));
    start_indices[shape_dim] = slicing_start;

    padded_shape.mutable_dimensions()[shape_dim] =
        window_dim->padding_low() +
        window_util::DilatedBound(padded_shape.dimensions(shape_dim),
                                  window_dim->base_dilation()) +
        window_dim->padding_high();
    window_dim->clear_padding_high();
    window_dim->clear_padding_low();
    window_dim->set_base_dilation(1);
    input->mutable_shape()->set_dynamic_dimension(shape_dim, false);
  }
  // Reconstruct the dynamic padding using a pad and a dynamic slice.

  HloInstruction* pad =
      MakePadHlo(input, padding_value, padding_configs).ValueOrDie();
  input = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
      padded_shape, pad, start_indices, padded_shape.dimensions()));
  return input;
}

StatusOr<bool> RewriteDynamicConvolutionInputGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* grad = custom_call_conv->mutable_operand(1);
  HloInstruction* kernel = custom_call_conv->mutable_operand(2);
  TF_RET_CHECK(kernel->shape().is_static());
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloComputation* comp = custom_call_conv->parent();
  Window window = custom_call_conv->window();
  HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, input_spatial_dim);
    if (operand_dynamic_size == nullptr) {
      continue;
    }
    grad = PadWithScalar(grad, input_spatial_dim, operand_dynamic_size, zero);
    HloInstruction* slice = comp->AddInstruction(HloInstruction::CreateSlice(
        ShapeUtil::MakeShape(S32, {1}), custom_call_conv->mutable_operand(0),
        {input_spatial_dim}, {input_spatial_dim + 1}, {1}));
    HloInstruction* dynamic_input_size = comp->AddInstruction(
        HloInstruction::CreateReshape(ShapeUtil::MakeScalarShape(S32), slice));
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    // The window stride of the forward prop is the same as the base dilation
    // of the backward prop.
    DynamicWindowDims dynamic_window_dims = GetWindowedInputGradSize(
        dynamic_input_size, /*window_size=*/window_dim.size(),
        /*window_dilation=*/window_dim.window_dilation(),
        /*window_stride=*/window_dim.base_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    grad = RewriteInputWithDynamicPadding(
        custom_call_conv, grad, zero, absl::MakeSpan(padding_before), &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  PrecisionConfig precision_config;
  if (custom_call_conv->precision_config().operand_precision_size() == 3) {
    // We are not interested in the precision config of the first operand,
    // which is the input_sizes.
    *precision_config.mutable_operand_precision() = {
        custom_call_conv->precision_config().operand_precision().begin() + 1,
        custom_call_conv->precision_config().operand_precision().end()};
  }
  HloInstruction* static_conv = comp->AddInstruction(
      HloInstruction::CreateConvolve(
          custom_call_conv->shape(), grad, kernel,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(), precision_config),
      "ConvBackwardInput");
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConvolutionForward(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = custom_call_conv->mutable_operand(0);
  HloInstruction* kernel = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(kernel->shape().is_static());
  TF_RET_CHECK(input->shape().is_dynamic());
  HloComputation* comp = custom_call_conv->parent();
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    input =
        PadWithScalar(input, input_spatial_dim, operand_dynamic_size, zero);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }
  // The input feature dim can be dynamic too; pad its padded area with zeros.
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    input = PadWithScalar(input, input_feature_dim, input_feature_dynamic_size,
                          zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    input = RewriteInputWithDynamicPadding(
        custom_call_conv, input, zero, absl::MakeSpan(padding_before), &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv = comp->AddInstruction(
      HloInstruction::CreateConvolve(
          custom_call_conv->shape(), input, kernel,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()),
      "ConvForward");
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConvolutionKernelGrad(
    HloInstruction* custom_call_conv,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* activations = custom_call_conv->mutable_operand(0);
  HloInstruction* gradients = custom_call_conv->mutable_operand(1);
  TF_RET_CHECK(activations->shape().is_dynamic());
  TF_RET_CHECK(gradients->shape().is_dynamic());
  HloComputation* comp = custom_call_conv->parent();
  Window window = custom_call_conv->window();
  auto dnums = custom_call_conv->convolution_dimension_numbers();
  HloInstruction* zero = comp->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(custom_call_conv->shape().element_type())));
  std::vector<HloInstruction*> padding_before(
      dnums.input_spatial_dimensions_size(), nullptr);
  for (int64_t spatial_dim_index = 0;
       spatial_dim_index < dnums.input_spatial_dimensions_size();
       ++spatial_dim_index) {
    int64_t input_spatial_dim =
        dnums.input_spatial_dimensions(spatial_dim_index);
    int64_t kernel_spatial_dim =
        dnums.kernel_spatial_dimensions(spatial_dim_index);
    HloInstruction* activations_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(0), {}, input_spatial_dim);
    if (activations_dynamic_size != nullptr) {
      activations = PadWithScalar(activations, input_spatial_dim,
                                  activations_dynamic_size, zero);
    }

    HloInstruction* gradients_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(
            custom_call_conv->mutable_operand(1), {}, kernel_spatial_dim);
    if (gradients_dynamic_size != nullptr) {
      gradients = PadWithScalar(gradients, kernel_spatial_dim,
                                gradients_dynamic_size, zero);
    }
    if (activations_dynamic_size == nullptr ||
        gradients_dynamic_size == nullptr) {
      TF_RET_CHECK(activations_dynamic_size == nullptr &&
                   gradients_dynamic_size == nullptr);
      continue;
    }
    int64_t output_spatial_dim =
        dnums.output_spatial_dimensions(spatial_dim_index);
    const WindowDimension& window_dim = window.dimensions(spatial_dim_index);
    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        activations_dynamic_size, /*window_size=*/
        custom_call_conv->shape().dimensions(output_spatial_dim),
        /*window_dilation=*/window_dim.stride(),
        /*window_stride=*/window_dim.window_dilation(),
        custom_call_conv->padding_type());
    padding_before[spatial_dim_index] = dynamic_window_dims.padding_before;
  }

  // We only need to pad the input feature on the lhs to 0 -- it's
  // mathematically equivalent to padding both lhs and rhs to 0.
  const int64_t input_feature_dim = dnums.input_feature_dimension();
  if (HloInstruction* input_feature_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(
              custom_call_conv->mutable_operand(0), {}, input_feature_dim)) {
    activations = PadWithScalar(activations, input_feature_dim,
                                input_feature_dynamic_size, zero);
  }

  if (custom_call_conv->padding_type() == PaddingType::PADDING_SAME) {
    activations = RewriteInputWithDynamicPadding(
        custom_call_conv, activations, zero, absl::MakeSpan(padding_before),
        &window,
        [&](int64_t dim) { return dnums.input_spatial_dimensions(dim); });
  }

  HloInstruction* static_conv = comp->AddInstruction(
      HloInstruction::CreateConvolve(
          custom_call_conv->shape(), activations, gradients,
          custom_call_conv->feature_group_count(),
          custom_call_conv->batch_group_count(), window,
          custom_call_conv->convolution_dimension_numbers(),
          custom_call_conv->precision_config()),
      "ConvBackwardGrad");
  TF_RETURN_IF_ERROR(custom_call_conv->ReplaceAllUsesWith(static_conv));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      custom_call_conv, static_conv, {}));
  return true;
}

StatusOr<bool> RewriteDynamicReduceWindowSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  if (hlo->shape().IsTuple()) {
    // TODO(b/73062247): variadic reduce window is not yet supported here.
    return Unimplemented("DynamicReduceWindowSamePadding not yet supported.");
  }
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* init = hlo->mutable_operand(1);
  HloComputation* comp = hlo->parent();
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      continue;
    }
    input = PadWithScalar(input, dim_index, operand_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, init, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  HloInstruction* rewritten = comp->AddInstruction(
      HloInstruction::CreateReduceWindow(hlo->shape(), input, init, window,
                                         hlo->called_computations()[0]),
      "DynamicReduceWindow");
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}

StatusOr<bool> RewriteDynamicSelectAndScatterSamePadding(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* input = hlo->mutable_operand(0);
  HloInstruction* source = hlo->mutable_operand(1);
  HloInstruction* init = hlo->mutable_operand(2);
  TF_ASSIGN_OR_RETURN(HloInstruction * input_padding_value,
                      ChooseIdentityValue(hlo, /*operand_number=*/0));
  HloComputation* comp = hlo->parent();
  int64_t rank = hlo->shape().rank();
  Window window = hlo->window();
  std::vector<HloInstruction*> padding_before(hlo->shape().rank(), nullptr);
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    const WindowDimension& window_dim = window.dimensions(dim_index);
    if (window_util::IsTrivialWindowDimension(window_dim)) {
      continue;
    }
    HloInstruction* operand_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(0), {},
                                                    dim_index);
    if (operand_dynamic_size == nullptr) {
      continue;
    }

    input = PadWithScalar(input, dim_index, operand_dynamic_size,
                          input_padding_value);

    HloInstruction* source_dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(hlo->mutable_operand(1), {},
                                                    dim_index);
    if (source_dynamic_size == nullptr) {
      continue;
    }
    source = PadWithScalar(source, dim_index, source_dynamic_size, init);

    DynamicWindowDims dynamic_window_dims = GetWindowedOutputSize(
        operand_dynamic_size, window_dim.size(), window_dim.window_dilation(),
        window_dim.stride(), PaddingType::PADDING_SAME);
    padding_before[dim_index] = dynamic_window_dims.padding_before;
  }

  input = RewriteInputWithDynamicPadding(
      hlo, input, input_padding_value, absl::MakeSpan(padding_before), &window,
      [](int64_t dim) { return dim; });

  // RewriteInputWithDynamicPadding adds padding to the input. However, that
  // padding should not be materialized in select and scatter's output, and we
  // need to slice it out using a dynamic slice. To prevent the dynamic slice
  // from going OOB, we first add some high padding to the output to leave
  // enough space.
  HloInstruction* rewritten = comp->AddInstruction(
      HloInstruction::CreateSelectAndScatter(
          input->shape(), input, hlo->called_computations()[0], window, source,
          init, hlo->called_computations()[1]),
      "DynamicSelectAndScatter");
  std::vector<HloInstruction*> start_indices(
      input->shape().rank(),
      comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(S32))));
  PaddingConfig padding_configs;
  for (int64_t dim_index = 0; dim_index < rank; ++dim_index) {
    PaddingConfig::PaddingConfigDimension padding_dim;
    if (padding_before[dim_index] != nullptr) {
      const WindowDimension& window_dim = window.dimensions(dim_index);
      const int64_t dilated_window_size = window_util::DilatedBound(
          window_dim.size(), window_dim.window_dilation());
      padding_dim.set_edge_padding_high(dilated_window_size);
      start_indices[dim_index] = padding_before[dim_index];
    }
    *padding_configs.add_dimensions() = padding_dim;
  }
  HloInstruction* padded =
      MakePadHlo(rewritten, init, padding_configs).ValueOrDie();
  rewritten = comp->AddInstruction(HloInstruction::CreateDynamicSlice(
      hlo->shape(), padded, start_indices, hlo->shape().dimensions()));
  TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(rewritten));
  TF_RETURN_IF_ERROR(
      dynamic_dimension_inference->ForwardDynamicSize(hlo, rewritten, {}));
  return true;
}

StatusOr<bool> RewriteDynamicConcat(
    HloInstruction* concat,
    DynamicDimensionInference* dynamic_dimension_inference) {
  const int64_t concat_dim = concat->concatenate_dimension();
  HloComputation* comp = concat->parent();
  if (dynamic_dimension_inference->GetDynamicSize(concat, {}, concat_dim) ==
      nullptr) {
    // Concat dimension is not dynamic -- no rewrite needed.
    return false;
  }
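  // The concat is rewritten as a chain of dynamic-update-slices into a result
  // buffer, with a running offset along the concat dimension. E.g. for two
  // operands (a sketch; the code below reuses the original concat as the
  // initial buffer):
  //
  //   buffer = dynamic-update-slice(concat, operand0, offset=0)
  //   result = dynamic-update-slice(buffer, operand1, offset=size0)
  //
  // where size0 is operand0's (possibly dynamic) size along the concat dim.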
  std::vector<HloInstruction*> offsets;
  for (int64_t i = 0; i < concat->shape().dimensions_size(); ++i) {
    offsets.push_back(comp->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))));
  }
  HloInstruction* rewritten_concat = concat;
  // Keep track of the previous users before the rewrite so that we can update
  // their operands later.
  auto prev_users = concat->users();
  for (int64_t i = 0; i < concat->operand_count(); ++i) {
    // Rewrite the concat by dynamic-update-slicing the operand into the
    // concat dim.
    HloInstruction* operand = concat->mutable_operand(i);
    rewritten_concat =
        comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
            rewritten_concat->shape(), rewritten_concat, operand, offsets));
    // Update the offset of the concat dimension by adding the size of the
    // concat dimension of the operand to it.
    HloInstruction* dynamic_size =
        dynamic_dimension_inference->GetDynamicSize(operand, {}, concat_dim);
    if (dynamic_size == nullptr) {
      HloInstruction* static_size = comp->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(
              operand->shape().dimensions(concat_dim))));
      offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
          static_size));
    } else {
      offsets[concat_dim] = comp->AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeScalarShape(S32), HloOpcode::kAdd, offsets[concat_dim],
          dynamic_size));
    }
  }
  TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat));
  TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
      concat, rewritten_concat, {}));
  return true;
}

StatusOr<bool> RewriteDynamicSort(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloInstruction* dynamic_size = nullptr;
  HloSortInstruction* sort = Cast<HloSortInstruction>(hlo);
  HloComputation* comp = hlo->parent();
  int64_t sort_dim = sort->sort_dimension();
  // Find the dynamic dimension in the operand.
  for (auto* operand : sort->operands()) {
    if (dynamic_size == nullptr) {
      dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(operand, {}, sort_dim);
    }
  }

  if (dynamic_size == nullptr) {
    // Not a dynamic sort, ignore.
    return false;
  }

  Shape operand_shape =
      ShapeUtil::ChangeElementType(sort->operand(0)->shape(), S32);
  HloInstruction* iota =
      comp->AddInstruction(HloInstruction::CreateIota(operand_shape, sort_dim));
  HloInstruction* dynamic_size_broadcasted = comp->AddInstruction(
      HloInstruction::CreateBroadcast(operand_shape, dynamic_size, {}));
  HloInstruction* lt = comp->AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::ChangeElementType(operand_shape, PRED), iota,
      dynamic_size_broadcasted, ComparisonDirection::kLt));
  sort->AppendOperand(lt);

  const int64_t param_number_before_rewritten =
      sort->called_computations()[0]->num_parameters();
  auto new_param_0 = HloInstruction::CreateParameter(
      param_number_before_rewritten, ShapeUtil::MakeScalarShape(PRED),
      "inbound_lhs");
  auto new_param_1 = HloInstruction::CreateParameter(
      param_number_before_rewritten + 1, ShapeUtil::MakeScalarShape(PRED),
      "inbound_rhs");
  std::vector<const HloInstruction*> extra_parameters{new_param_0.get(),
                                                      new_param_1.get()};
  HloComputation* sort_comp = sort->parent()->parent()->AddEmbeddedComputation(
      sort->called_computations()[0]->CloneWithReplacements(
          /*replacements=*/absl::flat_hash_map<
              const HloInstruction*, std::unique_ptr<HloInstruction>>(),
          extra_parameters));
  auto inbound_lhs =
      sort_comp->parameter_instruction(param_number_before_rewritten);
  auto inbound_rhs =
      sort_comp->parameter_instruction(param_number_before_rewritten + 1);
  sort->ReplaceCalledComputations(
      [&](HloComputation* comp) { return sort_comp; });

  // inbound_lhs & (sort_comp | !inbound_rhs)
  // Select the lhs if it is in bounds, and either the rhs is out of bounds or
  // the sort_comp returns true.
1349 auto out_of_bound_rhs = sort_comp->AddInstruction(HloInstruction::CreateUnary(
1350 ShapeUtil::MakeScalarShape(PRED), HloOpcode::kNot, inbound_rhs));
1351 auto sort_comp_or_out_of_bound_rhs =
1352 sort_comp->AddInstruction(HloInstruction::CreateBinary(
1353 ShapeUtil::MakeScalarShape(PRED), HloOpcode::kOr,
1354 sort_comp->root_instruction(), out_of_bound_rhs));
1355
1356 auto new_root = sort_comp->AddInstruction(HloInstruction::CreateBinary(
1357 ShapeUtil::MakeScalarShape(PRED), HloOpcode::kAnd, inbound_lhs,
1358 sort_comp_or_out_of_bound_rhs));
1359 sort_comp->set_root_instruction(new_root);
1360 Shape compare_shape =
1361 ShapeUtil::ChangeElementType(sort->operand(0)->shape(), PRED);
1362 if (sort->shape().IsTuple()) {
1363 // For sort that is already tuple, simply add another result to the tuple.
1364 *sort->mutable_shape()->add_tuple_shapes() =
1365 ShapeUtil::ChangeElementType(operand_shape, PRED);
1366 } else {
1367 auto sort_users = sort->users();
1368 auto sort_clone = comp->AddInstruction(sort->Clone());
1369 *sort_clone->mutable_shape() = ShapeUtil::MakeTupleShape(
1370 {sort->shape(), ShapeUtil::ChangeElementType(operand_shape, PRED)});
1371 auto rewritten_sort = comp->AddInstruction(
1372 HloInstruction::CreateGetTupleElement(sort->shape(), sort_clone, 0));
1373 for (HloInstruction* user : sort_users) {
1374 TF_RETURN_IF_ERROR(sort->ReplaceUseWith(user, rewritten_sort));
1375 }
1376 TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize(
1377 sort, rewritten_sort, {}));
1378 if (comp->root_instruction() == sort) {
1379 comp->set_root_instruction(rewritten_sort);
1380 }
1381 }
1382
1383 return true;
1384 }
1385
RewriteDynamicBinaryOp(HloInstruction * binary,DynamicDimensionInference * dynamic_dimension_inference)1386 StatusOr<bool> RewriteDynamicBinaryOp(
1387 HloInstruction* binary,
1388 DynamicDimensionInference* dynamic_dimension_inference) {
1389 HloInstruction* operand_0 = binary->mutable_operand(0);
1390 HloInstruction* operand_1 = binary->mutable_operand(1);
1391
1392 HloComputation* comp = binary->parent();
1393 TF_RET_CHECK(operand_0->shape().rank() == operand_1->shape().rank());
1394 auto dims_0 = dynamic_dimension_inference->GetDynamicSizes(operand_0, {});
1395 auto dims_1 = dynamic_dimension_inference->GetDynamicSizes(operand_1, {});
1396 bool changed = false;
1397 for (int64_t i = 0; i < dims_0.size(); ++i) {
1398 HloInstruction* dim_0 = dims_0[i];
1399 HloInstruction* dim_1 = dims_1[i];
1400
1401 if (dims_0[i] != dims_1[i] && dims_0[i] != nullptr &&
1402 dims_1[i] != nullptr) {
1403 changed = true;
1404 // It is possible that a dynamic dimension of one operand is size 1 while
1405 // the other is greater than one. According to implicit broadcast
1406 // semantics, we need to insert broadcast in this case to make the dynamic
1407 // shape match.
1408
1409 // An implicit broadcast is inserted by slicing the small shape into a
1410 // size 1 slice, reshape out the size 1 dimension then broadcast to the
1411 // full shape:
1412 //
1413 // Input [2, <=5, 3]
1414 // |
1415 // Slice [2, 1, 3]
1416 // |
1417 // Reshape [2, 3]
1418 // |
1419 // Broadcast [2, 5, 3]
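      //
      // The broadcast is applied behind a select (see `rewrite_operand`
      // below), so at runtime an operand is only replaced when its dynamic
      // size is actually the smaller of the two; otherwise the original
      // operand flows through unchanged.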
      auto rewrite_operand = [&](HloInstruction* pred,
                                 HloInstruction* operand) -> HloInstruction* {
        Shape static_shape = operand->shape();
        static_shape.clear_dynamic_dimensions();
        pred = comp->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::ChangeElementType(static_shape, PRED), pred, {}));
        Shape slice_shape = static_shape;
        slice_shape.set_dimensions(i, 1);
        std::vector<int64_t> start_indices(slice_shape.rank(), 0);
        std::vector<int64_t> strides(slice_shape.rank(), 1);
        HloInstruction* slice = comp->AddInstruction(
            HloInstruction::CreateSlice(slice_shape, operand, start_indices,
                                        slice_shape.dimensions(), strides));
        Shape reshape_shape = ShapeUtil::DeleteDimension(i, slice_shape);
        HloInstruction* reshape = comp->AddInstruction(
            HloInstruction::CreateReshape(reshape_shape, slice));
        std::vector<int64_t> broadcast_dims;
        broadcast_dims.reserve(static_shape.rank() - 1);
        // Broadcast to all dims except for i.
        for (int64_t j = 0; j < static_shape.rank(); ++j) {
          if (j != i) {
            broadcast_dims.push_back(j);
          }
        }

        HloInstruction* broadcast = comp->AddInstruction(
            HloInstruction::CreateBroadcast(static_shape, reshape,
                                            broadcast_dims),
            "implicit_broadcast");

        // Use a select instead of a conditional, as elementwise operations
        // fuse more readily.
        HloInstruction* select =
            comp->AddInstruction(HloInstruction::CreateTernary(
                static_shape, HloOpcode::kSelect, pred, broadcast, operand));
        return select;
      };
      auto operand_0_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_0,
                                        dim_1, ComparisonDirection::kLt),
          "lhs_needs_implicit_broadcast");
      operand_0 = rewrite_operand(operand_0_needs_broadcast, operand_0);

      auto operand_1_needs_broadcast = binary->parent()->AddInstruction(
          HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), dim_1,
                                        dim_0, ComparisonDirection::kLt),
          "rhs_needs_implicit_broadcast");
      operand_1 = rewrite_operand(operand_1_needs_broadcast, operand_1);
    }
  }
  if (changed) {
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(0, operand_0));
    TF_RETURN_IF_ERROR(binary->ReplaceOperandWith(1, operand_1));
  }
  return changed;
}

StatusOr<bool> RewriteDynamicUpdateSlice(
    HloInstruction* hlo,
    DynamicDimensionInference* dynamic_dimension_inference) {
  HloDynamicUpdateSliceInstruction* dus =
      Cast<HloDynamicUpdateSliceInstruction>(hlo);
  HloComputation* comp = hlo->parent();
  // Suppose we have a base area that we want to update:
  // +------------------------+
  // |                        |
  // |          base          |
  // |                        |
  // +------------------------+
  //
  // A partial update with dynamic padding looks like this:
  //
  //           +------+-------+
  //           |update|padding|
  //           +------+-------+
  //
  // We don't want the padding to overwrite the base area:
  //
  // +------------------------+
  // |         +------+-------+
  // |<-begin->|update|padding| (what we want to avoid)
  // |         +------+-------+
  // +------------------------+
  //
  // Instead we want to keep the base area untouched except for the update
  // region:
  //
  // +------------------------+
  // |         +------+       |
  // |<-begin->|update| base  | (what we want)
  // |         +------+       |
  // +------------------------+
  //
  // We do this by dynamic-slicing the base area out first with the same begin
  // index:
  //
  //           +--------------+
  // <-begin-> |     base     |
  //           +--------------+
  //
  // Then replace the update's padding part with base:
  //
  //           +------+-------+
  //           |update| base  |
  //           +------+-------+
  //
  // Then do the DUS.

  HloInstruction* update = dus->mutable_operand(1);
  HloInstruction* base = dus->mutable_operand(0);
  std::vector<HloInstruction*> dynamic_dims_in_partial_update(
      update->shape().rank(), nullptr);
  bool needs_rewrite = false;
  for (int64_t i = 0; i < update->shape().rank(); ++i) {
    if (update->shape().dimensions(i) < base->shape().dimensions(i)) {
      HloInstruction* dynamic_dim =
          dynamic_dimension_inference->GetDynamicSize(update, {}, i);

      if (dynamic_dim != nullptr) {
        dynamic_dims_in_partial_update[i] = dynamic_dim;
        needs_rewrite = true;
      }
    }
  }

  if (!needs_rewrite) {
    return false;
  }
  std::vector<HloInstruction*> indices;
  indices.reserve(dus->operand_count() - 2);
  for (int64_t i = 2; i < dus->operand_count(); ++i) {
    indices.push_back(dus->mutable_operand(i));
  }
  HloInstruction* base_slice =
      comp->AddInstruction(HloInstruction::CreateDynamicSlice(
          update->shape(), base, indices, update->shape().dimensions()));
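  // For each dynamic dimension of the update, build a mask that is true
  // inside the effective update region and false in the padding. As a sketch,
  // for a dynamic size of 2 in a dimension of extent 4, the mask along that
  // dimension is [1, 1, 0, 0]; padding positions then take their values from
  // `base_slice`.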

  for (int64_t i = 0; i < dynamic_dims_in_partial_update.size(); ++i) {
    HloInstruction* dynamic_dim = dynamic_dims_in_partial_update[i];
    if (dynamic_dim != nullptr) {
      Shape mask_shape_int = ShapeUtil::ChangeElementType(update->shape(), S32);
      Shape mask_shape_pred =
          ShapeUtil::ChangeElementType(update->shape(), PRED);
      // Generate the mask by comparing an iota along dimension i against the
      // broadcasted dynamic size.
      HloInstruction* iota =
          comp->AddInstruction(HloInstruction::CreateIota(mask_shape_int, i));
      HloInstruction* broadcast_dim = comp->AddInstruction(
          HloInstruction::CreateBroadcast(mask_shape_int, dynamic_dim, {}));
      HloInstruction* pred = comp->AddInstruction(HloInstruction::CreateCompare(
          mask_shape_pred, iota, broadcast_dim, ComparisonDirection::kLt));
      // Replace the padding region of `update` with values from `base_slice`.
      update = comp->AddInstruction(HloInstruction::CreateTernary(
          update->shape(), HloOpcode::kSelect, pred, update, base_slice));
    }
  }
  TF_RETURN_IF_ERROR(dus->ReplaceOperandWith(1, update));

  return true;
}

StatusOr<bool> RewriteDynamicReshape(
    HloInstruction* reshape,
    DynamicDimensionInference* dynamic_dimension_inference) {
  bool changed = false;
  HloInstruction* operand = reshape->mutable_operand(0);
  std::vector<HloInstruction*> input_dynamic_dims;
  for (int64_t dim = 0; dim < operand->shape().dimensions_size(); ++dim) {
    input_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(operand, {}, dim));
  }

  std::vector<HloInstruction*> output_dynamic_dims;
  for (int64_t dim = 0; dim < reshape->shape().dimensions_size(); ++dim) {
    output_dynamic_dims.push_back(
        dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim));
  }

  auto common_factors = CommonFactors(operand->shape().dimensions(),
                                      reshape->shape().dimensions());
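  // CommonFactors returns the boundaries of maximal dimension groups that
  // reshape into each other. For example (roughly), reshaping [2,3,4] into
  // [6,4] yields the boundary list {(0,0), (2,1), (3,2)}: input dims {0,1}
  // map to output dim {0} and input dim {2} maps to output dim {1}.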
  // Process each common factor group.
  for (int64_t i = 0; i < common_factors.size() - 1; ++i) {
    auto start = common_factors[i];
    auto end = common_factors[i + 1];
    std::vector<int64_t> input_dims;
    std::vector<int64_t> output_dims;
    for (int64_t dim = start.first; dim < end.first; ++dim) {
      input_dims.push_back(dim);
    }
    for (int64_t dim = start.second; dim < end.second; ++dim) {
      output_dims.push_back(dim);
    }

    VLOG(2) << "input_dims: " << VectorString(input_dims);
    VLOG(2) << "output_dims: " << VectorString(output_dims);

    if (input_dims.empty() || output_dims.empty()) {
      continue;
    }
    bool has_dynamic_dimension = absl::c_any_of(output_dims, [&](int64_t dim) {
      HloInstruction* operand_dynamic_size =
          dynamic_dimension_inference->GetDynamicSize(reshape, {}, dim);

      return operand_dynamic_size != nullptr ||
             reshape->shape().is_dynamic_dimension(dim);
    });

    if (!has_dynamic_dimension) {
      // No need to rewrite a group without dynamic dimensions.
      VLOG(2) << "All dimensions are static in this common factor group";
      continue;
    }

    if (input_dims.size() == 1 && output_dims.size() == 1) {
      // The dimension is unchanged. No rewrite needed.
      continue;
    }
    if (input_dims.size() > 1 && output_dims.size() > 1) {
      // We don't support the case when a dynamic dimension is both combined
      // with and split into other dimensions:
      //
      //  [x, yz]
      //     | Reshape
      //  [xy, z]
      //
      // TODO(yunxing): This can be supported by canonicalizing
      // the offending reshape into two reshapes:
      //
      //  [x, yz]
      //     | Reshape
      //  [x, y, z]
      //     | Reshape
      //  [xy, z]
      //
      return Unimplemented(
          "Dynamic input dimension to reshape that is both split and "
          "combined is not supported: %s",
          reshape->ToString());
    }

    TF_RETURN_IF_ERROR(RewriteDynamicReshapeSingleGroup(
        reshape, input_dims, output_dims, absl::MakeSpan(input_dynamic_dims),
        absl::MakeSpan(output_dynamic_dims), dynamic_dimension_inference));
    changed = true;
  }

  return changed;
}

// Insert pad-to-static after `inst` if `inst` has dynamic dimensions in it.
// Recurses into tuple instructions.
StatusOr<HloInstruction*> InsertPadToStaticOnInstruction(HloInstruction* inst) {
  if (inst->shape().is_static()) {
    return inst;
  }
  HloComputation* comp = inst->parent();
  if (!inst->shape().IsTuple()) {
    // The output shape of PadToStatic is a tuple. The 0th element is the data
    // output, which is the same as the input shape but without dynamic
    // dimensions; the i-th element is the dynamic dimension size of the
    // (i-1)-th input dimension.
    Shape data_output_shape = inst->shape();  // 0th element.
    data_output_shape.clear_dynamic_dimensions();
    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
    for (int64_t i = 0; i < inst->shape().rank(); ++i) {
      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                    &output_shape);
    }
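    // E.g. for an `inst` of shape f32[<=10, 5], `output_shape` at this point
    // is the tuple (f32[10,5], s32[], s32[]) -- the static data plus one size
    // scalar per dimension.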
    HloInstruction* pad_to_static =
        comp->AddInstruction(HloInstruction::CreateCustomCall(
            output_shape, {inst}, "PadToStatic", ""));
    HloInstruction* data_output =
        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
            data_output_shape, pad_to_static, 0));
    return data_output;
  }

  TF_RET_CHECK(inst->shape().IsTuple());
  std::vector<HloInstruction*> static_tuple_elements;
  for (int64_t i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
    // For each tuple element, if it is static, pass it through. If it is
    // dynamic, recursively call this function again.
    HloInstruction* gte =
        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
            inst->shape().tuple_shapes(i), inst, i));

    if (gte->shape().is_static()) {
      static_tuple_elements.push_back(gte);
    } else {
      TF_ASSIGN_OR_RETURN(HloInstruction * static_gte,
                          InsertPadToStaticOnInstruction(gte));
      static_tuple_elements.push_back(static_gte);
    }
  }

  return comp->AddInstruction(
      HloInstruction::CreateTuple(static_tuple_elements));
}

Status InsertPadToStaticAfterModuleInputs(HloModule* module) {
  HloComputation* entry = module->entry_computation();
  for (int64_t i = 0; i < entry->num_parameters(); ++i) {
    HloInstruction* param = entry->parameter_instruction(i);
    auto users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * static_param,
                        InsertPadToStaticOnInstruction(param));
    for (auto* user : users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, static_param));
    }
    if (param == entry->root_instruction()) {
      entry->set_root_instruction(static_param);
    }
  }
  return Status::OK();
}

// Remove all dynamic shapes between pad-to-static and slice-to-dynamic.
//
// After this visitor runs, the entry computation looks like:
//  Param(dynamic)
//   |
//  GTE (dynamic)
//   |
//  PadToStatic(static)
//   |
//  .... regular computation with static shapes.
//   |
//  SliceToDynamic(dynamic)
//   |
//  ROOT tuple (dynamic)
class DynamicShapeRemovingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit DynamicShapeRemovingVisitor(
      const DynamicPadder::OpSupportsDynamismHandler&
          op_supports_dynamism_handler,
      DynamicDimensionInference* dynamic_dimension_inference)
      : op_supports_dynamism_handler_(op_supports_dynamism_handler),
        dynamic_dimension_inference_(dynamic_dimension_inference) {}

  Status DefaultAction(HloInstruction* hlo) override;

  Status HandleCustomCall(HloInstruction* hlo) override;

  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;

  Status HandleParameter(HloInstruction* hlo) override;

  static Status Run(HloComputation* computation,
                    const DynamicPadder::OpSupportsDynamismHandler&
                        op_supports_dynamism_handler,
                    DynamicDimensionInference* dynamic_shape_inference,
                    bool require_dynamic_output) {
    DynamicShapeRemovingVisitor visitor(op_supports_dynamism_handler,
                                        dynamic_shape_inference);
    TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    // If the output is required to be in dynamic form, insert a
    // static-to-dynamic conversion at the root.
    if (require_dynamic_output) {
      HloInstruction* root = computation->root_instruction();
      if (dynamic_shape_inference->HasDynamicDimension(root)) {
        TF_ASSIGN_OR_RETURN(HloInstruction * new_root,
                            visitor.ConvertToDynamic(root));
        computation->set_root_instruction(new_root);
      }
    }
    return Status::OK();
  }

 private:
  // If the tensor produced by `inst` is in dynamic form, convert it to static
  // and return the new instruction.
  StatusOr<HloInstruction*> ConvertToStatic(HloInstruction* inst);

  // If the tensor produced by `inst` is in static form, convert it to dynamic
  // and return the new instruction.
  StatusOr<HloInstruction*> ConvertToDynamic(HloInstruction* inst);

  const DynamicPadder::OpSupportsDynamismHandler& op_supports_dynamism_handler_;

  DynamicDimensionInference* dynamic_dimension_inference_;
};

StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToDynamic(
    HloInstruction* inst) {
  auto* comp = inst->parent();
  const Shape& shape = inst->shape();
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> dynamic_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      if (dynamic_dimension_inference_->HasDynamicDimension(inst, {i})) {
        TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
        TF_ASSIGN_OR_RETURN(auto dynamic, ConvertToDynamic(gte));
        dynamic_operands.push_back(dynamic);
      } else {
        dynamic_operands.push_back(gte);
      }
    }
    return comp->AddInstruction(HloInstruction::CreateTuple(dynamic_operands));
  } else {
    // Collect the data input, as well as dimension sizes, and feed them to a
    // SliceToDynamic custom call to create a dynamic tensor.
    Shape output_shape = shape;  // 0th element.
    CHECK(output_shape.is_static());
    std::vector<HloInstruction*> slice_operand;
    slice_operand.push_back(inst);
    for (int64_t i = 0; i < output_shape.dimensions_size(); ++i) {
      auto dimension_size =
          dynamic_dimension_inference_->GetDynamicSize(inst, {}, i);
      if (dimension_size == nullptr) {
        dimension_size = comp->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::CreateR0<int32>(output_shape.dimensions(i))));
      } else {
        output_shape.set_dynamic_dimension(i, true);
      }
      slice_operand.push_back(dimension_size);
    }
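    // The resulting custom call takes (data, size_0, ..., size_{rank-1}) and
    // re-attaches the dynamic dimensions; as an illustrative sketch:
    //   f32[<=10,5] custom-call(f32[10,5] %data, s32[] %size0, s32[] %size1),
    //       custom_call_target="SliceToDynamic"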
    return comp->AddInstruction(HloInstruction::CreateCustomCall(
        output_shape, slice_operand, "SliceToDynamic"));
  }
}

StatusOr<HloInstruction*> DynamicShapeRemovingVisitor::ConvertToStatic(
    HloInstruction* inst) {
  auto* comp = inst->parent();
  const Shape& shape = inst->shape();
  CHECK(shape.is_dynamic());
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> static_operands;
    for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
      // Read each element through a GTE so this works for any
      // tuple-producing instruction, not just kTuple.
      auto gte = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
          shape.tuple_shapes(i), inst, i));
      TF_RETURN_IF_ERROR(dynamic_dimension_inference_->Update(gte));
      if (shape.tuple_shapes(i).is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_inst, ConvertToStatic(gte));
        static_operands.push_back(static_inst);
      } else {
        static_operands.push_back(gte);
      }
    }
    return comp->AddInstruction(HloInstruction::CreateTuple(static_operands));
  } else {
    // The output shape of PadToStatic is a tuple. The 0th element is the data
    // output, which is the same as the input shape but without dynamic
    // dimensions; the i-th element is the dynamic dimension size of the
    // (i-1)-th input dimension.
    Shape data_output_shape = shape;  // 0th element.
    data_output_shape.clear_dynamic_dimensions();
    Shape output_shape = ShapeUtil::MakeTupleShape({data_output_shape});
    for (int64_t i = 0; i < shape.rank(); ++i) {
      ShapeUtil::AppendShapeToTuple(ShapeUtil::MakeScalarShape(S32),
                                    &output_shape);
    }
    HloInstruction* pad_to_static =
        comp->AddInstruction(HloInstruction::CreateCustomCall(
            output_shape, {inst}, "PadToStatic", ""));
    HloInstruction* data_output =
        comp->AddInstruction(HloInstruction::CreateGetTupleElement(
            data_output_shape, pad_to_static, 0));
    return data_output;
  }
}

Status DynamicShapeRemovingVisitor::DefaultAction(HloInstruction* hlo) {
  const bool input_is_dynamic = absl::c_any_of(
      hlo->operands(),
      [](const HloInstruction* hlo) { return hlo->shape().is_dynamic(); });

  // By default, ops don't support dynamic lowering.
  OpDynamismSupport op_support = OpDynamismSupport::kNoSupport;
  if (op_supports_dynamism_handler_) {
    op_support = op_supports_dynamism_handler_(hlo);
  }
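  // From here there are effectively four cases, keyed on (input dynamism, op
  // support): static input with no support strips dynamism from the output
  // shape; dynamic input with no support pads operands to static; static
  // input where dynamism is required slices operands to dynamic; anything
  // else passes through unchanged.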
  if (op_support == OpDynamismSupport::kNoSupport) {
    for (auto* sub_computation : hlo->called_computations()) {
      for (auto* param : sub_computation->parameter_instructions()) {
        param->mutable_shape()->clear_dynamic_dimensions();
      }
    }
  }
  // If the input to an op is static and the op doesn't support
  // dynamic output, remove dynamism in the output -- dynamic_padder should
  // have rewritten it to support static shapes.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return Status::OK();
  }

  // The op doesn't support dynamic tensors: rewrite each dynamic operand into
  // a static one using pad-to-static.
  if (input_is_dynamic && op_support == OpDynamismSupport::kNoSupport) {
    VLOG(1) << "op doesn't support dynamic tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().is_dynamic()) {
        TF_ASSIGN_OR_RETURN(auto static_operand,
                            ConvertToStatic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, static_operand));
      }
    }
    // This op doesn't support dynamic lowering so the op has to be static.
    hlo->mutable_shape()->clear_dynamic_dimensions();
    return Status::OK();
  }

  // If the op requires dynamic tensors and the input is static -- construct a
  // dynamic tensor from the static tensor to feed it.
  if (!input_is_dynamic && op_support == OpDynamismSupport::kRequired) {
    VLOG(1) << "op doesn't support static tensor: " << hlo->ToString();
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      auto operand = hlo->mutable_operand(i);
      if (dynamic_dimension_inference_->HasDynamicDimension(operand)) {
        TF_ASSIGN_OR_RETURN(auto dynamic_operand,
                            ConvertToDynamic(hlo->mutable_operand(i)));
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, dynamic_operand));
      }
    }
    return Status::OK();
  }

  return Status::OK();
}

Status DynamicShapeRemovingVisitor::HandleGetTupleElement(HloInstruction* hlo) {
  *hlo->mutable_shape() =
      hlo->operand(0)->shape().tuple_shapes(hlo->tuple_index());
  return Status::OK();
}

Status DynamicShapeRemovingVisitor::HandleTuple(HloInstruction* hlo) {
  for (int64_t i = 0; i < hlo->operand_count(); ++i) {
    *hlo->mutable_shape()->mutable_tuple_shapes(i) = hlo->operand(i)->shape();
  }
  return Status::OK();
}

Status DynamicShapeRemovingVisitor::HandleParameter(HloInstruction* hlo) {
  return Status::OK();
}

Status DynamicShapeRemovingVisitor::HandleCustomCall(HloInstruction* hlo) {
  if (hlo->custom_call_target() == "SliceToDynamic" ||
      hlo->custom_call_target() == "PadToStatic") {
    // These ops are created specifically to handle dynamic tensors, so by
    // their nature they support dynamic lowering.
    return Status::OK();
  }

  return DefaultAction(hlo);
}

}  // namespace

StatusOr<bool> DynamicPadder::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Pre DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  // Remove dynamic dimensions on parameters if there is already a binding for
  // them. We do this because we have two different APIs to express a dynamic
  // dimension:
  //
  // 1. Dynamic dimension as specified directly in the shape -- Needed for
  //    PyTorch.
  //
  // 2. Dynamic dimension using a dynamic parameter binding object. This
  //    is needed for TensorFlow.
  //
  // For case 1, we will insert a "pad-to-static" instruction at the
  // beginning of the XLA execution, to make it into a static layout.
  //
  // For case 2, since it already has a static layout, we remove the
  // dynamic dimension.
  //
  // TODO(b/145140571): Convert all API invocations to case 1.
  //
  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().ForEachBinding(
      [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter,
          const DynamicParameterBinding::DynamicDimension& dynamic_dimension)
          -> Status {
        HloInstruction* parameter =
            module->entry_computation()->parameter_instruction(
                dynamic_dimension.parameter_num);
        ShapeUtil::UpdateDynamicDimension(parameter->mutable_shape(),
                                          dynamic_dimension.parameter_index,
                                          dynamic_dimension.dimension, false);
        return Status::OK();
      }));
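  // Each binding visited above says, as a hypothetical example, "dimension 0
  // of parameter 0 has its runtime size stored in parameter 1 (a scalar
  // s32)". Because the size travels out-of-band, the dimension itself can be
  // marked static here.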

  TF_RETURN_IF_ERROR(InsertPadToStaticAfterModuleInputs(module));
  TF_ASSIGN_OR_RETURN(
      DynamicDimensionInference dynamic_dimension_inference,
      DynamicDimensionInference::Run(module, custom_call_handler_));

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
      OpDynamismSupport has_dynamism_support = OpDynamismSupport::kNoSupport;
      if (op_supports_dynamism_handler_ != nullptr) {
        has_dynamism_support = op_supports_dynamism_handler_(inst);
      }
      // This op supports dynamic lowering; no padding is required.
      if (has_dynamism_support != OpDynamismSupport::kNoSupport) {
        continue;
      }
      if (inst->opcode() == HloOpcode::kConcatenate) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicConcat(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReverse) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten, RewriteReverse(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }
      if (inst->opcode() == HloOpcode::kSort) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicSort(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }
      if (inst->opcode() == HloOpcode::kReshape) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicReshape(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      // Elementwise binary ops with dynamic shapes have implicit broadcast
      // semantics.
      if (inst->IsElementwiseBinary()) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicBinaryOp(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->opcode() == HloOpcode::kDynamicUpdateSlice) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicUpdateSlice(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->opcode() == HloOpcode::kDynamicReshape) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicReshape(inst, &dynamic_dimension_inference));
        changed |= rewritten;
        auto* static_reshape =
            computation->AddInstruction(HloInstruction::CreateReshape(
                inst->shape(), inst->mutable_operand(0)));
        TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape));
        TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize(
            inst, static_reshape, {}));
        // A static reshape is always inserted as a replacement here, so the
        // module changes regardless of `rewritten`.
        changed = true;
        continue;
      }
      if (inst->IsCustomCall("DynamicConvolutionInputGrad")) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicConvolutionInputGrad(inst,
                                               &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionForward")) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicConvolutionForward(inst,
                                             &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->IsCustomCall("DynamicConvolutionKernelGrad")) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicConvolutionKernelGrad(inst,
                                                &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->IsCustomCall("DynamicReduceWindowSamePadding")) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicReduceWindowSamePadding(
                inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      if (inst->IsCustomCall("DynamicSelectAndScatterSamePadding")) {
        TF_ASSIGN_OR_RETURN(
            bool rewritten,
            RewriteDynamicSelectAndScatterSamePadding(
                inst, &dynamic_dimension_inference));
        changed |= rewritten;
        continue;
      }

      for (int64_t operand_num = 0; operand_num < inst->operand_count();
           ++operand_num) {
        HloInstruction* original_operand = inst->mutable_operand(operand_num);
        HloInstruction* operand = original_operand;
        if (!operand->shape().IsArray()) {
          continue;
        }

        for (int64_t input_dim = 0; input_dim < operand->shape().rank();
             ++input_dim) {
          HloInstruction* operand_dynamic_size =
              dynamic_dimension_inference.GetDynamicSize(original_operand, {},
                                                         input_dim);
          if (operand_dynamic_size == nullptr) {
            continue;
          }
          VLOG(2) << "Has dynamic dimension of operand " << operand_num
                  << " @" << input_dim;

          if (ShouldSkipPadOnOperand(inst, operand_num, input_dim)) {
            continue;
          }

          TF_ASSIGN_OR_RETURN(HloInstruction * identity_value,
                              ChooseIdentityValue(inst, operand_num));
          if (identity_value == nullptr) {
            continue;
          }

          HloInstruction* padded = PadWithScalar(
              operand, input_dim, operand_dynamic_size, identity_value);
          TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, padded));
          operand = inst->mutable_operand(operand_num);
          changed = true;
        }
      }
    }
  }
  if (changed) {
    module->set_is_dynamic(true);
  }

  // There are ops that only support dynamic lowering and ops that only support
  // static lowering; add dynamic<->static tensor conversions around the
  // boundary between those ops, as well as around the root instruction.
  auto computations = module->MakeComputationPostOrder();
  // Walk the computations in reverse post order so that if a caller doesn't
  // support dynamic tensors (while, etc.), its called computations are changed
  // to only take static tensors.
  for (auto it = computations.rbegin(); it != computations.rend(); ++it) {
    HloComputation* computation = *it;
    // If slice_dynamic_output_ is set and this is the entry computation, we
    // need the output tensor to be in dynamic form.
    bool require_dynamic_output =
        slice_dynamic_output_ && computation == module->entry_computation();
    TF_RETURN_IF_ERROR(DynamicShapeRemovingVisitor::Run(
        computation, op_supports_dynamism_handler_,
        &dynamic_dimension_inference,
        /*require_dynamic_output=*/require_dynamic_output));
  }

  for (auto* computation : module->computations()) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(
          bool replaced_get_size,
          ReplaceGetSize(instruction, &dynamic_dimension_inference));
      changed = changed || replaced_get_size;
    }
  }

  for (auto* computation : module->computations()) {
    for (auto instruction : computation->MakeInstructionPostOrder()) {
      TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction));
      TF_ASSIGN_OR_RETURN(bool replaced_set_bound,
                          ReplaceSetBound(instruction));
      changed = changed || replaced_set_size;
      changed = changed || replaced_set_bound;
    }
  }
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));
  changed |= dce_changed;

  VLOG(2) << "Post DynamicPadder HLO:";
  XLA_VLOG_LINES(2, module->ToString());
  return changed;
}

}  // namespace xla