/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/space_to_batch_converter.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/algorithm.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

// ConvolutionVisitor traverses the HLO computation and rewrites Convolution
// operations with small batch counts into convolutions with larger batch
// counts by moving space to batch.
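// For example, with number_of_splits = 8, a convolution over batch 1 and a
// 256-wide spatial dimension can be rewritten to operate on batch 8 and a
// 32-wide spatial dimension (plus halo), which makes better use of
// batch-parallel hardware.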
class ConvolutionVisitor {
 public:
  // Top-level function to begin space-to-batch conversion.
  Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);

  // Struct containing details about a convolution.
  struct ConvDetails {
    std::vector<int64_t> spatial_dimensions_to_split;
    int64_t inherent_low_padding, inherent_high_padding, stride, spatial_size,
        base_dilation_factor, halo_size, high_padding_for_conv,
        low_padding_for_conv, kernel_spatial_dim_size, input_dim_size;
  };

  // Return a struct containing various necessary information pieces for
  // performing space-to-batch on a convolution.
  ConvDetails GetConvolutionDetails(HloInstruction* convolution,
                                    ConvolutionDimensionNumbers& dim_numbers);

  // Returns the sets of old and new spatial dimensions, respectively.
  std::pair<std::vector<int64_t>, std::vector<int64_t>> GetSpatialDimsToSplit(
      HloInstruction* old_operand);

  // Returns true if the convolution is a forward window dilated convolution.
  bool IsForwardWindowDilatedConv(HloInstruction* convolution,
                                  ConvolutionDimensionNumbers& dim_numbers);

  // Function that determines if space-to-batch can be propagated into the
  // consumer. Such propagation is only possible when all required operands are
  // space-to-batch'ed.
  bool CanPropagate(HloInstruction* consumer, HloInstruction* producer);

  // Returns true if the op has all its direct and indirect operands being
  // created via broadcasts. Consumer uses op, and is space-to-batched.
  // instructions_to_transform is populated with the instructions to rewrite,
  // in reverse post order.
  bool IsBroadcastTree(HloInstruction* op, HloInstruction* consumer,
                       std::vector<HloInstruction*>& instructions_to_transform);

  // Replicates the broadcast tree with space-to-batched instructions.
  void RewriteBroadcastTree(
      HloInstruction* producer,
      std::vector<HloInstruction*>& instructions_to_transform);

  // Propagate space-to-batch on a broadcast instruction.
  void PropagateOnBroadcast(HloInstruction* consumer, HloInstruction* producer);

  // Returns false if the opcode should definitely not be propagated upon.
  bool IsOpcodeNonPropagatable(HloInstruction* consumer);

  // This function checks if the HLO instruction supports propagation.
  bool SupportedOpForPropagation(HloInstruction* consumer,
                                 HloInstruction* producer);

  // Method that checks validity of Broadcast propagation.
  bool IsBroadcastPropagatable(HloInstruction* broadcast,
                               HloInstruction* old_other_op);

  // Propagates space-to-batch on the op, and returns a bool that indicates if
  // the users of the op need to be propagated through.
  StatusOr<bool> Propagate(HloInstruction* consumer, HloInstruction* producer);

  // Splits the given spatial dimension on the activations and returns the
  // new instructions, and the dimension permutation of the new shape.
  StatusOr<std::pair<HloInstruction*, std::vector<int64_t>>> SplitSpace(
      HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
      int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding,
      int64_t spatial_split_size, int64_t num_splits,
      std::vector<int64_t>* spatial_dimensions_to_split,
      bool is_backprop = false, bool is_rhs = false);

  // Performs the actual dimension splitting.
  StatusOr<HloInstruction*> PerformSplitSpace(
      HloInstruction* activations,
      absl::Span<const int64_t> spatial_dimensions_to_split,
      int64_t activations_batch_dim, int64_t spatial_split_size,
      int64_t num_splits);

  // Helper function that puts individually split dimensions together, and
  // merges the batch(es).
  // The input activations dimensions are ... B, B0, S0, B1, S1, ... Bn, Sn, ...
  // The output dimensions will be ..., B, S0, S1,.. Sn, ...
  StatusOr<HloInstruction*> TransposeAndMergeBatch(
      HloInstruction* activations,
      absl::Span<const int64_t> final_split_spatial_dim_positioning,
      int64_t activations_batch_dim, int64_t old_batch_size);

  // Helper function for the SplitSpace function above. Handles padding and
  // reshaping to generate space-to-batched shape.
  StatusOr<HloInstruction*> PadAndSplitSpace(
      HloInstruction* activations,
      absl::Span<const int64_t> spatial_dimensions_to_split,
      int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding,
      int64_t spatial_split_size, int64_t num_splits);

  // Perform space-to-batch propagation on constants.
  StatusOr<HloInstruction*> PropagateOnConstant(HloInstruction* consumer,
                                                HloInstruction* producer);

  // Perform space-to-batch propagation on the convolution. Assumes the
  // activations were already space-to-batched.
  Status PropagateOnConv(HloInstruction* convolution);

  // Perform space-to-batch propagation on concatenate.
  Status PropagateOnConcat(HloInstruction* concat);

  // Perform space-to-batch propagation on reverse.
  Status PropagateOnReverse(HloInstruction* reverse);

  // Perform space-to-batch propagation on pad.
  Status PropagateOnPad(HloInstruction* pad);

  // Perform space-to-batch propagation on the backprop filter convolution.
  // Assumes the activations and kernel were already space-to-batched.
  Status PropagateOnBackpropFilterConv(HloInstruction* convolution);

  // Method that checks validity of space-to-batch on a given convolution.
  bool IsConvSuitableForSpaceToBatch(HloInstruction* convolution);

  // Method that returns true if this is a backprop filter convolution.
  bool IsThisBackPropFilterConv(HloInstruction* convolution);

  // Once a convolution has been space-to-batch'ed, this function will
  // transitively propagate the space-to-batch-ness on rest of the graph.
  Status PropagateOnUsers(HloInstruction* old_conv);

  // Generates masked output with valid data. This is useful when larger shapes
  // are generated due to space-to-batch.
  StatusOr<HloInstruction*> SelectValidPortion(
      HloInstruction* new_instr, HloInstruction* old_instr,
      HloInstruction* select_val, int64_t new_batch_dim,
      absl::Span<const int64_t> new_space_dims, int64_t old_batch_dim,
      absl::Span<const int64_t> old_space_dims);

  struct SpaceNextToBatchDetails {
    HloInstruction* instr;
    std::vector<int64_t> transpose_dims;
  };

  // Performs transposition so that the space dimension follows the batch
  // dimension.
  StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
      HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
      int64_t& activations_batch_dim,
      std::vector<int64_t>* spatial_dimensions_to_split,
      bool is_backprop = false, bool is_rhs = false);

  // Decreases the spatial dimension size in an already space-to-batched shape
  // so that the new size is new_spatial_dim_size.
  StatusOr<HloInstruction*> ChangeSpatialSizeOnSpaceToBatchedShape(
      HloInstruction* activations, int64_t batch_dimension,
      int64_t old_batch_size,
      absl::Span<const int64_t> spatial_dimensions_to_split,
      int64_t new_spatial_dim_size, bool increase_spatial_size = false);

  // Turns B, S0, S1, ..., Sn into B, B0, S0, B1, S1,... Bn, Sn.
  StatusOr<HloInstruction*> SplitAndTransposeMergedBatch(
      HloInstruction* activations, int64_t batch_dimension,
      int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions);

  // Function that converts a space-to-batched shape back to the original.
  StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);

  // Duplicates elements at boundaries.
  StatusOr<HloInstruction*> HaloDuplicateWithSlice(
      HloInstruction* activations,
      absl::Span<const int64_t> spatial_dimensions_to_split,
      int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size,
      HloInstruction* pad_val = nullptr);

  // Runs the visitor on a computation.
  StatusOr<bool> Run();

  // Returns whether any convolution ops were rewritten.
  bool changed() const { return changed_; }

  ~ConvolutionVisitor() = default;

  explicit ConvolutionVisitor(SpaceToBatchController ctrl,
                              HloComputation* computation);

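  // The chosen spatial dims are a contiguous run of
  // count_of_dimensions_to_convert dims ending dimension_from_end_to_convert
  // from the back of the spatial dims. E.g., with 3 spatial dims,
  // dimension_from_end_to_convert = 2 and count_of_dimensions_to_convert = 1,
  // this returns index 1.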
  int64_t GetFirstChosenSpatialDim(HloInstruction* convolution) {
    const int64_t dim_count = ctrl_.count_of_dimensions_to_convert;
    const int64_t end_point = convolution->convolution_dimension_numbers()
                                  .input_spatial_dimensions_size() -
                              ctrl_.dimension_from_end_to_convert;
    return end_point - dim_count + 1;
  }

  std::vector<int64_t> GetChosenSpatialDims(HloInstruction* convolution) {
    const int64_t dim_count = ctrl_.count_of_dimensions_to_convert;
    const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
    std::vector<int64_t> dims(dim_count);
    for (int i = 0; i < dim_count; ++i) {
      dims[i] =
          convolution->convolution_dimension_numbers().input_spatial_dimensions(
              first_dim + i);
    }
    return dims;
  }

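  // permute_dims maps old dimension indices to their positions in the
  // space-to-batched shape: permute_dims[old_dim] == new_dim.
  // ReverseDimLookUp below computes the inverse mapping.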
  int64_t DimLookUp(absl::Span<const int64_t> permute_dims, int64_t id) {
    return permute_dims[id];
  }

  int DimMapper(SpaceToBatchDimMap s) { return static_cast<int>(s); }

  int64_t ReverseDimLookUp(absl::Span<const int64_t> permute_dims, int64_t id) {
    return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
  }

  HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
      HloInstruction* instr, int64_t depth);

  // Returns true if instr feeds an unpropagatable op before it feeds 'depth'
  // number of convolutions.
  bool DoesConvolutionFeedUnpropagatableOp(
      HloInstruction* instr, int64_t depth = kUnpropagatableOpSearchDepth);

  // Checks that the space-to-batched shape has not rendered the new spatial
  // dimension to be smaller than the window's size.
  bool IsSpaceToBatchedSpaceSizeSuitable(HloInstruction* instr);

 private:
  // Current HloComputation instance the ConvolutionVisitor is traversing.
  HloComputation* computation_;

  absl::flat_hash_set<HloInstruction*> convs_to_visit_;
  std::vector<HloInstruction*> conv_visitor_list_;
  HloInstructionSet non_propagatable_instrs_;
  // Map from a given space-to-batched instruction to its batch-to-space
  // version.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> batch_to_space_map_;

  // Map from old (non space-to-batch) instructions to space-to-batch'ed
  // instructions.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_instrs_;

  // Map from instruction to dimensions of the shape. This is with respect to
  // the old instruction.
  absl::flat_hash_map<HloInstruction*, std::vector<int64_t>> instr_to_dim_map_;

  // Map from space-to-batch'ed instruction to its permute dims.
  absl::flat_hash_map<HloInstruction*, std::vector<int64_t>>
      instr_to_dim_permute_map_;

  // Map maintaining previously space-to-batched broadcasts.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
      broadcast_map_;

  // Whether rewrite has occurred.
  bool changed_ = false;

  // Depth for searching reduce window.
  static constexpr int64_t kReduceWindowSearchDepth = 10;

  // Depth for searching unpropagatable op.
  static constexpr int64_t kUnpropagatableOpSearchDepth = 3;

  // Penalty on size for base dilated convs.
  static constexpr int64_t kMultiplierOnSpaceForBaseDilation = 3;

  // Cache for <instruction, depth> ==> unpropagatability decision.
  absl::flat_hash_map<std::pair<HloInstruction*, int64_t>, bool>
      unpropagatability_cache_;

  // Controller for various knobs.
  SpaceToBatchController ctrl_;
};

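// The constructor builds the initial worklist: every convolution in the
// computation that passes the legality checks below is recorded and later
// visited in order by Run().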
ConvolutionVisitor::ConvolutionVisitor(SpaceToBatchController ctrl,
                                       HloComputation* computation) {
  ctrl_ = ctrl;
  computation_ = computation;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    if (inst->opcode() != HloOpcode::kConvolution) {
      continue;
    }

    auto convolution = inst;
    // Perform legality checks.
    if (!IsConvSuitableForSpaceToBatch(convolution)) {
      VLOG(1) << "Conv not suitable for space-to-batch "
              << convolution->ToString();
      continue;
    }
    VLOG(1) << "Conv added to space-to-batch worklist "
            << convolution->ToString();
    convs_to_visit_.insert(convolution);
    conv_visitor_list_.push_back(convolution);
  }
}

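// Note: the dimensions to split are assumed to be contiguous, so only the
// first one is looked up in the maps; the rest are derived by adding offsets.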
std::pair<std::vector<int64_t>, std::vector<int64_t>>
ConvolutionVisitor::GetSpatialDimsToSplit(HloInstruction* old_operand) {
  auto new_operand = old_to_new_instrs_[old_operand];
  auto dim_map_val = instr_to_dim_map_[old_operand];
  auto permute_dims = instr_to_dim_permute_map_[new_operand];
  std::vector<int64_t> old_dims(ctrl_.count_of_dimensions_to_convert),
      new_dims(ctrl_.count_of_dimensions_to_convert);

  old_dims[0] = dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
  new_dims[0] = DimLookUp(permute_dims, old_dims[0]);
  for (int i = 1; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    old_dims[i] = old_dims[0] + i;
    new_dims[i] = new_dims[0] + i;
  }
  return std::make_pair(old_dims, new_dims);
}

bool ConvolutionVisitor::IsForwardWindowDilatedConv(
    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
  const int64_t window_dilation_factor =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .window_dilation();

  if (window_dilation_factor == 1) {
    return false;
  }

  const int64_t output_spatial_dim = dim_numbers.output_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));
  const int64_t kernel_spatial_dim = dim_numbers.kernel_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));

  // If convolution's spatial dim size is larger than that of RHS, this is a
  // forward RHS dilated convolution.
  return convolution->operand(1)->shape().dimensions(kernel_spatial_dim) <
         convolution->shape().dimensions(output_spatial_dim);
}

bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
    HloInstruction* convolution) {
  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  // If there are no specified spatial dims, we return.
  if (GetFirstChosenSpatialDim(convolution) < 0) {
    return false;
  }

  // Batch in batch_group_count has different semantics (it isn't true batch).
  // Consider supporting this case in future if needed.
  if (convolution->batch_group_count() != 1) {
    return false;
  }

  if (convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .window_dilation() != 1) {
    if (!IsForwardWindowDilatedConv(convolution, dim_numbers)) {
      return false;
    }
  }

  const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  const int64_t low_pad = convolution->window()
                              .dimensions(GetFirstChosenSpatialDim(convolution))
                              .padding_low();

  // TODO(b/168316428): Support base dilations more generically.
  if (c.base_dilation_factor != 1) {
    if (!ctrl_.enable_propagations_on_base_dilations) {
      return false;
    }
    if (c.stride != 1) {
      return false;
    }
    // For low pad of 0, only support a pointwise kernel.
    if (low_pad == 0) {
      if (c.kernel_spatial_dim_size != 1) {
        return false;
      }
    } else if (low_pad != c.base_dilation_factor - 1 &&
               low_pad != c.base_dilation_factor) {
      // Only support dilations such that base dilation factor and low pad are
      // compatible with kernel_spatial_dim_size to be compatible with
      // HaloDuplicateWithSlice.
      return false;
    }
  }

  int64_t activations_batch_dim = dim_numbers.input_batch_dimension();

  const int64_t old_batch_size =
      convolution->operand(0)->shape().dimensions(activations_batch_dim);

  if (old_batch_size > ctrl_.limit_on_batch_size) {
    return false;
  }

  VLOG(1) << "spatial size " << c.spatial_size << " halo size " << c.halo_size;

  // If the ratio is not within the 2X range, we can't Halo Pad from the next
  // split.
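  // E.g., with spatial_size 255 and number_of_splits 8, each split holds
  // ceil(255/8) = 32 elements, so a halo wider than 32 cannot be sourced
  // from the single neighboring split.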
  if (c.halo_size > CeilOfRatio(c.spatial_size, ctrl_.number_of_splits)) {
    return false;
  }

  // TODO(b/201444224): The following cost model is needed to escape slowing
  // down ssd batch 4.
  if (c.base_dilation_factor > 1 &&
      c.inherent_low_padding == c.base_dilation_factor) {
    if (c.spatial_size <
        kMultiplierOnSpaceForBaseDilation * ctrl_.number_of_splits) {
      return false;
    }
  }

  VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
  return true;
}

bool ConvolutionVisitor::IsThisBackPropFilterConv(HloInstruction* convolution) {
  auto activations = convolution->mutable_operand(0);
  auto kernel = convolution->mutable_operand(1);
  auto dim_numbers = convolution->convolution_dimension_numbers();

  if (!old_to_new_instrs_.contains(kernel) &&
      !old_to_new_instrs_.contains(activations)) {
    return false;
  }

  if (old_to_new_instrs_.contains(kernel)) {
    auto dim_map_val_op_0 = instr_to_dim_map_[kernel];
    const int64_t old_batch_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
    if (convolution->convolution_dimension_numbers()
            .kernel_input_feature_dimension() != old_batch_dim) {
      return false;
    }
  }

  if (old_to_new_instrs_.contains(activations)) {
    auto dim_map_val_op_0 = instr_to_dim_map_[activations];
    const int64_t old_batch_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
    if (dim_numbers.input_feature_dimension() != old_batch_dim) {
      return false;
    }
  }

  return true;
}

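// Halo exchange between consecutive splits: for each split spatial dim, the
// last low_padding elements of the previous split are prepended and the
// first (halo_size - low_padding) elements of the next split are appended,
// using slice/pad/concat on the reshaped (per-split) batch.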
StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
    HloInstruction* activations,
    absl::Span<const int64_t> spatial_dimensions_to_split,
    int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size,
    HloInstruction* pad_val) {
  const int64_t spatial_dim_count = spatial_dimensions_to_split.size();
  const int64_t additional_batch_size =
      IPow<int64_t>(ctrl_.number_of_splits, spatial_dim_count);
  const int64_t original_batch_size =
      activations->shape().dimensions(activations_batch_dim) /
      additional_batch_size;

  const int64_t spatial_split_size =
      activations->shape().dimensions(spatial_dimensions_to_split[0]);
  const int64_t batch_size = ctrl_.number_of_splits;

  TF_ASSIGN_OR_RETURN(
      activations, SplitAndTransposeMergedBatch(
                       activations, activations_batch_dim, original_batch_size,
                       spatial_dimensions_to_split));

  const int64_t rank = activations->shape().rank();

  VLOG(1) << "In HaloDuplicateWithSlice with activations "
          << activations->ToString() << " batch_size " << batch_size
          << " spatial_split_size " << spatial_split_size << " low_padding "
          << low_padding << " halo size " << halo_size;

  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);

  for (int64_t i = 0; i < spatial_dimensions_to_split.size(); ++i) {
    int64_t spatial_dimension_to_split = activations_batch_dim + 2 * (i + 1);
    int64_t remapped_batch_dimension = spatial_dimension_to_split - 1;
    HloInstruction* first_slice = nullptr;

    std::vector<int64_t> strides(rank, 1);
    HloInstruction* padding =
        pad_val == nullptr
            ? computation_->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(activations->shape().element_type())))
            : pad_val;

    if (low_padding > 0) {
      std::vector<int64_t> start_indices(rank, 0),
          end_indices(activations->shape().dimensions().begin(),
                      activations->shape().dimensions().end());
      start_indices[spatial_dimension_to_split] =
          spatial_split_size - low_padding;
      end_indices[remapped_batch_dimension] = batch_size - 1;
      end_indices[spatial_dimension_to_split] = spatial_split_size;

      TF_ASSIGN_OR_RETURN(first_slice, MakeSliceHlo(activations, start_indices,
                                                    end_indices, strides));
      VLOG(1) << "first slice " << first_slice->ToString();

      PaddingConfig padding_config =
          MakeNoPaddingConfig(first_slice->shape().dimensions_size());
      padding_config.mutable_dimensions(remapped_batch_dimension)
          ->set_edge_padding_low(1);

      TF_ASSIGN_OR_RETURN(first_slice,
                          MakePadHlo(first_slice, padding, padding_config));
    }

    HloInstruction* halo_region = nullptr;
    if (halo_size - low_padding > 0) {
      std::vector<int64_t> start_indices_halo(rank, 0),
          end_indices_halo(activations->shape().dimensions().begin(),
                           activations->shape().dimensions().end());

      start_indices_halo[remapped_batch_dimension] = 1;
      end_indices_halo[spatial_dimension_to_split] = halo_size - low_padding;

      TF_ASSIGN_OR_RETURN(halo_region,
                          MakeSliceHlo(activations, start_indices_halo,
                                       end_indices_halo, strides));
      VLOG(1) << "halo_region " << halo_region->ToString();
      PaddingConfig padding_config_halo =
          MakeNoPaddingConfig(halo_region->shape().dimensions_size());
      padding_config_halo.mutable_dimensions(remapped_batch_dimension)
          ->set_edge_padding_high(1);
      TF_ASSIGN_OR_RETURN(
          halo_region, MakePadHlo(halo_region, padding, padding_config_halo));
    }

    if (halo_size == 0 && low_padding != 0) {
      std::vector<int64_t> start_indices_activations_cut(rank, 0),
          end_indices_activations_cut(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());
      // When no halo is needed, we must slice out activations.
      if (low_padding > 0) {
        end_indices_activations_cut[spatial_dimension_to_split] =
            spatial_split_size - low_padding;
      } else {
        start_indices_activations_cut[spatial_dimension_to_split] =
            0 - low_padding;
        end_indices_activations_cut[spatial_dimension_to_split] =
            spatial_split_size;
      }

      TF_ASSIGN_OR_RETURN(
          activations, MakeSliceHlo(activations, start_indices_activations_cut,
                                    end_indices_activations_cut, strides));
    }

    if (first_slice != nullptr) {
      TF_ASSIGN_OR_RETURN(activations,
                          MakeConcatHlo({first_slice, activations},
                                        spatial_dimension_to_split));
    }

    if (halo_region != nullptr) {
      TF_ASSIGN_OR_RETURN(activations,
                          MakeConcatHlo({activations, halo_region},
                                        spatial_dimension_to_split));
    }
  }

  TF_ASSIGN_OR_RETURN(
      activations,
      TransposeAndMergeBatch(
          activations,
          /*final_split_spatial_dim_positioning=*/spatial_dimensions_to_split,
          activations_batch_dim, original_batch_size));

  VLOG(1) << "HaloDuplicated activations " << activations->ToString();
  return activations;
}

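// Transposes activations so that the (contiguous) spatial dims to split
// immediately follow the batch dim, updating dim_numbers and
// activations_batch_dim to describe the transposed layout.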
StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch(
    HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
    int64_t& activations_batch_dim,
    std::vector<int64_t>* spatial_dimensions_to_split, bool is_backprop,
    bool is_rhs) {
  for (int64_t i = 1; i < spatial_dimensions_to_split->size(); ++i) {
    CHECK_EQ(spatial_dimensions_to_split->at(i),
             spatial_dimensions_to_split->at(i - 1) + 1)
        << "Spatial dimensions are not contiguous";
  }

  int64_t spatial_dimension_to_split = spatial_dimensions_to_split->at(0);

  std::vector<int64_t> transpose_dims(activations->shape().rank());
  if (spatial_dimension_to_split == activations_batch_dim + 1) {
    absl::c_iota(transpose_dims, 0);
  } else {
    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
    int64_t pushed_counter = 0;
    int64_t new_batch_dim, new_spatial_dim;
    int64_t dim_counter = 0;
    if (is_rhs) {
      CHECK(is_backprop);
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (i == dim_numbers.kernel_output_feature_dimension()) {
          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
            int64_t j = it - dim_numbers.kernel_spatial_dimensions().begin();
            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);

    } else {
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
          new_dim_numbers.set_input_batch_dimension(pushed_counter);
        } else if (i == dim_numbers.input_feature_dimension()) {
          new_dim_numbers.set_input_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
          if (it != dim_numbers.input_spatial_dimensions().end()) {
            int64_t j = it - dim_numbers.input_spatial_dimensions().begin();
            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      if (is_backprop) {
        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
      } else {
        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
      }
    }

    dim_numbers = new_dim_numbers;
  }

  // Note that the spatial dimensions are in a sequential increasing order.
  for (int64_t i = 0; i < spatial_dimensions_to_split->size(); ++i) {
    (*spatial_dimensions_to_split)[i] = spatial_dimension_to_split + i;
  }

  return SpaceNextToBatchDetails{activations, transpose_dims};
}

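// E.g., with old batch size B, two split dims and N = number_of_splits, the
// shape [B*N*N, S0, S1, ...] is reshaped to [B, N, N, S0, S1, ...] and then
// transposed to [B, N, S0, N, S1, ...], i.e. B, B0, S0, B1, S1, ...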
StatusOr<HloInstruction*> ConvolutionVisitor::SplitAndTransposeMergedBatch(
    HloInstruction* activations, int64_t batch_dimension,
    int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions) {
  CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]);
  std::vector<int64_t> new_dimensions(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());

  const int64_t new_batch_size =
      activations->shape().dimensions(batch_dimension);

  VLOG(3) << "Decreasing the spatial size while propagating new_batch_size "
          << new_batch_size << " old_batch_size " << old_batch_size;

  new_dimensions[batch_dimension] = old_batch_size;

  const int64_t spatial_dim_count = spatial_dimensions.size();
  // Create additional batch dimensions.
  for (int64_t i = 0; i < spatial_dim_count; ++i) {
    new_dimensions.insert(new_dimensions.begin() + spatial_dimensions[0],
                          ctrl_.number_of_splits);
  }

  // Reshape to split the merged batch into the individual per-split batch
  // dimensions.
  TF_ASSIGN_OR_RETURN(HloInstruction * batch_split_activations,
                      MakeReshapeHlo(new_dimensions, activations));

  if (spatial_dim_count > 1) {
    // Transpose such that we get B, B0, S0, B1, S1,...
    std::vector<int64_t> transpose_dims(new_dimensions.size());
    absl::c_iota(transpose_dims, 0);

    int64_t start_batch_dim_position = batch_dimension + 1;
    int64_t start_space_dim_position = batch_dimension + 2;

    for (int i = 0; i < spatial_dim_count; ++i) {
      transpose_dims[start_batch_dim_position + 2 * i] =
          batch_dimension + spatial_dim_count - i;
      transpose_dims[start_space_dim_position + 2 * i] =
          batch_dimension + spatial_dim_count + 1 + i;
    }

    TF_ASSIGN_OR_RETURN(
        batch_split_activations,
        MakeTransposeHlo(batch_split_activations, transpose_dims));
  }
  return batch_split_activations;
}

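// Merges the split batches back into the space dims, pads (to grow) or
// slices (to shrink) the merged space to new_spatial_dim_size *
// number_of_splits, and then re-splits it into number_of_splits pieces.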
StatusOr<HloInstruction*>
ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape(
    HloInstruction* activations, int64_t batch_dimension,
    int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions,
    int64_t new_spatial_dim_size, bool increase_spatial_size) {
  CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]);
  std::vector<int64_t> new_dimensions(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());

  const int64_t spatial_dim_count = spatial_dimensions.size();
  const int64_t spatial_dim_size =
      activations->shape().dimensions(spatial_dimensions[0]);
  const int64_t reshaped_space_size = spatial_dim_size * ctrl_.number_of_splits;

  // Split the merged batch back out so that the batch and space dims can be
  // merged pairwise below.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * batch_split_activations,
      SplitAndTransposeMergedBatch(activations, batch_dimension, old_batch_size,
                                   spatial_dimensions));

  // Now merge the individual (split) batch and space dimensions.
  std::vector<int64_t> batch_space_collapse_reshape_dims(
      batch_split_activations->shape().dimensions().begin(),
      batch_split_activations->shape().dimensions().end());

  batch_space_collapse_reshape_dims.erase(
      batch_space_collapse_reshape_dims.begin() + spatial_dimensions[0],
      batch_space_collapse_reshape_dims.begin() + spatial_dimensions[0] +
          spatial_dim_count);

  for (auto spatial_dimension : spatial_dimensions) {
    batch_space_collapse_reshape_dims[spatial_dimension] = reshaped_space_size;
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_space_collapsed_reshape,
                      MakeReshapeHlo(batch_space_collapse_reshape_dims,
                                     batch_split_activations));

  VLOG(3) << "First reshape done";

  const int64_t rank = activations->shape().rank();

  // If spatial size is increased, we add padding. If it has shrunk, we slice
  // out the padding that was added before.
  if (increase_spatial_size) {
    PaddingConfig padding_config = MakeNoPaddingConfig(
        batch_space_collapsed_reshape->shape().dimensions_size());
    for (auto spatial_dimension : spatial_dimensions) {
      padding_config.mutable_dimensions(spatial_dimension)
          ->set_edge_padding_high(new_spatial_dim_size *
                                      ctrl_.number_of_splits -
                                  reshaped_space_size);
      padding_config.mutable_dimensions(spatial_dimension)
          ->set_edge_padding_low(0);
    }
    HloInstruction* padding = computation_->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::Zero(
            batch_space_collapsed_reshape->shape().element_type())));

    TF_ASSIGN_OR_RETURN(
        batch_space_collapsed_reshape,
        MakePadHlo(batch_space_collapsed_reshape, padding, padding_config));
  } else {
    std::vector<int64_t> start_indices(rank, 0),
        end_indices(batch_space_collapsed_reshape->shape().dimensions().begin(),
                    batch_space_collapsed_reshape->shape().dimensions().end()),
        strides(rank, 1);
    for (auto spatial_dimension : spatial_dimensions) {
      end_indices[spatial_dimension] =
          new_spatial_dim_size * ctrl_.number_of_splits;
    }

    // This is the slice from halo padding.
    TF_ASSIGN_OR_RETURN(batch_space_collapsed_reshape,
                        MakeSliceHlo(batch_space_collapsed_reshape,
                                     start_indices, end_indices, strides));
  }

  TF_ASSIGN_OR_RETURN(
      HloInstruction * activations_new,
      PerformSplitSpace(batch_space_collapsed_reshape, spatial_dimensions,
                        batch_dimension, new_spatial_dim_size,
                        ctrl_.number_of_splits));

  VLOG(3) << "Size decreased activations " << activations_new->ToString();

  return activations_new;
}

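// Drives the pass: first space-to-batches every worklist convolution and
// propagates through its users, then repairs the frontier by converting the
// operands of instructions that could not be propagated through back to
// batch-to-space form.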
StatusOr<bool> ConvolutionVisitor::Run() {
  for (auto conv : conv_visitor_list_) {
    // If we expect to see an unpropagatable op, space-to-batch may not be
    // beneficial.
    if (ctrl_.disable_starting_on_small_chains &&
        DoesConvolutionFeedUnpropagatableOp(conv)) {
      VLOG(1) << "Giving up on conv " << conv->ToString()
              << " because it feeds an unpropagatable op";
      convs_to_visit_.erase(conv);
    }
    if (convs_to_visit_.count(conv) > 0) {
      TF_CHECK_OK(PerformSpaceToBatchOnConvolution(conv));
    }
  }
  conv_visitor_list_.clear();
  convs_to_visit_.clear();
  // Iterate through all instructions that we could not propagate through, and
  // convert their operands back from batch to space as needed.
  for (auto instr : non_propagatable_instrs_) {
    if (instr->opcode() == HloOpcode::kConvolution) {
      VLOG(1) << "Instr " << instr->ToString();
    }
    // Try to propagate on backprop filters.
    if (instr->opcode() == HloOpcode::kConvolution &&
        !IsConvSuitableForSpaceToBatch(instr)) {
      HloInstruction* producer = nullptr;
      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
        producer = instr->mutable_operand(0);
      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
        producer = instr->mutable_operand(1);
      }
      if (producer) {
        if (CanPropagate(instr, producer)) {
          bool needs_further_propagation;
          TF_ASSIGN_OR_RETURN(needs_further_propagation,
                              Propagate(instr, producer));
          TF_CHECK_OK(computation_->ReplaceInstruction(
              instr, old_to_new_instrs_[instr]));
          continue;
        }
      }
    }
    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
    absl::flat_hash_map<int64_t, HloInstruction*> operand_map;
    for (int64_t i = 0; i < instr->operand_count(); ++i) {
      if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
        TF_ASSIGN_OR_RETURN(operand_map[i],
                            BatchToSpace(instr->mutable_operand(i)));
      }
    }
    for (auto entry : operand_map) {
      TF_CHECK_OK(instr->ReplaceOperandWith(entry.first, entry.second));
    }
  }
  non_propagatable_instrs_.clear();
  return changed_;
}

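// Trivial elementwise ops are elementwise ops other than fusion, rng, copy,
// constant, iota and map; only these are safe to propagate through blindly.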
bool IsTrivialElementwise(HloInstruction* hlo) {
  if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
      hlo->opcode() == HloOpcode::kCopy ||
      hlo->opcode() == HloOpcode::kConstant ||
      hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
    return false;
  }
  return hlo->IsElementwise();
}

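// An elementwise consumer is propagatable only if every operand that is not
// a (propagatable) broadcast or constant has already been space-to-batched
// with the same batch/space placement and the same dimension permutation,
// checked against a pivot operand.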
bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
                                      HloInstruction* producer) {
  if (IsTrivialElementwise(consumer)) {
    VLOG(2) << "Doing propagation check on elementwise op: "
            << consumer->ToString();

    HloInstruction* pivot_operand = nullptr;
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      std::vector<HloInstruction*> to_transform;
      const bool broadcast_or_constant =
          (old_producer->opcode() == HloOpcode::kConstant) ||
          (old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastPropagatable(old_producer, producer)) ||
          (consumer->IsElementwiseBinary() &&
           old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastTree(old_producer, producer, to_transform));

      if (!old_to_new_instrs_.contains(old_producer) &&
          !broadcast_or_constant) {
        VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
                << " because operand " << old_producer->ToString()
                << " isn't ready ";
        return false;
      } else {
        if (broadcast_or_constant) {
          VLOG(2) << "Skipping on " << old_producer->ToString();
          continue;
        }

        CHECK(old_to_new_instrs_.contains(old_producer));

        CHECK(instr_to_dim_map_.contains(old_producer));
        if (pivot_operand == nullptr) {
          pivot_operand = old_producer;
          VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
        } else {
          if (instr_to_dim_map_[pivot_operand]
                               [DimMapper(SpaceToBatchDimMap::kBatch)] !=
                  instr_to_dim_map_[old_producer]
                                   [DimMapper(SpaceToBatchDimMap::kBatch)] ||
              instr_to_dim_map_[pivot_operand]
                               [DimMapper(SpaceToBatchDimMap::kSpace0)] !=
                  instr_to_dim_map_[old_producer]
                                   [DimMapper(SpaceToBatchDimMap::kSpace0)]) {
            VLOG(2) << "Elementwise op: checking for shape equivalence "
                    << consumer->ToString()
                    << " failed due to changed batch space ordering ";
            return false;
          }
          auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
          auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
          auto new_instr = old_to_new_instrs_[old_producer];
          auto permute_dims = instr_to_dim_permute_map_[new_instr];
          for (int j = 0; j < pivot_permute_dims.size(); ++j) {
            // Ensure the dimension mapping is the same.
            if (pivot_permute_dims[j] != permute_dims[j]) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
                      << " failed due to permuted dimensions ";
              return false;
            }

            // Make sure all other dimensions are of the same size.
            if (pivot_new_instr->shape().dimensions(j) !=
                new_instr->shape().dimensions(j)) {
              if (!((consumer->IsElementwiseBinary() ||
                     consumer->opcode() == HloOpcode::kSelect) &&
                    j == instr_to_dim_map_[pivot_operand][DimMapper(
                             SpaceToBatchDimMap::kSpace0)])) {
                VLOG(2) << "Elementwise op: checking for shape equivalence "
                        << consumer->ToString()
                        << " failed due to changed shape sizes ";
                return false;
              }
            }
          }
        }
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kConcatenate) {
    // Make sure all operands have been space-to-batched.
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      if (!instr_to_dim_map_.contains(consumer->mutable_operand(i))) {
        return false;
      }
    }
    auto pivot_operand = consumer->mutable_operand(0);
    auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
    auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
    for (int64_t i = 1; i < consumer->operand_count(); ++i) {
      auto new_instr = old_to_new_instrs_[consumer->mutable_operand(i)];
      auto permute_dims = instr_to_dim_permute_map_[new_instr];

      for (int j = 0; j < pivot_permute_dims.size(); ++j) {
        // Ensure the dimension mapping is the same.
        if (pivot_permute_dims[j] != permute_dims[j]) {
          VLOG(2) << "Concat op: checking for shape equivalence "
                  << consumer->ToString()
                  << " failed due to permuted dimensions ";
          return false;
        }
        // Make sure all other dimensions are of the same size.
        if (pivot_new_instr->shape().dimensions(j) !=
            new_instr->shape().dimensions(j)) {
          VLOG(2) << "Concat op: checking for shape equivalence "
                  << consumer->ToString()
                  << " failed due to changed shape sizes ";
          return false;
        }
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    if (!ConsumeFuel("space-to-batch-converter", [&] {
          return "Skipping space-to-batch propagation because fuel over\n";
        })) {
      return false;
    }
    // Lambda that checks basic sanity of dimension propagation on
    // convolutions. This includes: the split dimension from the previous
    // convolution should remain the same. No feature/batch dimension should be
    // turned into a spatial dimension.
    auto are_conv_dims_compatible =
        [&](const ConvolutionDimensionNumbers dim_numbers,
            std::vector<int64_t>& dim_map, bool check_lhs) {
          if (check_lhs) {
            if (dim_numbers.input_spatial_dimensions(
                    GetFirstChosenSpatialDim(consumer)) !=
                dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)]) {
              return false;
            }
            for (int i = 0; i < dim_numbers.input_spatial_dimensions().size();
                 ++i) {
              if (dim_numbers.input_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] ||
                  dim_numbers.input_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kFeature)]) {
                return false;
              }
            }
          } else {
            if (dim_numbers.kernel_spatial_dimensions(
                    GetFirstChosenSpatialDim(consumer)) !=
                dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)]) {
              return false;
            }
            for (int i = 0; i < dim_numbers.kernel_spatial_dimensions().size();
                 ++i) {
              if (dim_numbers.kernel_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] ||
                  dim_numbers.kernel_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kFeature)]) {
                return false;
              }
            }
          }
          return true;
        };

    VLOG(1) << "Checking if conv is supported for propagation "
            << consumer->ToString();
    bool found_good_non_window_dilated_conv = true;
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      // Activations must have been space-to-batched to enable propagation.
      if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
        found_good_non_window_dilated_conv = false;
      }
      auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_0, /*check_lhs*/ true)) {
        found_good_non_window_dilated_conv = false;
      }
      // Make sure that the batch dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)]) {
        found_good_non_window_dilated_conv = false;
      }

      if (found_good_non_window_dilated_conv) {
        return true;
      }
    }

    if (!ctrl_.enable_propagations_on_window_dilations) {
      return false;
    }

    if (!IsThisBackPropFilterConv(consumer)) {
      return false;
    }
    // Check for space-to-batch readiness here. Note this is not done in
    // SupportedOpForPropagation because the readiness is dependent upon
    // space-to-batchedness of the operands.

    // If there are no specified spatial dims, we return.
    if (GetFirstChosenSpatialDim(consumer) < 0) {
      return false;
    }

    // We currently only support stride of 1.
    if (consumer->window()
            .dimensions(GetFirstChosenSpatialDim(consumer))
            .stride() != 1) {
      return false;
    }

    // Same reason why we give up on batch group counts applies to features in
    // backprop.
    if (consumer->feature_group_count() != 1) {
      return false;
    }

    VLOG(2) << "Checking for backprop filter conv propagatability";
    CHECK_EQ(consumer->operand_count(), 2);

    auto activations = consumer->mutable_operand(0);
    auto kernel = consumer->mutable_operand(1);

    auto win_dims =
        consumer->window().dimensions(GetFirstChosenSpatialDim(consumer));
    const int64_t rhs_dilation = win_dims.window_dilation();
    const int64_t lhs_dilation = win_dims.base_dilation();

    // LHS dilations are supported by PropagateOnConv, and not by
    // PropagateOnBackpropFilterConv.
    if (lhs_dilation != 1) {
      return false;
    }

    // If the rhs_dilation is absent, we want both LHS and RHS to be space-to-
    // batched for propagating on backprop convolutions.
    if (rhs_dilation == 1 &&
        !ctrl_.enable_propagations_on_trivial_window_dilations) {
      if (!old_to_new_instrs_.contains(kernel) ||
          !old_to_new_instrs_.contains(activations)) {
        return false;
      }
    }

    if (!old_to_new_instrs_.contains(kernel)) {
      const int64_t rhs_batch =
          kernel->shape().dimensions(consumer->convolution_dimension_numbers()
                                         .kernel_input_feature_dimension());
      auto dim_map_val_op_0 = instr_to_dim_map_[activations];
      const int64_t old_batch_dim =
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
      const int64_t old_space_dim =
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)];
      auto first_operand = old_to_new_instrs_[activations];
      auto permute_dims_first_operand =
          instr_to_dim_permute_map_[first_operand];
      const int64_t new_batch_dim =
          DimLookUp(permute_dims_first_operand, old_batch_dim);
      const int64_t new_space_dim =
          DimLookUp(permute_dims_first_operand, old_space_dim);
      const int64_t lhs_batch =
          first_operand->shape().dimensions(new_batch_dim);

      if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
          0) {
        return false;
      }
      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since LHS is already space-to-batched, we need to account for it too.
      if (rhs_batch * ctrl_.number_of_splits != lhs_batch) {
        return false;
      }

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_0, /*check_lhs*/ true)) {
        return false;
      }

      // The kernel has not been propagated through; since the activations
      // have been, we can space-to-batch the kernel here.
      VLOG(2)
          << "Backprop filter conv ready for propagation: activations ready, "
             " kernel will be space-to-batched";
      return true;
    }

    if (!old_to_new_instrs_.contains(activations)) {
      const int64_t lhs_batch = activations->shape().dimensions(
          consumer->convolution_dimension_numbers().input_feature_dimension());
      auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
      const int64_t old_batch_dim =
          dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kBatch)];
      auto second_operand = old_to_new_instrs_[kernel];
      auto permute_dims_second_operand =
          instr_to_dim_permute_map_[second_operand];
      const int64_t new_batch_dim =
          DimLookUp(permute_dims_second_operand, old_batch_dim);
      const int64_t rhs_batch =
          second_operand->shape().dimensions(new_batch_dim);

      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since RHS is already space-to-batched, we need to account for it too.
      if (rhs_batch != ctrl_.number_of_splits * lhs_batch) {
        return false;
      }

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_1, /*check_lhs*/ false)) {
        return false;
      }

      // If activations have not been propagated through, we can do
      // space-to-batch on them provided kernel has been propagated.
      VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
                 " activations will be space-to-batched";
      return true;
    }

    auto first_operand = old_to_new_instrs_[activations];
    auto dim_map_val_op_0 = instr_to_dim_map_[activations];

    auto second_operand = old_to_new_instrs_[kernel];
    auto dim_map_val_op_1 = instr_to_dim_map_[kernel];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    const int64_t new_batch_dim_operand_0 =
        DimLookUp(permute_dims_first_operand,
                  dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)]);

    const int64_t new_space_dim_operand_0 =
        DimLookUp(permute_dims_first_operand,
                  dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)]);

    const int64_t new_batch_dim_operand_1 =
        DimLookUp(permute_dims_second_operand,
                  dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kBatch)]);
    const int64_t new_space_dim_operand_1 =
        DimLookUp(permute_dims_second_operand,
                  dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kSpace0)]);

    if (first_operand->shape().dimensions(new_batch_dim_operand_0) !=
        second_operand->shape().dimensions(new_batch_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because batch "
                 "dimensions don't line up";
      return false;
    }

    if (first_operand->shape().dimensions(new_space_dim_operand_0) >
        rhs_dilation *
            second_operand->shape().dimensions(new_space_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because of "
                 "dilation factor mismatch";
      return false;
    }

    if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                  dim_map_val_op_0, /*check_lhs*/ true)) {
      return false;
    }

    if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                  dim_map_val_op_1, /*check_lhs*/ false)) {
      return false;
    }

    VLOG(2) << "Backprop filter conv ready for propagation";

    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kReduce) {
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    // Make sure the post space-to-batch dim size is larger than window size.
    if (consumer->opcode() == HloOpcode::kReduceWindow) {
      return IsSpaceToBatchedSpaceSizeSuitable(consumer);
    }
  }

  if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i < 2 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
    auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
    auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    // The permuting must match.
    if (permute_dims_first_operand != permute_dims_second_operand) {
      VLOG(2) << "Can't propagate through select and scatter due to "
                 "permutation mismatch";
      return false;
    }

    const int64_t old_batch_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)];

    const int64_t new_batch_dim =
        DimLookUp(permute_dims_first_operand, old_batch_dim);
    const int64_t new_space_dim =
        DimLookUp(permute_dims_first_operand, old_space_dim);

    if (first_operand->shape().dimensions(new_batch_dim) !=
        second_operand->shape().dimensions(new_batch_dim)) {
      VLOG(2)
          << "Can't propagate through select and scatter due to dim mismatch";
      return false;
    }

    const int64_t stride =
        consumer->window().dimensions(old_space_dim).stride();
    const int64_t pad_high =
        consumer->window().dimensions(old_space_dim).padding_high();
    const int64_t pad_low =
        consumer->window().dimensions(old_space_dim).padding_low();

    if ((first_operand->shape().dimensions(new_space_dim) + pad_high +
         pad_low) /
            stride !=
        second_operand->shape().dimensions(new_space_dim)) {
      VLOG(2) << "Can't propagate through select and scatter due to stride "
                 "mismatch";
      return false;
    }

    return IsSpaceToBatchedSpaceSizeSuitable(consumer);
  }
  return true;
}

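// Recreates the broadcast in space-to-batched form. If the batch dim itself
// was broadcasted, the result is first built with the space dim merged
// (space * number_of_splits) and then reshaped to the producer's
// space-to-batched shape. broadcast_map_ caches previously created versions
// so identical broadcasts are not duplicated.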
void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
                                              HloInstruction* producer) {
  auto new_producer = old_to_new_instrs_[producer];
  auto permute_dims = instr_to_dim_permute_map_[new_producer];
  auto dim_map_val = instr_to_dim_map_[producer];

  const int64_t old_batch_dim =
      dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
  const int64_t old_space_dim =
      dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];

  auto orig_broadcast_dims = consumer->dimensions();

  bool batch_is_broadcasted =
      absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
  const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64_t new_space_dim = DimLookUp(permute_dims, old_space_dim);

  bool map_found = broadcast_map_.contains(consumer);
  if (map_found) {
    // Check if we had previously created the same broadcast.
    for (auto previous_broadcast : broadcast_map_[consumer]) {
      if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
                                                   new_producer->shape())) {
        return;
      }
    }
  }

  std::vector<int64_t> final_shape_dims(
      new_producer->shape().dimensions().begin(),
      new_producer->shape().dimensions().end());
  if (batch_is_broadcasted) {
    final_shape_dims[new_batch_dim] =
        producer->shape().dimensions(old_batch_dim);
    final_shape_dims[new_space_dim] *= ctrl_.number_of_splits;
  }

  std::vector<int64_t> broadcast_dims;
  const auto& dimensions = consumer->dimensions();
  broadcast_dims.reserve(dimensions.size());
  for (auto j : dimensions) {
    broadcast_dims.push_back(DimLookUp(permute_dims, j));
  }
  auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
                                        broadcast_dims, final_shape_dims);
  VLOG(1) << "Created broadcast " << new_broadcast->ToString();

  if (batch_is_broadcasted) {
    new_broadcast =
        MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
            .ValueOrDie();
    VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
  }

  if (!map_found) {
    absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
    broadcast_map_[consumer] = set_of_broadcasts;
  }
  broadcast_map_[consumer].insert(new_broadcast);
}

void ConvolutionVisitor::RewriteBroadcastTree(
    HloInstruction* producer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  CHECK(old_to_new_instrs_.contains(producer));
  for (auto instr : instructions_to_transform) {
    if (instr->opcode() == HloOpcode::kBroadcast) {
      PropagateOnBroadcast(instr, producer);
    } else if (IsTrivialElementwise(instr)) {
      Propagate(instr, /*producer=*/instr->mutable_operand(0)).ValueOrDie();
    } else {
      LOG(FATAL) << "Unsupported opcode in RewriteBroadcastTree";
    }
  }
}

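// Recursively walks the operand tree rooted at `op`, collecting the
// broadcasts and trivial elementwise ops that would need rewriting. Returns
// false if the tree contains anything other than propagatable broadcasts,
// scalar constants and trivial elementwise ops.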
bool ConvolutionVisitor::IsBroadcastTree(
    HloInstruction* op, HloInstruction* consumer,
    std::vector<HloInstruction*>& instructions_to_transform) {
  if (op->opcode() == HloOpcode::kBroadcast) {
    // We want to ensure that the broadcast did not happen on the space and
    // batch dimensions.
    if (IsBroadcastPropagatable(op, consumer)) {
      instructions_to_transform.push_back(op);
      return true;
    } else {
      return false;
    }
  }
  if (Match(op, m::ConstantScalar())) {
    return true;
  }
  if (!IsTrivialElementwise(op)) {
    return false;
  }
  for (int64_t i = 0; i < op->operand_count(); ++i) {
    if (!IsBroadcastTree(op->mutable_operand(i), consumer,
                         instructions_to_transform)) {
      return false;
    }
  }
  instructions_to_transform.push_back(op);
  return true;
}

bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
                                                 HloInstruction* old_other_op) {
  CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
  CHECK(instr_to_dim_map_.contains(old_other_op));

  auto result = instr_to_dim_map_[old_other_op];
  const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
  auto broadcast_dims = broadcast->dimensions();
  return !absl::c_linear_search(broadcast_dims, space_dim);
}

bool ConvolutionVisitor::IsOpcodeNonPropagatable(HloInstruction* consumer) {
  // We can add more non-propagatable opcodes as needed.
  switch (consumer->opcode()) {
    case HloOpcode::kCustomCall:
      return true;
    default:
      return false;
  }
}

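// Checks opcode-specific preconditions for propagation. Note that this only
// answers whether propagation through `consumer` is supported in principle;
// whether its operands are ready is checked separately in CanPropagate.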
bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
                                                   HloInstruction* producer) {
  if (IsOpcodeNonPropagatable(consumer)) {
    return false;
  }

  if (IsTrivialElementwise(consumer)) {
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        if (!IsBroadcastPropagatable(consumer->mutable_operand(i), producer)) {
          VLOG(2) << "Could not propagate through broadcast";
          return false;
        }
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConcatenate) {
    HloInstruction* pivot_operand = nullptr;
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      if (instr_to_dim_map_.contains(consumer->mutable_operand(i))) {
        pivot_operand = consumer->mutable_operand(i);
        break;
      }
    }
    if (pivot_operand == nullptr) {
      VLOG(1) << "Concat: Dim map not found on any operand";
      return false;
    }
    // Disallow concatenating on the batch and space dims.
    auto result = instr_to_dim_map_[pivot_operand];
    const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        result[DimMapper(SpaceToBatchDimMap::kSpace0)];
    if (consumer->concatenate_dimension() == old_batch_dim ||
        consumer->concatenate_dimension() == old_space_dim) {
      return false;
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReverse) {
    auto operand_0 = consumer->mutable_operand(0);
    if (!instr_to_dim_map_.contains(operand_0)) {
      return false;
    }
    // Disallow reversing on the batch and space dims.
    auto result = instr_to_dim_map_[operand_0];
    const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        result[DimMapper(SpaceToBatchDimMap::kSpace0)];

    for (auto dim : consumer->dimensions()) {
      if (dim == old_batch_dim || dim == old_space_dim) {
        return false;
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kTranspose) {
    return true;
  }

  if (consumer->opcode() == HloOpcode::kPad) {
    auto operand_0 = consumer->mutable_operand(0);
    if (!instr_to_dim_map_.contains(operand_0)) {
      return false;
    }
    // Disallow padding on the batch and space dims.
    auto result = instr_to_dim_map_[operand_0];
    const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        result[DimMapper(SpaceToBatchDimMap::kSpace0)];

    auto does_dim_have_padding = [](PaddingConfig padding_config, int64_t dim) {
      return padding_config.dimensions(dim).edge_padding_low() != 0 ||
             padding_config.dimensions(dim).edge_padding_high() != 0 ||
             padding_config.dimensions(dim).interior_padding() != 0;
    };
    // Batch and space dims should not have padding.
    if (does_dim_have_padding(consumer->padding_config(), old_batch_dim) ||
        does_dim_have_padding(consumer->padding_config(), old_space_dim)) {
      return false;
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    // Support only the trivial case where both the batch and the split
    // spatial dim are being reduced.
    auto reduce_dims = consumer->dimensions();
    auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
    VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
            << " space_dim " << space_dim << " reduce "
            << consumer->ToString();
    return absl::c_linear_search(reduce_dims, batch_dim) &&
           absl::c_linear_search(reduce_dims, space_dim);
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow &&
      consumer->shape().IsTuple()) {
    // TODO(b/73062247): Variadic reduce window is not yet supported.
    return false;
  }
  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kSelectAndScatter) {
    auto first_operand = consumer->mutable_operand(0);
    auto window = consumer->window();
    if (instr_to_dim_map_.count(first_operand) <= 0) {
      VLOG(1) << "Dim map not found on windowed operand. Window dim count "
              << window.dimensions().size();
      return false;
    }
    // Disallow windowing on the batch dim.
    auto result = instr_to_dim_map_[first_operand];
    const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        result[DimMapper(SpaceToBatchDimMap::kSpace0)];
    if (window.dimensions(old_batch_dim).size() != 1) {
      return false;
    }

    // Only allow no-low-padding cases.
    if (window.dimensions(old_space_dim).padding_low() != 0) {
      return false;
    }

    // No base/window dilations allowed on the space dimension.
    if (window.dimensions(old_space_dim).base_dilation() != 1 ||
        window.dimensions(old_space_dim).window_dilation() != 1) {
      return false;
    }
    // No base/window dilations allowed on the batch dimension either.
    if (window.dimensions(old_batch_dim).base_dilation() != 1 ||
        window.dimensions(old_batch_dim).window_dilation() != 1) {
      return false;
    }

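    // Illustrative example (assumed values): with a window size of 3 along
    // the old space dimension, a padding_high of up to 3 passes the check
    // below, while a padding_high of 4 or more disables propagation.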
    // Only allow small high pads.
    if (window.dimensions(old_space_dim).padding_high() >
        window.dimensions(old_space_dim).size()) {
      return false;
    }

    // Operand 0 must have been propagated through.
    if (old_to_new_instrs_.count(first_operand) <= 0) {
      return false;
    }

    auto new_operand = old_to_new_instrs_[first_operand];
    auto permute_dims = instr_to_dim_permute_map_[new_operand];

    // Select-and-scatter specific checks.
    if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
      const int64_t new_space_dim = DimLookUp(permute_dims, old_space_dim);
      // Make sure that the stride lines up.
      if (new_operand->shape().dimensions(new_space_dim) %
              window.dimensions(old_space_dim).stride() !=
          0) {
        return false;
      }

      // Only support floating point datatypes.
      if (!ShapeUtil::ElementIsFloating(consumer->shape())) {
        return false;
      }
      // We currently only support adds in the scatter.
      auto scatter_comp = consumer->scatter();
      if (!Match(scatter_comp->root_instruction(),
                 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
        return false;
      }
      // Select should just be a single comparison with GE as the direction.
      auto select_comp = consumer->select();
      if (!Match(select_comp->root_instruction(),
                 m::Compare(m::Parameter(0), m::Parameter(1))
                     .WithComparisonDirection(ComparisonDirection::kGe)) &&
          !Match(select_comp->root_instruction(),
                 m::Compare(m::Parameter(1), m::Parameter(0))
                     .WithComparisonDirection(ComparisonDirection::kGe))) {
        return false;
      }
      // We do not support low padding on select-and-scatter.
      if (consumer->window().dimensions(old_space_dim).padding_low() != 0) {
        return false;
      }
    }

    return true;
  }

  return false;
}

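// Space-to-batches `consumer`, given an already-propagated `producer`. The
// returned bool indicates whether the result still needs further propagation
// (e.g. reductions that consume both the batch and the split space dimension
// need none).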
StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
                                             HloInstruction* producer) {
  auto computation = consumer->parent();
  if (IsTrivialElementwise(consumer)) {
    auto dim_map_val = instr_to_dim_map_[producer];
    auto new_consumer = computation->AddInstruction(consumer->Clone());

    bool is_pivot_producer_modified = false;
    // For elementwise binary ops, both of whose operands have been space-to-
    // batched, if their new spatial sizes don't match, choose the bigger one
    // as the producer.
    if (consumer->IsElementwiseBinary() ||
        consumer->opcode() == HloOpcode::kSelect) {
      int64_t pivot_operand_number = -1;
      HloInstruction* pivot_operand = nullptr;
      for (int i = 0; i < consumer->operand_count(); ++i) {
        if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
          continue;
        }
        auto operand = consumer->mutable_operand(i);
        if (old_to_new_instrs_.contains(operand)) {
          if (pivot_operand_number == -1 ||
              old_to_new_instrs_[pivot_operand]->shape().dimensions() <
                  old_to_new_instrs_[operand]->shape().dimensions()) {
            is_pivot_producer_modified = true;
            pivot_operand_number = i;
            pivot_operand = consumer->mutable_operand(pivot_operand_number);
          }
        }
      }
      if (pivot_operand_number != -1) {
        producer = pivot_operand;
      }
    }

    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      std::vector<HloInstruction*> instructions_to_transform;

      if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
        auto broadcast = consumer->mutable_operand(i);
        PropagateOnBroadcast(broadcast, producer);
        HloInstruction* new_broadcast = nullptr;
        auto new_producer = old_to_new_instrs_[producer];
        for (auto previous_broadcast : broadcast_map_[broadcast]) {
          if (ShapeUtil::CompatibleIgnoringElementType(
                  previous_broadcast->shape(), new_producer->shape())) {
            new_broadcast = previous_broadcast;
            break;
          }
        }
        CHECK_NE(new_broadcast, nullptr);
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
      } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
        HloInstruction* operand_to_use = nullptr;
        auto result = instr_to_dim_map_[producer];
        const int64_t old_batch_dim =
            result[DimMapper(SpaceToBatchDimMap::kBatch)];
        const int64_t old_space_dim =
            result[DimMapper(SpaceToBatchDimMap::kSpace0)];
        const int64_t old_batch_size =
            producer->shape().dimensions(old_batch_dim);
        HloInstruction* new_instr =
            old_to_new_instrs_[consumer->mutable_operand(i)];
        HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];

        auto permute_dims = instr_to_dim_permute_map_[new_instr];
        const int64_t batch_dim = DimLookUp(permute_dims, old_batch_dim);
        const int64_t space_dim = DimLookUp(permute_dims, old_space_dim);
        const int64_t batch_size = new_instr->shape().dimensions(batch_dim);

        if (new_instr->shape().dimensions(space_dim) !=
            pivot_new_instr->shape().dimensions(space_dim)) {
          // Because we do not propagate through transposes, the batch should
          // always be followed by the split space dimension.
          CHECK_EQ(batch_dim + 1, space_dim);

          // Merge the split batch back into the space dimension, pad the
          // space up to the producer's size, and then reshape back.
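          // Illustrative example (assumed sizes): if new_instr is [8, 4]
          // with an old batch size of 2, it is first reshaped to [2, 16]; if
          // the pivot operand is [8, 6], the space is padded from 16 up to
          // 6 * 8 / 2 = 24, and the result is reshaped to the pivot's [8, 6].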
          std::vector<int64_t> new_dimensions(
              new_instr->shape().dimensions().begin(),
              new_instr->shape().dimensions().end());
          new_dimensions[space_dim] *= (batch_size / old_batch_size);
          new_dimensions[batch_dim] = old_batch_size;

          TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                              MakeReshapeHlo(new_dimensions, new_instr));

          const int64_t pivot_space_size =
              pivot_new_instr->shape().dimensions(space_dim) * batch_size /
              old_batch_size;

          CHECK(pivot_space_size > new_dimensions[space_dim] ||
                !is_pivot_producer_modified);

          PaddingConfig padding_config =
              MakeNoPaddingConfig(reshape->shape().dimensions_size());
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
              pivot_space_size - new_dimensions[space_dim]);
          padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
          HloInstruction* padding =
              computation_->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(reshape->shape().element_type())));

          TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
                              MakePadHlo(reshape, padding, padding_config));

          TF_ASSIGN_OR_RETURN(
              operand_to_use,
              MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
                             padded_operand));

        } else {
          operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
        }
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
      } else if (consumer->IsElementwiseBinary() &&
                 consumer->mutable_operand(i)->opcode() ==
                     HloOpcode::kBroadcast &&
                 IsBroadcastTree(consumer->mutable_operand(i), producer,
                                 instructions_to_transform)) {
        RewriteBroadcastTree(producer, instructions_to_transform);
        TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
            i, old_to_new_instrs_[consumer->mutable_operand(i)]));
      } else if (consumer->operand(i)->opcode() == HloOpcode::kConstant) {
        TF_ASSIGN_OR_RETURN(
            auto new_constant,
            PropagateOnConstant(consumer->mutable_operand(i), producer));
        TF_CHECK_OK(
            new_consumer->ReplaceOperandWithDifferentShape(i, new_constant));
      }
    }
    auto old_type = new_consumer->mutable_shape()->element_type();
    *(new_consumer->mutable_shape()) = old_to_new_instrs_[producer]->shape();

    // The element type needs to be retained.
    new_consumer->mutable_shape()->set_element_type(old_type);

    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);
    CHECK(instr_to_dim_permute_map_.contains(old_to_new_instrs_[producer]));
    instr_to_dim_permute_map_[new_consumer] = std::vector<int64_t>(
        instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);

    VLOG(2) << " new_consumer " << new_consumer->ToString()
            << " old_to_new_instrs_[producer] "
            << old_to_new_instrs_[producer]->ToString() << " permute dims "
            << instr_to_dim_permute_map_.count(new_consumer);

    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      TF_CHECK_OK(PropagateOnConv(consumer));
      return true;
    } else {
      TF_CHECK_OK(PropagateOnBackpropFilterConv(consumer));
      return false;
    }
  }

  if (consumer->opcode() == HloOpcode::kConcatenate) {
    TF_CHECK_OK(PropagateOnConcat(consumer));
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReverse) {
    TF_CHECK_OK(PropagateOnReverse(consumer));
    return true;
  }

  // TODO(b/189500737): Consider a common way of propagation for
  // slice/pad/reduce-window.
  if (consumer->opcode() == HloOpcode::kPad) {
    TF_CHECK_OK(PropagateOnPad(consumer));
    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduce) {
    auto new_consumer = computation->AddInstruction(consumer->Clone());
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];

    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64_t old_batch_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];

    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);

    auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
    std::vector<int64_t> old_spatial_dims = retval.first;
    std::vector<int64_t> new_spatial_dims = retval.second;

    TF_ASSIGN_OR_RETURN(
        first_operand,
        SelectValidPortion(first_operand, consumer->mutable_operand(0),
                           consumer->mutable_operand(1), new_batch_dim,
                           new_spatial_dims, old_batch_dim, old_spatial_dims));

    std::vector<int64_t> changed_dims(new_consumer->dimensions().size());
    for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) {
      changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
    }
    *(new_consumer->mutable_dimensions()) = changed_dims;
    // Replace operand 0.
    TF_CHECK_OK(
        new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
    // We do not set instr_to_dim_permute_map_ because no further propagation
    // is needed here.
    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);

    // Since the resultant ordering of dimensions is the same as before, no
    // further propagation is needed.
    return false;
  }

  if (consumer->opcode() == HloOpcode::kTranspose) {
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
    // Pass the first operand forward, with the map of the permuted dims.
    auto new_consumer = computation->AddInstruction(first_operand->Clone());
    old_to_new_instrs_[consumer] = new_consumer;
    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
    const int64_t old_batch_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
    const int64_t old_feature_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kFeature)];

    int64_t new_batch_dim, new_space_dim, new_feature_dim;
    std::vector<int64_t> new_dimensions(consumer->dimensions().size());
    for (int64_t ctr = 0; ctr < consumer->dimensions().size(); ++ctr) {
      int64_t dim = consumer->dimensions(ctr);
      if (dim == old_batch_dim) {
        new_batch_dim = ctr;
      }
      if (dim == old_space_dim) {
        new_space_dim = ctr;
      }
      if (dim == old_feature_dim) {
        new_feature_dim = ctr;
      }
    }

    std::vector<int64_t> dim_map(NumMappedDims());
    dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = new_batch_dim;
    dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = new_feature_dim;
    dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = new_space_dim;
    instr_to_dim_map_[consumer] = dim_map;

    std::vector<int64_t> new_permute_dims(consumer->dimensions().size());
    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
      new_permute_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i));
    }

    instr_to_dim_permute_map_[new_consumer] = new_permute_dims;

    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kSelectAndScatter) {
    bool is_select_and_scatter =
        consumer->opcode() == HloOpcode::kSelectAndScatter;
    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];

    auto init_val = is_select_and_scatter ? consumer->mutable_operand(2)
                                          : consumer->mutable_operand(1);
    auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];

    auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
    std::vector<int64_t> old_spatial_dims = retval.first;
    std::vector<int64_t> new_spatial_dims = retval.second;

    const int64_t old_batch_dim =
        dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim = old_spatial_dims[0];
    auto permute_dims = instr_to_dim_permute_map_[first_operand];
    const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
    const int64_t new_space_dim = new_spatial_dims[0];

    // Calculate the required halo size.
    auto new_shape = first_operand->shape();
    auto old_shape = consumer->mutable_operand(0)->shape();

    const int64_t new_space_size = new_shape.dimensions(new_space_dim);
    const int64_t stride =
        consumer->window().dimensions(old_space_dim).stride();

    auto pad_val =
        is_select_and_scatter
            ? computation_->AddInstruction(
                  HloInstruction::CreateConstant(LiteralUtil::MinValue(
                      consumer->operand(2)->shape().element_type())))
            : init_val;
    TF_ASSIGN_OR_RETURN(
        first_operand,
        SelectValidPortion(first_operand, consumer->mutable_operand(0), pad_val,
                           new_batch_dim, new_spatial_dims, old_batch_dim,
                           old_spatial_dims));

    const int64_t extra_space = new_space_size % stride;
    if (extra_space) {
      CHECK_EQ(consumer->opcode(), HloOpcode::kReduceWindow);
      const int64_t old_batch_size = old_shape.dimensions(old_batch_dim);
      const int64_t old_space_size = old_shape.dimensions(old_space_dim);
      // If the shrunk space is still larger than or equal to the original
      // space, we reduce the space.
      if ((new_space_size - extra_space) * old_batch_size *
              ctrl_.number_of_splits >=
          old_batch_size * old_space_size) {
        TF_ASSIGN_OR_RETURN(
            first_operand, ChangeSpatialSizeOnSpaceToBatchedShape(
                               first_operand, new_batch_dim, old_batch_size,
                               new_spatial_dims, new_space_size - extra_space));
      } else {
        TF_ASSIGN_OR_RETURN(
            first_operand,
            ChangeSpatialSizeOnSpaceToBatchedShape(
                first_operand, new_batch_dim, old_batch_size, new_spatial_dims,
                new_space_size + stride - extra_space,
                /*increase_spatial_size*/ true));
      }
    }
    const int64_t window_size =
        consumer->window().dimensions(old_space_dim).size();
    const int64_t last_overlap_point = ((new_space_size - 1) / stride) * stride;
    VLOG(1) << "last_overlap_point " << last_overlap_point << " window_size "
            << window_size << " new_space_size " << new_space_size;

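    // Illustrative example (assumed values): with new_space_size = 8,
    // stride = 2 and window_size = 3, the last window starts at offset 6, so
    // one element of halo (6 + 3 - 8 = 1) must be borrowed from the
    // neighboring split batch.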
    const int64_t halo_size = last_overlap_point + window_size - new_space_size;
    if (halo_size > 0) {
      TF_ASSIGN_OR_RETURN(
          first_operand,
          HaloDuplicateWithSlice(first_operand, new_spatial_dims, new_batch_dim,
                                 /*low_padding=*/0, halo_size, init_val));
    }

    Window new_win;
    for (int64_t i = 0; i < consumer->window().dimensions().size(); ++i) {
      auto dim = ReverseDimLookUp(permute_dims, i);
      new_win.add_dimensions();
      new_win.mutable_dimensions(i)->set_stride(
          consumer->window().dimensions(dim).stride());
      new_win.mutable_dimensions(i)->set_size(
          consumer->window().dimensions(dim).size());
      if (i == old_space_dim) {
        new_win.mutable_dimensions(i)->set_padding_high(0);
        new_win.mutable_dimensions(i)->set_padding_low(0);
      } else {
        new_win.mutable_dimensions(i)->set_padding_high(
            consumer->window().dimensions(dim).padding_high());
        new_win.mutable_dimensions(i)->set_padding_low(
            consumer->window().dimensions(dim).padding_low());
      }
      new_win.mutable_dimensions(i)->set_window_dilation(
          consumer->window().dimensions(dim).window_dilation());
      new_win.mutable_dimensions(i)->set_base_dilation(
          consumer->window().dimensions(dim).base_dilation());
      new_win.mutable_dimensions(i)->set_window_reversal(
          consumer->window().dimensions(dim).window_reversal());
    }

    new_shape = first_operand->shape();

    HloInstruction* new_consumer = nullptr;
    if (is_select_and_scatter) {
      auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

      auto select_comp = consumer->select();

      auto scatter_comp = consumer->scatter();
      TF_ASSIGN_OR_RETURN(
          auto new_select_and_scatter_shape,
          ShapeInference::InferSelectAndScatterShape(
              new_shape, select_comp->ComputeProgramShape(), new_win,
              second_operand->shape(), init_val->shape(),
              scatter_comp->ComputeProgramShape()));
      new_consumer =
          computation_->AddInstruction(HloInstruction::CreateSelectAndScatter(
              new_select_and_scatter_shape, first_operand, select_comp, new_win,
              second_operand, init_val, scatter_comp));
      // Replace operand 0.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
      // Replace operand 1.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(1, second_operand));
      VLOG(2) << "New select and scatter " << new_consumer->ToString();

      // If the window size was larger than the stride, there could be overlaps.
      // Such cases require updates from both overlaps to be applied.
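      // The slices below extract the halo region of each split batch
      // ("bottom") and the region it aliases at the start of the next split
      // batch ("top"); scatter updates that landed in both copies are summed,
      // which is valid because only add-based scatters were accepted above.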
      if (halo_size > 0) {
        const int64_t rank = new_consumer->shape().rank();

        const int64_t batch_size =
            new_consumer->shape().dimensions(new_batch_dim);

        std::vector<int64_t> start_indices(rank, 0),
            end_indices(new_consumer->shape().dimensions().begin(),
                        new_consumer->shape().dimensions().end()),
            strides(rank, 1);
        start_indices[new_space_dim] = new_space_size;
        end_indices[new_space_dim] = new_space_size + halo_size;
        end_indices[new_batch_dim] = batch_size - 1;

        // This is the slice from halo padding.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom,
            MakeSliceHlo(new_consumer, start_indices, end_indices, strides));

        std::vector<int64_t> start_indices_top(rank, 0),
            end_indices_top(new_consumer->shape().dimensions().begin(),
                            new_consumer->shape().dimensions().end());
        end_indices_top[new_space_dim] = halo_size;
        // The first batch has correct data.
        start_indices_top[new_batch_dim] = 1;

        // This is the original area from where the halo pad was extracted.
        TF_ASSIGN_OR_RETURN(HloInstruction * top,
                            MakeSliceHlo(new_consumer, start_indices_top,
                                         end_indices_top, strides));

        HloInstruction* default_fill =
            MakeBroadcastHlo(init_val, {}, top->shape().dimensions());

        // Compare to see if the bottom area was changed.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom_compare,
            MakeCompareHlo(ComparisonDirection::kNe, bottom, default_fill));

        // Take out only the changed values.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * bottom_taken,
            MakeSelectHlo(bottom_compare, bottom, default_fill));

        // Compare to see if the top area was changed.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * top_compare,
            MakeCompareHlo(ComparisonDirection::kNe, top, default_fill));

        // Take out only the changed values.
        TF_ASSIGN_OR_RETURN(HloInstruction * top_taken,
                            MakeSelectHlo(top_compare, top, bottom_taken));

        // This checks whether the area was updated by both overlaps.
        TF_ASSIGN_OR_RETURN(
            HloInstruction * both_compare,
            MakeBinaryHlo(HloOpcode::kAnd, top_compare, bottom_compare));

        // If it was, add them up.
        TF_ASSIGN_OR_RETURN(HloInstruction * both_added,
                            MakeBinaryHlo(HloOpcode::kAdd, top, bottom));

        // Pad the final result to the original shape.
        TF_ASSIGN_OR_RETURN(HloInstruction * final_selection,
                            MakeSelectHlo(both_compare, both_added, top_taken));

        PaddingConfig padding_config =
            MakeNoPaddingConfig(final_selection->shape().dimensions_size());
        padding_config.mutable_dimensions(new_batch_dim)
            ->set_edge_padding_low(1);
        padding_config.mutable_dimensions(new_space_dim)
            ->set_edge_padding_high(new_space_size);
        HloInstruction* padding =
            computation_->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::Zero(final_selection->shape().element_type())));

        TF_ASSIGN_OR_RETURN(
            final_selection,
            MakePadHlo(final_selection, padding, padding_config));

        tensorflow::core::Bitmap b(batch_size * (new_space_size + halo_size));
        for (int k = 0; k < batch_size * (new_space_size + halo_size); ++k) {
          const int64_t space_index = k % (new_space_size + halo_size);
          const int64_t batch_index = (k / (new_space_size + halo_size));
          if (batch_index < 1 || space_index >= halo_size) {
            b.set(k);
          } else {
            b.clear(k);
          }
        }

        auto arg_literal = LiteralUtil::CreateR1(b);
        VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
        HloInstruction* slice_mask = computation_->AddInstruction(
            HloInstruction::CreateConstant(std::move(arg_literal)));

        std::vector<int64_t> slice_mask_reshape_dims(2);
        slice_mask_reshape_dims[0] = batch_size;
        slice_mask_reshape_dims[1] = (new_space_size + halo_size);

        TF_ASSIGN_OR_RETURN(
            HloInstruction * slice_mask_reshaped,
            MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));

        // Broadcast the mask in all dimensions.
        HloInstruction* shape_mask = MakeBroadcastHlo(
            slice_mask_reshaped, {new_batch_dim, new_space_dim},
            final_selection->shape().dimensions());

        TF_ASSIGN_OR_RETURN(
            new_consumer,
            MakeSelectHlo(shape_mask, new_consumer, final_selection));
      }

      auto previous_shape =
          old_to_new_instrs_[consumer->mutable_operand(0)]->shape();
      std::vector<int64_t> start_indices(previous_shape.rank(), 0),
          end_indices(previous_shape.dimensions().begin(),
                      previous_shape.dimensions().end()),
          strides(previous_shape.rank(), 1);

      TF_ASSIGN_OR_RETURN(
          new_consumer,
          MakeSliceHlo(new_consumer, start_indices, end_indices, strides));

    } else {
      auto reduce_comp = consumer->to_apply();
      TF_ASSIGN_OR_RETURN(auto new_reduce_window_shape,
                          ShapeInference::InferReduceWindowShape(
                              new_shape, init_val->shape(), new_win));
      new_consumer =
          computation_->AddInstruction(HloInstruction::CreateReduceWindow(
              new_reduce_window_shape, first_operand, init_val, new_win,
              reduce_comp));
      // Replace operand 0.
      TF_CHECK_OK(
          new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
      VLOG(1) << "New reduce window " << new_consumer->ToString();
    }

    old_to_new_instrs_[consumer] = new_consumer;
    instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);

    instr_to_dim_permute_map_[new_consumer] = std::vector<int64_t>(
        instr_to_dim_permute_map_[old_to_new_instrs_[consumer->mutable_operand(
            0)]]);

    return true;
  }

  LOG(FATAL) << "Trying to propagate through an unsupported instruction "
             << consumer->ToString();
  return true;
}

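// Masks `new_instr` so that only positions that map back to valid locations
// in `old_instr` survive; everything that originated from space-to-batch
// padding is replaced with `select_val`.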
StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
    HloInstruction* new_instr, HloInstruction* old_instr,
    HloInstruction* select_val, int64_t new_batch_dim,
    absl::Span<const int64_t> new_space_dims, int64_t old_batch_dim,
    absl::Span<const int64_t> old_space_dims) {
  auto new_shape = new_instr->shape();
  auto old_shape = old_instr->shape();
  VLOG(1) << "In SelectValidPortion new_batch_dim " << new_batch_dim
          << " new_space_dim " << new_space_dims[0] << " old_batch_dim "
          << old_batch_dim << " old_space_dim " << old_space_dims[0];
  const int64_t new_batch_size = new_shape.dimensions(new_batch_dim);
  const int64_t new_space_size = new_shape.dimensions(new_space_dims[0]);
  const int64_t old_batch_size = old_shape.dimensions(old_batch_dim);
  const int64_t old_space_size = old_shape.dimensions(old_space_dims[0]);
  CHECK_EQ(new_batch_size % old_batch_size, 0)
      << " New batch size " << new_batch_size << " old batch size "
      << old_batch_size;
  const int64_t num_splits = ctrl_.number_of_splits;
  const int64_t spatial_dim_count = new_space_dims.size();

  // The ordering of dimensions in the bounds is: Old Batch, BN, BN-1, ...,
  // B0, S0, S1, ..., SN, where the Bi are the per-dimension split-batch
  // components and the Si are the split spatial dimensions.
  std::vector<int64_t> bounds(2 + spatial_dim_count, new_space_size);
  bounds[0] = old_batch_size;
  bounds[1] = IPow<int64_t>(num_splits, spatial_dim_count);

  const int64_t total_new_space =
      IPow<int64_t>(new_space_size, spatial_dim_count);
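
  // Illustrative example (assumed sizes): with one spatial dimension,
  // num_splits = 4, new_space_size = 5 and old_space_size = 18, the split
  // batches cover offsets 0-4, 5-9, 10-14 and 15-19, so only the last split
  // batch gets masked, at its final two positions (offsets 18 and 19).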

  // Build a constant PRED to decide which elements in the split dimension
  // are from halo.
  tensorflow::core::Bitmap b(new_batch_size * total_new_space);
  for (int k = 0; k < new_batch_size * total_new_space; ++k) {
    auto radix = ToMixedRadix(k, bounds);

    bool out_of_bounds = false;
    int64_t batch_residue = 1;
    for (int i = 0; i < spatial_dim_count; ++i) {
      const int64_t space_index = radix[2 + i];
      const int64_t batch_index = (radix[1] / batch_residue) % num_splits;
      batch_residue *= num_splits;
      if (batch_index * new_space_size + space_index >= old_space_size) {
        out_of_bounds = true;
      }
    }

    if (!out_of_bounds) {
      b.set(k);
    } else {
      b.clear(k);
    }
  }

  auto arg_literal = LiteralUtil::CreateR1(b);
  VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
  HloInstruction* slice_mask = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(arg_literal)));

  std::vector<int64_t> slice_mask_reshape_dims(1 + spatial_dim_count,
                                               new_space_size);
  slice_mask_reshape_dims[0] = new_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped,
                      MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));

  std::vector<int64_t> broadcast_dims(new_space_dims.begin(),
                                      new_space_dims.end());
  broadcast_dims.insert(broadcast_dims.begin(), new_batch_dim);
  // Broadcast the mask in all dimensions of the activations.
  HloInstruction* shape_mask = MakeBroadcastHlo(
      slice_mask_reshaped, broadcast_dims, new_instr->shape().dimensions());

  VLOG(1) << "Shape mask made " << shape_mask->ToString();

  HloInstruction* zeroes =
      MakeBroadcastHlo(select_val, {}, new_instr->shape().dimensions());

  TF_ASSIGN_OR_RETURN(new_instr, MakeSelectHlo(shape_mask, new_instr, zeroes));

  return new_instr;
}

StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
    HloInstruction* old_instr) {
  if (batch_to_space_map_.count(old_instr)) {
    CHECK_NE(batch_to_space_map_[old_instr], nullptr);
    return batch_to_space_map_[old_instr];
  }

  auto result = instr_to_dim_map_[old_instr];
  const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
  const int64_t old_space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];

  const int64_t old_batch_size = old_instr->shape().dimensions(old_batch_dim);
  CHECK(old_to_new_instrs_.contains(old_instr));
  auto new_instr = old_to_new_instrs_[old_instr];
  VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
          << old_space_dim << " old_instr " << old_instr->ToString()
          << "\n new_instr " << new_instr->ToString() << " permute dims "
          << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
          << old_batch_size;
  CHECK(instr_to_dim_permute_map_.contains(new_instr));
  auto permute_dims = instr_to_dim_permute_map_[new_instr];
  const int64_t batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64_t space_dim = DimLookUp(permute_dims, old_space_dim);

  const int64_t spatial_dim_size = new_instr->shape().dimensions(space_dim);

  std::vector<int64_t> split_spatial_dimensions(
      ctrl_.count_of_dimensions_to_convert);
  absl::c_iota(split_spatial_dimensions, space_dim);

  TF_ASSIGN_OR_RETURN(new_instr, SplitAndTransposeMergedBatch(
                                     new_instr, batch_dim, old_batch_size,
                                     split_spatial_dimensions));

  std::vector<int64_t> new_dimensions(new_instr->shape().dimensions().begin(),
                                      new_instr->shape().dimensions().end());

  new_dimensions.erase(new_dimensions.begin() + split_spatial_dimensions[0],
                       new_dimensions.begin() + split_spatial_dimensions[0] +
                           ctrl_.count_of_dimensions_to_convert);

  for (auto spatial_dimension : split_spatial_dimensions) {
    new_dimensions[spatial_dimension] =
        spatial_dim_size * ctrl_.number_of_splits;
  }

  // Reshape the output of the new conv into the old convolution's shape.
  TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                      MakeReshapeHlo(new_dimensions, new_instr));

  VLOG(1) << "Batch to space reshape " << reshape->ToString();
  const int64_t rank = old_instr->shape().rank();
  std::vector<int64_t> start_indices(rank, 0),
      end_indices(new_dimensions.begin(), new_dimensions.end()),
      strides(rank, 1);

  for (auto spatial_dimension : split_spatial_dimensions) {
    end_indices[spatial_dimension] =
        old_instr->shape().dimensions(old_space_dim);
  }

  // This slicing gets rid of the padding we added to evenly divide space.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * output_slice,
      MakeSliceHlo(reshape, start_indices, end_indices, strides));
  VLOG(1) << "Batch to space slice " << output_slice->ToString();
  std::vector<int64_t> transpose_dims(permute_dims);
  TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose,
                      MakeTransposeHlo(output_slice, transpose_dims));
  old_instr->SetupDerivedInstruction(output_transpose);

  batch_to_space_map_[old_instr] = output_transpose;
  return output_transpose;
}

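// Walks the users of the newly space-to-batched convolution in BFS order,
// propagating through supported ops and inserting batch-to-space conversions
// in front of unsupported users (or at the root).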
Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) {
  std::queue<std::pair<HloInstruction*, HloInstruction*>> propagation_worklist;

  if (old_conv->user_count() == 0) {
    TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                        BatchToSpace(old_conv));
    VLOG(1) << "Replacing the root instruction to "
            << batch_to_space->ToString();
    TF_CHECK_OK(computation_->ReplaceInstruction(old_conv, batch_to_space));
    VLOG(1) << "Replacement successful";
    return OkStatus();
  }

  int64_t iteration_count = 0;
  propagation_worklist.push(
      std::make_pair(old_conv, old_conv->mutable_operand(0)));

  while (!propagation_worklist.empty()) {
    auto top = propagation_worklist.front();
    auto node = top.first;
    auto parent = top.second;
    VLOG(1) << "Traversing for propagation operating on " << node->ToString();
    propagation_worklist.pop();

    // Don't work on the same node again.
    if (old_to_new_instrs_.count(node) > 0 && iteration_count != 0) {
      continue;
    }

    bool needs_further_propagation = true;
    if (iteration_count != 0) {
      // Do the space-to-batch propagation on this node.
      TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent));
    }
    iteration_count++;
    // If this is the root, there is no room for further propagation.
    if (node->parent()->root_instruction() == node) {
      // The case below does not need to go back to space.
      if (!needs_further_propagation) {
        VLOG(1) << "Replacing the root instruction to "
                << old_to_new_instrs_[node]->ToString();
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }

      TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node));
      VLOG(1) << "Replacing the root instruction to "
              << batch_to_space->ToString();
      TF_CHECK_OK(computation_->ReplaceInstruction(node, batch_to_space));
    } else {
      if (!needs_further_propagation) {
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }

      HloInstructionSet unsupported_users;
      // Insert all users into the queue, as long as the ops are supported and
      // the op is ready for propagation. If the op is unsupported, do
      // batch-to-space. If not ready, mark as non-propagatable.
      for (auto user : node->users()) {
        if (!SupportedOpForPropagation(user, node)) {
          VLOG(1) << "Unsupported op found " << user->ToString();
          unsupported_users.insert(user);
          continue;
        }
        // If the instruction is ready for propagation, add it to the queue.
        if (CanPropagate(user, node)) {
          non_propagatable_instrs_.erase(user);
          propagation_worklist.push(std::make_pair(user, node));
        } else {
          // Mark it as non-propagatable for now, for later revisiting.
          non_propagatable_instrs_.insert(user);
        }
      }

      if (!unsupported_users.empty()) {
        TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                            BatchToSpace(node));
        for (auto user : unsupported_users) {
          for (int64_t i = 0; i < user->operand_count(); ++i) {
            if (user->operand(i) == node) {
              TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space));
            }
          }
        }
      }
    }
  }
  return OkStatus();
}

Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  auto activations_old = convolution->mutable_operand(0);

  CHECK(old_to_new_instrs_.contains(activations_old));
  auto activations_new = old_to_new_instrs_[activations_old];
  auto permute_dims = instr_to_dim_permute_map_[activations_new];

  auto original_conv_dims = convolution->convolution_dimension_numbers();

  auto old_new_dims = GetSpatialDimsToSplit(activations_old);
  std::vector<int64_t> old_spatial_dims = old_new_dims.first;
  std::vector<int64_t> new_spatial_dims = old_new_dims.second;

  auto permuted_conv_dims_numbers = original_conv_dims;

  int64_t activations_batch_dim =
      DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
  int64_t activations_feature_dim =
      DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
  permuted_conv_dims_numbers.set_input_batch_dimension(activations_batch_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(
      activations_feature_dim);

  for (int64_t i = 0; i < original_conv_dims.input_spatial_dimensions_size();
       ++i) {
    permuted_conv_dims_numbers.set_input_spatial_dimensions(
        i, DimLookUp(permute_dims,
                     original_conv_dims.input_spatial_dimensions(i)));
  }

  const int64_t old_batch_dim = original_conv_dims.input_batch_dimension();
  const int64_t old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  ConvDetails c =
      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << c.spatial_dimensions_to_split[0] << " old_batch_size "
          << old_batch_size;

  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            activations_batch_dim, &new_spatial_dims));
  activations_new = retval.instr;
  std::vector<int64_t> trans_dims = retval.transpose_dims;
  CHECK(!trans_dims.empty());
  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      activations_new,
      SelectValidPortion(activations_new, activations_old, select_val,
                         activations_batch_dim, new_spatial_dims, old_batch_dim,
                         old_spatial_dims));
  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  const int64_t num_splits = ctrl_.number_of_splits;
  const int64_t output_offsets = convolution->shape().dimensions(
      permuted_conv_dims_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution)));
  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);

  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;

  VLOG(1) << "spatial size " << c.spatial_size << " halo size " << c.halo_size
          << " spatial_split_size " << spatial_split_size;

  // Keep increasing the split size so that the overall size isn't smaller
  // than the original spatial dimension. Unlike for the first space-to-batched
  // convolution, while propagating, we can use the last halo_size as available
  // spatial size.
  // If the spatial size is less than the halo size required, we need to
  // increase the spatial size.
  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0 ||
         spatial_split_size < c.halo_size) {
    spatial_split_size += c.stride;
  }
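
  // Illustrative example (assumed values): with num_splits = 8, spatial
  // size = 28, halo_size = 2 and stride = 1, an initial spatial_split_size of
  // 3 covers only 3 * 8 + 2 = 26 < 28 elements, so it is bumped to 4.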
2619
2620 VLOG(1) << "Modified spatial_split_size " << spatial_split_size;
2621 const int64_t new_space_size =
2622 activations_new->shape().dimensions(new_spatial_dims[0]);
2623
2624 int64_t slice_size = spatial_split_size + c.halo_size;
2625 // In the below case, we cannot use the activations directly for Halo
2626 // Duplication. We must reshape them.
2627 if (spatial_split_size > new_space_size) {
2628 TF_ASSIGN_OR_RETURN(
2629 activations_new,
2630 ChangeSpatialSizeOnSpaceToBatchedShape(
2631 activations_new, activations_batch_dim, old_batch_size,
2632 new_spatial_dims, spatial_split_size,
2633 /*increase_spatial_size*/ true));
2634
2635 } else {
2636 // If the ideal spatial_split_size was smaller than the incoming spatial
2637 // dimension size, we don't need reshaping. Instead, we determine the
2638 // additional space available, and adjust the required slice size (and
2639 // thereby the halo size).
2640 VLOG(3)
2641 << "Decreasing the spatial size while propagating spatial_split_size "
2642 << spatial_split_size << " new_space_size " << new_space_size;
2643 if (spatial_split_size < new_space_size) {
2644 // If there's a stride mismatch, we change the new_space_size be
2645 // smaller (equal to spatial_split_size).
2646 if (new_space_size % c.stride != 0 || c.base_dilation_factor != 1) {
2647 TF_ASSIGN_OR_RETURN(
2648 activations_new,
2649 ChangeSpatialSizeOnSpaceToBatchedShape(
2650 activations_new, activations_batch_dim, old_batch_size,
2651 new_spatial_dims, spatial_split_size));
2652 } else {
2653 const int64_t additional_space_present = spatial_split_size % c.stride;
2654 spatial_split_size = new_space_size;
2655 slice_size =
2656 spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
2657 additional_space_present,
2658 static_cast<int64_t>(0));
2659 }
2660 }
2661 }
2662
2663 // For space-to-batch supported base-dilated convolutions, the low padding is
2664 // passed on to the new convolutions. Halo does not have to account for it.
2665 TF_ASSIGN_OR_RETURN(
2666 activations_new,
2667 HaloDuplicateWithSlice(
2668 activations_new, new_spatial_dims, activations_batch_dim,
2669 /*low_padding=*/c.base_dilation_factor != 1 &&
2670 c.inherent_low_padding != 0
2671 ? (c.inherent_low_padding == c.base_dilation_factor ? 1 : 0)
2672 : c.inherent_low_padding,
2673 slice_size - spatial_split_size));
2674
2675 // We will generate output such that batch is followed by the split spatial
2676 // dimension.
2677 const int64_t rank = (convolution->shape().rank());
2678 std::vector<int64_t> transpose_dims(rank);
2679 int dim_count = 0;
2680 std::map<int64_t, int64_t> dim_translator;
2681
2682 for (int j = 0;
2683 j < permuted_conv_dims_numbers.output_spatial_dimensions_size(); ++j) {
2684 if (j == GetFirstChosenSpatialDim(convolution)) {
2685 dim_translator[permuted_conv_dims_numbers.output_batch_dimension()] =
2686 dim_count;
2687 new_dim_numbers.set_output_batch_dimension(dim_count++);
2688 }
2689 dim_translator[permuted_conv_dims_numbers.output_spatial_dimensions(j)] =
2690 dim_count;
2691 new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
2692 dim_count++;
2693 }
2694
2695 dim_translator[permuted_conv_dims_numbers.output_feature_dimension()] =
2696 dim_count;
2697 new_dim_numbers.set_output_feature_dimension(dim_count);
2698
2699 int p = 0;
2700 for (const auto& entry : dim_translator) {
2701 transpose_dims[p] = entry.second;
2702 p++;
2703 }
2704
2705 auto new_window = convolution->window();
2706 const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
2707 for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
2708 new_window.mutable_dimensions(first_dim + i)
2709 ->set_padding_high(c.high_padding_for_conv);
2710 new_window.mutable_dimensions(first_dim + i)
2711 ->set_padding_low(c.low_padding_for_conv);
2712 }
2713 TF_ASSIGN_OR_RETURN(
2714 HloInstruction * new_conv,
2715 MakeConvolveHlo(
2716 activations_new, /*rhs=*/convolution->mutable_operand(1),
2717 convolution->feature_group_count(), convolution->batch_group_count(),
2718 new_window, new_dim_numbers, convolution->precision_config(),
2719 /*preferred_element_type=*/convolution->shape().element_type()));
2720 convolution->SetupDerivedInstruction(new_conv);
2721
2722 old_to_new_instrs_[convolution] = new_conv;
2723 VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
2724
2725 std::vector<int64_t> dim_map(NumMappedDims());
2726 dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
2727 original_conv_dims.output_batch_dimension();
2728 dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
2729 original_conv_dims.output_feature_dimension();
2730 dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
2731 original_conv_dims.output_spatial_dimensions(
2732 GetFirstChosenSpatialDim(convolution));
2733 instr_to_dim_map_[convolution] = dim_map;
2734
2735 instr_to_dim_permute_map_[new_conv] = std::vector<int64_t>(transpose_dims);
2736
2737 convs_to_visit_.erase(convolution);
2738 return OkStatus();
2739 }
2740
PropagateOnConcat(HloInstruction * concat)2741 Status ConvolutionVisitor::PropagateOnConcat(HloInstruction* concat) {
2742 auto first_operand = old_to_new_instrs_[concat->mutable_operand(0)];
2743 auto permute_dims = instr_to_dim_permute_map_[first_operand];
2744 const int64_t new_concat_dim =
2745 DimLookUp(permute_dims, concat->concatenate_dimension());
2746 std::vector<HloInstruction*> new_operands(concat->operand_count());
2747 for (int64_t i = 0; i < concat->operand_count(); ++i) {
2748 new_operands[i] = old_to_new_instrs_[concat->mutable_operand(i)];
2749 }
2750 TF_ASSIGN_OR_RETURN(HloInstruction * new_concat,
2751 MakeConcatHlo(new_operands, new_concat_dim));
2752 old_to_new_instrs_[concat] = new_concat;
2753 // Set mappings from operand 0.
2754 instr_to_dim_map_[concat] =
2755 std::vector<int64_t>(instr_to_dim_map_[concat->mutable_operand(0)]);
2756 instr_to_dim_permute_map_[new_concat] =
2757 std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);
2758
2759 return OkStatus();
2760 }
2761
PropagateOnReverse(HloInstruction * reverse)2762 Status ConvolutionVisitor::PropagateOnReverse(HloInstruction* reverse) {
2763 auto first_operand = old_to_new_instrs_[reverse->mutable_operand(0)];
2764 auto permute_dims = instr_to_dim_permute_map_[first_operand];
2765
2766 std::vector<int64_t> new_reverse_dimensions(reverse->dimensions().size());
2767 int dim_count = 0;
2768 for (auto dim : reverse->dimensions()) {
2769 new_reverse_dimensions[dim_count++] = DimLookUp(permute_dims, dim);
2770 }
2771 TF_ASSIGN_OR_RETURN(HloInstruction * new_reverse,
2772 MakeReverseHlo(first_operand, new_reverse_dimensions));
2773 old_to_new_instrs_[reverse] = new_reverse;
2774 // Set mappings from operand 0.
2775 instr_to_dim_map_[reverse] =
2776 std::vector<int64_t>(instr_to_dim_map_[reverse->mutable_operand(0)]);
2777 instr_to_dim_permute_map_[new_reverse] =
2778 std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);
2779
2780 return OkStatus();
2781 }
2782
Status ConvolutionVisitor::PropagateOnPad(HloInstruction* pad) {
  auto first_operand = old_to_new_instrs_[pad->mutable_operand(0)];
  auto permute_dims = instr_to_dim_permute_map_[first_operand];

  PaddingConfig padding_config;
  for (int i = 0; i < pad->shape().rank(); ++i) {
    auto dimension = padding_config.add_dimensions();
    const int64_t old_dim = ReverseDimLookUp(permute_dims, i);
    auto old_padding = pad->padding_config().dimensions(old_dim);
    dimension->set_edge_padding_low(old_padding.edge_padding_low());
    dimension->set_edge_padding_high(old_padding.edge_padding_high());
    dimension->set_interior_padding(old_padding.interior_padding());
  }

  HloInstruction* padding = pad->mutable_operand(1);

  TF_ASSIGN_OR_RETURN(auto new_pad,
                      MakePadHlo(first_operand, padding, padding_config));

  old_to_new_instrs_[pad] = new_pad;
  // Set mappings from operand 0.
  instr_to_dim_map_[pad] =
      std::vector<int64_t>(instr_to_dim_map_[pad->mutable_operand(0)]);
  instr_to_dim_permute_map_[new_pad] =
      std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);

  return OkStatus();
}

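// After PerformSplitSpace, each split spatial dimension S_i is preceded by its
// split-count dimension N_i, e.g. [B, N0, S0, N1, S1] for two split
// dimensions. For multiple spatial dimensions we first transpose so that all
// split-count dimensions sit directly after the batch, e.g.
// [B, N1, N0, S0, S1], and then collapse the leading dimensions into a single
// batch of size old_batch_size * number_of_splits^spatial_dim_count.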
StatusOr<HloInstruction*> ConvolutionVisitor::TransposeAndMergeBatch(
    HloInstruction* activations,
    absl::Span<const int64_t> final_split_spatial_dim_positioning,
    int64_t activations_batch_dim, int64_t old_batch_size) {
  const int64_t spatial_dim_count = final_split_spatial_dim_positioning.size();

  if (final_split_spatial_dim_positioning.size() > 1) {
    int64_t start_batch_dim_position = activations_batch_dim + 1;
    int64_t start_space_dim_position =
        start_batch_dim_position + spatial_dim_count;

    std::vector<int64_t> trans_dims(activations->shape().dimensions_size());
    absl::c_iota(trans_dims, 0);

    for (int i = 0; i < spatial_dim_count; ++i) {
      trans_dims[start_batch_dim_position + i] =
          start_batch_dim_position + (spatial_dim_count - 1 - i) * 2;
      trans_dims[start_space_dim_position + i] =
          start_batch_dim_position + i * 2 + 1;
    }

    TF_ASSIGN_OR_RETURN(activations, MakeTransposeHlo(activations, trans_dims));
  }

  std::vector<int64_t> batch_collapse_reshape_dims(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  const int64_t collapsed_batch_size =
      old_batch_size * IPow<int64_t>(ctrl_.number_of_splits, spatial_dim_count);

  batch_collapse_reshape_dims.erase(
      batch_collapse_reshape_dims.begin() + activations_batch_dim,
      batch_collapse_reshape_dims.begin() + activations_batch_dim +
          spatial_dim_count);
  batch_collapse_reshape_dims[activations_batch_dim] = collapsed_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_collapsed_reshape,
                      MakeReshapeHlo(batch_collapse_reshape_dims, activations));
  return batch_collapsed_reshape;
}

StatusOr<HloInstruction*> ConvolutionVisitor::PerformSplitSpace(
    HloInstruction* activations,
    absl::Span<const int64_t> spatial_dimensions_to_split,
    int64_t activations_batch_dim, int64_t spatial_split_size,
    int64_t num_splits) {
  const int64_t old_batch_size =
      activations->shape().dimensions(activations_batch_dim);

  // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
  // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
  // in the spatial dimension, we generate a gather. E.g. if halo size was 2,
  // we'd create a shape of [24] using the gather, and reshape it into [6, 4]
  // (4 being the batch).

  // The benefit of the above mentioned scheme is that it allows for batch
  // growth. Here are some examples of the size increases it causes for a 3x3
  // kernel.
  // with batch=1, [1,16] -> [4,4] -> [4,6] -> [1,24] growth of 8.
  // with batch=2, [2,16] -> [8,4] -> [8,6] -> [1,48] growth of 16.
  // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.

  std::vector<int64_t> reshape_dimensions(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
    reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
  }

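  // Insert one split-count dimension (of size num_splits) immediately before
  // each shrunk spatial dimension; `counter` accounts for the positions
  // already shifted by earlier insertions.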
  int counter = 0;
  for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
    reshape_dimensions.insert(
        reshape_dimensions.begin() + (spatial_dimension_to_split + counter),
        num_splits);
    counter++;
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
                      MakeReshapeHlo(reshape_dimensions, activations));

  return TransposeAndMergeBatch(
      batch_increased_reshape,
      /*final_split_spatial_dim_positioning=*/spatial_dimensions_to_split,
      activations_batch_dim, old_batch_size);
}

StatusOr<HloInstruction*> ConvolutionVisitor::PadAndSplitSpace(
    HloInstruction* activations,
    absl::Span<const int64_t> spatial_dimensions_to_split,
    int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding,
    int64_t spatial_split_size, int64_t num_splits) {
  const int64_t old_batch_size =
      activations->shape().dimensions(activations_batch_dim);

  // Because we are splitting the spatial dimension, if convolution needed
  // padding in the spatial dimension, we materialize it.
  if (high_padding || low_padding) {
    PaddingConfig padding_config =
        MakeNoPaddingConfig(activations->shape().dimensions_size());
    for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
      padding_config.mutable_dimensions(spatial_dimension_to_split)
          ->set_edge_padding_high(high_padding);
      padding_config.mutable_dimensions(spatial_dimension_to_split)
          ->set_edge_padding_low(low_padding);
    }
    HloInstruction* padding =
        computation_->AddInstruction(HloInstruction::CreateConstant(
            LiteralUtil::Zero(activations->shape().element_type())));
    TF_ASSIGN_OR_RETURN(activations,
                        MakePadHlo(activations, padding, padding_config));
  }
  VLOG(1) << "Initial padded activations shape "
          << activations->shape().ToString() << " old_batch_size "
          << old_batch_size << " activations_batch_dim "
          << activations_batch_dim;
  return PerformSplitSpace(activations, spatial_dimensions_to_split,
                           activations_batch_dim, spatial_split_size,
                           num_splits);
}

StatusOr<std::pair<HloInstruction*, std::vector<int64_t>>>
ConvolutionVisitor::SplitSpace(
    HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
    int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding,
    int64_t spatial_split_size, int64_t num_splits,
    std::vector<int64_t>* spatial_dimensions_to_split, bool is_backprop,
    bool is_rhs) {
  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations, dim_numbers, activations_batch_dim,
                            spatial_dimensions_to_split, is_backprop, is_rhs));

  activations = retval.instr;
  std::vector<int64_t> transpose_dims = retval.transpose_dims;
  TF_ASSIGN_OR_RETURN(
      auto new_activations,
      PadAndSplitSpace(activations, *spatial_dimensions_to_split,
                       activations_batch_dim, high_padding, low_padding,
                       spatial_split_size, num_splits));
  return std::make_pair(new_activations, transpose_dims);
}

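// Conforms `consumer` to the space-to-batched form of `producer`: `consumer`
// is transposed by the inverse of the producer's dimension permutation, then
// padded and split along the producer's split spatial dimensions so that its
// shape matches the space-to-batched producer.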
StatusOr<HloInstruction*> ConvolutionVisitor::PropagateOnConstant(
    HloInstruction* consumer, HloInstruction* producer) {
  CHECK(old_to_new_instrs_.contains(producer));
  HloInstruction* new_producer = old_to_new_instrs_[producer];
  auto prod_transpose_dims = instr_to_dim_permute_map_[new_producer];
  std::vector<int64_t> reversed_transpose_dims(prod_transpose_dims.size());
  for (int64_t i = 0; i < prod_transpose_dims.size(); ++i) {
    reversed_transpose_dims[i] = ReverseDimLookUp(prod_transpose_dims, i);
  }
  // Bring space next to batch.
  TF_ASSIGN_OR_RETURN(consumer,
                      MakeTransposeHlo(consumer, reversed_transpose_dims));

  auto retval = GetSpatialDimsToSplit(producer);
  std::vector<int64_t> old_spatial_dims = retval.first;
  std::vector<int64_t> new_spatial_dims = retval.second;

  auto dim_map = instr_to_dim_map_[producer];
  const int64_t old_batch_dim = dim_map[DimMapper(SpaceToBatchDimMap::kBatch)];
  const int64_t old_space_dim = old_spatial_dims[0];
  const int64_t new_batch_dim = DimLookUp(prod_transpose_dims, old_batch_dim);
  const int64_t new_space_dim = new_spatial_dims[0];

  const int64_t old_batch_size = producer->shape().dimensions(old_batch_dim);
  const int64_t new_batch_size = old_batch_size * ctrl_.number_of_splits;
  const int64_t high_padding =
      (new_batch_size * new_producer->shape().dimensions(new_space_dim) -
       old_batch_size * producer->shape().dimensions(old_space_dim)) /
      old_batch_size;

  auto new_consumer = PadAndSplitSpace(
      consumer, new_spatial_dims, new_batch_dim, high_padding,
      /*low_padding=*/0, new_producer->shape().dimensions(new_space_dim),
      ctrl_.number_of_splits);

  return new_consumer;
}

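// Backprop-filter convolutions invert the roles of batch and feature on the
// activations (see the dimension-number swaps below). Both operands are
// brought into space-to-batched form, converting whichever one has not been
// converted yet, and an extra spatial dimension is appended to both operands:
// for the activations it concatenates shifted copies, one per kernel
// placement, built below with HaloDuplicateWithSlice; for the kernel it is a
// trailing size-1 dimension.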
Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
    HloInstruction* convolution) {
  auto activations_old = convolution->mutable_operand(0);

  const int64_t rhs_dilation =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .window_dilation();

  auto original_conv_dims = convolution->convolution_dimension_numbers();

  std::vector<int64_t> old_split_spatial_dims(
      ctrl_.dimension_from_end_to_convert),
      old_split_kernel_spatial_dims(ctrl_.dimension_from_end_to_convert);
  for (int i = 0; i < ctrl_.dimension_from_end_to_convert; ++i) {
    old_split_spatial_dims[i] = original_conv_dims.input_spatial_dimensions(
        GetFirstChosenSpatialDim(convolution) + i);
    old_split_kernel_spatial_dims[i] =
        original_conv_dims.kernel_spatial_dimensions(
            GetFirstChosenSpatialDim(convolution) + i);
  }

  auto kernel_old = convolution->mutable_operand(1);
  const int64_t old_kernel_split_dim_size =
      kernel_old->shape().dimensions(old_split_kernel_spatial_dims[0]);

  int64_t old_split_dim_size =
      activations_old->shape().dimensions(old_split_spatial_dims[0]);

  int64_t old_batch_dim = original_conv_dims.input_feature_dimension();
  int64_t kernel_old_batch_dim =
      original_conv_dims.kernel_input_feature_dimension();
  const int64_t old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  CHECK(old_to_new_instrs_.contains(kernel_old) ||
        old_to_new_instrs_.contains(activations_old));

  HloInstruction* activations_new = nullptr;
  HloInstruction* kernel_new = nullptr;
  bool activations_locally_space_to_batched = false;
  bool kernel_locally_space_to_batched = false;
  std::vector<int64_t> permute_dims_kernel, permute_dims;

  if (old_to_new_instrs_.contains(activations_old)) {
    activations_new = old_to_new_instrs_[activations_old];
    permute_dims = instr_to_dim_permute_map_[activations_new];
  }

  if (old_to_new_instrs_.contains(kernel_old)) {
    kernel_new = old_to_new_instrs_[kernel_old];
    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
  }

  // If activations were not space-to-batched, we space-to-batch them below.
  if (!old_to_new_instrs_.contains(activations_old)) {
    kernel_new = old_to_new_instrs_[kernel_old];
    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];

    VLOG(1) << "Space-to-batching activations to enable space-to-depth";

    const int64_t new_kernel_space_dim =
        DimLookUp(permute_dims_kernel, old_split_kernel_spatial_dims[0]);

    const int64_t new_kernel_split_dim_size =
        kernel_new->shape().dimensions(new_kernel_space_dim);
    const int64_t needed_spatial_size =
        rhs_dilation * new_kernel_split_dim_size;
    const int64_t pad_size =
        needed_spatial_size * ctrl_.number_of_splits - old_split_dim_size;
    ConvolutionDimensionNumbers tmp_dim_numbers;
    tmp_dim_numbers = original_conv_dims;
    TF_ASSIGN_OR_RETURN(
        auto retval,
        SplitSpace(activations_old, tmp_dim_numbers, old_batch_dim,
                   /*high_padding=*/pad_size, /*low_padding=*/0,
                   needed_spatial_size, ctrl_.number_of_splits,
                   &old_split_spatial_dims,
                   /*is_backprop=*/true));

    activations_new = retval.first;

    std::vector<int64_t> reversed_transpose_dims(retval.second.size());
    for (int64_t i = 0; i < retval.second.size(); ++i) {
      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
    }
    permute_dims = reversed_transpose_dims;

    VLOG(3) << "New Activations " << retval.first->ToString();

    activations_locally_space_to_batched = true;
  } else if (!old_to_new_instrs_.contains(kernel_old)) {
    activations_new = old_to_new_instrs_[activations_old];
    permute_dims = instr_to_dim_permute_map_[activations_new];

    VLOG(1) << "Space-to-batching kernel to enable space-to-depth";

    const int64_t new_space_dim =
        DimLookUp(permute_dims, old_split_spatial_dims[0]);
    const int64_t new_split_dim_size =
        activations_new->shape().dimensions(new_space_dim);
    const int64_t needed_spatial_size =
        CeilOfRatio(new_split_dim_size, rhs_dilation);
    int64_t old_kernel_split_dim_size =
        kernel_old->shape().dimensions(old_split_kernel_spatial_dims[0]);
    const int64_t pad_size = needed_spatial_size * ctrl_.number_of_splits -
                             old_kernel_split_dim_size;

    ConvolutionDimensionNumbers tmp_dim_numbers;
    tmp_dim_numbers = original_conv_dims;
    TF_ASSIGN_OR_RETURN(
        auto retval,
        SplitSpace(kernel_old, tmp_dim_numbers, kernel_old_batch_dim,
                   /*high_padding=*/pad_size, /*low_padding=*/0,
                   needed_spatial_size, ctrl_.number_of_splits,
                   &old_split_kernel_spatial_dims,
                   /*is_backprop=*/true, /*is_rhs=*/true));

    kernel_new = retval.first;

    std::vector<int64_t> reversed_transpose_dims(retval.second.size());
    for (int64_t i = 0; i < retval.second.size(); ++i) {
      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
    }
    permute_dims_kernel = reversed_transpose_dims;

    VLOG(3) << "New kernel " << retval.first->ToString();

    kernel_locally_space_to_batched = true;
  }

  CHECK_NE(activations_new, nullptr);
  CHECK_NE(kernel_new, nullptr);

  // TODO(b/189500737): For multi-dimensional space-to-batch, we'd need to add
  // an auxiliary dim per converted dimension.
  const int64_t new_spatial_dimension =
      activations_new->shape().dimensions_size();

  auto permuted_conv_dims_numbers = original_conv_dims;

  // Note the inversion here: batch and feature are inverted in backprop
  // filters.
  int64_t activations_batch_dim =
      DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
  int64_t activations_feature_dim =
      DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());

  const int64_t previous_spatial_dim_count =
      original_conv_dims.input_spatial_dimensions_size();
  for (int64_t i = 0; i < previous_spatial_dim_count; ++i) {
    permuted_conv_dims_numbers.set_input_spatial_dimensions(
        i, DimLookUp(permute_dims,
                     original_conv_dims.input_spatial_dimensions(i)));
    permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
        i, DimLookUp(permute_dims_kernel,
                     original_conv_dims.kernel_spatial_dimensions(i)));
  }

  permuted_conv_dims_numbers.add_input_spatial_dimensions(
      new_spatial_dimension);
  permuted_conv_dims_numbers.add_kernel_spatial_dimensions(
      new_spatial_dimension);
  permuted_conv_dims_numbers.add_output_spatial_dimensions(
      new_spatial_dimension);

  // For the output, make the last dimension size 1.
  const int64_t previous_chosen_spatial_dim_in_output =
      permuted_conv_dims_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  permuted_conv_dims_numbers.set_output_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution), new_spatial_dimension);
  permuted_conv_dims_numbers.set_output_spatial_dimensions(
      previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);

  const int64_t kernel_input_feature_dim = DimLookUp(
      permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());

  const int64_t kernel_output_feature_dim =
      DimLookUp(permute_dims_kernel,
                original_conv_dims.kernel_output_feature_dimension());

  permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
      kernel_input_feature_dim);
  permuted_conv_dims_numbers.set_kernel_output_feature_dimension(
      kernel_output_feature_dim);

  std::vector<int64_t> spatial_dimensions_to_split(
      ctrl_.count_of_dimensions_to_convert);
  const int64_t first_dim_to_split = GetFirstChosenSpatialDim(convolution);
  for (int64_t i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    spatial_dimensions_to_split[i] =
        permuted_conv_dims_numbers.input_spatial_dimensions(first_dim_to_split +
                                                            i);
  }

  const int64_t kernel_spatial_dimension_to_split =
      permuted_conv_dims_numbers.kernel_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));

  int64_t new_split_dim_size =
      activations_new->shape().dimensions(spatial_dimensions_to_split[0]);

  const int64_t kernel_new_split_dim_size =
      kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);

  permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << spatial_dimensions_to_split[0] << " old_batch_size "
          << old_batch_size << " new_split_dim_size " << new_split_dim_size;

  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            activations_batch_dim, &spatial_dimensions_to_split,
                            /*is_backprop=*/true));

  int64_t spatial_dimension_to_split = spatial_dimensions_to_split[0];

  std::vector<int64_t> transpose_dims = retval.transpose_dims;
  CHECK(!transpose_dims.empty());
  activations_new = retval.instr;

  VLOG(1) << "Activations_new post BringSpaceNextToBatch "
          << activations_new->ToString();
  VLOG(1) << "activations_batch_dim " << activations_batch_dim
          << " activations_feature_dim " << activations_feature_dim;
  const int64_t expected_split_dim_size =
      rhs_dilation * kernel_new_split_dim_size;
  if (new_split_dim_size != expected_split_dim_size) {
    CHECK_LT(new_split_dim_size, expected_split_dim_size);
    new_split_dim_size = expected_split_dim_size;
    TF_ASSIGN_OR_RETURN(
        activations_new,
        ChangeSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            spatial_dimensions_to_split, new_split_dim_size, true));
  }

  spatial_dimension_to_split = spatial_dimensions_to_split[0];

  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  if (!activations_locally_space_to_batched) {
    // Select activations correctly by masking additional space.
    TF_ASSIGN_OR_RETURN(
        activations_new,
        SelectValidPortion(activations_new, activations_old, select_val,
                           activations_batch_dim, spatial_dimensions_to_split,
                           old_batch_dim, old_split_spatial_dims));
  }
  if (!kernel_locally_space_to_batched) {
    VLOG(3) << "Selecting the valid kernel area";
    // Select kernel correctly by masking additional space.

    std::vector<int64_t> new_kernel_split_spatial_dims(
        ctrl_.dimension_from_end_to_convert);

    // TODO(b/189500737): Extend this once
    // IncreaseSpatialSizeOnSpaceToBatchedShape returns all dimensions.
    new_kernel_split_spatial_dims[0] = kernel_spatial_dimension_to_split;

    TF_ASSIGN_OR_RETURN(
        kernel_new,
        SelectValidPortion(kernel_new, kernel_old, select_val,
                           /*new_batch_dim=*/kernel_input_feature_dim,
                           new_kernel_split_spatial_dims,
                           /*old_batch_dim=*/
                           original_conv_dims.kernel_input_feature_dimension(),
                           old_split_kernel_spatial_dims));
  }

  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  VLOG(2) << "New dim numbers " << new_dim_numbers.DebugString();

  const int64_t inherent_low_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_low();

  const int64_t inherent_high_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_high();

  std::vector<HloInstruction*> activations_chunks;

  // Insert slices for low padding.
  for (int64_t i = 0; i < inherent_low_padding; ++i) {
    HloInstruction* activations_to_use = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
    } else {
      activations_to_use = activations_chunks.back();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split,
                               activations_batch_dim, /*low_padding=*/1,
                               /*halo_size=*/0));
    activations_chunks.push_back(activations_slice);
  }
  // Reverse the low padding slices because we created them in the opposite
  // order above.
  absl::c_reverse(activations_chunks);

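  // The (window-dilated) kernel spans expanded_kernel input elements, so a
  // spatial dimension of old_split_dim_size admits
  // old_split_dim_size - expanded_kernel + 1 kernel placements; negative
  // inherent padding removes placements, while positive padding contributes
  // the extra placements counted in total_overlap_count below.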
  const int64_t expanded_kernel =
      old_kernel_split_dim_size * rhs_dilation - (rhs_dilation - 1);
  const int64_t overlap_count =
      old_split_dim_size - expanded_kernel + 1 +
      (inherent_low_padding < 0 ? inherent_low_padding : 0) +
      (inherent_high_padding < 0 ? inherent_high_padding : 0);
  VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
          << inherent_low_padding << " inherent_high_padding "
          << inherent_high_padding;

  const int64_t total_overlap_count =
      overlap_count + (inherent_low_padding > 0 ? inherent_low_padding : 0) +
      (inherent_high_padding > 0 ? inherent_high_padding : 0);

  // Insert original activations.
  for (int64_t i = 0; i < overlap_count; ++i) {
    HloInstruction* activations_to_use = nullptr;
    HloInstruction* activations_slice = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
      if (inherent_low_padding < 0) {
        TF_ASSIGN_OR_RETURN(
            activations_slice,
            HaloDuplicateWithSlice(
                activations_to_use, spatial_dimensions_to_split,
                activations_batch_dim,
                /*low_padding=*/inherent_low_padding, /*halo_size=*/0));
      } else {
        activations_slice = activations_to_use;
      }
    } else {
      activations_to_use = activations_chunks.back();

      TF_ASSIGN_OR_RETURN(activations_slice,
                          HaloDuplicateWithSlice(
                              activations_to_use, spatial_dimensions_to_split,
                              activations_batch_dim, /*low_padding=*/-1,
                              /*halo_size=*/0));
    }

    activations_chunks.push_back(activations_slice);
  }

  int64_t high_padding_to_materialize = 0;

  if (inherent_high_padding > 0) {
    high_padding_to_materialize =
        std::max(total_overlap_count -
                     (std::max(overlap_count, static_cast<int64_t>(0)) +
                      std::max(inherent_low_padding, static_cast<int64_t>(0))),
                 static_cast<int64_t>(0));
  }

  // Insert slices for high padding.
  for (int64_t i = 0; i < high_padding_to_materialize; ++i) {
    HloInstruction* activations_to_use = nullptr;
    activations_to_use = activations_chunks.back();

    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split,
                               activations_batch_dim,
                               /*low_padding=*/-1, /*halo_size=*/0));
    activations_chunks.push_back(activations_slice);
  }

  for (int64_t i = 0; i < activations_chunks.size(); ++i) {
    std::vector<int64_t> input_sizes(
        activations_chunks[i]->shape().dimensions().begin(),
        activations_chunks[i]->shape().dimensions().end());
    // Insert 1-sized dimension at the end.
    input_sizes.push_back(1);
    TF_ASSIGN_OR_RETURN(activations_chunks[i],
                        MakeReshapeHlo(input_sizes, activations_chunks[i]));
    VLOG(1) << "new_spatial_dimension " << new_spatial_dimension << " slice "
            << activations_chunks[i]->ToString();
  }

  TF_ASSIGN_OR_RETURN(
      activations_new,
      MakeConcatHlo(absl::MakeSpan(activations_chunks), new_spatial_dimension));

  // Reshape the kernel with additional spatial dim.
  std::vector<int64_t> kernel_sizes(kernel_new->shape().dimensions().begin(),
                                    kernel_new->shape().dimensions().end());
  // Insert 1-sized dimension at the end.
  kernel_sizes.push_back(1);
  TF_ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new));

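  // The window size is chosen so that, together with the window dilation, the
  // kernel exactly covers the split spatial dimension: a dilated window of
  // size ceil(S / d) spans S - (d - 1) elements, and the negative high padding
  // trims the remaining d - 1 elements, producing one output element per
  // split along this dimension.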
  auto new_window = convolution->window();
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_padding_high(-(rhs_dilation - 1));
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_padding_low(0);
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));

  // Set the window for the additional spatial dim. This is a vanilla window.
  auto window_dim = new_window.add_dimensions();
  window_dim->set_base_dilation(1);
  window_dim->set_size(1);
  int64_t stride = 1;
  // This condition means there's only a single overlap possible (as the shapes
  // were grown due to padding). In this case, we increase the stride.
  if (inherent_low_padding > total_overlap_count) {
    stride = activations_chunks.size();
  }
  window_dim->set_stride(stride);
  window_dim->set_padding_low(0);
  window_dim->set_padding_high(0);
  window_dim->set_window_reversal(false);
  window_dim->set_window_dilation(1);

  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations_new, kernel_new, convolution->feature_group_count(),
          convolution->batch_group_count(), new_window, new_dim_numbers,
          convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  VLOG(2) << "New backprop filter convolution " << new_conv->ToString();

  std::vector<int64_t> output_sizes(new_conv->shape().dimensions().begin(),
                                    new_conv->shape().dimensions().end());

  output_sizes.erase(output_sizes.begin() +
                     new_dim_numbers.output_spatial_dimensions(
                         GetFirstChosenSpatialDim(convolution)));

  TF_ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv));

  old_to_new_instrs_[convolution] = new_conv;
  VLOG(1) << "Space-to-featured convolution " << new_conv->ToString();

  std::vector<int64_t> dim_map(NumMappedDims());
  dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
      original_conv_dims.output_batch_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
      original_conv_dims.output_feature_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
      original_conv_dims.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  instr_to_dim_map_[convolution] = dim_map;

  std::vector<int64_t> trans_dims(convolution->shape().dimensions_size());
  absl::c_iota(trans_dims, 0);
  instr_to_dim_permute_map_[new_conv] = trans_dims;

  return OkStatus();
}

HloInstruction*
ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
    HloInstruction* instr, int64_t depth = kReduceWindowSearchDepth) {
  if (depth == 0) {
    return nullptr;
  }

  for (auto user : instr->users()) {
    if (user->opcode() == HloOpcode::kReduceWindow ||
        user->opcode() == HloOpcode::kSelectAndScatter) {
      return user;
    }
    // Stop the search if these ops are encountered.
    if (user->opcode() == HloOpcode::kConvolution ||
        user->opcode() == HloOpcode::kPad ||
        user->opcode() == HloOpcode::kTranspose) {
      continue;
    }
    auto ret =
        DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
    if (ret != nullptr) {
      return ret;
    }
  }
  return nullptr;
}

bool ConvolutionVisitor::DoesConvolutionFeedUnpropagatableOp(
    HloInstruction* instr, int64_t depth) {
  auto key = std::make_pair(instr, depth);
  if (unpropagatability_cache_.contains(key)) {
    return unpropagatability_cache_[key];
  }

  if (depth == 0 || instr->user_count() == 0) {
    unpropagatability_cache_[key] = false;
    return false;
  }

  for (auto user : instr->users()) {
    if (IsOpcodeNonPropagatable(user)) {
      unpropagatability_cache_[key] = true;
      return true;
    }

    int64_t depth_to_use = depth;
    // When we see a convolution, we reduce the remaining search depth.
    if (user->opcode() == HloOpcode::kConvolution) {
      depth_to_use--;
    }

    if (DoesConvolutionFeedUnpropagatableOp(user, depth_to_use)) {
      unpropagatability_cache_[key] = true;
      return true;
    }
  }

  unpropagatability_cache_[key] = false;
  return false;
}

bool ConvolutionVisitor::IsSpaceToBatchedSpaceSizeSuitable(
    HloInstruction* instr) {
  CHECK(instr->opcode() == HloOpcode::kSelectAndScatter ||
        instr->opcode() == HloOpcode::kReduceWindow);
  auto old_producer = instr->mutable_operand(0);

  auto dim_map_val_op = instr_to_dim_map_[old_producer];
  const int64_t old_space_dim =
      dim_map_val_op[DimMapper(SpaceToBatchDimMap::kSpace0)];
  auto first_operand = old_to_new_instrs_[old_producer];
  auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
  const int64_t new_space_dim =
      DimLookUp(permute_dims_first_operand, old_space_dim);

  const int64_t window_size = instr->window().dimensions(old_space_dim).size();

  if (first_operand->shape().dimensions(new_space_dim) < window_size) {
    return false;
  }

  return true;
}

ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
  auto activations = convolution->mutable_operand(0);

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
  const int64_t kernel_spatial_dim = dim_numbers.kernel_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));
  int64_t kernel_spatial_dim_size = kernel_shape.dimensions(kernel_spatial_dim);

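  // Window dilation enlarges the effective kernel extent to d * (k - 1) + 1;
  // fold that into kernel_spatial_dim_size so the halo and padding
  // computations below see the dilated size.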
  if (IsForwardWindowDilatedConv(convolution, dim_numbers)) {
    const int64_t window_dilation_factor =
        convolution->window()
            .dimensions(GetFirstChosenSpatialDim(convolution))
            .window_dilation();
    kernel_spatial_dim_size =
        (kernel_spatial_dim_size - 1) * (window_dilation_factor - 1) +
        kernel_spatial_dim_size;
  }

  std::vector<int64_t> spatial_dimensions_to_split =
      GetChosenSpatialDims(convolution);
  const int64_t spatial_dimension_to_split = spatial_dimensions_to_split[0];

  const int64_t input_dim_size =
      activations->shape().dimensions(spatial_dimension_to_split);

  const int64_t inherent_low_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_low();
  const int64_t inherent_high_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_high();

  const int64_t stride = convolution->window()
                             .dimensions(GetFirstChosenSpatialDim(convolution))
                             .stride();

  const int64_t base_dilation_factor =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .base_dilation();

  bool is_base_dilated = base_dilation_factor > 1;

  const int64_t spatial_size = input_dim_size +
                               (is_base_dilated ? 0 : inherent_low_padding) +
                               inherent_high_padding;

  const int64_t last_overlap = base_dilation_factor == inherent_low_padding
                                   ? kernel_spatial_dim_size
                                   : kernel_spatial_dim_size - 1;
  const int64_t halo_size = is_base_dilated
                                ? last_overlap / base_dilation_factor
                                : kernel_spatial_dim_size - 1;

  const int64_t high_padding_for_base_dilation =
      inherent_low_padding == 0 ? base_dilation_factor - 1
                                : last_overlap % base_dilation_factor;

  const int64_t high_padding_for_conv =
      is_base_dilated ? high_padding_for_base_dilation : 0;

  const int64_t low_padding_for_conv =
      is_base_dilated && (base_dilation_factor != inherent_low_padding)
          ? inherent_low_padding
          : 0;

  return ConvDetails{spatial_dimensions_to_split,
                     inherent_low_padding,
                     inherent_high_padding,
                     stride,
                     spatial_size,
                     base_dilation_factor,
                     halo_size,
                     high_padding_for_conv,
                     low_padding_for_conv,
                     kernel_spatial_dim_size,
                     input_dim_size};
}

Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    HloInstruction* convolution) {
  if (!ConsumeFuel("space-to-batch-converter", [&] {
        return "Skipping space-to-batch propagation because fuel is "
               "exhausted\n";
      })) {
    return OkStatus();
  }
  VLOG(1) << "Handling conv " << convolution->ToString();

  changed_ = false;

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  int64_t activations_batch_dim = dim_numbers.input_batch_dimension();

  auto activations = convolution->mutable_operand(0);

  VLOG(1) << "spatial size " << c.spatial_size;

  // A very primitive cost model to thwart propagations on tiny shapes.
  if (c.spatial_size < 2 * ctrl_.number_of_splits) {
    return OkStatus();
  }

  auto original_conv = convolution;

  const int64_t output_spatial_dim = dim_numbers.output_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));
  const int64_t output_offsets =
      convolution->shape().dimensions(output_spatial_dim);
  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, ctrl_.number_of_splits);

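  // Each split must produce output_offsets_per_split output elements; with
  // stride `stride` that requires about output_offsets_per_split * stride
  // input elements (divided by the base dilation factor, since base dilation
  // stretches the input).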
  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension.
  while (spatial_split_size * ctrl_.number_of_splits - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }

  auto reduce_window_or_select_and_scatter =
      DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);

  if (reduce_window_or_select_and_scatter != nullptr &&
      reduce_window_or_select_and_scatter->shape().IsArray() &&
      reduce_window_or_select_and_scatter->shape().rank() ==
          convolution->shape().rank()) {
    VLOG(2)
        << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
    // Take into account the stride of the reduce window while choosing the
    // spatial_split_size. This will guarantee propagation through reduce
    // windows.
    const int64_t win_stride =
        std::max(reduce_window_or_select_and_scatter->window()
                     .dimensions(output_spatial_dim)
                     .stride(),
                 static_cast<int64_t>(1));
    CHECK_NE(win_stride, 0)
        << "Bad op " << reduce_window_or_select_and_scatter->ToString();
    CHECK_NE(c.stride, 0) << "Bad op " << convolution->ToString();
    while ((spatial_split_size / c.stride) % win_stride != 0) {
      spatial_split_size += c.stride;
    }
  }

  const int64_t slice_size = spatial_split_size + c.halo_size;

  const int64_t low_pad_to_handle_base_dilation =
      (c.base_dilation_factor > 1 &&
       c.base_dilation_factor == c.inherent_low_padding)
          ? 1
          : 0;

  // Pad spatial dim.
  int64_t pad_size =
      spatial_split_size * ctrl_.number_of_splits - c.spatial_size;

  bool handle_low_pad_in_first_reshape = false;
  if (pad_size > low_pad_to_handle_base_dilation) {
    pad_size -= low_pad_to_handle_base_dilation;
    handle_low_pad_in_first_reshape = true;
  }

  VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
          << c.stride << " slice_size " << slice_size;
  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimensions_to_split[0]
          << " num_splits " << ctrl_.number_of_splits
          << " kernel_spatial_dim_size " << c.kernel_spatial_dim_size;
  std::vector<int64_t> spatial_dimensions_to_split =
      c.spatial_dimensions_to_split;
  TF_ASSIGN_OR_RETURN(
      auto retval,
      SplitSpace(
          activations, dim_numbers, activations_batch_dim,
          /*high_padding=*/c.inherent_high_padding + pad_size,
          /*low_padding=*/c.base_dilation_factor == 1 ? c.inherent_low_padding
          : handle_low_pad_in_first_reshape ? low_pad_to_handle_base_dilation
                                            : 0,
          spatial_split_size, ctrl_.number_of_splits,
          &spatial_dimensions_to_split));
  HloInstruction* batch_increased_reshape = retval.first;
  convolution->SetupDerivedInstruction(batch_increased_reshape);

  VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();

  TF_ASSIGN_OR_RETURN(
      activations,
      HaloDuplicateWithSlice(
          batch_increased_reshape, spatial_dimensions_to_split,
          activations_batch_dim,
          /*low_padding=*/
          handle_low_pad_in_first_reshape ? 0 : low_pad_to_handle_base_dilation,
          c.halo_size));

  VLOG(1) << "Batch merge done " << activations->ToString();

  // Now, we rewrite the convolution with a larger batch.

  // Create the new convolution dim numbers.
  auto new_dim_numbers = dim_numbers;

  // We will generate output such that batch is followed by the split spatial
  // dimension.
  const int64_t rank = convolution->shape().rank();
  std::vector<int64_t> transpose_dims(rank);
  int dim_count = 0;
  std::map<int64_t, int64_t> dim_translator;

  for (int j = 0; j < dim_numbers.output_spatial_dimensions_size(); ++j) {
    if (j == GetFirstChosenSpatialDim(convolution)) {
      dim_translator[dim_numbers.output_batch_dimension()] = dim_count;
      new_dim_numbers.set_output_batch_dimension(dim_count++);
    }
    dim_translator[dim_numbers.output_spatial_dimensions(j)] = dim_count;
    new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
    dim_count++;
  }

  dim_translator[dim_numbers.output_feature_dimension()] = dim_count;
  new_dim_numbers.set_output_feature_dimension(dim_count);

  int p = 0;
  for (const auto& entry : dim_translator) {
    transpose_dims[p] = entry.second;
    p++;
  }
  VLOG(1) << "New dim numbers " << new_dim_numbers.DebugString()
          << " batch dim " << new_dim_numbers.input_batch_dimension();
  auto new_window = convolution->window();
  const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
  for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_high(c.high_padding_for_conv);
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_low(c.low_padding_for_conv);
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations, /*rhs=*/convolution->mutable_operand(1),
          convolution->feature_group_count(), convolution->batch_group_count(),
          new_window, new_dim_numbers, convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  // If the activations were to be batch-to-spaced again, simply use the
  // original value.
  batch_to_space_map_[convolution->mutable_operand(0)] =
      convolution->mutable_operand(0);

  VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

  std::vector<int64_t> new_output_split_spatial_dims(
      ctrl_.count_of_dimensions_to_convert),
      old_output_split_spatial_dims(ctrl_.count_of_dimensions_to_convert);
  for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    old_output_split_spatial_dims[i] =
        dim_numbers.output_spatial_dimensions(first_dim + i);
    new_output_split_spatial_dims[i] =
        new_dim_numbers.output_spatial_dimensions(first_dim + i);
  }

  const int64_t output_batch_dim = new_dim_numbers.output_batch_dimension();

  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(new_conv->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      new_conv,
      SelectValidPortion(new_conv, original_conv, select_val, output_batch_dim,
                         new_output_split_spatial_dims,
                         dim_numbers.output_batch_dimension(),
                         old_output_split_spatial_dims));
  old_to_new_instrs_[original_conv] = new_conv;

  std::vector<int64_t> dim_map(NumMappedDims());
  dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
      dim_numbers.output_batch_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
      dim_numbers.output_feature_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
      dim_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  instr_to_dim_map_[original_conv] = dim_map;

  instr_to_dim_permute_map_[new_conv] = std::vector<int64_t>(transpose_dims);
  if (non_propagatable_instrs_.count(convolution) > 0) {
    non_propagatable_instrs_.erase(convolution);
  }
  TF_CHECK_OK(PropagateOnUsers(original_conv));

  changed_ = true;

  return OkStatus();
}

}  // namespace

StatusOr<bool> SpaceToBatchConverter::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(
      2, "SpaceToBatchConverter::Run(), before:\n" + module->ToString());
  bool changed = false;

  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    ConvolutionVisitor visitor(ctrl_, comp);
    if (visitor.Run().ValueOrDie()) {
      changed = true;
    }
    VLOG(1) << "Done operating on computation";
  }
  XLA_VLOG_LINES(2,
                 "SpaceToBatchConverter::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla