• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/space_to_batch_converter.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <iterator>
20 #include <map>
21 #include <memory>
22 #include <queue>
23 #include <tuple>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/algorithm.h"
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/container/flat_hash_set.h"
31 #include "absl/types/span.h"
32 #include "tensorflow/compiler/xla/debug_options_flags.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
41 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
42 #include "tensorflow/compiler/xla/service/shape_inference.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/status_macros.h"
45 #include "tensorflow/compiler/xla/statusor.h"
46 #include "tensorflow/compiler/xla/types.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/core/bitmap.h"
50 #include "tensorflow/core/lib/core/errors.h"
51 #include "tensorflow/core/lib/core/status.h"
52 #include "tensorflow/core/lib/math/math_util.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/stream_executor/lib/statusor.h"
55 
56 namespace xla {
57 
58 namespace {
59 
60 namespace m = match;
61 
62 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
63 // operations with small batch counts into convolutions with larger batch
64 // counts by moving space to batch.
65 class ConvolutionVisitor {
66  public:
67   // Top-level function to begin space-to-batch conversion.
68   Status PerformSpaceToBatchOnConvolution(HloInstruction* convolution);
69 
70   // Struct containing details about a convolution.
71   struct ConvDetails {
72     std::vector<int64_t> spatial_dimensions_to_split;
73     int64_t inherent_low_padding, inherent_high_padding, stride, spatial_size,
74         base_dilation_factor, halo_size, high_padding_for_conv,
75         low_padding_for_conv, kernel_spatial_dim_size, input_dim_size;
76   };
77 
78   // Return a struct containing various necessary information pieces for
79   // performing space-to-batch on a convolution.
80   ConvDetails GetConvolutionDetails(HloInstruction* convolution,
81                                     ConvolutionDimensionNumbers& dim_numbers);
82 
83   // Returns the set of old and new spatial dimensions respectively.
84   std::pair<std::vector<int64_t>, std::vector<int64_t>> GetSpatialDimsToSplit(
85       HloInstruction* old_operand);
86 
87   // Returns if the convolution is a forward window dilated convolution.
88   bool IsForwardWindowDilatedConv(HloInstruction* convolution,
89                                   ConvolutionDimensionNumbers& dim_numbers);
90 
91   // Function that determines if space-to-batch can be propagated into the
92   // consumer. Such propagation is only possible when all required operands are
93   // space-to-batch'ed.
94   bool CanPropagate(HloInstruction* consumer, HloInstruction* producer);
95 
96   // Returns true if the op has all its direct and indirect operands being
97   // created via broadcasts. Consumer uses op, and is space-to-batched.
98   // instructions_to_transform returns the reverse post order instruction graph.
99   bool IsBroadcastTree(HloInstruction* op, HloInstruction* consumer,
100                        std::vector<HloInstruction*>& instructions_to_transform);
101 
102   // Replicates the broadcast tree with space-to-batched instructions.
103   void RewriteBroadcastTree(
104       HloInstruction* producer,
105       std::vector<HloInstruction*>& instructions_to_transform);
106 
107   // Propagate space-to-batch on a broadcast instruction.
108   void PropagateOnBroadcast(HloInstruction* consumer, HloInstruction* producer);
109 
110   // Returns false if the opcode should definitely not be propagated upon.
111   bool IsOpcodeNonPropagatable(HloInstruction* consumer);
112 
113   // This function checks if the HLO instruction supports propagation.
114   bool SupportedOpForPropagation(HloInstruction* consumer,
115                                  HloInstruction* producer);
116 
117   // Method that checks validity of Broadcast propagation.
118   bool IsBroadcastPropagatable(HloInstruction* broadcast,
119                                HloInstruction* old_other_op);
120 
121   // Propagates space-to-batch on the op, and returns a bool that indicates if
122   // the users of the op need to be propagated through.
123   StatusOr<bool> Propagate(HloInstruction* consumer, HloInstruction* producer);
124 
125   // Splits the given spatial dimension on the activations and returns the
126   // new instructions, and the dimension permutation of the new shape.
127   StatusOr<std::pair<HloInstruction*, std::vector<int64_t>>> SplitSpace(
128       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
129       int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding,
130       int64_t spatial_split_size, int64_t num_splits,
131       std::vector<int64_t>* spatial_dimensions_to_split,
132       bool is_backprop = false, bool is_rhs = false);
133 
134   // Performs the actual dimension splitting.
135   StatusOr<HloInstruction*> PerformSplitSpace(
136       HloInstruction* activations,
137       absl::Span<const int64_t> spatial_dimensions_to_split,
138       int64_t activations_batch_dim, int64_t spatial_split_size,
139       int64_t num_splits);
140 
141   // Helper function that puts individually split dimensions together, and
142   // merges the batch(es).
143   // The input activations dimensions are ... B, B0, S0, B1, S1, ... Bn, Sn, ...
144   // The output dimensions will be ..., B, S0, S1,.. Sn, ...
145   StatusOr<HloInstruction*> TransposeAndMergeBatch(
146       HloInstruction* activations,
147       absl::Span<const int64_t> final_split_spatial_dim_positioning,
148       int64_t activations_batch_dim, int64_t old_batch_size);
149 
150   // Helper function for the SplitSpace function above. Handles padding and
151   // reshaping to generate space-to-batched shape.
152   StatusOr<HloInstruction*> PadAndSplitSpace(
153       HloInstruction* activations,
154       absl::Span<const int64_t> spatial_dimensions_to_split,
155       int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding,
156       int64_t spatial_split_size, int64_t num_splits);
157 
158   // Perform space-to-batch propagation on constants.
159   StatusOr<HloInstruction*> PropagateOnConstant(HloInstruction* consumer,
160                                                 HloInstruction* producer);
161 
162   // Perform space-to-batch propagation on the convolution. Assumes the
163   // activations were already space-to-batched.
164   Status PropagateOnConv(HloInstruction* convolution);
165 
166   // Perform space-to-batch propagation on concatenate.
167   Status PropagateOnConcat(HloInstruction* concat);
168 
169   // Perform space-to-batch propagation on reverse.
170   Status PropagateOnReverse(HloInstruction* reverse);
171 
172   // Perform space-to-batch propagation on reverse.
173   Status PropagateOnPad(HloInstruction* pad);
174 
175   // Perform space-to-batch propagation on the backprop filter convolution.
176   // Assumes the activations and kernel were already space-to-batched.
177   Status PropagateOnBackpropFilterConv(HloInstruction* convolution);
178 
179   // Method that checks validity of space-to-batch on a given convolution.
180   bool IsConvSuitableForSpaceToBatch(HloInstruction* convolution);
181 
182   // Method that returns true if this is a backprop filter convolution.
183   bool IsThisBackPropFilterConv(HloInstruction* convolution);
184 
185   // Once a convolution has been space-to-batch'ed, this function will
186   // transitively propagate the space-to-batch-ness on rest of the graph.
187   Status PropagateOnUsers(HloInstruction* old_conv);
188 
189   // Generates masked output with valid data. This is useful when larger shapes
190   // are generated due to space-to-batch.
191   StatusOr<HloInstruction*> SelectValidPortion(
192       HloInstruction* new_instr, HloInstruction* old_instr,
193       HloInstruction* select_val, int64_t new_batch_dim,
194       absl::Span<const int64_t> new_space_dims, int64_t old_batch_dim,
195       absl::Span<const int64_t> old_space_dims);
196 
197   struct SpaceNextToBatchDetails {
198     HloInstruction* instr;
199     std::vector<int64_t> transpose_dims;
200   };
201 
202   // Performs tranposition so that space dimension follows the batch dimension.
203   StatusOr<SpaceNextToBatchDetails> BringSpaceNextToBatch(
204       HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
205       int64_t& activations_batch_dim,
206       std::vector<int64_t>* spatial_dimensions_to_split,
207       bool is_backprop = false, bool is_rhs = false);
208 
209   // Decreases the spatial dimension size in an already space-to-batched shape
210   // so that the new size is new_spatial_dim_size.
211   StatusOr<HloInstruction*> ChangeSpatialSizeOnSpaceToBatchedShape(
212       HloInstruction* activations, int64_t batch_dimension,
213       int64_t old_batch_size,
214       absl::Span<const int64_t> spatial_dimensions_to_split,
215       int64_t new_spatial_dim_size, bool increase_spatial_size = false);
216 
217   // Turns B, S0, S1, ..., Sn into B, B0, S0, B1, S1,... Bn, Sn.
218   StatusOr<HloInstruction*> SplitAndTransposeMergedBatch(
219       HloInstruction* activations, int64_t batch_dimension,
220       int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions);
221 
222   // Function that converts spaced-to-batch shape back to the original.
223   StatusOr<HloInstruction*> BatchToSpace(HloInstruction* old_instr);
224 
225   // Duplicates elements at boundaries.
226   StatusOr<HloInstruction*> HaloDuplicateWithSlice(
227       HloInstruction* activations,
228       absl::Span<const int64_t> spatial_dimensions_to_split,
229       int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size,
230       HloInstruction* pad_val = nullptr);
231 
232   // Runs the visitor on a computation.
233   StatusOr<bool> Run();
234 
235   // Returns whether any convolution ops were rewritten.
changed() const236   const bool changed() const { return changed_; }
237 
238   ~ConvolutionVisitor() = default;
239 
240   explicit ConvolutionVisitor(SpaceToBatchController ctrl,
241                               HloComputation* computation);
242 
GetFirstChosenSpatialDim(HloInstruction * convolution)243   int64_t GetFirstChosenSpatialDim(HloInstruction* convolution) {
244     const int64_t dim_count = ctrl_.count_of_dimensions_to_convert;
245     const int64_t end_point = convolution->convolution_dimension_numbers()
246                                   .input_spatial_dimensions_size() -
247                               ctrl_.dimension_from_end_to_convert;
248     return end_point - dim_count + 1;
249   }
250 
GetChosenSpatialDims(HloInstruction * convolution)251   std::vector<int64_t> GetChosenSpatialDims(HloInstruction* convolution) {
252     const int64_t dim_count = ctrl_.count_of_dimensions_to_convert;
253     const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
254     std::vector<int64_t> dims(dim_count);
255     for (int i = 0; i < dim_count; ++i) {
256       dims[i] =
257           convolution->convolution_dimension_numbers().input_spatial_dimensions(
258               first_dim + i);
259     }
260     return dims;
261   }
262 
DimLookUp(absl::Span<const int64_t> permute_dims,int64_t id)263   int64_t DimLookUp(absl::Span<const int64_t> permute_dims, int64_t id) {
264     return permute_dims[id];
265   }
266 
DimMapper(SpaceToBatchDimMap s)267   int DimMapper(SpaceToBatchDimMap s) { return static_cast<int>(s); }
268 
ReverseDimLookUp(absl::Span<const int64_t> permute_dims,int64_t id)269   int64_t ReverseDimLookUp(absl::Span<const int64_t> permute_dims, int64_t id) {
270     return std::distance(permute_dims.begin(), absl::c_find(permute_dims, id));
271   }
272 
273   HloInstruction* DoesConvolutionFeedReduceWindowOrSelectAndScatter(
274       HloInstruction* instr, int64_t depth);
275 
276   // Returns true if instr feeds an unpropagatable op before it feeds 'depth'
277   // number of convolutions.
278   bool DoesConvolutionFeedUnpropagatableOp(
279       HloInstruction* instr, int64_t depth = kUnpropagatableOpSearchDepth);
280 
281   // Checks that the space-to-batched shape has not rendered the new spatial
282   // dimension to be smaller than the window's size.
283   bool IsSpaceToBatchedSpaceSizeSuitable(HloInstruction* instr);
284 
285  private:
286   // Current HloComputation instance the ConvolutionVisitor is traversing.
287   HloComputation* computation_;
288 
289   absl::flat_hash_set<HloInstruction*> convs_to_visit_;
290   std::vector<HloInstruction*> conv_visitor_list_;
291   HloInstructionSet non_propagatable_instrs_;
292   // Map from a given spaced-to-batch instruction to its batched-to-space
293   // version.
294   absl::flat_hash_map<HloInstruction*, HloInstruction*> batch_to_space_map_;
295 
296   // Map from old (non space-to-batch) instructions to space-to-batch'ed
297   // instructions.
298   absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new_instrs_;
299 
300   // Map from instruction to dimensions of the shape. This is with respect to
301   // the old instruction.
302   absl::flat_hash_map<HloInstruction*, std::vector<int64_t>> instr_to_dim_map_;
303 
304   // Map from space-to-batch'ed instruction to its permute dims.
305   absl::flat_hash_map<HloInstruction*, std::vector<int64_t>>
306       instr_to_dim_permute_map_;
307 
308   // Map maintaining previously space-to-batched broadcasts.
309   absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>>
310       broadcast_map_;
311 
312   // Whether rewrite has occurred.
313   bool changed_ = false;
314 
315   // Depth for searching reduce window
316   static constexpr int64_t kReduceWindowSearchDepth = 10;
317 
318   // Depth for searching unpropagatable op.
319   static constexpr int64_t kUnpropagatableOpSearchDepth = 3;
320 
321   // Penalty on size for base dilated convs
322   static constexpr int64_t kMultiplierOnSpaceForBaseDilation = 3;
323 
324   // Cache for <instruction, depth> ==> unpropagatablilty decision.
325   absl::flat_hash_map<std::pair<HloInstruction*, int64_t>, bool>
326       unpropagatability_cache_;
327 
328   // Controller for various knobs.
329   SpaceToBatchController ctrl_;
330 };
331 
ConvolutionVisitor(SpaceToBatchController ctrl,HloComputation * computation)332 ConvolutionVisitor::ConvolutionVisitor(SpaceToBatchController ctrl,
333                                        HloComputation* computation) {
334   ctrl_ = ctrl;
335   computation_ = computation;
336   for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
337     if (inst->opcode() != HloOpcode::kConvolution) {
338       continue;
339     }
340 
341     auto convolution = inst;
342     // Perform legality checks.
343     if (!IsConvSuitableForSpaceToBatch(convolution)) {
344       VLOG(1) << "Conv not suitable for space-to-batch "
345               << convolution->ToString();
346       continue;
347     }
348     VLOG(1) << "Conv added to space-to-batch worklist "
349             << convolution->ToString();
350     convs_to_visit_.insert(convolution);
351     conv_visitor_list_.push_back(convolution);
352   }
353 }
354 
355 std::pair<std::vector<int64_t>, std::vector<int64_t>>
GetSpatialDimsToSplit(HloInstruction * old_operand)356 ConvolutionVisitor::GetSpatialDimsToSplit(HloInstruction* old_operand) {
357   auto new_operand = old_to_new_instrs_[old_operand];
358   auto dim_map_val = instr_to_dim_map_[old_operand];
359   auto permute_dims = instr_to_dim_permute_map_[new_operand];
360   std::vector<int64_t> old_dims(ctrl_.count_of_dimensions_to_convert),
361       new_dims(ctrl_.count_of_dimensions_to_convert);
362 
363   old_dims[0] = dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
364   new_dims[0] = DimLookUp(permute_dims, old_dims[0]);
365   for (int i = 1; i < ctrl_.count_of_dimensions_to_convert; ++i) {
366     old_dims[i] = old_dims[0] + i;
367     new_dims[i] = new_dims[0] + i;
368   }
369   return std::make_pair(old_dims, new_dims);
370 }
371 
IsForwardWindowDilatedConv(HloInstruction * convolution,ConvolutionDimensionNumbers & dim_numbers)372 bool ConvolutionVisitor::IsForwardWindowDilatedConv(
373     HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
374   const int64_t window_dilation_factor =
375       convolution->window()
376           .dimensions(GetFirstChosenSpatialDim(convolution))
377           .window_dilation();
378 
379   if (window_dilation_factor == 1) {
380     return false;
381   }
382 
383   const int64_t output_spatial_dim = dim_numbers.output_spatial_dimensions(
384       GetFirstChosenSpatialDim(convolution));
385   const int64_t kernel_spatial_dim = dim_numbers.kernel_spatial_dimensions(
386       GetFirstChosenSpatialDim(convolution));
387 
388   // If convolution's spatial dim size is larger than that of RHS, this is a
389   // forward RHS dilated convolution.
390   return convolution->operand(1)->shape().dimensions(kernel_spatial_dim) <
391          convolution->shape().dimensions(output_spatial_dim);
392 }
393 
IsConvSuitableForSpaceToBatch(HloInstruction * convolution)394 bool ConvolutionVisitor::IsConvSuitableForSpaceToBatch(
395     HloInstruction* convolution) {
396   ConvolutionDimensionNumbers dim_numbers =
397       convolution->convolution_dimension_numbers();
398 
399   // If there are no specified spatial dims, we return.
400   if (GetFirstChosenSpatialDim(convolution) < 0) {
401     return false;
402   }
403 
404   // Batch in batch_group_count has different semantics (it isn't true batch).
405   // Consider supporting this case in future if needed.
406   if (convolution->batch_group_count() != 1) {
407     return false;
408   }
409 
410   if (convolution->window()
411           .dimensions(GetFirstChosenSpatialDim(convolution))
412           .window_dilation() != 1) {
413     if (!IsForwardWindowDilatedConv(convolution, dim_numbers)) {
414       return false;
415     }
416   }
417 
418   const ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);
419 
420   const int64_t low_pad = convolution->window()
421                               .dimensions(GetFirstChosenSpatialDim(convolution))
422                               .padding_low();
423 
424   // TODO(b/168316428): Support base dilations more generically.
425   if (c.base_dilation_factor != 1) {
426     if (!ctrl_.enable_propagations_on_base_dilations) {
427       return false;
428     }
429     if (c.stride != 1) {
430       return false;
431     }
432     // For low pad of 0, only support a pointwise kernel.
433     if (low_pad == 0) {
434       if (c.kernel_spatial_dim_size != 1) {
435         return false;
436       }
437     } else if (low_pad != c.base_dilation_factor - 1 &&
438                low_pad != c.base_dilation_factor) {
439       // Only support dilations such that base dilation factor and low pad are
440       // compatible with kernel_spatial_dim_size to be compatible with
441       // HaloDuplicateWithSlice.
442       return false;
443     }
444   }
445 
446   int64_t activations_batch_dim = dim_numbers.input_batch_dimension();
447 
448   const int64_t old_batch_size =
449       convolution->operand(0)->shape().dimensions(activations_batch_dim);
450 
451   if (old_batch_size > ctrl_.limit_on_batch_size) {
452     return false;
453   }
454 
455   VLOG(1) << "spatial size " << c.spatial_size << " halo size " << c.halo_size;
456 
457   // If the ratio is not within the 2X range, we can't Halo Pad from the next
458   // split.
459   if (c.halo_size > CeilOfRatio(c.spatial_size, ctrl_.number_of_splits)) {
460     return false;
461   }
462 
463   // TODO(b/201444224): The following cost model is needed to escape slowing
464   // down ssd batch 4.
465   if (c.base_dilation_factor > 1 &&
466       c.inherent_low_padding == c.base_dilation_factor) {
467     if (c.spatial_size <
468         kMultiplierOnSpaceForBaseDilation * ctrl_.number_of_splits) {
469       return false;
470     }
471   }
472 
473   VLOG(1) << "Legal space-to-batch convolution " << convolution->ToString();
474   return true;
475 }
476 
IsThisBackPropFilterConv(HloInstruction * convolution)477 bool ConvolutionVisitor::IsThisBackPropFilterConv(HloInstruction* convolution) {
478   auto activations = convolution->mutable_operand(0);
479   auto kernel = convolution->mutable_operand(1);
480   auto dim_numbers = convolution->convolution_dimension_numbers();
481 
482   if (!old_to_new_instrs_.contains(kernel) &&
483       !old_to_new_instrs_.contains(activations)) {
484     return false;
485   }
486 
487   if (old_to_new_instrs_.contains(kernel)) {
488     auto dim_map_val_op_0 = instr_to_dim_map_[kernel];
489     const int64_t old_batch_dim =
490         dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
491     if (convolution->convolution_dimension_numbers()
492             .kernel_input_feature_dimension() != old_batch_dim) {
493       return false;
494     }
495   }
496 
497   if (old_to_new_instrs_.contains(activations)) {
498     auto dim_map_val_op_0 = instr_to_dim_map_[activations];
499     const int64_t old_batch_dim =
500         dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
501     if (dim_numbers.input_feature_dimension() != old_batch_dim) {
502       return false;
503     }
504   }
505 
506   return true;
507 }
508 
// Duplicates boundary ("halo") elements across the split batch chunks so that
// each chunk carries the neighboring data a windowed op needs. Works on an
// already space-to-batched `activations`:
//  1. Un-merge the batch into per-split sub-batches (B, B0, S0, B1, S1, ...).
//  2. For each split spatial dim, prepend data from the previous chunk
//     (low_padding) and append data from the next chunk (halo) via
//     slice + shifted pad + concat.
//  3. Re-merge the batch back into space-to-batched form.
// Returns the new activations instruction; `pad_val` (default zero constant)
// fills positions with no neighboring chunk.
StatusOr<HloInstruction*> ConvolutionVisitor::HaloDuplicateWithSlice(
    HloInstruction* activations,
    absl::Span<const int64_t> spatial_dimensions_to_split,
    int64_t activations_batch_dim, int64_t low_padding, int64_t halo_size,
    HloInstruction* pad_val) {
  const int64_t spatial_dim_count = spatial_dimensions_to_split.size();
  // The merged batch holds number_of_splits^spatial_dim_count sub-batches;
  // divide it out to recover the pre-split batch size.
  const int64_t additional_batch_size =
      IPow<int64_t>(ctrl_.number_of_splits, spatial_dim_count);
  const int64_t original_batch_size =
      activations->shape().dimensions(activations_batch_dim) /
      additional_batch_size;

  const int64_t spatial_split_size =
      activations->shape().dimensions(spatial_dimensions_to_split[0]);
  const int64_t batch_size = ctrl_.number_of_splits;

  // Turn B into B, B0, S0, B1, S1, ... so each split dim has its own
  // sub-batch dimension immediately before it.
  TF_ASSIGN_OR_RETURN(
      activations, SplitAndTransposeMergedBatch(
                       activations, activations_batch_dim, original_batch_size,
                       spatial_dimensions_to_split));

  const int64_t rank = activations->shape().rank();

  VLOG(1) << "In HaloDuplicateWithSlice with activations "
          << activations->ToString() << " batch_size " << batch_size
          << " spatial_split_size " << spatial_split_size << " low_padding "
          << low_padding << " halo size " << halo_size;

  // Halo material must come from the immediately adjacent chunk only.
  CHECK_LE(std::abs(halo_size - low_padding), spatial_split_size);

  for (int64_t i = 0; i < spatial_dimensions_to_split.size(); ++i) {
    // After SplitAndTransposeMergedBatch, split dim i sits at
    // batch_dim + 2*(i+1), with its sub-batch dimension directly before it.
    int64_t spatial_dimension_to_split = activations_batch_dim + 2 * (i + 1);
    int64_t remapped_batch_dimension = spatial_dimension_to_split - 1;
    HloInstruction* first_slice = nullptr;

    std::vector<int64_t> strides(rank, 1);
    HloInstruction* padding =
        pad_val == nullptr
            ? computation_->AddInstruction(HloInstruction::CreateConstant(
                  LiteralUtil::Zero(activations->shape().element_type())))
            : pad_val;

    if (low_padding > 0) {
      // Take the last `low_padding` elements of each chunk (except the last
      // chunk) ...
      std::vector<int64_t> start_indices(rank, 0),
          end_indices(activations->shape().dimensions().begin(),
                      activations->shape().dimensions().end());
      start_indices[spatial_dimension_to_split] =
          spatial_split_size - low_padding;
      end_indices[remapped_batch_dimension] = batch_size - 1;
      end_indices[spatial_dimension_to_split] = spatial_split_size;

      TF_ASSIGN_OR_RETURN(first_slice, MakeSliceHlo(activations, start_indices,
                                                    end_indices, strides));
      VLOG(1) << "first slice " << first_slice->ToString();

      // ... and shift them one chunk forward by padding the sub-batch dim
      // low, so chunk k receives the tail of chunk k-1.
      PaddingConfig padding_config =
          MakeNoPaddingConfig(first_slice->shape().dimensions_size());
      padding_config.mutable_dimensions(remapped_batch_dimension)
          ->set_edge_padding_low(1);

      TF_ASSIGN_OR_RETURN(first_slice,
                          MakePadHlo(first_slice, padding, padding_config));
    }

    HloInstruction* halo_region = nullptr;
    if (halo_size - low_padding > 0) {
      // Take the first `halo_size - low_padding` elements of each chunk
      // (except the first chunk) ...
      std::vector<int64_t> start_indices_halo(rank, 0),
          end_indices_halo(activations->shape().dimensions().begin(),
                           activations->shape().dimensions().end());

      start_indices_halo[remapped_batch_dimension] = 1;
      end_indices_halo[spatial_dimension_to_split] = halo_size - low_padding;

      TF_ASSIGN_OR_RETURN(halo_region,
                          MakeSliceHlo(activations, start_indices_halo,
                                       end_indices_halo, strides));
      VLOG(1) << "halo_region " << halo_region->ToString();
      // ... and shift them one chunk backward by padding the sub-batch dim
      // high, so chunk k receives the head of chunk k+1.
      PaddingConfig padding_config_halo =
          MakeNoPaddingConfig(halo_region->shape().dimensions_size());
      padding_config_halo.mutable_dimensions(remapped_batch_dimension)
          ->set_edge_padding_high(1);
      TF_ASSIGN_OR_RETURN(
          halo_region, MakePadHlo(halo_region, padding, padding_config_halo));
    }

    if (halo_size == 0 && low_padding != 0) {
      std::vector<int64_t> start_indices_activations_cut(rank, 0),
          end_indices_activations_cut(activations->shape().dimensions().begin(),
                                      activations->shape().dimensions().end());
      // When no halo is needed, we must slice out activations.
      if (low_padding > 0) {
        end_indices_activations_cut[spatial_dimension_to_split] =
            spatial_split_size - low_padding;
      } else {
        // Negative low_padding: drop elements from the front instead.
        start_indices_activations_cut[spatial_dimension_to_split] =
            0 - low_padding;
        end_indices_activations_cut[spatial_dimension_to_split] =
            spatial_split_size;
      }

      TF_ASSIGN_OR_RETURN(
          activations, MakeSliceHlo(activations, start_indices_activations_cut,
                                    end_indices_activations_cut, strides));
    }

    // Stitch: [previous-chunk tail | activations | next-chunk head] along the
    // split spatial dimension.
    if (first_slice != nullptr) {
      TF_ASSIGN_OR_RETURN(activations,
                          MakeConcatHlo({first_slice, activations},
                                        spatial_dimension_to_split));
    }

    if (halo_region != nullptr) {
      TF_ASSIGN_OR_RETURN(activations,
                          MakeConcatHlo({activations, halo_region},
                                        spatial_dimension_to_split));
    }
  }

  // Merge the sub-batches back: B, B0, S0, ... -> space-to-batched layout.
  TF_ASSIGN_OR_RETURN(
      activations,
      TransposeAndMergeBatch(
          activations,
          /*final_split_spatial_dim_positioning=*/spatial_dimensions_to_split,
          activations_batch_dim, original_batch_size));

  VLOG(1) << "HaloDuplicated activations " << activations->ToString();
  return activations;
}
637 
// Transposes `activations` so that the (contiguous) spatial dimensions to be
// split come immediately after the batch dimension, updating `dim_numbers`
// and `activations_batch_dim` in place to describe the transposed operand.
// For is_rhs (kernel of a backprop filter conv) the kernel dimension numbers
// are remapped instead of the input ones. Returns the transposed instruction
// together with the permutation that was applied.
StatusOr<ConvolutionVisitor::SpaceNextToBatchDetails>
ConvolutionVisitor::BringSpaceNextToBatch(
    HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
    int64_t& activations_batch_dim,
    std::vector<int64_t>* spatial_dimensions_to_split, bool is_backprop,
    bool is_rhs) {
  // The algorithm below only moves a single leading spatial dim and assumes
  // the rest follow consecutively.
  for (int64_t i = 1; i < spatial_dimensions_to_split->size(); ++i) {
    CHECK_EQ(spatial_dimensions_to_split->at(i),
             spatial_dimensions_to_split->at(i - 1) + 1)
        << "Spatial dimensions are not contiguous";
  }

  int64_t spatial_dimension_to_split = spatial_dimensions_to_split->at(0);

  std::vector<int64_t> transpose_dims(activations->shape().rank());
  if (spatial_dimension_to_split == activations_batch_dim + 1) {
    // Already adjacent: identity permutation, no transpose needed.
    absl::c_iota(transpose_dims, 0);
  } else {
    ConvolutionDimensionNumbers new_dim_numbers = dim_numbers;
    int64_t pushed_counter = 0;
    // NOTE(review): new_batch_dim/new_spatial_dim are only assigned inside the
    // loops when i == spatial_dimension_to_split; presumably that always
    // happens for valid inputs — confirm, else they are read uninitialized.
    int64_t new_batch_dim, new_spatial_dim;
    int64_t dim_counter = 0;
    if (is_rhs) {
      // Kernel operand of a backprop filter conv: remap kernel dim numbers.
      CHECK(is_backprop);
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          // The batch dim is re-inserted just before the split spatial dim.
          continue;
        }
        if (i == spatial_dimension_to_split) {
          // Insert the batch dim here so the spatial dim lands right after it.
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        // Remap whichever kernel dimension number dimension i corresponds to.
        if (i == dim_numbers.kernel_output_feature_dimension()) {
          new_dim_numbers.set_kernel_output_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.kernel_spatial_dimensions(), i);
          if (it != dim_numbers.kernel_spatial_dimensions().end()) {
            int64_t j = it - dim_numbers.kernel_spatial_dimensions().begin();
            new_dim_numbers.set_kernel_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      // For the RHS, the batch dim plays the kernel-input-feature role.
      new_dim_numbers.set_kernel_input_feature_dimension(activations_batch_dim);

    } else {
      // LHS operand: remap input dim numbers the same way.
      for (int i = 0; i < activations->shape().rank(); ++i) {
        if (i == activations_batch_dim) {
          continue;
        }
        if (i == spatial_dimension_to_split) {
          transpose_dims[dim_counter++] = activations_batch_dim;
          new_batch_dim = pushed_counter;
          pushed_counter++;
          new_spatial_dim = pushed_counter;
        }

        if (is_backprop && i == dim_numbers.input_batch_dimension()) {
          new_dim_numbers.set_input_batch_dimension(pushed_counter);
        } else if (i == dim_numbers.input_feature_dimension()) {
          new_dim_numbers.set_input_feature_dimension(pushed_counter);
        } else {
          auto it = absl::c_find(dim_numbers.input_spatial_dimensions(), i);
          if (it != dim_numbers.input_spatial_dimensions().end()) {
            int64_t j = it - dim_numbers.input_spatial_dimensions().begin();
            new_dim_numbers.set_input_spatial_dimensions(j, pushed_counter);
          }
        }
        transpose_dims[dim_counter++] = i;
        pushed_counter++;
      }

      activations_batch_dim = new_batch_dim;
      spatial_dimension_to_split = new_spatial_dim;
      TF_ASSIGN_OR_RETURN(activations,
                          MakeTransposeHlo(activations, transpose_dims));

      // In backprop the moved dim acts as the feature dim; otherwise batch.
      if (is_backprop) {
        new_dim_numbers.set_input_feature_dimension(activations_batch_dim);
      } else {
        new_dim_numbers.set_input_batch_dimension(activations_batch_dim);
      }
    }

    dim_numbers = new_dim_numbers;
  }

  // Note that the spatial dimensions are in a sequential increasing order.
  for (int64_t i = 0; i < spatial_dimensions_to_split->size(); ++i) {
    (*spatial_dimensions_to_split)[i] = spatial_dimension_to_split + i;
  }

  return SpaceNextToBatchDetails{activations, transpose_dims};
}
742 
SplitAndTransposeMergedBatch(HloInstruction * activations,int64_t batch_dimension,int64_t old_batch_size,absl::Span<const int64_t> spatial_dimensions)743 StatusOr<HloInstruction*> ConvolutionVisitor::SplitAndTransposeMergedBatch(
744     HloInstruction* activations, int64_t batch_dimension,
745     int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions) {
746   CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]);
747   std::vector<int64_t> new_dimensions(activations->shape().dimensions().begin(),
748                                       activations->shape().dimensions().end());
749 
750   const int64_t new_batch_size =
751       activations->shape().dimensions(batch_dimension);
752 
753   VLOG(3) << "Decreasing the spatial size while propagating new_batch_size "
754           << new_batch_size << " old_batch_size " << old_batch_size;
755 
756   new_dimensions[batch_dimension] = old_batch_size;
757 
758   const int64_t spatial_dim_count = spatial_dimensions.size();
759   // Create additional batch dimensions.
760   for (int64_t i = 0; i < spatial_dim_count; ++i) {
761     new_dimensions.insert(new_dimensions.begin() + spatial_dimensions[0],
762                           ctrl_.number_of_splits);
763   }
764 
765   // Reshape the output of the new conv into the old convolutions shape.
766   TF_ASSIGN_OR_RETURN(HloInstruction * batch_split_activations,
767                       MakeReshapeHlo(new_dimensions, activations));
768 
769   if (spatial_dim_count > 1) {
770     // Transpose such that we get // B, B0, S0, B1, S1,...
771     std::vector<int64_t> transpose_dims(new_dimensions.size());
772     absl::c_iota(transpose_dims, 0);
773     // Transpose such that we get B, B0, S0, B1, S1,...
774     std::vector<int64_t> trans_dims(new_dimensions.size());
775     absl::c_iota(trans_dims, 0);
776 
777     int64_t start_batch_dim_position = batch_dimension + 1;
778     int64_t start_space_dim_position = batch_dimension + 2;
779 
780     for (int i = 0; i < spatial_dim_count; ++i) {
781       transpose_dims[start_batch_dim_position + 2 * i] =
782           batch_dimension + spatial_dim_count - i;
783       transpose_dims[start_space_dim_position + 2 * i] =
784           batch_dimension + spatial_dim_count + 1 + i;
785     }
786 
787     TF_ASSIGN_OR_RETURN(
788         batch_split_activations,
789         MakeTransposeHlo(batch_split_activations, transpose_dims));
790   }
791   return batch_split_activations;
792 }
793 
794 StatusOr<HloInstruction*>
ChangeSpatialSizeOnSpaceToBatchedShape(HloInstruction * activations,int64_t batch_dimension,int64_t old_batch_size,absl::Span<const int64_t> spatial_dimensions,int64_t new_spatial_dim_size,bool increase_spatial_size)795 ConvolutionVisitor::ChangeSpatialSizeOnSpaceToBatchedShape(
796     HloInstruction* activations, int64_t batch_dimension,
797     int64_t old_batch_size, absl::Span<const int64_t> spatial_dimensions,
798     int64_t new_spatial_dim_size, bool increase_spatial_size) {
799   CHECK_EQ(batch_dimension + 1, spatial_dimensions[0]);
800   std::vector<int64_t> new_dimensions(activations->shape().dimensions().begin(),
801                                       activations->shape().dimensions().end());
802 
803   const int64_t spatial_dim_count = spatial_dimensions.size();
804   const int64_t spatial_dim_size =
805       activations->shape().dimensions(spatial_dimensions[0]);
806   const int64_t reshaped_space_size = spatial_dim_size * ctrl_.number_of_splits;
807 
808   // Reshape the output of the new conv into the old convolutions shape.
809   TF_ASSIGN_OR_RETURN(
810       HloInstruction * batch_split_activations,
811       SplitAndTransposeMergedBatch(activations, batch_dimension, old_batch_size,
812                                    spatial_dimensions));
813 
814   // Now merge the individual (split) batch and space dimensions.
815   std::vector<int64_t> batch_space_collapse_reshape_dims(
816       batch_split_activations->shape().dimensions().begin(),
817       batch_split_activations->shape().dimensions().end());
818 
819   batch_space_collapse_reshape_dims.erase(
820       batch_space_collapse_reshape_dims.begin() + spatial_dimensions[0],
821       batch_space_collapse_reshape_dims.begin() + spatial_dimensions[0] +
822           spatial_dim_count);
823 
824   for (auto spatial_dimension : spatial_dimensions) {
825     batch_space_collapse_reshape_dims[spatial_dimension] = reshaped_space_size;
826   }
827 
828   TF_ASSIGN_OR_RETURN(HloInstruction * batch_space_collapsed_reshape,
829                       MakeReshapeHlo(batch_space_collapse_reshape_dims,
830                                      batch_split_activations));
831 
832   VLOG(3) << "First reshape done";
833 
834   const int64_t rank = activations->shape().rank();
835 
836   // If spatial size is increased, we add padding. If it has shrunk, we slice
837   // out the padding that was added before.
838   if (increase_spatial_size) {
839     PaddingConfig padding_config = MakeNoPaddingConfig(
840         batch_space_collapsed_reshape->shape().dimensions_size());
841     for (auto spatial_dimension : spatial_dimensions) {
842       padding_config.mutable_dimensions(spatial_dimension)
843           ->set_edge_padding_high(new_spatial_dim_size *
844                                       ctrl_.number_of_splits -
845                                   reshaped_space_size);
846       padding_config.mutable_dimensions(spatial_dimension)
847           ->set_edge_padding_low(0);
848     }
849     HloInstruction* padding = computation_->AddInstruction(
850         HloInstruction::CreateConstant(LiteralUtil::Zero(
851             batch_space_collapsed_reshape->shape().element_type())));
852 
853     TF_ASSIGN_OR_RETURN(
854         batch_space_collapsed_reshape,
855         MakePadHlo(batch_space_collapsed_reshape, padding, padding_config));
856   } else {
857     std::vector<int64_t> start_indices(rank, 0),
858         end_indices(batch_space_collapsed_reshape->shape().dimensions().begin(),
859                     batch_space_collapsed_reshape->shape().dimensions().end()),
860         strides(rank, 1);
861     for (auto spatial_dimension : spatial_dimensions) {
862       end_indices[spatial_dimension] =
863           new_spatial_dim_size * ctrl_.number_of_splits;
864     }
865 
866     // This is the slice from halo padding.
867     TF_ASSIGN_OR_RETURN(batch_space_collapsed_reshape,
868                         MakeSliceHlo(batch_space_collapsed_reshape,
869                                      start_indices, end_indices, strides));
870   }
871 
872   TF_ASSIGN_OR_RETURN(
873       HloInstruction * activations_new,
874       PerformSplitSpace(batch_space_collapsed_reshape, spatial_dimensions,
875                         batch_dimension, new_spatial_dim_size,
876                         ctrl_.number_of_splits));
877 
878   VLOG(3) << "Size decreased activations " << activations_new->ToString();
879 
880   return activations_new;
881 }
882 
// Drives the pass over the computation: first space-to-batches every suitable
// convolution in visit order, then post-processes instructions that could not
// be propagated through, either by a last-chance backprop-filter propagation
// or by batch-to-space'ing their operands. Returns whether anything changed.
StatusOr<bool> ConvolutionVisitor::Run() {
  for (auto conv : conv_visitor_list_) {
    // If we expect to see an unpropagatable op, space-to-batch may not be
    // beneficial.
    if (ctrl_.disable_starting_on_small_chains &&
        DoesConvolutionFeedUnpropagatableOp(conv)) {
      VLOG(1) << "Giving up on conv " << conv->ToString()
              << " because it feeds an unpropagatable op";
      convs_to_visit_.erase(conv);
    }
    if (convs_to_visit_.count(conv) > 0) {
      TF_CHECK_OK(PerformSpaceToBatchOnConvolution(conv));
    }
  }
  conv_visitor_list_.clear();
  convs_to_visit_.clear();
  // Iterate through all instructions that we could not propagate through, and
  // turn their operands from batch-to-space as needed.
  for (auto instr : non_propagatable_instrs_) {
    if (instr->opcode() == HloOpcode::kConvolution) {
      VLOG(1) << "Instr " << instr->ToString();
    }
    // Try to propagate on backprop filters
    if (instr->opcode() == HloOpcode::kConvolution &&
        !IsConvSuitableForSpaceToBatch(instr)) {
      HloInstruction* producer = nullptr;
      // Prefer operand 0 as the propagation source; fall back to operand 1
      // if only that one has been space-to-batched.
      if (old_to_new_instrs_.contains(instr->mutable_operand(0))) {
        producer = instr->mutable_operand(0);
      } else if (old_to_new_instrs_.contains(instr->mutable_operand(1))) {
        producer = instr->mutable_operand(1);
      }
      if (producer) {
        if (CanPropagate(instr, producer)) {
          bool needs_further_propagation;
          TF_ASSIGN_OR_RETURN(needs_further_propagation,
                              Propagate(instr, producer));
          // Propagate() registers the replacement in old_to_new_instrs_.
          TF_CHECK_OK(computation_->ReplaceInstruction(
              instr, old_to_new_instrs_[instr]));
          continue;
        }
      }
    }
    VLOG(1) << "Could not eventually propagate through " << instr->ToString();
    // Batch-to-space any operands that were space-to-batched so this
    // instruction keeps seeing the original layout.
    absl::flat_hash_map<int64_t, HloInstruction*> operand_map;
    for (int64_t i = 0; i < instr->operand_count(); ++i) {
      if (old_to_new_instrs_.count(instr->mutable_operand(i))) {
        TF_ASSIGN_OR_RETURN(operand_map[i],
                            BatchToSpace(instr->mutable_operand(i)));
      }
    }
    for (auto entry : operand_map) {
      TF_CHECK_OK(instr->ReplaceOperandWith(entry.first, entry.second));
    }
  }
  non_propagatable_instrs_.clear();
  return changed_;
}
940 
IsTrivialElementwise(HloInstruction * hlo)941 bool IsTrivialElementwise(HloInstruction* hlo) {
942   if (hlo->opcode() == HloOpcode::kFusion || hlo->opcode() == HloOpcode::kRng ||
943       hlo->opcode() == HloOpcode::kCopy ||
944       hlo->opcode() == HloOpcode::kConstant ||
945       hlo->opcode() == HloOpcode::kIota || hlo->opcode() == HloOpcode::kMap) {
946     return false;
947   }
948   return hlo->IsElementwise();
949 }
950 
// Decides whether space-to-batch state can be propagated from `producer`
// (already space-to-batched) onto `consumer`. Performs per-opcode readiness
// checks: elementwise ops need all operands converted with matching dimension
// permutations; concatenate needs all operands converted identically;
// convolutions are checked both for regular propagation and the
// backprop-filter special cases; reduce/reduce-window/select-and-scatter have
// shape and window compatibility checks. Returns true when propagation is
// considered safe.
bool ConvolutionVisitor::CanPropagate(HloInstruction* consumer,
                                      HloInstruction* producer) {
  if (IsTrivialElementwise(consumer)) {
    VLOG(2) << "Doing propagation check on elementwise op: "
            << consumer->ToString();

    // The first converted operand becomes the "pivot"; all other converted
    // operands must agree with it on batch/space ordering and permutation.
    HloInstruction* pivot_operand = nullptr;
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      std::vector<HloInstruction*> to_transform;
      // Constants and propagatable broadcasts don't need to be converted
      // themselves; they can be adapted at propagation time.
      const bool broadcast_or_constant =
          (old_producer->opcode() == HloOpcode::kConstant) ||
          (old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastPropagatable(old_producer, producer)) ||
          (consumer->IsElementwiseBinary() &&
           old_producer->opcode() == HloOpcode::kBroadcast &&
           IsBroadcastTree(old_producer, producer, to_transform));

      if (!old_to_new_instrs_.contains(old_producer) &&
          !broadcast_or_constant) {
        VLOG(1) << "Cannot propagate on elementwise op " << consumer->ToString()
                << " because operand " << old_producer->ToString()
                << " isn't ready ";
        return false;
      } else {
        if (broadcast_or_constant) {
          VLOG(2) << "Skipping on " << old_producer->ToString();
          continue;
        }

        CHECK(old_to_new_instrs_.contains(old_producer));

        CHECK(instr_to_dim_map_.contains(old_producer));
        if (pivot_operand == nullptr) {
          pivot_operand = old_producer;
          VLOG(2) << "Elementwise op: pivot " << old_producer->ToString();
        } else {
          // Batch and space dims of this operand must match the pivot's.
          if (instr_to_dim_map_[pivot_operand]
                               [DimMapper(SpaceToBatchDimMap::kBatch)] !=
                  instr_to_dim_map_[old_producer]
                                   [DimMapper(SpaceToBatchDimMap::kBatch)] ||
              instr_to_dim_map_[pivot_operand]
                               [DimMapper(SpaceToBatchDimMap::kSpace0)] !=
                  instr_to_dim_map_[old_producer]
                                   [DimMapper(SpaceToBatchDimMap::kSpace0)]) {
            VLOG(2) << "Elementwise op: checking for shape equivalence "
                    << consumer->ToString()
                    << " failed due to changed batch space ordering ";
            return false;
          }
          auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
          auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
          auto new_instr = old_to_new_instrs_[old_producer];
          auto permute_dims = instr_to_dim_permute_map_[new_instr];
          for (int j = 0; j < pivot_permute_dims.size(); ++j) {
            // Ensure the dimension mapping is the same.
            if (pivot_permute_dims[j] != permute_dims[j]) {
              VLOG(2) << "Elementwise op: checking for shape equivalence "
                      << consumer->ToString()
                      << " failed due to permuted dimensions ";
              return false;
            }

            // Make sure all other dimensions are of the same size.
            if (pivot_new_instr->shape().dimensions(j) !=
                new_instr->shape().dimensions(j)) {
              // A space-dim size mismatch is tolerated for binary/select ops
              // (it can be reconciled during propagation).
              if (!((consumer->IsElementwiseBinary() ||
                     consumer->opcode() == HloOpcode::kSelect) &&
                    j == instr_to_dim_map_[pivot_operand][DimMapper(
                             SpaceToBatchDimMap::kSpace0)])) {
                VLOG(2) << "Elementwise op: checking for shape equivalence "
                        << consumer->ToString()
                        << " failed due to changed shape sizes ";
                return false;
              }
            }
          }
        }
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kConcatenate) {
    // Make sure all operands have been space-to-batched.
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      if (!instr_to_dim_map_.contains(consumer->mutable_operand(i))) {
        return false;
      }
    }
    // Operand 0 serves as the pivot; all others must match its permutation
    // and per-dimension sizes.
    auto pivot_operand = consumer->mutable_operand(0);
    auto pivot_new_instr = old_to_new_instrs_[pivot_operand];
    auto pivot_permute_dims = instr_to_dim_permute_map_[pivot_new_instr];
    for (int64_t i = 1; i < consumer->operand_count(); ++i) {
      auto new_instr = old_to_new_instrs_[consumer->mutable_operand(i)];
      auto permute_dims = instr_to_dim_permute_map_[new_instr];

      for (int j = 0; j < pivot_permute_dims.size(); ++j) {
        // Ensure the dimension mapping is the same.
        if (pivot_permute_dims[j] != permute_dims[j]) {
          VLOG(2) << "Concat op: checking for shape equivalence "
                  << consumer->ToString()
                  << " failed due to permuted dimensions ";
          return false;
        }
        // Make sure all other dimensions are of the same size.
        if (pivot_new_instr->shape().dimensions(j) !=
            new_instr->shape().dimensions(j)) {
          VLOG(2) << "Concat op: checking for shape equivalence "
                  << consumer->ToString()
                  << " failed due to changed shape sizes ";
          return false;
        }
      }
    }
    return true;
  }

  if (consumer->opcode() == HloOpcode::kConvolution) {
    if (!ConsumeFuel("space-to-batch-converter", [&] {
          return "Skipping space-to-batch propagation because fuel over\n";
        })) {
      return false;
    }
    // Lambda that checks basic sanity of dimension propagation on convolutions.
    // This includes: the split dimension from the previous convolution should
    // remain the same. No feature/batch dimension should be turned into a
    // spatial dimension.
    auto are_conv_dims_compatible =
        [&](const ConvolutionDimensionNumbers dim_numbers,
            std::vector<int64_t>& dim_map, bool check_lhs) {
          if (check_lhs) {
            if (dim_numbers.input_spatial_dimensions(
                    GetFirstChosenSpatialDim(consumer)) !=
                dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)]) {
              return false;
            }
            for (int i = 0; i < dim_numbers.input_spatial_dimensions().size();
                 ++i) {
              if (dim_numbers.input_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] ||
                  dim_numbers.input_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kFeature)]) {
                return false;
              }
            }
          } else {
            if (dim_numbers.kernel_spatial_dimensions(
                    GetFirstChosenSpatialDim(consumer)) !=
                dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)]) {
              return false;
            }
            for (int i = 0; i < dim_numbers.kernel_spatial_dimensions().size();
                 ++i) {
              if (dim_numbers.kernel_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] ||
                  dim_numbers.kernel_spatial_dimensions(i) ==
                      dim_map[DimMapper(SpaceToBatchDimMap::kFeature)]) {
                return false;
              }
            }
          }
          return true;
        };

    VLOG(1) << "Checking if conv is supported for propagation "
            << consumer->ToString();
    bool found_good_non_window_dilated_conv = true;
    if (IsConvSuitableForSpaceToBatch(consumer)) {
      // Activations must have been space-to-batched to enable propagation.
      if (!old_to_new_instrs_.contains(consumer->mutable_operand(0))) {
        found_good_non_window_dilated_conv = false;
      }
      auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_0, /*check_lhs*/ true)) {
        found_good_non_window_dilated_conv = false;
      }
      // Make sure that the batch dimension is the same across the producer
      // and consumer.
      if (consumer->convolution_dimension_numbers().input_batch_dimension() !=
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)]) {
        found_good_non_window_dilated_conv = false;
      }

      if (found_good_non_window_dilated_conv) {
        return true;
      }
    }

    // Beyond this point we only consider backprop-filter convolutions with
    // window dilations, which requires the corresponding control flag.
    if (!ctrl_.enable_propagations_on_window_dilations) {
      return false;
    }

    if (!IsThisBackPropFilterConv(consumer)) {
      return false;
    }
    // Check for space-to-depth readiness here. Note this is not done in
    // SupportedOpForPropagation because the readiness is dependent upon
    // space-to-batchedness of the operands.

    // If there are no specified spatial dims, we return.
    if (GetFirstChosenSpatialDim(consumer) < 0) {
      return false;
    }

    // We currently only support stride of 1.
    if (consumer->window()
            .dimensions(GetFirstChosenSpatialDim(consumer))
            .stride() != 1) {
      return false;
    }

    // Same reason why we give up on batch group counts applies to features in
    // backprop.
    if (consumer->feature_group_count() != 1) {
      return false;
    }

    VLOG(2) << "Checking for backprop filter conv propagatability";
    CHECK_EQ(consumer->operand_count(), 2);

    auto activations = consumer->mutable_operand(0);
    auto kernel = consumer->mutable_operand(1);

    auto win_dims =
        consumer->window().dimensions(GetFirstChosenSpatialDim(consumer));
    const int64_t rhs_dilation = win_dims.window_dilation();
    const int64_t lhs_dilation = win_dims.base_dilation();

    // LHS dilations are supported by PropagateOnConv, and not by
    // PropagateOnBackpropFilterConv.
    if (lhs_dilation != 1) {
      return false;
    }
    // If the rhs_dilation is absent, we want both LHS and RHS to be space-to-
    // batched for propagating on backprop convolutions.

    if (rhs_dilation == 1 &&
        !ctrl_.enable_propagations_on_trivial_window_dilations) {
      if (!old_to_new_instrs_.contains(kernel) ||
          !old_to_new_instrs_.contains(activations)) {
        return false;
      }
    }

    // Case 1: only activations space-to-batched; kernel will be converted.
    if (!old_to_new_instrs_.contains(kernel)) {
      const int64_t rhs_batch =
          kernel->shape().dimensions(consumer->convolution_dimension_numbers()
                                         .kernel_input_feature_dimension());
      auto dim_map_val_op_0 = instr_to_dim_map_[activations];
      const int64_t old_batch_dim =
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
      const int64_t old_space_dim =
          dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)];
      auto first_operand = old_to_new_instrs_[activations];
      auto permute_dims_first_operand =
          instr_to_dim_permute_map_[first_operand];
      const int64_t new_batch_dim =
          DimLookUp(permute_dims_first_operand, old_batch_dim);
      const int64_t new_space_dim =
          DimLookUp(permute_dims_first_operand, old_space_dim);
      const int64_t lhs_batch =
          first_operand->shape().dimensions(new_batch_dim);

      // The space dimension must divide evenly by the window dilation.
      if (first_operand->shape().dimensions(new_space_dim) % rhs_dilation !=
          0) {
        return false;
      }
      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since LHS is already space-to-batched, we need to account for it too.
      if (rhs_batch * ctrl_.number_of_splits != lhs_batch) {
        return false;
      }

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_0, /*check_lhs*/ true)) {
        return false;
      }

      // If kernel have not been propagated through, we can do
      // space-to-batch on them provided kernel has been propagated.
      VLOG(2)
          << "Backprop filter conv ready for propagation: activations ready, "
             " kernel will be space-to-batched";
      return true;
    }

    // Case 2: only kernel space-to-batched; activations will be converted.
    if (!old_to_new_instrs_.contains(activations)) {
      const int64_t lhs_batch = activations->shape().dimensions(
          consumer->convolution_dimension_numbers().input_feature_dimension());
      auto dim_map_val_op_1 = instr_to_dim_map_[consumer->mutable_operand(1)];
      const int64_t old_batch_dim =
          dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kBatch)];
      auto second_operand = old_to_new_instrs_[kernel];
      auto permute_dims_second_operand =
          instr_to_dim_permute_map_[second_operand];
      const int64_t new_batch_dim =
          DimLookUp(permute_dims_second_operand, old_batch_dim);
      const int64_t rhs_batch =
          second_operand->shape().dimensions(new_batch_dim);

      // Because we want to convert activations into a space-to-batched version
      // only for backprop filter convolutions, we want to make sure that the
      // batch dimensions (feature dimensions, technically) are same sized.
      // Since RHS is already space-to-batched, we need to account for it too.
      if (rhs_batch != ctrl_.number_of_splits * lhs_batch) {
        return false;
      }

      if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                    dim_map_val_op_1, /*check_lhs*/ false)) {
        return false;
      }

      // If activations have not been propagated through, we can do
      // space-to-batch on them provided kernel has been propagated.
      VLOG(2) << "Backprop filter conv ready for propagation: kernel ready, "
                 " activations will be space-to-batched";
      return true;
    }

    // Case 3: both operands are already space-to-batched; verify that their
    // batch/space dimensions line up.
    auto first_operand = old_to_new_instrs_[activations];
    auto dim_map_val_op_0 = instr_to_dim_map_[activations];

    auto second_operand = old_to_new_instrs_[kernel];
    auto dim_map_val_op_1 = instr_to_dim_map_[kernel];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    const int64_t new_batch_dim_operand_0 =
        DimLookUp(permute_dims_first_operand,
                  dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)]);

    const int64_t new_space_dim_operand_0 =
        DimLookUp(permute_dims_first_operand,
                  dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)]);

    const int64_t new_batch_dim_operand_1 =
        DimLookUp(permute_dims_second_operand,
                  dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kBatch)]);
    const int64_t new_space_dim_operand_1 =
        DimLookUp(permute_dims_second_operand,
                  dim_map_val_op_1[DimMapper(SpaceToBatchDimMap::kSpace0)]);

    if (first_operand->shape().dimensions(new_batch_dim_operand_0) !=
        second_operand->shape().dimensions(new_batch_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because batch "
                 "dimensions don't line up";
      return false;
    }

    if (first_operand->shape().dimensions(new_space_dim_operand_0) >
        rhs_dilation *
            second_operand->shape().dimensions(new_space_dim_operand_1)) {
      VLOG(2) << "Backprop filter conv not ready for propagation because of "
                 "dilation factor mismatch";
      return false;
    }

    if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                  dim_map_val_op_0, /*check_lhs*/ true)) {
      return false;
    }

    if (!are_conv_dims_compatible(consumer->convolution_dimension_numbers(),
                                  dim_map_val_op_1, /*check_lhs*/ false)) {
      return false;
    }

    VLOG(2) << "Backprop filter conv ready for propagation";

    return true;
  }

  if (consumer->opcode() == HloOpcode::kReduceWindow ||
      consumer->opcode() == HloOpcode::kReduce) {
    // Only the first operand (the data input) must be space-to-batched;
    // init values don't need conversion.
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i == 0 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    // Make sure the post space-to-batch dim size is larger than window size.
    if (consumer->opcode() == HloOpcode::kReduceWindow) {
      return IsSpaceToBatchedSpaceSizeSuitable(consumer);
    }
  }

  if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
    // Both the operand and the source must be space-to-batched.
    for (int64_t i = 0; i < consumer->operand_count(); ++i) {
      auto old_producer = consumer->mutable_operand(i);
      if (i < 2 && !old_to_new_instrs_.contains(old_producer)) {
        return false;
      }
    }

    auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
    auto dim_map_val_op_0 = instr_to_dim_map_[consumer->mutable_operand(0)];
    auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];

    auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
    auto permute_dims_second_operand =
        instr_to_dim_permute_map_[second_operand];

    // The permuting must match.
    if (permute_dims_first_operand != permute_dims_second_operand) {
      VLOG(2) << "Can't propagate through select and scatter due to "
                 "permutation mismatch";
      return false;
    }

    const int64_t old_batch_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kBatch)];
    const int64_t old_space_dim =
        dim_map_val_op_0[DimMapper(SpaceToBatchDimMap::kSpace0)];

    const int64_t new_batch_dim =
        DimLookUp(permute_dims_first_operand, old_batch_dim);
    const int64_t new_space_dim =
        DimLookUp(permute_dims_first_operand, old_space_dim);

    if (first_operand->shape().dimensions(new_batch_dim) !=
        second_operand->shape().dimensions(new_batch_dim)) {
      VLOG(2)
          << "Can't propagate through select and scatter due to dim mismatch";
      return false;
    }

    // The source shape must be consistent with the operand shape under the
    // window's stride and padding along the space dimension.
    const int64_t stride =
        consumer->window().dimensions(old_space_dim).stride();
    const int64_t pad_high =
        consumer->window().dimensions(old_space_dim).padding_high();
    const int64_t pad_low =
        consumer->window().dimensions(old_space_dim).padding_low();

    if ((first_operand->shape().dimensions(new_space_dim) + pad_high +
         pad_low) /
            stride !=
        second_operand->shape().dimensions(new_space_dim)) {
      VLOG(2) << "Can't propagate through select and scatter due to stride "
                 "mismatch";
      return false;
    }

    return IsSpaceToBatchedSpaceSizeSuitable(consumer);
  }
  // Opcodes without specific checks are considered propagatable.
  return true;
}
1405 
PropagateOnBroadcast(HloInstruction * consumer,HloInstruction * producer)1406 void ConvolutionVisitor::PropagateOnBroadcast(HloInstruction* consumer,
1407                                               HloInstruction* producer) {
1408   auto new_producer = old_to_new_instrs_[producer];
1409   auto permute_dims = instr_to_dim_permute_map_[new_producer];
1410   auto dim_map_val = instr_to_dim_map_[producer];
1411 
1412   const int64_t old_batch_dim =
1413       dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
1414   const int64_t old_space_dim =
1415       dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
1416 
1417   auto orig_broadcast_dims = consumer->dimensions();
1418 
1419   bool batch_is_broadcasted =
1420       absl::c_linear_search(orig_broadcast_dims, old_batch_dim);
1421   const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
1422   const int64_t new_space_dim = DimLookUp(permute_dims, old_space_dim);
1423 
1424   bool map_found = broadcast_map_.contains(consumer);
1425   if (map_found) {
1426     // Check if we previously had created the same broadcast.
1427     for (auto previous_broadcast : broadcast_map_[consumer]) {
1428       if (ShapeUtil::CompatibleIgnoringElementType(previous_broadcast->shape(),
1429                                                    new_producer->shape())) {
1430         return;
1431       }
1432     }
1433   }
1434 
1435   std::vector<int64_t> final_shape_dims(
1436       new_producer->shape().dimensions().begin(),
1437       new_producer->shape().dimensions().end());
1438   if (batch_is_broadcasted) {
1439     final_shape_dims[new_batch_dim] =
1440         producer->shape().dimensions(old_batch_dim);
1441     final_shape_dims[new_space_dim] *= ctrl_.number_of_splits;
1442   }
1443 
1444   std::vector<int64_t> broadcast_dims;
1445   const auto& dimensions = consumer->dimensions();
1446   broadcast_dims.reserve(dimensions.size());
1447   for (auto j : dimensions) {
1448     broadcast_dims.push_back(DimLookUp(permute_dims, j));
1449   }
1450   auto new_broadcast = MakeBroadcastHlo(consumer->mutable_operand(0),
1451                                         broadcast_dims, final_shape_dims);
1452   VLOG(1) << "Created broadcast " << new_broadcast->ToString();
1453 
1454   if (batch_is_broadcasted) {
1455     new_broadcast =
1456         MakeReshapeHlo(new_producer->shape().dimensions(), new_broadcast)
1457             .ValueOrDie();
1458     VLOG(2) << "Created reshape of broadcast " << new_broadcast->ToString();
1459   }
1460 
1461   if (!map_found) {
1462     absl::flat_hash_set<HloInstruction*> set_of_broadcasts;
1463     broadcast_map_[consumer] = set_of_broadcasts;
1464   }
1465   broadcast_map_[consumer].insert(new_broadcast);
1466 }
1467 
RewriteBroadcastTree(HloInstruction * producer,std::vector<HloInstruction * > & instructions_to_transform)1468 void ConvolutionVisitor::RewriteBroadcastTree(
1469     HloInstruction* producer,
1470     std::vector<HloInstruction*>& instructions_to_transform) {
1471   CHECK(old_to_new_instrs_.contains(producer));
1472   for (auto instr : instructions_to_transform) {
1473     if (instr->opcode() == HloOpcode::kBroadcast) {
1474       PropagateOnBroadcast(instr, producer);
1475     } else if (IsTrivialElementwise(instr)) {
1476       Propagate(instr, /*producer=*/instr->mutable_operand(0)).ValueOrDie();
1477     } else {
1478       LOG(FATAL) << "Unsupported opcode in RewriteBroadcastTree";
1479     }
1480   }
1481 }
1482 
IsBroadcastTree(HloInstruction * op,HloInstruction * consumer,std::vector<HloInstruction * > & instructions_to_transform)1483 bool ConvolutionVisitor::IsBroadcastTree(
1484     HloInstruction* op, HloInstruction* consumer,
1485     std::vector<HloInstruction*>& instructions_to_transform) {
1486   if (op->opcode() == HloOpcode::kBroadcast) {
1487     // We want to ensure that the broadcast did not happen on the space and
1488     // batch dimensions.
1489     if (IsBroadcastPropagatable(op, consumer)) {
1490       instructions_to_transform.push_back(op);
1491       return true;
1492     } else {
1493       return false;
1494     }
1495   }
1496   if (Match(op, m::ConstantScalar())) {
1497     return true;
1498   }
1499   if (!IsTrivialElementwise(op)) {
1500     return false;
1501   }
1502   for (int64_t i = 0; i < op->operand_count(); ++i) {
1503     if (!IsBroadcastTree(op->mutable_operand(i), consumer,
1504                          instructions_to_transform)) {
1505       return false;
1506     }
1507   }
1508   instructions_to_transform.push_back(op);
1509   return true;
1510 }
1511 
IsBroadcastPropagatable(HloInstruction * broadcast,HloInstruction * old_other_op)1512 bool ConvolutionVisitor::IsBroadcastPropagatable(HloInstruction* broadcast,
1513                                                  HloInstruction* old_other_op) {
1514   CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
1515   CHECK(instr_to_dim_map_.contains(old_other_op));
1516 
1517   auto result = instr_to_dim_map_[old_other_op];
1518   const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1519   auto broadcast_dims = broadcast->dimensions();
1520   return !absl::c_linear_search(broadcast_dims, space_dim);
1521 }
1522 
IsOpcodeNonPropagatable(HloInstruction * consumer)1523 bool ConvolutionVisitor::IsOpcodeNonPropagatable(HloInstruction* consumer) {
1524   // We can add more non-propagatable opcodes as needed.
1525   switch (consumer->opcode()) {
1526     case HloOpcode::kCustomCall:
1527       return true;
1528     default:
1529       return false;
1530   }
1531 }
1532 
SupportedOpForPropagation(HloInstruction * consumer,HloInstruction * producer)1533 bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer,
1534                                                    HloInstruction* producer) {
1535   if (IsOpcodeNonPropagatable(consumer)) {
1536     return false;
1537   }
1538 
1539   if (IsTrivialElementwise(consumer)) {
1540     for (int64_t i = 0; i < consumer->operand_count(); ++i) {
1541       if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
1542         if (!IsBroadcastPropagatable(consumer->mutable_operand(i), producer)) {
1543           VLOG(2) << "Could not propagate through broadcast";
1544           return false;
1545         }
1546       }
1547     }
1548     return true;
1549   }
1550 
1551   if (consumer->opcode() == HloOpcode::kConvolution) {
1552     return true;
1553   }
1554 
1555   if (consumer->opcode() == HloOpcode::kConcatenate) {
1556     HloInstruction* pivot_operand = nullptr;
1557     for (int64_t i = 0; i < consumer->operand_count(); ++i) {
1558       if (instr_to_dim_map_.contains(consumer->mutable_operand(i))) {
1559         pivot_operand = consumer->mutable_operand(i);
1560         break;
1561       }
1562     }
1563     if (pivot_operand == nullptr) {
1564       VLOG(1) << "Concat: Dim map not found on any operand";
1565       return false;
1566     }
1567     // Disallow concating on the batch and space dims
1568     auto result = instr_to_dim_map_[pivot_operand];
1569     const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
1570     const int64_t old_space_dim =
1571         result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1572     if (consumer->concatenate_dimension() == old_batch_dim ||
1573         consumer->concatenate_dimension() == old_space_dim) {
1574       return false;
1575     }
1576     return true;
1577   }
1578 
1579   if (consumer->opcode() == HloOpcode::kReverse) {
1580     auto operand_0 = consumer->mutable_operand(0);
1581     if (!instr_to_dim_map_.contains(operand_0)) {
1582       return false;
1583     }
1584     // Disallow reversing on the batch and space dims
1585     auto result = instr_to_dim_map_[operand_0];
1586     const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
1587     const int64_t old_space_dim =
1588         result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1589 
1590     for (auto dim : consumer->dimensions()) {
1591       if (dim == old_batch_dim || dim == old_space_dim) {
1592         return false;
1593       }
1594     }
1595     return true;
1596   }
1597 
1598   if (consumer->opcode() == HloOpcode::kTranspose) {
1599     return true;
1600   }
1601 
1602   if (consumer->opcode() == HloOpcode::kPad) {
1603     auto operand_0 = consumer->mutable_operand(0);
1604     if (!instr_to_dim_map_.contains(operand_0)) {
1605       return false;
1606     }
1607     // Disallow reversing on the batch and space dims
1608     auto result = instr_to_dim_map_[operand_0];
1609     const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
1610     const int64_t old_space_dim =
1611         result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1612 
1613     auto does_dim_have_padding = [](PaddingConfig padding_config, int64_t dim) {
1614       return padding_config.dimensions(dim).edge_padding_low() != 0 ||
1615              padding_config.dimensions(dim).edge_padding_high() != 0 ||
1616              padding_config.dimensions(dim).interior_padding() != 0;
1617     };
1618     // Batch and space dims should not have padding.
1619     if (does_dim_have_padding(consumer->padding_config(), old_batch_dim) ||
1620         does_dim_have_padding(consumer->padding_config(), old_space_dim)) {
1621       return false;
1622     }
1623     return true;
1624   }
1625 
1626   if (consumer->opcode() == HloOpcode::kReduce) {
1627     // Support only the trivial case where both batch and split spatial dim are
1628     // being reduced
1629 
1630     auto reduce_dims = consumer->dimensions();
1631     auto result = instr_to_dim_map_[consumer->mutable_operand(0)];
1632     const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
1633     const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1634     VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim
1635             << "  space_dim " << space_dim << " reduce "
1636             << consumer->ToString();
1637     return absl::c_linear_search(reduce_dims, batch_dim) &&
1638            absl::c_linear_search(reduce_dims, space_dim);
1639   }
1640 
1641   if (consumer->opcode() == HloOpcode::kReduceWindow &&
1642       consumer->shape().IsTuple()) {
1643     // TODO (b/73062247) variadic reduce window is not yet supported.
1644     return false;
1645   }
1646   if (consumer->opcode() == HloOpcode::kReduceWindow ||
1647       consumer->opcode() == HloOpcode::kSelectAndScatter) {
1648     auto first_operand = consumer->mutable_operand(0);
1649     auto window = consumer->window();
1650     if (instr_to_dim_map_.count(first_operand) <= 0) {
1651       VLOG(1) << "Dim map not found on windowed operand. Window dim count "
1652               << window.dimensions().size();
1653       return false;
1654     }
1655     // Disallow windowing on the batch dim
1656     auto result = instr_to_dim_map_[first_operand];
1657     const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
1658     const int64_t old_space_dim =
1659         result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1660     if (window.dimensions(old_batch_dim).size() != 1) {
1661       return false;
1662     }
1663 
1664     // Only allow no-low-padding cases.
1665     if (window.dimensions(old_space_dim).padding_low() != 0) {
1666       return false;
1667     }
1668 
1669     // No base/window dilations allowed on space and batch dimensions.
1670     if (window.dimensions(old_space_dim).base_dilation() != 1 ||
1671         window.dimensions(old_space_dim).window_dilation() != 1) {
1672       return false;
1673     }
1674     // No base/window dilations allowed on space and batch dimensions.
1675     if (window.dimensions(old_batch_dim).base_dilation() != 1 ||
1676         window.dimensions(old_batch_dim).window_dilation() != 1) {
1677       return false;
1678     }
1679 
1680     // Only allow small high pads.
1681     if (window.dimensions(old_space_dim).padding_high() >
1682         window.dimensions(old_space_dim).size()) {
1683       return false;
1684     }
1685 
1686     // Operand 0 must have been propagated through
1687     if (old_to_new_instrs_.count(first_operand) <= 0) {
1688       return false;
1689     }
1690 
1691     auto new_operand = old_to_new_instrs_[first_operand];
1692     auto permute_dims = instr_to_dim_permute_map_[new_operand];
1693 
1694     // Select-and-scatter specific checks.
1695     if (consumer->opcode() == HloOpcode::kSelectAndScatter) {
1696       const int64_t new_space_dim = DimLookUp(permute_dims, old_space_dim);
1697       // Make sure that the stride lines up.
1698       if (new_operand->shape().dimensions(new_space_dim) %
1699               window.dimensions(old_space_dim).stride() !=
1700           0) {
1701         return false;
1702       }
1703 
1704       // Only support floating point datatypes.
1705       if (!ShapeUtil::ElementIsFloating(consumer->shape())) {
1706         return false;
1707       }
1708       // We currently only support adds in the scatter.
1709       auto scatter_comp = consumer->scatter();
1710       if (!Match(scatter_comp->root_instruction(),
1711                  m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
1712         return false;
1713       }
1714       // Select should just be a single comparison with GE as the direction.
1715       auto select_comp = consumer->select();
1716       if (!Match(select_comp->root_instruction(),
1717                  m::Compare(m::Parameter(0), m::Parameter(1))
1718                      .WithComparisonDirection(ComparisonDirection::kGe)) &&
1719           !Match(select_comp->root_instruction(),
1720                  m::Compare(m::Parameter(1), m::Parameter(0))
1721                      .WithComparisonDirection(ComparisonDirection::kGe))) {
1722         return false;
1723       }
1724       // We do not support low padding on select-and-scatter.
1725       if (consumer->window().dimensions(old_space_dim).padding_low() != 0) {
1726         return false;
1727       }
1728     }
1729 
1730     return true;
1731   }
1732 
1733   return false;
1734 }
1735 
Propagate(HloInstruction * consumer,HloInstruction * producer)1736 StatusOr<bool> ConvolutionVisitor::Propagate(HloInstruction* consumer,
1737                                              HloInstruction* producer) {
1738   auto computation = consumer->parent();
1739   if (IsTrivialElementwise(consumer)) {
1740     auto dim_map_val = instr_to_dim_map_[producer];
1741     auto new_consumer = computation->AddInstruction(consumer->Clone());
1742 
1743     bool is_pivot_producer_modified = false;
1744     // For elementwise binary ops, both of whose operands have been space-to-
1745     // batched, if their new spatial sizes don't match, choose the bigger one
1746     // as the producer.
1747     if (consumer->IsElementwiseBinary() ||
1748         consumer->opcode() == HloOpcode::kSelect) {
1749       int64_t pivot_operand_number = -1;
1750       HloInstruction* pivot_operand = nullptr;
1751       for (int i = 0; i < consumer->operand_count(); ++i) {
1752         if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
1753           continue;
1754         }
1755         auto operand = consumer->mutable_operand(i);
1756         if (old_to_new_instrs_.contains(operand)) {
1757           if (pivot_operand_number == -1 ||
1758               old_to_new_instrs_[pivot_operand]->shape().dimensions() <
1759                   old_to_new_instrs_[operand]->shape().dimensions()) {
1760             is_pivot_producer_modified = true;
1761             pivot_operand_number = i;
1762             pivot_operand = consumer->mutable_operand(pivot_operand_number);
1763           }
1764         }
1765       }
1766       if (pivot_operand_number != -1) {
1767         producer = pivot_operand;
1768       }
1769     }
1770 
1771     for (int64_t i = 0; i < consumer->operand_count(); ++i) {
1772       std::vector<HloInstruction*> instructions_to_transform;
1773 
1774       if (consumer->operand(i)->opcode() == HloOpcode::kBroadcast) {
1775         auto broadcast = consumer->mutable_operand(i);
1776         PropagateOnBroadcast(broadcast, producer);
1777         HloInstruction* new_broadcast = nullptr;
1778         auto new_producer = old_to_new_instrs_[producer];
1779         for (auto previous_broadcast : broadcast_map_[broadcast]) {
1780           if (ShapeUtil::CompatibleIgnoringElementType(
1781                   previous_broadcast->shape(), new_producer->shape())) {
1782             new_broadcast = previous_broadcast;
1783             break;
1784           }
1785         }
1786         CHECK_NE(new_broadcast, nullptr);
1787         TF_CHECK_OK(
1788             new_consumer->ReplaceOperandWithDifferentShape(i, new_broadcast));
1789       } else if (old_to_new_instrs_.contains(consumer->mutable_operand(i))) {
1790         HloInstruction* operand_to_use = nullptr;
1791         auto result = instr_to_dim_map_[producer];
1792         const int64_t old_batch_dim =
1793             result[DimMapper(SpaceToBatchDimMap::kBatch)];
1794         const int64_t old_space_dim =
1795             result[DimMapper(SpaceToBatchDimMap::kSpace0)];
1796         const int64_t old_batch_size =
1797             producer->shape().dimensions(old_batch_dim);
1798         HloInstruction* new_instr =
1799             old_to_new_instrs_[consumer->mutable_operand(i)];
1800         HloInstruction* pivot_new_instr = old_to_new_instrs_[producer];
1801 
1802         auto permute_dims = instr_to_dim_permute_map_[new_instr];
1803         const int64_t batch_dim = DimLookUp(permute_dims, old_batch_dim);
1804         const int64_t space_dim = DimLookUp(permute_dims, old_space_dim);
1805         const int64_t batch_size = new_instr->shape().dimensions(batch_dim);
1806 
1807         if (new_instr->shape().dimensions(space_dim) !=
1808             pivot_new_instr->shape().dimensions(space_dim)) {
1809           // Because we do not propagate through transposes, the batch should
1810           // always be followed by the split space dimension.
1811           CHECK_EQ(batch_dim + 1, space_dim);
1812 
1813           // Reshape to 1D, pad to the producer's size, reshape back to 2D.
1814           std::vector<int64_t> new_dimensions(
1815               new_instr->shape().dimensions().begin(),
1816               new_instr->shape().dimensions().end());
1817           new_dimensions[space_dim] *= (batch_size / old_batch_size);
1818           new_dimensions[batch_dim] = old_batch_size;
1819 
1820           TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
1821                               MakeReshapeHlo(new_dimensions, new_instr));
1822 
1823           const int64_t pivot_space_size =
1824               pivot_new_instr->shape().dimensions(space_dim) * batch_size /
1825               old_batch_size;
1826 
1827           CHECK(pivot_space_size > new_dimensions[space_dim] ||
1828                 !is_pivot_producer_modified);
1829 
1830           PaddingConfig padding_config =
1831               MakeNoPaddingConfig(reshape->shape().dimensions_size());
1832           padding_config.mutable_dimensions(space_dim)->set_edge_padding_high(
1833               pivot_space_size - new_dimensions[space_dim]);
1834           padding_config.mutable_dimensions(space_dim)->set_edge_padding_low(0);
1835           HloInstruction* padding =
1836               computation_->AddInstruction(HloInstruction::CreateConstant(
1837                   LiteralUtil::Zero(reshape->shape().element_type())));
1838 
1839           TF_ASSIGN_OR_RETURN(HloInstruction * padded_operand,
1840                               MakePadHlo(reshape, padding, padding_config));
1841 
1842           TF_ASSIGN_OR_RETURN(
1843               operand_to_use,
1844               MakeReshapeHlo(pivot_new_instr->shape().dimensions(),
1845                              padded_operand));
1846 
1847         } else {
1848           operand_to_use = old_to_new_instrs_[consumer->mutable_operand(i)];
1849         }
1850         TF_CHECK_OK(
1851             new_consumer->ReplaceOperandWithDifferentShape(i, operand_to_use));
1852       } else if (consumer->IsElementwiseBinary() &&
1853                  consumer->mutable_operand(i)->opcode() ==
1854                      HloOpcode::kBroadcast &&
1855                  IsBroadcastTree(consumer->mutable_operand(i), producer,
1856                                  instructions_to_transform)) {
1857         RewriteBroadcastTree(producer, instructions_to_transform);
1858         TF_CHECK_OK(new_consumer->ReplaceOperandWithDifferentShape(
1859             i, old_to_new_instrs_[consumer->mutable_operand(i)]));
1860       } else if (consumer->operand(i)->opcode() == HloOpcode::kConstant) {
1861         TF_ASSIGN_OR_RETURN(
1862             auto new_constant,
1863             PropagateOnConstant(consumer->mutable_operand(i), producer));
1864         TF_CHECK_OK(
1865             new_consumer->ReplaceOperandWithDifferentShape(i, new_constant));
1866       }
1867     }
1868     auto old_type = new_consumer->mutable_shape()->element_type();
1869     *(new_consumer->mutable_shape()) = old_to_new_instrs_[producer]->shape();
1870 
1871     // The element type needs to be retained.
1872     new_consumer->mutable_shape()->set_element_type(old_type);
1873 
1874     old_to_new_instrs_[consumer] = new_consumer;
1875     instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);
1876     CHECK(instr_to_dim_permute_map_.contains(old_to_new_instrs_[producer]));
1877     instr_to_dim_permute_map_[new_consumer] = std::vector<int64_t>(
1878         instr_to_dim_permute_map_[old_to_new_instrs_[producer]]);
1879 
1880     VLOG(2) << " new_consumer " << new_consumer->ToString()
1881             << " old_to_new_instrs_[producer] "
1882             << old_to_new_instrs_[producer]->ToString() << " permute dims "
1883             << instr_to_dim_permute_map_.count(new_consumer);
1884 
1885     return true;
1886   }
1887 
1888   if (consumer->opcode() == HloOpcode::kConvolution) {
1889     if (IsConvSuitableForSpaceToBatch(consumer)) {
1890       TF_CHECK_OK(PropagateOnConv(consumer));
1891       return true;
1892     } else {
1893       TF_CHECK_OK(PropagateOnBackpropFilterConv(consumer));
1894       return false;
1895     }
1896   }
1897 
1898   if (consumer->opcode() == HloOpcode::kConcatenate) {
1899     TF_CHECK_OK(PropagateOnConcat(consumer));
1900     return true;
1901   }
1902 
1903   if (consumer->opcode() == HloOpcode::kReverse) {
1904     TF_CHECK_OK(PropagateOnReverse(consumer));
1905     return true;
1906   }
1907 
1908   // TODO(b/189500737) : Consider a common way of propagation for
1909   // slice/pad/reduce-window.
1910   if (consumer->opcode() == HloOpcode::kPad) {
1911     TF_CHECK_OK(PropagateOnPad(consumer));
1912     return true;
1913   }
1914 
1915   if (consumer->opcode() == HloOpcode::kReduce) {
1916     auto new_consumer = computation->AddInstruction(consumer->Clone());
1917     auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
1918 
1919     auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
1920     const int64_t old_batch_dim =
1921         dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
1922 
1923     auto permute_dims = instr_to_dim_permute_map_[first_operand];
1924     const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
1925 
1926     auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
1927     std::vector<int64_t> old_spatial_dims = retval.first;
1928     std::vector<int64_t> new_spatial_dims = retval.second;
1929 
1930     TF_ASSIGN_OR_RETURN(
1931         first_operand,
1932         SelectValidPortion(first_operand, consumer->mutable_operand(0),
1933                            consumer->mutable_operand(1), new_batch_dim,
1934                            new_spatial_dims, old_batch_dim, old_spatial_dims));
1935 
1936     std::vector<int64_t> changed_dims(new_consumer->dimensions().size());
1937     for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) {
1938       changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i));
1939     }
1940     *(new_consumer->mutable_dimensions()) = changed_dims;
1941     // Replace operand 0.
1942     TF_CHECK_OK(
1943         new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
1944     // We do not set instr_to_dim_permute_map_ here because no further
1945     // propagation is needed here.
1946     old_to_new_instrs_[consumer] = new_consumer;
1947     instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);
1948 
1949     // Since the resultant ordering of dimension is the same as before, no
1950     // further propagation is needed.
1951     return false;
1952   }
1953 
1954   if (consumer->opcode() == HloOpcode::kTranspose) {
1955     auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
1956     // Pass the first operand forward, with the map of the permuted dims.
1957     auto new_consumer = computation->AddInstruction(first_operand->Clone());
1958     old_to_new_instrs_[consumer] = new_consumer;
1959     auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
1960     const int64_t old_batch_dim =
1961         dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
1962     const int64_t old_space_dim =
1963         dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)];
1964     const int64_t old_feature_dim =
1965         dim_map_val[DimMapper(SpaceToBatchDimMap::kFeature)];
1966 
1967     int64_t new_batch_dim, new_space_dim, new_feature_dim;
1968     std::vector<int64_t> new_dimensions(consumer->dimensions().size());
1969     for (int64_t ctr = 0; ctr < consumer->dimensions().size(); ++ctr) {
1970       int64_t dim = consumer->dimensions(ctr);
1971       if (dim == old_batch_dim) {
1972         new_batch_dim = ctr;
1973       }
1974       if (dim == old_space_dim) {
1975         new_space_dim = ctr;
1976       }
1977       if (dim == old_feature_dim) {
1978         new_feature_dim = ctr;
1979       }
1980     }
1981 
1982     std::vector<int64_t> dim_map(NumMappedDims());
1983     dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = new_batch_dim;
1984     dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = new_feature_dim;
1985     dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = new_space_dim;
1986     instr_to_dim_map_[consumer] = dim_map;
1987 
1988     std::vector<int64_t> new_permute_dims(consumer->dimensions().size());
1989     auto permute_dims = instr_to_dim_permute_map_[first_operand];
1990     for (int64_t i = 0; i < consumer->dimensions().size(); ++i) {
1991       new_permute_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i));
1992     }
1993 
1994     instr_to_dim_permute_map_[new_consumer] = new_permute_dims;
1995 
1996     return true;
1997   }
1998 
1999   if (consumer->opcode() == HloOpcode::kReduceWindow ||
2000       consumer->opcode() == HloOpcode::kSelectAndScatter) {
2001     bool is_select_and_scatter =
2002         consumer->opcode() == HloOpcode::kSelectAndScatter;
2003     auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)];
2004 
2005     auto init_val = is_select_and_scatter ? consumer->mutable_operand(2)
2006                                           : consumer->mutable_operand(1);
2007     auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)];
2008 
2009     auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0));
2010     std::vector<int64_t> old_spatial_dims = retval.first;
2011     std::vector<int64_t> new_spatial_dims = retval.second;
2012 
2013     const int64_t old_batch_dim =
2014         dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)];
2015     const int64_t old_space_dim = old_spatial_dims[0];
2016     auto permute_dims = instr_to_dim_permute_map_[first_operand];
2017     const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim);
2018     const int64_t new_space_dim = new_spatial_dims[0];
2019 
2020     // Calculate the required halo size
2021     auto new_shape = first_operand->shape();
2022     auto old_shape = consumer->mutable_operand(0)->shape();
2023 
2024     const int64_t new_space_size = new_shape.dimensions(new_space_dim);
2025     const int64_t stride =
2026         consumer->window().dimensions(old_space_dim).stride();
2027 
2028     auto pad_val =
2029         is_select_and_scatter
2030             ? computation_->AddInstruction(
2031                   HloInstruction::CreateConstant(LiteralUtil::MinValue(
2032                       consumer->operand(2)->shape().element_type())))
2033             : init_val;
2034     TF_ASSIGN_OR_RETURN(
2035         first_operand,
2036         SelectValidPortion(first_operand, consumer->mutable_operand(0), pad_val,
2037                            new_batch_dim, new_spatial_dims, old_batch_dim,
2038                            old_spatial_dims));
2039 
2040     const int64_t extra_space = new_space_size % stride;
2041     if (extra_space) {
2042       CHECK_EQ(consumer->opcode(), HloOpcode::kReduceWindow);
2043       const int64_t old_batch_size = old_shape.dimensions(old_batch_dim);
2044       const int64_t old_space_size = old_shape.dimensions(old_space_dim);
2045       // If the shrunk space is still larger/equal than the original space, we
2046       // reduce the space.
2047       if ((new_space_size - extra_space) * old_batch_size *
2048               ctrl_.number_of_splits >=
2049           old_batch_size * old_space_size) {
2050         TF_ASSIGN_OR_RETURN(
2051             first_operand, ChangeSpatialSizeOnSpaceToBatchedShape(
2052                                first_operand, new_batch_dim, old_batch_size,
2053                                new_spatial_dims, new_space_size - extra_space));
2054       } else {
2055         TF_ASSIGN_OR_RETURN(
2056             first_operand,
2057             ChangeSpatialSizeOnSpaceToBatchedShape(
2058                 first_operand, new_batch_dim, old_batch_size, new_spatial_dims,
2059                 new_space_size + stride - extra_space,
2060                 /*increase_spatial_size*/ true));
2061       }
2062     }
2063     const int64_t window_size =
2064         consumer->window().dimensions(old_space_dim).size();
2065     const int64_t last_overlap_point = ((new_space_size - 1) / stride) * stride;
2066     VLOG(1) << "last_overlap_point " << last_overlap_point << " window_size "
2067             << window_size << " new_space_size " << new_space_size;
2068 
2069     const int64_t halo_size = last_overlap_point + window_size - new_space_size;
2070     if (halo_size > 0) {
2071       TF_ASSIGN_OR_RETURN(
2072           first_operand,
2073           HaloDuplicateWithSlice(first_operand, new_spatial_dims, new_batch_dim,
2074                                  /*low_padding=*/0, halo_size, init_val));
2075     }
2076 
2077     Window new_win;
2078     for (int64_t i = 0; i < consumer->window().dimensions().size(); ++i) {
2079       auto dim = ReverseDimLookUp(permute_dims, i);
2080       new_win.add_dimensions();
2081       new_win.mutable_dimensions(i)->set_stride(
2082           consumer->window().dimensions(dim).stride());
2083       new_win.mutable_dimensions(i)->set_size(
2084           consumer->window().dimensions(dim).size());
2085       if (i == old_space_dim) {
2086         new_win.mutable_dimensions(i)->set_padding_high(0);
2087         new_win.mutable_dimensions(i)->set_padding_low(0);
2088       } else {
2089         new_win.mutable_dimensions(i)->set_padding_high(
2090             consumer->window().dimensions(dim).padding_high());
2091         new_win.mutable_dimensions(i)->set_padding_low(
2092             consumer->window().dimensions(dim).padding_low());
2093       }
2094       new_win.mutable_dimensions(i)->set_window_dilation(
2095           consumer->window().dimensions(dim).window_dilation());
2096       new_win.mutable_dimensions(i)->set_base_dilation(
2097           consumer->window().dimensions(dim).base_dilation());
2098       new_win.mutable_dimensions(i)->set_window_reversal(
2099           consumer->window().dimensions(dim).window_reversal());
2100     }
2101 
2102     new_shape = first_operand->shape();
2103 
2104     HloInstruction* new_consumer = nullptr;
2105     if (is_select_and_scatter) {
2106       auto second_operand = old_to_new_instrs_[consumer->mutable_operand(1)];
2107 
2108       auto select_comp = consumer->select();
2109 
2110       auto scatter_comp = consumer->scatter();
2111       TF_ASSIGN_OR_RETURN(
2112           auto new_select_and_scatter_shape,
2113           ShapeInference::InferSelectAndScatterShape(
2114               new_shape, select_comp->ComputeProgramShape(), new_win,
2115               second_operand->shape(), init_val->shape(),
2116               scatter_comp->ComputeProgramShape()));
2117       new_consumer =
2118           computation_->AddInstruction(HloInstruction::CreateSelectAndScatter(
2119               new_select_and_scatter_shape, first_operand, select_comp, new_win,
2120               second_operand, init_val, scatter_comp));
2121       // Replace operand 0.
2122       TF_CHECK_OK(
2123           new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
2124       // Replace operand 1.
2125       TF_CHECK_OK(
2126           new_consumer->ReplaceOperandWithDifferentShape(1, second_operand));
2127       VLOG(2) << "New select and scatter " << new_consumer->ToString();
2128 
2129       // If the window size was larger than the stride, there could be overlaps.
2130       // Such cases require updates from both overlaps to be applied.
2131       if (halo_size > 0) {
2132         const int64_t rank = new_consumer->shape().rank();
2133 
2134         const int64_t batch_size =
2135             new_consumer->shape().dimensions(new_batch_dim);
2136 
2137         std::vector<int64_t> start_indices(rank, 0),
2138             end_indices(new_consumer->shape().dimensions().begin(),
2139                         new_consumer->shape().dimensions().end()),
2140             strides(rank, 1);
2141         start_indices[new_space_dim] = new_space_size;
2142         end_indices[new_space_dim] = new_space_size + halo_size;
2143         end_indices[new_batch_dim] = batch_size - 1;
2144 
2145         // This is the slice from halo padding.
2146         TF_ASSIGN_OR_RETURN(
2147             HloInstruction * bottom,
2148             MakeSliceHlo(new_consumer, start_indices, end_indices, strides));
2149 
2150         std::vector<int64_t> start_indices_top(rank, 0),
2151             end_indices_top(new_consumer->shape().dimensions().begin(),
2152                             new_consumer->shape().dimensions().end());
2153         end_indices_top[new_space_dim] = halo_size;
2154         // The first batch has correct data.
2155         start_indices_top[new_batch_dim] = 1;
2156 
2157         // This is the original area from where halo pad was extracted.
2158         TF_ASSIGN_OR_RETURN(HloInstruction * top,
2159                             MakeSliceHlo(new_consumer, start_indices_top,
2160                                          end_indices_top, strides));
2161 
2162         HloInstruction* default_fill =
2163             MakeBroadcastHlo(init_val, {}, top->shape().dimensions());
2164 
2165         // Compare to see if the bottom area was changed.
2166         TF_ASSIGN_OR_RETURN(
2167             HloInstruction * bottom_compare,
2168             MakeCompareHlo(ComparisonDirection::kNe, bottom, default_fill));
2169 
2170         // Take out only the changed values.
2171         TF_ASSIGN_OR_RETURN(
2172             HloInstruction * bottom_taken,
2173             MakeSelectHlo(bottom_compare, bottom, default_fill));
2174 
2175         // Compare to see if the top area was changed.
2176         TF_ASSIGN_OR_RETURN(
2177             HloInstruction * top_compare,
2178             MakeCompareHlo(ComparisonDirection::kNe, top, default_fill));
2179 
2180         // Take out only the changed values.
2181         TF_ASSIGN_OR_RETURN(HloInstruction * top_taken,
2182                             MakeSelectHlo(top_compare, top, bottom_taken));
2183 
2184         // This makes checks if the area was updated by both overlaps.
2185         TF_ASSIGN_OR_RETURN(
2186             HloInstruction * both_compare,
2187             MakeBinaryHlo(HloOpcode::kAnd, top_compare, bottom_compare));
2188 
2189         // If it was, add them up.
2190         TF_ASSIGN_OR_RETURN(HloInstruction * both_added,
2191                             MakeBinaryHlo(HloOpcode::kAdd, top, bottom));
2192 
2193         // Pad the final result to the original shape.
2194         TF_ASSIGN_OR_RETURN(HloInstruction * final_selection,
2195                             MakeSelectHlo(both_compare, both_added, top_taken));
2196 
2197         PaddingConfig padding_config =
2198             MakeNoPaddingConfig(final_selection->shape().dimensions_size());
2199         padding_config.mutable_dimensions(new_batch_dim)
2200             ->set_edge_padding_low(1);
2201         padding_config.mutable_dimensions(new_space_dim)
2202             ->set_edge_padding_high(new_space_size);
2203         HloInstruction* padding =
2204             computation_->AddInstruction(HloInstruction::CreateConstant(
2205                 LiteralUtil::Zero(final_selection->shape().element_type())));
2206 
2207         TF_ASSIGN_OR_RETURN(
2208             final_selection,
2209             MakePadHlo(final_selection, padding, padding_config));
2210 
2211         tensorflow::core::Bitmap b(batch_size * (new_space_size + halo_size));
2212         for (int k = 0; k < batch_size * (new_space_size + halo_size); ++k) {
2213           const int64_t space_index = k % (new_space_size + halo_size);
2214           const int64_t batch_index = (k / (new_space_size + halo_size));
2215           if (batch_index < 1 || space_index >= halo_size) {
2216             b.set(k);
2217           } else {
2218             b.clear(k);
2219           }
2220         }
2221 
2222         auto arg_literal = LiteralUtil::CreateR1(b);
2223         VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
2224         HloInstruction* slice_mask = computation_->AddInstruction(
2225             HloInstruction::CreateConstant(std::move(arg_literal)));
2226 
2227         std::vector<int64_t> slice_mask_reshape_dims(2);
2228         slice_mask_reshape_dims[0] = batch_size;
2229         slice_mask_reshape_dims[1] = (new_space_size + halo_size);
2230 
2231         TF_ASSIGN_OR_RETURN(
2232             HloInstruction * slice_mask_reshaped,
2233             MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));
2234 
2235         // Broadcast the mask in all dimensions.
2236         HloInstruction* shape_mask = MakeBroadcastHlo(
2237             slice_mask_reshaped, {new_batch_dim, new_space_dim},
2238             final_selection->shape().dimensions());
2239 
2240         TF_ASSIGN_OR_RETURN(
2241             new_consumer,
2242             MakeSelectHlo(shape_mask, new_consumer, final_selection));
2243       }
2244 
2245       auto previous_shape =
2246           old_to_new_instrs_[consumer->mutable_operand(0)]->shape();
2247       std::vector<int64_t> start_indices(previous_shape.rank(), 0),
2248           end_indices(previous_shape.dimensions().begin(),
2249                       previous_shape.dimensions().end()),
2250           strides(previous_shape.rank(), 1);
2251 
2252       TF_ASSIGN_OR_RETURN(
2253           new_consumer,
2254           MakeSliceHlo(new_consumer, start_indices, end_indices, strides));
2255 
2256     } else {
2257       auto reduce_comp = consumer->to_apply();
2258       TF_ASSIGN_OR_RETURN(auto new_reduce_window_shape,
2259                           ShapeInference::InferReduceWindowShape(
2260                               new_shape, init_val->shape(), new_win));
2261       new_consumer =
2262           computation_->AddInstruction(HloInstruction::CreateReduceWindow(
2263               new_reduce_window_shape, first_operand, init_val, new_win,
2264               reduce_comp));
2265       // Replace operand 0.
2266       TF_CHECK_OK(
2267           new_consumer->ReplaceOperandWithDifferentShape(0, first_operand));
2268       VLOG(1) << "New reduce window " << new_consumer->ToString();
2269     }
2270 
2271     old_to_new_instrs_[consumer] = new_consumer;
2272     instr_to_dim_map_[consumer] = std::vector<int64_t>(dim_map_val);
2273 
2274     instr_to_dim_permute_map_[new_consumer] = std::vector<int64_t>(
2275         instr_to_dim_permute_map_[old_to_new_instrs_[consumer->mutable_operand(
2276             0)]]);
2277 
2278     return true;
2279   }
2280 
2281   LOG(FATAL) << "Trying to propagate through an unsupported instruction "
2282              << consumer->ToString();
2283   return true;
2284 }
2285 
// Replaces the out-of-bounds (halo/padding) portion of a space-to-batched
// instruction `new_instr` with `select_val`, using the original (pre-split)
// instruction `old_instr` to determine which elements are valid. Returns the
// masked instruction.
StatusOr<HloInstruction*> ConvolutionVisitor::SelectValidPortion(
    HloInstruction* new_instr, HloInstruction* old_instr,
    HloInstruction* select_val, int64_t new_batch_dim,
    absl::Span<const int64_t> new_space_dims, int64_t old_batch_dim,
    absl::Span<const int64_t> old_space_dims) {
  auto new_shape = new_instr->shape();
  auto old_shape = old_instr->shape();
  VLOG(1) << "In SelectValidPortion new_batch_dim " << new_batch_dim
          << " new_space_dim " << new_space_dims[0] << " old_batch_dim "
          << old_batch_dim << " old_space_dim " << old_space_dims[0];
  const int64_t new_batch_size = new_shape.dimensions(new_batch_dim);
  const int64_t new_space_size = new_shape.dimensions(new_space_dims[0]);
  const int64_t old_batch_size = old_shape.dimensions(old_batch_dim);
  const int64_t old_space_size = old_shape.dimensions(old_space_dims[0]);
  // The new batch dimension must hold an integral number of copies of the
  // old batch (one per split combination).
  CHECK_EQ(new_batch_size % old_batch_size, 0)
      << " New batch size " << new_batch_size << " old batch size "
      << old_batch_size;
  const int64_t num_splits = ctrl_.number_of_splits;
  const int64_t spatial_dim_count = new_space_dims.size();

  // The dimension ordering found bounds is Old Batch, BN, BN -1 .., B0, S0, S1
  // ..., SN
  // bounds[0] is the old batch, bounds[1] packs all per-dimension split
  // indices into a single mixed-radix digit of size num_splits^spatial_dims,
  // and the remaining entries are the per-dimension spatial sizes.
  std::vector<int64_t> bounds(2 + spatial_dim_count, new_space_size);
  bounds[0] = old_batch_size;
  bounds[1] = IPow<int64_t>(num_splits, spatial_dim_count);

  const int64_t total_new_space =
      IPow<int64_t>(new_space_size, spatial_dim_count);

  // Build a constant PRED to decide which elements in the split dimension
  // are from halo.
  tensorflow::core::Bitmap b(new_batch_size * total_new_space);
  for (int k = 0; k < new_batch_size * total_new_space; ++k) {
    // Decompose the flat index k into (old batch, split indices, spatial
    // indices) using the mixed-radix bounds above.
    auto radix = ToMixedRadix(k, bounds);

    bool out_of_bounds = false;
    int64_t batch_residue = 1;
    for (int i = 0; i < spatial_dim_count; ++i) {
      const int64_t space_index = radix[2 + i];
      // Extract this spatial dimension's split index from the packed digit.
      const int64_t batch_index = (radix[1] / batch_residue) % num_splits;
      batch_residue *= num_splits;
      // An element is invalid if its reconstructed position in the original
      // (unsplit) spatial dimension falls past the old spatial extent.
      if (batch_index * new_space_size + space_index >= old_space_size) {
        out_of_bounds = true;
      }
    }

    if (!out_of_bounds) {
      b.set(k);
    } else {
      b.clear(k);
    }
  }

  auto arg_literal = LiteralUtil::CreateR1(b);
  VLOG(4) << "Slice mask created: arg literal " << arg_literal.ToString();
  HloInstruction* slice_mask = computation_->AddInstruction(
      HloInstruction::CreateConstant(std::move(arg_literal)));

  // Reshape the flat mask to [new_batch, S0, S1, ..., SN].
  std::vector<int64_t> slice_mask_reshape_dims(1 + spatial_dim_count,
                                               new_space_size);
  slice_mask_reshape_dims[0] = new_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * slice_mask_reshaped,
                      MakeReshapeHlo(slice_mask_reshape_dims, slice_mask));

  std::vector<int64_t> broadcast_dims(new_space_dims.begin(),
                                      new_space_dims.end());
  broadcast_dims.insert(broadcast_dims.begin(), new_batch_dim);
  // Broadcast the mask in all dimensions of the activations.
  HloInstruction* shape_mask = MakeBroadcastHlo(
      slice_mask_reshaped, broadcast_dims, new_instr->shape().dimensions());

  VLOG(1) << "Shape mask made " << shape_mask->ToString();

  // Fill value used wherever the mask marks an element invalid.
  HloInstruction* zeroes =
      MakeBroadcastHlo(select_val, {}, new_instr->shape().dimensions());

  TF_ASSIGN_OR_RETURN(new_instr, MakeSelectHlo(shape_mask, new_instr, zeroes));

  return new_instr;
}
2367 
// Converts a space-to-batched instruction back to the original layout:
// splits the merged batch, reshapes the split factors back into the spatial
// dimension(s), slices off the even-division padding, and transposes back to
// the old dimension order. Results are memoized in batch_to_space_map_.
StatusOr<HloInstruction*> ConvolutionVisitor::BatchToSpace(
    HloInstruction* old_instr) {
  // Return the cached conversion if we already did this instruction.
  if (batch_to_space_map_.count(old_instr)) {
    CHECK_NE(batch_to_space_map_[old_instr], nullptr);
    return batch_to_space_map_[old_instr];
  }

  auto result = instr_to_dim_map_[old_instr];
  const int64_t old_batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)];
  const int64_t old_space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)];

  const int64_t old_batch_size = old_instr->shape().dimensions(old_batch_dim);
  CHECK(old_to_new_instrs_.contains(old_instr));
  auto new_instr = old_to_new_instrs_[old_instr];
  VLOG(2) << "old_batch_dim " << old_batch_dim << " old_space_dim "
          << old_space_dim << " old_instr " << old_instr->ToString()
          << "\n new_instr " << new_instr->ToString() << " permute dims "
          << instr_to_dim_permute_map_.count(new_instr) << " old_batch_size "
          << old_batch_size;
  CHECK(instr_to_dim_permute_map_.contains(new_instr));
  auto permute_dims = instr_to_dim_permute_map_[new_instr];
  // Translate the old batch/space dims into the new (permuted) layout.
  const int64_t batch_dim = DimLookUp(permute_dims, old_batch_dim);
  const int64_t space_dim = DimLookUp(permute_dims, old_space_dim);

  const int64_t spatial_dim_size = new_instr->shape().dimensions(space_dim);

  // The converted spatial dims are consecutive, starting at space_dim.
  std::vector<int64_t> split_spatial_dimensions(
      ctrl_.count_of_dimensions_to_convert);
  absl::c_iota(split_spatial_dimensions, space_dim);

  // Un-merge the batch: split it into old batch + per-dimension split
  // factors placed next to their spatial dimensions.
  TF_ASSIGN_OR_RETURN(new_instr, SplitAndTransposeMergedBatch(
                                     new_instr, batch_dim, old_batch_size,
                                     split_spatial_dimensions));

  std::vector<int64_t> new_dimensions(new_instr->shape().dimensions().begin(),
                                      new_instr->shape().dimensions().end());

  // Drop the split-factor dimensions; they get folded into the spatial dims.
  new_dimensions.erase(new_dimensions.begin() + split_spatial_dimensions[0],
                       new_dimensions.begin() + split_spatial_dimensions[0] +
                           ctrl_.count_of_dimensions_to_convert);

  // Each spatial dimension regains its full (padded) extent.
  for (auto spatial_dimension : split_spatial_dimensions) {
    new_dimensions[spatial_dimension] =
        spatial_dim_size * ctrl_.number_of_splits;
  }

  // Reshape the output of the new conv into the old convolutions shape.
  TF_ASSIGN_OR_RETURN(HloInstruction * reshape,
                      MakeReshapeHlo(new_dimensions, new_instr));

  VLOG(1) << "Batch to space reshape " << reshape->ToString();
  const int64_t rank = old_instr->shape().rank();
  std::vector<int64_t> start_indices(rank, 0),
      end_indices(new_dimensions.begin(), new_dimensions.end()),
      strides(rank, 1);

  // Clamp each converted spatial dimension back to its original size.
  for (auto spatial_dimension : split_spatial_dimensions) {
    end_indices[spatial_dimension] =
        old_instr->shape().dimensions(old_space_dim);
  }

  // This slicing is getting rid of the padding we added to evenly divide space.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * output_slice,
      MakeSliceHlo(reshape, start_indices, end_indices, strides));
  VLOG(1) << "Batch to space slice " << output_slice->ToString();
  // Undo the dimension permutation applied during space-to-batch.
  std::vector<int64_t> transpose_dims(permute_dims);
  TF_ASSIGN_OR_RETURN(HloInstruction * output_transpose,
                      MakeTransposeHlo(output_slice, transpose_dims));
  old_instr->SetupDerivedInstruction(output_transpose);

  // Memoize so repeated requests reuse the same conversion.
  batch_to_space_map_[old_instr] = output_transpose;
  return output_transpose;
}
2442 
// Breadth-first propagation of the space-to-batch transformation from
// `old_conv` to its transitive users. Supported users are space-to-batched in
// place; unsupported or root users are bridged back with BatchToSpace.
Status ConvolutionVisitor::PropagateOnUsers(HloInstruction* old_conv) {
  // Each worklist entry is (instruction, producer it is reached from).
  std::queue<std::pair<HloInstruction*, HloInstruction*>> propagation_worklist;

  // Dead-end (root with no users): just rewrite the conv itself back to the
  // original layout and install it as the root.
  if (old_conv->user_count() == 0) {
    TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                        BatchToSpace(old_conv));
    VLOG(1) << "Replacing the root instruction to "
            << batch_to_space->ToString();
    TF_CHECK_OK(computation_->ReplaceInstruction(old_conv, batch_to_space));
    VLOG(1) << "Replacement successful";
    return OkStatus();
  }

  int64_t iteration_count = 0;
  propagation_worklist.push(
      std::make_pair(old_conv, old_conv->mutable_operand(0)));

  while (!propagation_worklist.empty()) {
    auto top = propagation_worklist.front();
    auto node = top.first;
    auto parent = top.second;
    VLOG(1) << "Traversing for propagation operating on " << node->ToString();
    propagation_worklist.pop();

    // Don't work on the same node again.
    // (iteration_count == 0 is the seed conv, which already has a mapping.)
    if (old_to_new_instrs_.count(node) > 0 && iteration_count != 0) {
      continue;
    }

    bool needs_further_propagation = true;
    if (iteration_count != 0) {
      // Do the space-to-batch propagation on this node.
      TF_ASSIGN_OR_RETURN(needs_further_propagation, Propagate(node, parent));
    }
    iteration_count++;
    // If this is the root, no room for further propagation.
    if (node->parent()->root_instruction() == node) {
      // The below case does not need going back to space.
      if (!needs_further_propagation) {
        VLOG(1) << "Replacing the root instruction to "
                << old_to_new_instrs_[node]->ToString();
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }

      // Otherwise bridge back to the original layout before installing.
      TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space, BatchToSpace(node));
      VLOG(1) << "Replacing the root instruction to "
              << batch_to_space->ToString();
      TF_CHECK_OK(computation_->ReplaceInstruction(node, batch_to_space));
    } else {
      if (!needs_further_propagation) {
        TF_CHECK_OK(
            computation_->ReplaceInstruction(node, old_to_new_instrs_[node]));
        continue;
      }

      HloInstructionSet unsupported_users;
      // Insert all users into the queue, as long as the ops are supported and
      // the op is ready for propagation. If the op is unsupported, do
      // batch-to-space. If not ready, mark as non-propagatable.
      for (auto user : node->users()) {
        if (!SupportedOpForPropagation(user, node)) {
          VLOG(1) << "Unsupported op found " << user->ToString();
          unsupported_users.insert(user);
          continue;
        }
        // If the instruction is ready for propagation, add it to the queue.
        if (CanPropagate(user, node)) {
          non_propagatable_instrs_.erase(user);
          propagation_worklist.push(std::make_pair(user, node));
        } else {
          // Mark it as non-propagatable for now, for later revisiting.
          non_propagatable_instrs_.insert(user);
        }
      }

      // Unsupported users consume the original-layout value via BatchToSpace.
      if (!unsupported_users.empty()) {
        TF_ASSIGN_OR_RETURN(HloInstruction * batch_to_space,
                            BatchToSpace(node));
        for (auto user : unsupported_users) {
          // A user may reference `node` through several operand slots.
          for (int64_t i = 0; i < user->operand_count(); ++i) {
            if (user->operand(i) == node) {
              TF_CHECK_OK(user->ReplaceOperandWith(i, batch_to_space));
            }
          }
        }
      }
    }
  }
  return OkStatus();
}
2535 
// Space-to-batches `convolution` whose activations have already been
// converted: re-permutes dim numbers, masks invalid activation regions,
// resizes the split spatial dimension, adds the halo, and emits the new
// convolution plus the dim/permutation bookkeeping for further propagation.
Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
  auto activations_old = convolution->mutable_operand(0);

  CHECK(old_to_new_instrs_.contains(activations_old));
  auto activations_new = old_to_new_instrs_[activations_old];
  auto permute_dims = instr_to_dim_permute_map_[activations_new];

  auto original_conv_dims = convolution->convolution_dimension_numbers();

  auto old_new_dims = GetSpatialDimsToSplit(activations_old);
  std::vector<int64_t> old_spatial_dims = old_new_dims.first;
  std::vector<int64_t> new_spatial_dims = old_new_dims.second;

  // Dim numbers rewritten into the permuted (space-to-batched) layout.
  auto permuted_conv_dims_numbers = original_conv_dims;

  int64_t activations_batch_dim =
      DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());
  int64_t activations_feature_dim =
      DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
  permuted_conv_dims_numbers.set_input_batch_dimension(activations_batch_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(
      activations_feature_dim);

  for (int64_t i = 0; i < original_conv_dims.input_spatial_dimensions_size();
       ++i) {
    permuted_conv_dims_numbers.set_input_spatial_dimensions(
        i, DimLookUp(permute_dims,
                     original_conv_dims.input_spatial_dimensions(i)));
  }

  const int64_t old_batch_dim = original_conv_dims.input_batch_dimension();
  const int64_t old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  ConvDetails c =
      GetConvolutionDetails(convolution, permuted_conv_dims_numbers);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << c.spatial_dimensions_to_split[0] << " old_batch_size "
          << old_batch_size;

  // Move the split spatial dimension(s) adjacent to the batch dimension;
  // new_spatial_dims is updated in place with the final positions.
  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            activations_batch_dim, &new_spatial_dims));
  activations_new = retval.instr;
  std::vector<int64_t> trans_dims = retval.transpose_dims;
  CHECK(!trans_dims.empty());
  // Zero of the activation element type: fill value for invalid regions.
  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  // Mask out halo/padding elements so they don't pollute the convolution.
  TF_ASSIGN_OR_RETURN(
      activations_new,
      SelectValidPortion(activations_new, activations_old, select_val,
                         activations_batch_dim, new_spatial_dims, old_batch_dim,
                         old_spatial_dims));
  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  const int64_t num_splits = ctrl_.number_of_splits;
  const int64_t output_offsets = convolution->shape().dimensions(
      permuted_conv_dims_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution)));
  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, num_splits);

  // Input elements each split must cover to produce its share of outputs.
  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;

  VLOG(1) << "spatial size " << c.spatial_size << " halo size " << c.halo_size
          << " spatial_split_size " << spatial_split_size;

  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension. Unlike for the first space-to-batch'ed
  // convolution, while propagating, we can use the last halo_size as available
  // spatial size.
  // If the spatial size is less than the halo size required, we need to
  // increase the spatial size.
  while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0 ||
         spatial_split_size < c.halo_size) {
    spatial_split_size += c.stride;
  }

  VLOG(1) << "Modified spatial_split_size " << spatial_split_size;
  const int64_t new_space_size =
      activations_new->shape().dimensions(new_spatial_dims[0]);

  int64_t slice_size = spatial_split_size + c.halo_size;
  // In the below case, we cannot use the activations directly for Halo
  // Duplication. We must reshape them.
  if (spatial_split_size > new_space_size) {
    TF_ASSIGN_OR_RETURN(
        activations_new,
        ChangeSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            new_spatial_dims, spatial_split_size,
            /*increase_spatial_size*/ true));

  } else {
    // If the ideal spatial_split_size was smaller than the incoming spatial
    // dimension size, we don't need reshaping. Instead, we determine the
    // additional space available, and adjust the required slice size (and
    // thereby the halo size).
    VLOG(3)
        << "Decreasing the spatial size while propagating spatial_split_size "
        << spatial_split_size << " new_space_size " << new_space_size;
    if (spatial_split_size < new_space_size) {
      // If there's a stride mismatch, we change the new_space_size be
      // smaller (equal to spatial_split_size).
      if (new_space_size % c.stride != 0 || c.base_dilation_factor != 1) {
        TF_ASSIGN_OR_RETURN(
            activations_new,
            ChangeSpatialSizeOnSpaceToBatchedShape(
                activations_new, activations_batch_dim, old_batch_size,
                new_spatial_dims, spatial_split_size));
      } else {
        const int64_t additional_space_present = spatial_split_size % c.stride;
        spatial_split_size = new_space_size;
        // Grow the slice to cover the kernel overhang not already covered by
        // the extra space present (clamped at zero).
        slice_size =
            spatial_split_size + std::max(c.kernel_spatial_dim_size - c.stride -
                                              additional_space_present,
                                          static_cast<int64_t>(0));
      }
    }
  }

  // For space-to-batch supported base-dilated convolutions, the low padding is
  // passed on to the new convolutions. Halo does not have to account for it.
  TF_ASSIGN_OR_RETURN(
      activations_new,
      HaloDuplicateWithSlice(
          activations_new, new_spatial_dims, activations_batch_dim,
          /*low_padding=*/c.base_dilation_factor != 1 &&
                  c.inherent_low_padding != 0
              ? (c.inherent_low_padding == c.base_dilation_factor ? 1 : 0)
              : c.inherent_low_padding,
          slice_size - spatial_split_size));

  // We will generate output such that batch is followed by the split spatial
  // dimension.
  const int64_t rank = (convolution->shape().rank());
  std::vector<int64_t> transpose_dims(rank);
  int dim_count = 0;
  // Ordered map: iterating it later yields the permutation sorted by old
  // dimension index.
  std::map<int64_t, int64_t> dim_translator;

  for (int j = 0;
       j < permuted_conv_dims_numbers.output_spatial_dimensions_size(); ++j) {
    // Place the output batch dimension immediately before the chosen split
    // spatial dimension.
    if (j == GetFirstChosenSpatialDim(convolution)) {
      dim_translator[permuted_conv_dims_numbers.output_batch_dimension()] =
          dim_count;
      new_dim_numbers.set_output_batch_dimension(dim_count++);
    }
    dim_translator[permuted_conv_dims_numbers.output_spatial_dimensions(j)] =
        dim_count;
    new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
    dim_count++;
  }

  dim_translator[permuted_conv_dims_numbers.output_feature_dimension()] =
      dim_count;
  new_dim_numbers.set_output_feature_dimension(dim_count);

  int p = 0;
  for (const auto& entry : dim_translator) {
    transpose_dims[p] = entry.second;
    p++;
  }

  auto new_window = convolution->window();
  const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
  // Override padding on the converted dims with the precomputed conv padding.
  for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_high(c.high_padding_for_conv);
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_low(c.low_padding_for_conv);
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations_new, /*rhs=*/convolution->mutable_operand(1),
          convolution->feature_group_count(), convolution->batch_group_count(),
          new_window, new_dim_numbers, convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  old_to_new_instrs_[convolution] = new_conv;
  VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

  // Record the old-layout batch/feature/space dims for the conv's output so
  // users can be propagated next.
  std::vector<int64_t> dim_map(NumMappedDims());
  dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
      original_conv_dims.output_batch_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
      original_conv_dims.output_feature_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
      original_conv_dims.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  instr_to_dim_map_[convolution] = dim_map;

  instr_to_dim_permute_map_[new_conv] = std::vector<int64_t>(transpose_dims);

  convs_to_visit_.erase(convolution);
  return OkStatus();
}
2740 
PropagateOnConcat(HloInstruction * concat)2741 Status ConvolutionVisitor::PropagateOnConcat(HloInstruction* concat) {
2742   auto first_operand = old_to_new_instrs_[concat->mutable_operand(0)];
2743   auto permute_dims = instr_to_dim_permute_map_[first_operand];
2744   const int64_t new_concat_dim =
2745       DimLookUp(permute_dims, concat->concatenate_dimension());
2746   std::vector<HloInstruction*> new_operands(concat->operand_count());
2747   for (int64_t i = 0; i < concat->operand_count(); ++i) {
2748     new_operands[i] = old_to_new_instrs_[concat->mutable_operand(i)];
2749   }
2750   TF_ASSIGN_OR_RETURN(HloInstruction * new_concat,
2751                       MakeConcatHlo(new_operands, new_concat_dim));
2752   old_to_new_instrs_[concat] = new_concat;
2753   // Set mappings from operand 0.
2754   instr_to_dim_map_[concat] =
2755       std::vector<int64_t>(instr_to_dim_map_[concat->mutable_operand(0)]);
2756   instr_to_dim_permute_map_[new_concat] =
2757       std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);
2758 
2759   return OkStatus();
2760 }
2761 
PropagateOnReverse(HloInstruction * reverse)2762 Status ConvolutionVisitor::PropagateOnReverse(HloInstruction* reverse) {
2763   auto first_operand = old_to_new_instrs_[reverse->mutable_operand(0)];
2764   auto permute_dims = instr_to_dim_permute_map_[first_operand];
2765 
2766   std::vector<int64_t> new_reverse_dimensions(reverse->dimensions().size());
2767   int dim_count = 0;
2768   for (auto dim : reverse->dimensions()) {
2769     new_reverse_dimensions[dim_count++] = DimLookUp(permute_dims, dim);
2770   }
2771   TF_ASSIGN_OR_RETURN(HloInstruction * new_reverse,
2772                       MakeReverseHlo(first_operand, new_reverse_dimensions));
2773   old_to_new_instrs_[reverse] = new_reverse;
2774   // Set mappings from operand 0.
2775   instr_to_dim_map_[reverse] =
2776       std::vector<int64_t>(instr_to_dim_map_[reverse->mutable_operand(0)]);
2777   instr_to_dim_permute_map_[new_reverse] =
2778       std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);
2779 
2780   return OkStatus();
2781 }
2782 
PropagateOnPad(HloInstruction * pad)2783 Status ConvolutionVisitor::PropagateOnPad(HloInstruction* pad) {
2784   auto first_operand = old_to_new_instrs_[pad->mutable_operand(0)];
2785   auto permute_dims = instr_to_dim_permute_map_[first_operand];
2786 
2787   PaddingConfig padding_config;
2788   for (int i = 0; i < pad->shape().rank(); ++i) {
2789     auto dimension = padding_config.add_dimensions();
2790     const int64_t old_dim = ReverseDimLookUp(permute_dims, i);
2791     auto old_padding = pad->padding_config().dimensions(old_dim);
2792     dimension->set_edge_padding_low(old_padding.edge_padding_low());
2793     dimension->set_edge_padding_high(old_padding.edge_padding_high());
2794     dimension->set_interior_padding(old_padding.interior_padding());
2795   }
2796 
2797   HloInstruction* padding = pad->mutable_operand(1);
2798 
2799   TF_ASSIGN_OR_RETURN(auto new_pad,
2800                       MakePadHlo(first_operand, padding, padding_config));
2801 
2802   old_to_new_instrs_[pad] = new_pad;
2803   // Set mappings from operand 0.
2804   instr_to_dim_map_[pad] =
2805       std::vector<int64_t>(instr_to_dim_map_[pad->mutable_operand(0)]);
2806   instr_to_dim_permute_map_[new_pad] =
2807       std::vector<int64_t>(instr_to_dim_permute_map_[first_operand]);
2808 
2809   return OkStatus();
2810 }
2811 
// Collapses the per-dimension split factors into the batch dimension.
//
// On entry, `activations` is expected to hold, right after
// `activations_batch_dim`, the split-factor and split spatial dimensions
// produced by PerformSplitSpace's reshape (interleaved as
// [split_0, space_0, split_1, space_1, ...] when more than one spatial
// dimension was converted). When needed, a transpose first gathers all
// split-factor dims adjacent to the batch dim; a reshape then merges them
// into a single batch of size
// old_batch_size * number_of_splits^spatial_dim_count.
StatusOr<HloInstruction*> ConvolutionVisitor::TransposeAndMergeBatch(
    HloInstruction* activations,
    absl::Span<const int64_t> final_split_spatial_dim_positioning,
    int64_t activations_batch_dim, int64_t old_batch_size) {
  const int64_t spatial_dim_count = final_split_spatial_dim_positioning.size();

  // With a single converted spatial dim the split factor already sits right
  // after the batch dim, so no transpose is necessary.
  if (final_split_spatial_dim_positioning.size() > 1) {
    int64_t start_batch_dim_position = activations_batch_dim + 1;
    int64_t start_space_dim_position =
        start_batch_dim_position + spatial_dim_count;

    std::vector<int64_t> trans_dims(activations->shape().dimensions_size());
    absl::c_iota(trans_dims, 0);

    // Destination slots [start_batch, start_batch + count) pull the
    // even-offset (split factor) source dims — note the reversed order —
    // and slots [start_space, start_space + count) pull the odd-offset
    // (spatial) source dims in order.
    for (int i = 0; i < spatial_dim_count; ++i) {
      trans_dims[start_batch_dim_position + i] =
          start_batch_dim_position + (spatial_dim_count - 1 - i) * 2;
      trans_dims[start_space_dim_position + i] =
          start_batch_dim_position + i * 2 + 1;
    }

    TF_ASSIGN_OR_RETURN(activations, MakeTransposeHlo(activations, trans_dims));
  }

  std::vector<int64_t> batch_collapse_reshape_dims(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  const int64_t collapsed_batch_size =
      old_batch_size * IPow<int64_t>(ctrl_.number_of_splits, spatial_dim_count);

  // Drop the batch dim plus the split-factor dims that follow it, then put a
  // single merged batch dim of the collapsed size in their place.
  batch_collapse_reshape_dims.erase(
      batch_collapse_reshape_dims.begin() + activations_batch_dim,
      batch_collapse_reshape_dims.begin() + activations_batch_dim +
          spatial_dim_count);
  batch_collapse_reshape_dims[activations_batch_dim] = collapsed_batch_size;

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_collapsed_reshape,
                      MakeReshapeHlo(batch_collapse_reshape_dims, activations));
  return batch_collapsed_reshape;
}
2853 
// Splits each spatial dimension in `spatial_dimensions_to_split` into
// (num_splits, spatial_split_size) via a reshape, then merges all the
// num_splits factors into the batch dimension (see TransposeAndMergeBatch).
// Assumes the spatial dims have already been brought next to the batch dim
// and that each split dim's size equals num_splits * spatial_split_size.
StatusOr<HloInstruction*> ConvolutionVisitor::PerformSplitSpace(
    HloInstruction* activations,
    absl::Span<const int64_t> spatial_dimensions_to_split,
    int64_t activations_batch_dim, int64_t spatial_split_size,
    int64_t num_splits) {
  const int64_t old_batch_size =
      activations->shape().dimensions(activations_batch_dim);

  // Now we reorganize the activations. E.g. if the shape [B, SPACE] was [1, 16]
  // and 4 splits were needed, we first create [4, 4]. Next, to deal with halo
  // in the spatial dimension, we generate a gather. E.g. if halo size was 2,
  // we'd create a shape of [24] using the gather, and reshape it into [6, 4]
  // (4 being the batch).

  // The benefit of the above mentioned scheme is that it allows for batch
  // growth. Here are some examples of the size increases it causes for a 3x3
  // kernel.
  // with batch=1, [1,16] -> [4,4] ->   [4,6] ->   [1,24] growth of 8.
  // with batch=2, [2,16] -> [8,4] ->   [8,6] ->   [1,48] growth of 16.
  // with batch=3, [3,16] -> [12,4] -> [12,6] -> [1,72] growth of 24.

  std::vector<int64_t> reshape_dimensions(
      activations->shape().dimensions().begin(),
      activations->shape().dimensions().end());

  // Shrink each split spatial dim to the per-split size...
  for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
    reshape_dimensions[spatial_dimension_to_split] = spatial_split_size;
  }

  // ...and insert a num_splits dim immediately before each of them. `counter`
  // compensates for the index shift caused by each earlier insertion.
  int counter = 0;
  for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
    reshape_dimensions.insert(
        reshape_dimensions.begin() + (spatial_dimension_to_split + counter),
        num_splits);
    counter++;
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * batch_increased_reshape,
                      MakeReshapeHlo(reshape_dimensions, activations));

  return TransposeAndMergeBatch(
      batch_increased_reshape,
      /*final_split_spatial_dim_positioning=*/spatial_dimensions_to_split,
      activations_batch_dim, old_batch_size);
}
2899 
PadAndSplitSpace(HloInstruction * activations,absl::Span<const int64_t> spatial_dimensions_to_split,int64_t activations_batch_dim,int64_t high_padding,int64_t low_padding,int64_t spatial_split_size,int64_t num_splits)2900 StatusOr<HloInstruction*> ConvolutionVisitor::PadAndSplitSpace(
2901     HloInstruction* activations,
2902     absl::Span<const int64_t> spatial_dimensions_to_split,
2903     int64_t activations_batch_dim, int64_t high_padding, int64_t low_padding,
2904     int64_t spatial_split_size, int64_t num_splits) {
2905   const int64_t old_batch_size =
2906       activations->shape().dimensions(activations_batch_dim);
2907 
2908   // Because we are splitting the spatial dimension, if convolution needed
2909   // padding in the spatial dimension, we materialize it.
2910   if (high_padding || low_padding) {
2911     PaddingConfig padding_config =
2912         MakeNoPaddingConfig(activations->shape().dimensions_size());
2913     for (auto spatial_dimension_to_split : spatial_dimensions_to_split) {
2914       padding_config.mutable_dimensions(spatial_dimension_to_split)
2915           ->set_edge_padding_high(high_padding);
2916       padding_config.mutable_dimensions(spatial_dimension_to_split)
2917           ->set_edge_padding_low(low_padding);
2918     }
2919     HloInstruction* padding =
2920         computation_->AddInstruction(HloInstruction::CreateConstant(
2921             LiteralUtil::Zero(activations->shape().element_type())));
2922     TF_ASSIGN_OR_RETURN(activations,
2923                         MakePadHlo(activations, padding, padding_config));
2924   }
2925   VLOG(1) << "Initial padded activations shape "
2926           << activations->shape().ToString() << " old_batch_size "
2927           << old_batch_size << " activations_batch_dim "
2928           << activations_batch_dim;
2929   return PerformSplitSpace(activations, spatial_dimensions_to_split,
2930                            activations_batch_dim, spatial_split_size,
2931                            num_splits);
2932 }
2933 
2934 StatusOr<std::pair<HloInstruction*, std::vector<int64_t>>>
SplitSpace(HloInstruction * activations,ConvolutionDimensionNumbers & dim_numbers,int64_t & activations_batch_dim,int64_t high_padding,int64_t low_padding,int64_t spatial_split_size,int64_t num_splits,std::vector<int64_t> * spatial_dimensions_to_split,bool is_backprop,bool is_rhs)2935 ConvolutionVisitor::SplitSpace(
2936     HloInstruction* activations, ConvolutionDimensionNumbers& dim_numbers,
2937     int64_t& activations_batch_dim, int64_t high_padding, int64_t low_padding,
2938     int64_t spatial_split_size, int64_t num_splits,
2939     std::vector<int64_t>* spatial_dimensions_to_split, bool is_backprop,
2940     bool is_rhs) {
2941   TF_ASSIGN_OR_RETURN(
2942       auto retval,
2943       BringSpaceNextToBatch(activations, dim_numbers, activations_batch_dim,
2944                             spatial_dimensions_to_split, is_backprop, is_rhs));
2945 
2946   activations = retval.instr;
2947   std::vector<int64_t> transpose_dims = retval.transpose_dims;
2948   TF_ASSIGN_OR_RETURN(
2949       auto new_activations,
2950       PadAndSplitSpace(activations, *spatial_dimensions_to_split,
2951                        activations_batch_dim, high_padding, low_padding,
2952                        spatial_split_size, num_splits));
2953   return std::make_pair(new_activations, transpose_dims);
2954 }
2955 
PropagateOnConstant(HloInstruction * consumer,HloInstruction * producer)2956 StatusOr<HloInstruction*> ConvolutionVisitor::PropagateOnConstant(
2957     HloInstruction* consumer, HloInstruction* producer) {
2958   CHECK(old_to_new_instrs_.contains(producer));
2959   HloInstruction* new_producer = old_to_new_instrs_[producer];
2960   auto prod_transpose_dims = instr_to_dim_permute_map_[new_producer];
2961   std::vector<int64_t> reversed_transpose_dims(prod_transpose_dims.size());
2962   for (int64_t i = 0; i < prod_transpose_dims.size(); ++i) {
2963     reversed_transpose_dims[i] = ReverseDimLookUp(prod_transpose_dims, i);
2964   }
2965   // Bring space next to batch.
2966   TF_ASSIGN_OR_RETURN(consumer,
2967                       MakeTransposeHlo(consumer, reversed_transpose_dims));
2968 
2969   auto retval = GetSpatialDimsToSplit(producer);
2970   std::vector<int64_t> old_spatial_dims = retval.first;
2971   std::vector<int64_t> new_spatial_dims = retval.second;
2972 
2973   auto dim_map = instr_to_dim_map_[producer];
2974   const int64_t old_batch_dim = dim_map[DimMapper(SpaceToBatchDimMap::kBatch)];
2975   const int64_t old_space_dim = old_spatial_dims[0];
2976   const int64_t new_batch_dim = DimLookUp(prod_transpose_dims, old_batch_dim);
2977   const int64_t new_space_dim = new_spatial_dims[0];
2978 
2979   const int64_t old_batch_size = producer->shape().dimensions(old_batch_dim);
2980   const int64_t new_batch_size = old_batch_size * ctrl_.number_of_splits;
2981   const int64_t high_padding =
2982       (new_batch_size * new_producer->shape().dimensions(new_space_dim) -
2983        old_batch_size * producer->shape().dimensions(old_space_dim)) /
2984       old_batch_size;
2985 
2986   auto new_consumer = PadAndSplitSpace(
2987       consumer, new_spatial_dims, new_batch_dim, high_padding,
2988       /*low_padding=*/0, new_producer->shape().dimensions(new_space_dim),
2989       ctrl_.number_of_splits);
2990 
2991   return new_consumer;
2992 }
2993 
// Propagates space-to-batch through a backprop-filter (kernel-gradient)
// convolution. For such convolutions batch and feature are inverted: the
// input feature dimension plays the role of batch. If only one of the two
// operands was previously space-to-batched, the other is converted here on
// the fly so both agree. The split spatial dimension is then re-expressed as
// halo-sliced chunks of the activations concatenated along a brand-new
// trailing spatial ("depth") dimension, a convolution is built over that, and
// the extra size-1 output dimension is reshaped away at the end.
Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
    HloInstruction* convolution) {
  auto activations_old = convolution->mutable_operand(0);

  const int64_t rhs_dilation =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .window_dilation();

  auto original_conv_dims = convolution->convolution_dimension_numbers();

  // Collect the input and kernel spatial dims chosen for conversion.
  std::vector<int64_t> old_split_spatial_dims(
      ctrl_.dimension_from_end_to_convert),
      old_split_kernel_spatial_dims(ctrl_.dimension_from_end_to_convert);
  for (int i = 0; i < ctrl_.dimension_from_end_to_convert; ++i) {
    old_split_spatial_dims[i] = original_conv_dims.input_spatial_dimensions(
        GetFirstChosenSpatialDim(convolution) + i);
    old_split_kernel_spatial_dims[i] =
        original_conv_dims.kernel_spatial_dimensions(
            GetFirstChosenSpatialDim(convolution) + i);
  }

  auto kernel_old = convolution->mutable_operand(1);
  const int64_t old_kernel_split_dim_size =
      kernel_old->shape().dimensions(old_split_kernel_spatial_dims[0]);

  int64_t old_split_dim_size =
      activations_old->shape().dimensions(old_split_spatial_dims[0]);

  // Batch/feature inversion: the input feature dim acts as batch here.
  int64_t old_batch_dim = original_conv_dims.input_feature_dimension();
  int64_t kernel_old_batch_dim =
      original_conv_dims.kernel_input_feature_dimension();
  const int64_t old_batch_size =
      activations_old->shape().dimensions(old_batch_dim);

  // At least one operand must have been space-to-batched already.
  CHECK(old_to_new_instrs_.contains(kernel_old) ||
        old_to_new_instrs_.contains(activations_old));

  HloInstruction* activations_new = nullptr;
  HloInstruction* kernel_new = nullptr;
  bool activations_locally_space_to_batched = false;
  bool kernel_locally_space_to_batched = false;
  std::vector<int64_t> permute_dims_kernel, permute_dims;

  if (old_to_new_instrs_.contains(activations_old)) {
    activations_new = old_to_new_instrs_[activations_old];
    permute_dims = instr_to_dim_permute_map_[activations_new];
  }

  if (old_to_new_instrs_.contains(kernel_old)) {
    kernel_new = old_to_new_instrs_[kernel_old];
    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];
  }

  // If activations were not space-to-batched, we space-to-batch them below.
  if (!old_to_new_instrs_.contains(activations_old)) {
    kernel_new = old_to_new_instrs_[kernel_old];
    permute_dims_kernel = instr_to_dim_permute_map_[kernel_new];

    VLOG(1) << "Space-to-batching activations to enable space-to-depth";

    const int64_t new_kernel_space_dim =
        DimLookUp(permute_dims_kernel, old_split_kernel_spatial_dims[0]);

    // Size the activation split so each split covers the (dilated) kernel
    // extent; pad the spatial dim up to num_splits * needed_spatial_size.
    const int64_t new_kernel_split_dim_size =
        kernel_new->shape().dimensions(new_kernel_space_dim);
    const int64_t needed_spatial_size =
        rhs_dilation * new_kernel_split_dim_size;
    const int64_t pad_size =
        needed_spatial_size * ctrl_.number_of_splits - old_split_dim_size;
    ConvolutionDimensionNumbers tmp_dim_numbers;
    tmp_dim_numbers = original_conv_dims;
    TF_ASSIGN_OR_RETURN(
        auto retval, SplitSpace(activations_old, tmp_dim_numbers, old_batch_dim,
                                /*high_padding=*/pad_size, /*low_padding=*/0,
                                needed_spatial_size, ctrl_.number_of_splits,
                                &old_split_spatial_dims,
                                /*is_backprop=*/true));

    activations_new = retval.first;

    // Record the inverse of the transpose SplitSpace applied.
    std::vector<int64_t> reversed_transpose_dims(retval.second.size());
    for (int64_t i = 0; i < retval.second.size(); ++i) {
      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
    }
    permute_dims = reversed_transpose_dims;

    VLOG(3) << "New Activations " << retval.first->ToString();

    activations_locally_space_to_batched = true;
  } else if (!old_to_new_instrs_.contains(kernel_old)) {
    // Symmetric case: activations are converted but the kernel is not, so
    // space-to-batch the kernel here.
    activations_new = old_to_new_instrs_[activations_old];
    permute_dims = instr_to_dim_permute_map_[activations_new];

    VLOG(1) << "Space-to-batching kernel to enable space-to-depth";

    const int64_t new_space_dim =
        DimLookUp(permute_dims, old_split_spatial_dims[0]);
    const int64_t new_split_dim_size =
        activations_new->shape().dimensions(new_space_dim);
    const int64_t needed_spatial_size =
        CeilOfRatio(new_split_dim_size, rhs_dilation);
    int64_t old_kernel_split_dim_size =
        kernel_old->shape().dimensions(old_split_kernel_spatial_dims[0]);
    const int64_t pad_size = needed_spatial_size * ctrl_.number_of_splits -
                             old_kernel_split_dim_size;

    ConvolutionDimensionNumbers tmp_dim_numbers;
    tmp_dim_numbers = original_conv_dims;
    TF_ASSIGN_OR_RETURN(
        auto retval,
        SplitSpace(kernel_old, tmp_dim_numbers, kernel_old_batch_dim,
                   /*high_padding=*/pad_size, /*low_padding=*/0,
                   needed_spatial_size, ctrl_.number_of_splits,
                   &old_split_kernel_spatial_dims,
                   /*is_backprop=*/true, /*is_rhs=*/true));

    kernel_new = retval.first;

    // Record the inverse of the transpose SplitSpace applied.
    std::vector<int64_t> reversed_transpose_dims(retval.second.size());
    for (int64_t i = 0; i < retval.second.size(); ++i) {
      reversed_transpose_dims[i] = ReverseDimLookUp(retval.second, i);
    }
    permute_dims_kernel = reversed_transpose_dims;

    VLOG(3) << "New kernel " << retval.first->ToString();

    kernel_locally_space_to_batched = true;
  }

  // By now both operands have converted counterparts.
  CHECK_NE(activations_new, nullptr);
  CHECK_NE(kernel_new, nullptr);

  // TODO(b/189500737): For multi-dimensional space-to-batch, we'd need to add
  // an auxiliary dim per converted dimension.
  const int64_t new_spatial_dimension =
      activations_new->shape().dimensions_size();

  auto permuted_conv_dims_numbers = original_conv_dims;

  // Note the inversion here : batch and feature are inverted in backprop
  // filters.
  int64_t activations_batch_dim =
      DimLookUp(permute_dims, original_conv_dims.input_feature_dimension());
  int64_t activations_feature_dim =
      DimLookUp(permute_dims, original_conv_dims.input_batch_dimension());

  // Remap all spatial dims through the respective operand permutations.
  const int64_t previous_spatial_dim_count =
      original_conv_dims.input_spatial_dimensions_size();
  for (int64_t i = 0; i < previous_spatial_dim_count; ++i) {
    permuted_conv_dims_numbers.set_input_spatial_dimensions(
        i, DimLookUp(permute_dims,
                     original_conv_dims.input_spatial_dimensions(i)));
    permuted_conv_dims_numbers.set_kernel_spatial_dimensions(
        i, DimLookUp(permute_dims_kernel,
                     original_conv_dims.kernel_spatial_dimensions(i)));
  }

  // Register the brand-new trailing spatial dim on input, kernel and output.
  permuted_conv_dims_numbers.add_input_spatial_dimensions(
      new_spatial_dimension);
  permuted_conv_dims_numbers.add_kernel_spatial_dimensions(
      new_spatial_dimension);
  permuted_conv_dims_numbers.add_output_spatial_dimensions(
      new_spatial_dimension);

  // For the output, make the last dimension size 1.
  const int64_t previous_chosen_spatial_dim_in_output =
      permuted_conv_dims_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  permuted_conv_dims_numbers.set_output_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution), new_spatial_dimension);
  permuted_conv_dims_numbers.set_output_spatial_dimensions(
      previous_spatial_dim_count, previous_chosen_spatial_dim_in_output);

  const int64_t kernel_input_feature_dim = DimLookUp(
      permute_dims_kernel, original_conv_dims.kernel_input_feature_dimension());

  const int64_t kernel_output_feature_dim =
      DimLookUp(permute_dims_kernel,
                original_conv_dims.kernel_output_feature_dimension());

  permuted_conv_dims_numbers.set_kernel_input_feature_dimension(
      kernel_input_feature_dim);
  permuted_conv_dims_numbers.set_kernel_output_feature_dimension(
      kernel_output_feature_dim);

  std::vector<int64_t> spatial_dimensions_to_split(
      ctrl_.count_of_dimensions_to_convert);
  const int64_t first_dim_to_split = GetFirstChosenSpatialDim(convolution);
  for (int64_t i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    spatial_dimensions_to_split[i] =
        permuted_conv_dims_numbers.input_spatial_dimensions(first_dim_to_split +
                                                            i);
  }

  const int64_t kernel_spatial_dimension_to_split =
      permuted_conv_dims_numbers.kernel_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));

  int64_t new_split_dim_size =
      activations_new->shape().dimensions(spatial_dimensions_to_split[0]);

  const int64_t kernel_new_split_dim_size =
      kernel_new->shape().dimensions(kernel_spatial_dimension_to_split);

  // Apply the batch<->feature inversion to the dim numbers as well.
  permuted_conv_dims_numbers.set_input_batch_dimension(activations_feature_dim);
  permuted_conv_dims_numbers.set_input_feature_dimension(activations_batch_dim);

  VLOG(1) << "Propagating on conv activations_batch_dim "
          << activations_batch_dim << " spatial_dimension_to_split "
          << spatial_dimensions_to_split[0] << " old_batch_size "
          << old_batch_size << " new_split_dim_size " << new_split_dim_size;

  TF_ASSIGN_OR_RETURN(
      auto retval,
      BringSpaceNextToBatch(activations_new, permuted_conv_dims_numbers,
                            activations_batch_dim, &spatial_dimensions_to_split,
                            /*is_backprop=*/true));

  int64_t spatial_dimension_to_split = spatial_dimensions_to_split[0];

  std::vector<int64_t> transpose_dims = retval.transpose_dims;
  CHECK(!transpose_dims.empty());
  activations_new = retval.instr;

  VLOG(1) << "Activations_new post BringSpaceNextToBatch "
          << activations_new->ToString();
  VLOG(1) << "activations_batch_dim " << activations_batch_dim
          << " activations_feature_dim " << activations_feature_dim;
  // Grow the split spatial dim if it is smaller than the dilated kernel
  // extent requires.
  const int64_t expected_split_dim_size =
      rhs_dilation * kernel_new_split_dim_size;
  if (new_split_dim_size != expected_split_dim_size) {
    CHECK_LT(new_split_dim_size, expected_split_dim_size);
    new_split_dim_size = expected_split_dim_size;
    TF_ASSIGN_OR_RETURN(
        activations_new,
        ChangeSpatialSizeOnSpaceToBatchedShape(
            activations_new, activations_batch_dim, old_batch_size,
            spatial_dimensions_to_split, new_split_dim_size, true));
  }

  spatial_dimension_to_split = spatial_dimensions_to_split[0];

  // Zero used to mask out the padding region in SelectValidPortion below.
  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(activations_new->shape().element_type())));

  // Masking is only needed for operands converted earlier; locally converted
  // ones were padded with zeros above already.
  if (!activations_locally_space_to_batched) {
    // Select activations correctly by masking additional space.
    TF_ASSIGN_OR_RETURN(
        activations_new,
        SelectValidPortion(activations_new, activations_old, select_val,
                           activations_batch_dim, spatial_dimensions_to_split,
                           old_batch_dim, old_split_spatial_dims));
  }
  if (!kernel_locally_space_to_batched) {
    VLOG(3) << "Selecting the valid kernel area";
    // Select kernel correctly by masking additional space.

    std::vector<int64_t> new_kernel_split_spatial_dims(
        ctrl_.dimension_from_end_to_convert);

    // TODO(b/189500737) : Extend this once
    // IncreaseSpatialSizeOnSpaceToBatchedShape returns all dimensions.
    new_kernel_split_spatial_dims[0] = kernel_spatial_dimension_to_split;

    TF_ASSIGN_OR_RETURN(
        kernel_new,
        SelectValidPortion(kernel_new, kernel_old, select_val,
                           /*new_batch_dim=*/kernel_input_feature_dim,
                           new_kernel_split_spatial_dims,
                           /*old_batch_dim=*/
                           original_conv_dims.kernel_input_feature_dimension(),
                           old_split_kernel_spatial_dims));
  }

  // Create the new convolution dim numbers.
  auto new_dim_numbers = permuted_conv_dims_numbers;

  VLOG(2) << "New dim numbers " << new_dim_numbers.DebugString();

  const int64_t inherent_low_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_low();

  const int64_t inherent_high_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_high();

  // Each chunk is one halo-shifted copy of the activations; chunks get
  // concatenated along the new trailing spatial dim below.
  std::vector<HloInstruction*> activations_chunks;

  // Insert slices for low padding.
  for (int64_t i = 0; i < inherent_low_padding; ++i) {
    HloInstruction* activations_to_use = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
    } else {
      activations_to_use = activations_chunks.back();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split,
                               activations_batch_dim, /*low_padding=*/1,
                               /*halo_size=*/0));
    activations_chunks.push_back(activations_slice);
  }
  // Reverse the low padding slices because we created them in the opposite
  // order above.
  absl::c_reverse(activations_chunks);

  // Number of valid kernel placements over the (possibly negatively padded)
  // original spatial extent.
  const int64_t expanded_kernel =
      old_kernel_split_dim_size * rhs_dilation - (rhs_dilation - 1);
  const int64_t overlap_count =
      old_split_dim_size - expanded_kernel + 1 +
      (inherent_low_padding < 0 ? inherent_low_padding : 0) +
      (inherent_high_padding < 0 ? inherent_high_padding : 0);
  VLOG(1) << "overlap_count " << overlap_count << " inherent_low_padding "
          << inherent_low_padding << " inherent_high_padding "
          << inherent_high_padding;

  const int64_t total_overlap_count =
      overlap_count + (inherent_low_padding > 0 ? inherent_low_padding : 0) +
      (inherent_high_padding > 0 ? inherent_high_padding : 0);

  // Insert original activations.
  for (int64_t i = 0; i < overlap_count; ++i) {
    HloInstruction* activations_to_use = nullptr;
    HloInstruction* activations_slice = nullptr;
    if (i == 0) {
      activations_to_use = activations_new;
      if (inherent_low_padding < 0) {
        // Negative low padding: shift the first chunk left accordingly.
        TF_ASSIGN_OR_RETURN(
            activations_slice,
            HaloDuplicateWithSlice(
                activations_to_use, spatial_dimensions_to_split,
                activations_batch_dim,
                /*low_padding=*/inherent_low_padding, /*halo_size=*/0));
      } else {
        activations_slice = activations_to_use;
      }
    } else {
      // Each subsequent chunk is the previous one shifted by one position.
      activations_to_use = activations_chunks.back();

      TF_ASSIGN_OR_RETURN(activations_slice,
                          HaloDuplicateWithSlice(
                              activations_to_use, spatial_dimensions_to_split,
                              activations_batch_dim, /*low_padding=*/-1,
                              /*halo_size=*/0));
    }

    activations_chunks.push_back(activations_slice);
  }

  int64_t high_padding_to_materialize = 0;

  if (inherent_high_padding > 0) {
    // Only materialize whatever high padding is not already accounted for by
    // the overlap and low-padding chunks.
    high_padding_to_materialize =
        std::max(total_overlap_count -
                     (std::max(overlap_count, static_cast<int64_t>(0)) +
                      std::max(inherent_low_padding, static_cast<int64_t>(0))),
                 static_cast<int64_t>(0));
  }

  // Insert slices for high padding.
  for (int64_t i = 0; i < high_padding_to_materialize; ++i) {
    HloInstruction* activations_to_use = nullptr;
    activations_to_use = activations_chunks.back();

    TF_ASSIGN_OR_RETURN(
        HloInstruction * activations_slice,
        HaloDuplicateWithSlice(activations_to_use, spatial_dimensions_to_split,
                               activations_batch_dim,
                               /*low_padding=*/-1, /*halo_size=*/0));
    activations_chunks.push_back(activations_slice);
  }

  // Give every chunk a trailing size-1 dim so they can be concatenated along
  // the new spatial dimension.
  for (int64_t i = 0; i < activations_chunks.size(); ++i) {
    std::vector<int64_t> input_sizes(
        activations_chunks[i]->shape().dimensions().begin(),
        activations_chunks[i]->shape().dimensions().end());
    // Insert 1-sized dimension at the end
    input_sizes.push_back(1);
    TF_ASSIGN_OR_RETURN(activations_chunks[i],
                        MakeReshapeHlo(input_sizes, activations_chunks[i]));
    VLOG(1) << "new_spatial_dimension " << new_spatial_dimension << " slice "
            << activations_chunks[i]->ToString();
  }

  TF_ASSIGN_OR_RETURN(
      activations_new,
      MakeConcatHlo(absl::MakeSpan(activations_chunks), new_spatial_dimension));

  // Reshape the kernel with additional spatial dim.
  std::vector<int64_t> kernel_sizes(kernel_new->shape().dimensions().begin(),
                                    kernel_new->shape().dimensions().end());
  // Insert 1-sized dimension at the end
  kernel_sizes.push_back(1);
  TF_ASSIGN_OR_RETURN(kernel_new, MakeReshapeHlo(kernel_sizes, kernel_new));

  auto new_window = convolution->window();
  // Negative high padding trims the excess produced by rhs dilation.
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_padding_high(-(rhs_dilation - 1));
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_padding_low(0);
  new_window.mutable_dimensions(GetFirstChosenSpatialDim(convolution))
      ->set_size(CeilOfRatio(new_split_dim_size, rhs_dilation));

  // Set the window for the additional spatial dim. This is a vanilla window.
  auto window_dim = new_window.add_dimensions();
  window_dim->set_base_dilation(1);
  window_dim->set_size(1);
  int64_t stride = 1;
  // This condition means there's only a single overlap possible (as the shapes
  // were grown due to padding). In this case, we increase the stride.
  if (inherent_low_padding > total_overlap_count) {
    stride = activations_chunks.size();
  }
  window_dim->set_stride(stride);
  window_dim->set_padding_low(0);
  window_dim->set_padding_high(0);
  window_dim->set_window_reversal(false);
  window_dim->set_window_dilation(1);

  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations_new, kernel_new, convolution->feature_group_count(),
          convolution->batch_group_count(), new_window, new_dim_numbers,
          convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  VLOG(2) << "New backprop filter convolution " << new_conv->ToString();

  // Drop the size-1 output spatial dim that replaced the chosen dim.
  std::vector<int64_t> output_sizes(new_conv->shape().dimensions().begin(),
                                    new_conv->shape().dimensions().end());

  output_sizes.erase(output_sizes.begin() +
                     new_dim_numbers.output_spatial_dimensions(
                         GetFirstChosenSpatialDim(convolution)));

  TF_ASSIGN_OR_RETURN(new_conv, MakeReshapeHlo(output_sizes, new_conv));

  old_to_new_instrs_[convolution] = new_conv;
  VLOG(1) << "Space-to-featured convolution " << new_conv->ToString();

  // Record output dim mappings in terms of the original conv's dim numbers.
  std::vector<int64_t> dim_map(NumMappedDims());
  dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
      original_conv_dims.output_batch_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
      original_conv_dims.output_feature_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
      original_conv_dims.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  instr_to_dim_map_[convolution] = dim_map;

  // The result needs no permutation: record an identity permutation.
  std::vector<int64_t> trans_dims(convolution->shape().dimensions_size());
  absl::c_iota(trans_dims, 0);
  instr_to_dim_permute_map_[new_conv] = trans_dims;

  return OkStatus();
}
3457 
3458 HloInstruction*
DoesConvolutionFeedReduceWindowOrSelectAndScatter(HloInstruction * instr,int64_t depth=kReduceWindowSearchDepth)3459 ConvolutionVisitor::DoesConvolutionFeedReduceWindowOrSelectAndScatter(
3460     HloInstruction* instr, int64_t depth = kReduceWindowSearchDepth) {
3461   if (depth == 0) {
3462     return nullptr;
3463   }
3464 
3465   for (auto user : instr->users()) {
3466     if (user->opcode() == HloOpcode::kReduceWindow ||
3467         user->opcode() == HloOpcode::kSelectAndScatter) {
3468       return user;
3469     }
3470     // Stop the search if these ops are encountered.
3471     if (user->opcode() == HloOpcode::kConvolution ||
3472         user->opcode() == HloOpcode::kPad ||
3473         user->opcode() == HloOpcode::kTranspose) {
3474       continue;
3475     }
3476     auto ret =
3477         DoesConvolutionFeedReduceWindowOrSelectAndScatter(user, depth - 1);
3478     if (ret != nullptr) {
3479       return ret;
3480     }
3481   }
3482   return nullptr;
3483 }
3484 
DoesConvolutionFeedUnpropagatableOp(HloInstruction * instr,int64_t depth)3485 bool ConvolutionVisitor::DoesConvolutionFeedUnpropagatableOp(
3486     HloInstruction* instr, int64_t depth) {
3487   auto key = std::make_pair(instr, depth);
3488   if (unpropagatability_cache_.contains(key)) {
3489     return unpropagatability_cache_[key];
3490   }
3491 
3492   if (depth == 0 || instr->user_count() == 0) {
3493     unpropagatability_cache_[key] = false;
3494     return false;
3495   }
3496 
3497   for (auto user : instr->users()) {
3498     if (IsOpcodeNonPropagatable(user)) {
3499       unpropagatability_cache_[key] = true;
3500       return true;
3501     }
3502 
3503     int64_t depth_to_use = depth;
3504     // When we see a convolution, we reduce the depth to look further for.
3505     if (user->opcode() == HloOpcode::kConvolution) {
3506       depth_to_use--;
3507     }
3508 
3509     if (DoesConvolutionFeedUnpropagatableOp(user, depth_to_use)) {
3510       unpropagatability_cache_[key] = true;
3511       return true;
3512     }
3513   }
3514 
3515   unpropagatability_cache_[key] = false;
3516   return false;
3517 }
3518 
IsSpaceToBatchedSpaceSizeSuitable(HloInstruction * instr)3519 bool ConvolutionVisitor::IsSpaceToBatchedSpaceSizeSuitable(
3520     HloInstruction* instr) {
3521   CHECK(instr->opcode() == HloOpcode::kSelectAndScatter ||
3522         instr->opcode() == HloOpcode::kReduceWindow);
3523   auto old_producer = instr->mutable_operand(0);
3524 
3525   auto dim_map_val_op = instr_to_dim_map_[old_producer];
3526   const int64_t old_space_dim =
3527       dim_map_val_op[DimMapper(SpaceToBatchDimMap::kSpace0)];
3528   auto first_operand = old_to_new_instrs_[old_producer];
3529   auto permute_dims_first_operand = instr_to_dim_permute_map_[first_operand];
3530   const int64_t new_space_dim =
3531       DimLookUp(permute_dims_first_operand, old_space_dim);
3532 
3533   const int64_t window_size = instr->window().dimensions(old_space_dim).size();
3534 
3535   if (first_operand->shape().dimensions(new_space_dim) < window_size) {
3536     return false;
3537   }
3538 
3539   return true;
3540 }
3541 
// Computes the geometry of `convolution` along its first chosen spatial
// dimension: inherent window padding, stride, dilation, the overall spatial
// extent, and the padding/halo amounts needed when that dimension is split
// for space-to-batch.
ConvolutionVisitor::ConvDetails ConvolutionVisitor::GetConvolutionDetails(
    HloInstruction* convolution, ConvolutionDimensionNumbers& dim_numbers) {
  auto activations = convolution->mutable_operand(0);

  auto kernel = convolution->mutable_operand(1);
  const auto& kernel_shape = kernel->shape();
  const int64_t kernel_spatial_dim = dim_numbers.kernel_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));
  int64_t kernel_spatial_dim_size = kernel_shape.dimensions(kernel_spatial_dim);

  if (IsForwardWindowDilatedConv(convolution, dim_numbers)) {
    // Window dilation spreads kernel taps apart: treat the kernel as having
    // an enlarged effective size along the chosen spatial dimension.
    const int64_t window_dilation_factor =
        convolution->window()
            .dimensions(GetFirstChosenSpatialDim(convolution))
            .window_dilation();
    kernel_spatial_dim_size =
        (kernel_spatial_dim_size - 1) * (window_dilation_factor - 1) +
        kernel_spatial_dim_size;
  }

  std::vector<int64_t> spatial_dimensions_to_split =
      GetChosenSpatialDims(convolution);
  const int64_t spatial_dimension_to_split = spatial_dimensions_to_split[0];

  const int64_t input_dim_size =
      activations->shape().dimensions(spatial_dimension_to_split);

  // Window attributes of the chosen spatial dimension, read straight off the
  // convolution's window config.
  const int64_t inherent_low_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_low();
  const int64_t inherent_high_padding =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .padding_high();

  const int64_t stride = convolution->window()
                             .dimensions(GetFirstChosenSpatialDim(convolution))
                             .stride();

  const int64_t base_dilation_factor =
      convolution->window()
          .dimensions(GetFirstChosenSpatialDim(convolution))
          .base_dilation();

  bool is_base_dilated = base_dilation_factor > 1;

  // Overall extent along the spatial dimension. For base-dilated convs the
  // low padding is excluded here; it is instead surfaced through
  // low_padding_for_conv below.
  const int64_t spatial_size = input_dim_size +
                               (is_base_dilated ? 0 : inherent_low_padding) +
                               inherent_high_padding;

  // When base dilation equals the low padding, one extra kernel tap can
  // overlap at the boundary (hence kernel_spatial_dim_size vs. size - 1).
  const int64_t last_overlap = base_dilation_factor == inherent_low_padding
                                   ? kernel_spatial_dim_size
                                   : kernel_spatial_dim_size - 1;
  // Extra elements each split chunk needs from its neighbor so that windows
  // straddling a split boundary still see their full input.
  const int64_t halo_size = is_base_dilated
                                ? last_overlap / base_dilation_factor
                                : kernel_spatial_dim_size - 1;

  const int64_t high_padding_for_base_dilation =
      inherent_low_padding == 0 ? base_dilation_factor - 1
                                : last_overlap % base_dilation_factor;

  const int64_t high_padding_for_conv =
      is_base_dilated ? high_padding_for_base_dilation : 0;

  // Low padding is kept on the rewritten convolution only when it cannot be
  // absorbed elsewhere: base-dilated and padding differs from the dilation.
  const int64_t low_padding_for_conv =
      is_base_dilated && (base_dilation_factor != inherent_low_padding)
          ? inherent_low_padding
          : 0;

  // NOTE: field order must match the ConvDetails aggregate declaration.
  return ConvDetails{spatial_dimensions_to_split,
                     inherent_low_padding,
                     inherent_high_padding,
                     stride,
                     spatial_size,
                     base_dilation_factor,
                     halo_size,
                     high_padding_for_conv,
                     low_padding_for_conv,
                     kernel_spatial_dim_size,
                     input_dim_size};
}
3624 
// Rewrites `convolution` by splitting its first chosen spatial dimension into
// ctrl_.number_of_splits chunks that are folded into the batch dimension
// (space-to-batch), then records the old->new mapping and propagates the
// transformation to the convolution's users. Sets changed_ on success; a
// no-op (OkStatus, changed_ == false) when fuel is exhausted or the spatial
// dimension is too small to be worth splitting.
Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
    HloInstruction* convolution) {
  if (!ConsumeFuel("space-to-batch-converter", [&] {
        return "Skipping space-to-batch propagation because fuel over\n";
      })) {
    return OkStatus();
  }
  VLOG(1) << "Handling conv " << convolution->ToString();

  changed_ = false;

  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();

  ConvDetails c = GetConvolutionDetails(convolution, dim_numbers);

  int64_t activations_batch_dim = dim_numbers.input_batch_dimension();

  auto activations = convolution->mutable_operand(0);

  VLOG(1) << "spatial size " << c.spatial_size;

  // A very primitive cost model to thwart propagations on tiny shapes.
  if (c.spatial_size < 2 * ctrl_.number_of_splits) {
    return OkStatus();
  }

  auto original_conv = convolution;

  // Size each split so that every split covers its share of output offsets,
  // scaled back to input coordinates via stride and base dilation.
  const int64_t output_spatial_dim = dim_numbers.output_spatial_dimensions(
      GetFirstChosenSpatialDim(convolution));
  const int64_t output_offsets =
      convolution->shape().dimensions(output_spatial_dim);
  const int64_t output_offsets_per_split =
      CeilOfRatio(output_offsets, ctrl_.number_of_splits);

  int64_t spatial_split_size =
      CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;
  // Keep increasing the split size so that overall size isn't smaller than the
  // original spatial dimension.
  while (spatial_split_size * ctrl_.number_of_splits - c.spatial_size < 0) {
    spatial_split_size += c.stride;
  }

  auto reduce_window_or_select_and_scatter =
      DoesConvolutionFeedReduceWindowOrSelectAndScatter(convolution);

  if (reduce_window_or_select_and_scatter != nullptr &&
      reduce_window_or_select_and_scatter->shape().IsArray() &&
      reduce_window_or_select_and_scatter->shape().rank() ==
          convolution->shape().rank()) {
    VLOG(2)
        << "DoesConvolutionFeedReduceWindowOrSelectAndScatter returned true";
    // Take into account the stride of the reduce window while choosing the
    // spatial_split_size. This will guarantee propagation through reduce
    // windows.
    const int64_t win_stride =
        std::max(reduce_window_or_select_and_scatter->window()
                     .dimensions(output_spatial_dim)
                     .stride(),
                 static_cast<int64_t>(1));
    CHECK_NE(win_stride, 0)
        << "Bad op " << reduce_window_or_select_and_scatter->ToString();
    CHECK_NE(c.stride, 0) << "Bad op " << convolution->ToString();
    while ((spatial_split_size / c.stride) % win_stride != 0) {
      spatial_split_size += c.stride;
    }
  }

  const int64_t slice_size = spatial_split_size + c.halo_size;

  // Base-dilated convs whose low padding equals the dilation factor need one
  // extra low-pad element, applied either in the first reshape or in the
  // later halo step (see handle_low_pad_in_first_reshape below).
  const int64_t low_pad_to_handle_base_dilation =
      (c.base_dilation_factor > 1 &&
       c.base_dilation_factor == c.inherent_low_padding)
          ? 1
          : 0;

  // Pad spatial dim.
  int64_t pad_size =
      spatial_split_size * ctrl_.number_of_splits - c.spatial_size;

  bool handle_low_pad_in_first_reshape = false;
  if (pad_size > low_pad_to_handle_base_dilation) {
    pad_size -= low_pad_to_handle_base_dilation;
    handle_low_pad_in_first_reshape = true;
  }

  VLOG(1) << "spatial_split_size " << spatial_split_size << " stride "
          << c.stride << " slice_size " << slice_size;
  VLOG(1) << "spatial_dimension_to_split " << c.spatial_dimensions_to_split[0]
          << " num_splits " << ctrl_.number_of_splits
          << " kernel_spatial_dim_size " << c.kernel_spatial_dim_size;
  std::vector<int64_t> spatial_dimensions_to_split =
      c.spatial_dimensions_to_split;
  // Step 1: reshape the activations so the split spatial chunks move into the
  // batch dimension.
  TF_ASSIGN_OR_RETURN(
      auto retval,
      SplitSpace(
          activations, dim_numbers, activations_batch_dim,
          /*high_padding=*/c.inherent_high_padding + pad_size,
          /*low_padding=*/c.base_dilation_factor == 1 ? c.inherent_low_padding
          : handle_low_pad_in_first_reshape ? low_pad_to_handle_base_dilation
                                            : 0,
          spatial_split_size, ctrl_.number_of_splits,
          &spatial_dimensions_to_split));
  HloInstruction* batch_increased_reshape = retval.first;
  convolution->SetupDerivedInstruction(batch_increased_reshape);

  VLOG(1) << "First reshape done " << batch_increased_reshape->ToString();

  // Step 2: duplicate halo regions across split boundaries so windows that
  // straddle two chunks still see all their inputs.
  TF_ASSIGN_OR_RETURN(
      activations,
      HaloDuplicateWithSlice(
          batch_increased_reshape, spatial_dimensions_to_split,
          activations_batch_dim,
          /*low_padding=*/
          handle_low_pad_in_first_reshape ? 0 : low_pad_to_handle_base_dilation,
          c.halo_size));

  VLOG(1) << "Batch merge done " << activations->ToString();

  // Now, we rewrite the convolution with a larger batch.

  // Create the new convolution dim numbers.
  auto new_dim_numbers = dim_numbers;

  // We will generate output such that batch is followed by the split spatial
  // dimension.
  const int64_t rank = convolution->shape().rank();
  std::vector<int64_t> transpose_dims(rank);
  int dim_count = 0;
  std::map<int64_t, int64_t> dim_translator;

  for (int j = 0; j < dim_numbers.output_spatial_dimensions_size(); ++j) {
    if (j == GetFirstChosenSpatialDim(convolution)) {
      dim_translator[dim_numbers.output_batch_dimension()] = dim_count;
      new_dim_numbers.set_output_batch_dimension(dim_count++);
    }
    dim_translator[dim_numbers.output_spatial_dimensions(j)] = dim_count;
    new_dim_numbers.set_output_spatial_dimensions(j, dim_count);
    dim_count++;
  }

  dim_translator[dim_numbers.output_feature_dimension()] = dim_count;
  new_dim_numbers.set_output_feature_dimension(dim_count);

  // dim_translator is ordered by old dimension index, so iterating it yields
  // the permutation from old to new dimension order.
  int p = 0;
  for (const auto& entry : dim_translator) {
    transpose_dims[p] = entry.second;
    p++;
  }
  VLOG(1) << "New dim numbers " << new_dim_numbers.DebugString()
          << " batch dim " << new_dim_numbers.input_batch_dimension();
  auto new_window = convolution->window();
  const int64_t first_dim = GetFirstChosenSpatialDim(convolution);
  for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_high(c.high_padding_for_conv);
    new_window.mutable_dimensions(first_dim + i)
        ->set_padding_low(c.low_padding_for_conv);
  }
  // Step 3: build the replacement convolution over the batch-expanded input.
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_conv,
      MakeConvolveHlo(
          activations, /*rhs=*/convolution->mutable_operand(1),
          convolution->feature_group_count(), convolution->batch_group_count(),
          new_window, new_dim_numbers, convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type()));
  convolution->SetupDerivedInstruction(new_conv);

  // If the activations were to be batch-to-spaced again, simply use the
  // original value.
  batch_to_space_map_[convolution->mutable_operand(0)] =
      convolution->mutable_operand(0);

  VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

  std::vector<int64_t> new_output_split_spatial_dims(
      ctrl_.count_of_dimensions_to_convert),
      old_output_split_spatial_dims(ctrl_.count_of_dimensions_to_convert);
  for (int i = 0; i < ctrl_.count_of_dimensions_to_convert; ++i) {
    old_output_split_spatial_dims[i] =
        dim_numbers.output_spatial_dimensions(first_dim + i);
    new_output_split_spatial_dims[i] =
        new_dim_numbers.output_spatial_dimensions(first_dim + i);
  }

  const int64_t output_batch_dim = new_dim_numbers.output_batch_dimension();

  // Step 4: mask out garbage values produced by padding with zeros, keeping
  // only the portion of the output that corresponds to real input.
  auto select_val = computation_->AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::Zero(new_conv->shape().element_type())));

  TF_ASSIGN_OR_RETURN(
      new_conv,
      SelectValidPortion(new_conv, original_conv, select_val, output_batch_dim,
                         new_output_split_spatial_dims,
                         dim_numbers.output_batch_dimension(),
                         old_output_split_spatial_dims));
  old_to_new_instrs_[original_conv] = new_conv;

  // Record which output dimensions of the original conv correspond to batch,
  // feature and the (first) split spatial dimension.
  std::vector<int64_t> dim_map(NumMappedDims());
  dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] =
      dim_numbers.output_batch_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] =
      dim_numbers.output_feature_dimension();
  dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] =
      dim_numbers.output_spatial_dimensions(
          GetFirstChosenSpatialDim(convolution));
  instr_to_dim_map_[original_conv] = dim_map;

  instr_to_dim_permute_map_[new_conv] = std::vector<int64_t>(transpose_dims);
  if (non_propagatable_instrs_.count(convolution) > 0) {
    non_propagatable_instrs_.erase(convolution);
  }
  // Step 5: push the space-to-batch form through the users of this conv.
  TF_CHECK_OK(PropagateOnUsers(original_conv));

  changed_ = true;

  return OkStatus();
}
3844 
3845 }  // namespace
3846 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)3847 StatusOr<bool> SpaceToBatchConverter::Run(
3848     HloModule* module,
3849     const absl::flat_hash_set<absl::string_view>& execution_threads) {
3850   XLA_VLOG_LINES(
3851       2, "SpaceToBatchConverter::Run(), before:\n" + module->ToString());
3852   bool changed = false;
3853 
3854   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
3855     ConvolutionVisitor visitor(ctrl_, comp);
3856     if (visitor.Run().ValueOrDie()) {
3857       changed = true;
3858     }
3859     VLOG(1) << "Done operating on computation";
3860   }
3861   XLA_VLOG_LINES(2,
3862                  "SpaceToBatchConverter::Run(), after:\n" + module->ToString());
3863   return changed;
3864 }
3865 
3866 }  // namespace xla
3867