• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
33 
34 namespace xla {
35 namespace spmd {
36 
37 struct SpmdPartitionerOptions {
38   // Always exchange halo on LHS for all convolutions. If false, backprop filter
39   // convolution exchanges halo on RHS.
40   bool conv_halo_exchange_always_on_lhs = true;
41 
42   // The number of instructions to be reported for the highest memory profile
43   // instructions.
44   int64 report_instruction_count = 5;
45 
46   // The minimum size in MiB of an einsum operand to be considered using
47   // windowed implementation in an HLO loop.
48   int64 threshold_for_windowed_einsum_mib = 256;
49 
50   // Whether unroll windowed einsum loop by degree of two.
51   bool unroll_windowed_einsum = false;
52 
53   // Whether doing bidirectional collective permute in windowed einsum loop.
54   bool bidirectional_windowed_einsum = false;
55 
56   // Whether the entry computations' signature could change after partitioning.
57   bool allow_module_signature_change = false;
58 
59   // Whether to use cached all-gather to avoid repeatedly replicate a tiled
60   // tensor. If it is set to false, the result tends to be more
61   // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
62   // pass to CSE some all-gathers which are relatively close to each other.
63   bool cache_all_gather = true;
64 
65   // When making a compromise between windowed einsum speed and memory usage
66   // prefer the former if true.
67   bool choose_faster_windowed_einsum_over_mem = false;
68 
69   // Whether doing bidirectional communication when decomposing independent
70   // all-gathers.
71   bool bidirectional_decomposed_all_gather = false;
72 };
73 
74 // Class to wrap the computation builder to capture information during SPMD
75 // transformation.
76 class SpmdBuilder : public HloComputation::Builder {
77  public:
SpmdBuilder(const std::string & name,HloInstruction * hlo)78   SpmdBuilder(const std::string& name, HloInstruction* hlo)
79       : HloComputation::Builder(name) {
80     visiting_hlo_ = hlo;
81   }
82 
83   HloInstruction* AddInstruction(
84       std::unique_ptr<HloInstruction> instruction) override;
85 
derived_instructions(HloInstruction * hlo)86   const std::vector<HloInstruction*>& derived_instructions(
87       HloInstruction* hlo) {
88     return instructions_.at(hlo);
89   }
90 
set_visiting_hlo(HloInstruction * hlo)91   void set_visiting_hlo(HloInstruction* hlo) {
92     visiting_hlo_ = hlo;
93     instructions_[hlo];
94   }
95 
visiting_hlo()96   HloInstruction* visiting_hlo() const { return visiting_hlo_; }
97 
98   // Wrapper of queries to broadcast_dims_.
BroadcastDimsForCreatedHlo(const HloInstruction * hlo)99   absl::optional<const absl::flat_hash_set<int64>*> BroadcastDimsForCreatedHlo(
100       const HloInstruction* hlo) {
101     auto it = broadcast_dims_.find(hlo);
102     if (it == broadcast_dims_.end()) {
103       return absl::nullopt;
104     }
105     return &it->second;
106   }
107 
108  private:
109   // Currently visiting instruction.
110   HloInstruction* visiting_hlo_;
111 
112   // Map from the currently visiting (old) instruction to new instructions
113   // created during SPMD partitioning.
114   HloInstructionMap<std::vector<HloInstruction*>> instructions_;
115 
116   // Maps from each created instruction to a set of dimensions that are from
117   // broadcasts or elementwise ops over broadcasts. This means elements along
118   // these dimensions have the same value.
119   absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64>>
120       broadcast_dims_;
121 };
122 
123 // A set of functions that create the cross-partition collective ops.
124 struct SPMDCollectiveOpsCreator {
125   // Function used to create a partition ID HLO.
126   std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;
127 
128   // Function used to create a cross-partition all-reduce HLO.
129   std::function<HloInstruction*(
130       SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
131       const std::vector<std::vector<int64>>& partition_subgroups,
132       int64_t channel_id)>
133       create_cross_partition_all_reduce;
134 
135   // Function used to create a cross-partition collective-permute HLO.
136   std::function<HloInstruction*(
137       SpmdBuilder*, HloInstruction* operand,
138       std::vector<std::pair<int64, int64>>& src_dst_pairs,
139       int64_t next_channel_id)>
140       create_cross_partition_collective_permute;
141 
142   // Function used to create a cross-partition all-to-all HLO.
143   std::function<HloInstruction*(
144       SpmdBuilder*, absl::Span<HloInstruction* const> operands,
145       const std::vector<std::vector<int64>>& partition_subgroups,
146       int64_t channel_id, absl::optional<int64> split_dimension)>
147       create_cross_partition_all_to_all;
148 
149   // Function used to create a cross-partition all-gather HLO. This is optional:
150   // if it is nullptr, the partitioner will use all-reduce instead.
151   std::function<HloInstruction*(
152       SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
153       const std::vector<std::vector<int64>>& partition_subgroups,
154       int64_t channel_id, int64_t all_gather_dimension)>
155       create_cross_partition_all_gather;
156 };
157 
158 // Create a default SPMDCollectiveOpsCreator.
159 SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
160                                                         int64_t num_replicas);
161 
162 // Logger to report memory usage during SPMD partitioning.
163 class SpmdLogger {
164  public:
SpmdLogger(int64_t report_instruction_count)165   explicit SpmdLogger(int64_t report_instruction_count)
166       : report_instruction_count_(report_instruction_count) {}
167   static std::string ReportBeforePartition(const HloModule& module,
168                                            int64_t report_instruction_count);
169   static std::string ReportAfterPartition(const HloModule& module,
170                                           int64_t report_instruction_count);
171 
172   // Registers the logging for the groups of instructions created to transform
173   // the given hlo.
174   void RegisterLogEntry(HloInstruction* hlo,
175                         const std::vector<HloInstruction*>& group);
176 
177   std::string MakeReport();
178 
179  private:
180   template <typename F>
181   static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
182                                        int64_t report_instruction_count);
183 
184   // A vector of logging messages (one for each original HLO instruction), where
185   // the first integer of the pair represents the size of the HBM used.
186   std::vector<std::pair<int64, std::string>> entries_;
187 
188   int64 report_instruction_count_;
189 };
190 
191 class SpmdPartitioningVisitor;
192 
193 class SpmdPartitioner : public HloModulePass {
194  public:
195   SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
196                   SpmdPartitionerOptions options);
SpmdPartitioner(int64_t num_partitions,int64_t num_replicas,SpmdPartitionerOptions options,SPMDCollectiveOpsCreator collective_ops_creator)197   SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
198                   SpmdPartitionerOptions options,
199                   SPMDCollectiveOpsCreator collective_ops_creator)
200       : num_partitions_(num_partitions),
201         num_replicas_(num_replicas),
202         options_(std::move(options)),
203         collective_ops_creator_(std::move(collective_ops_creator)) {}
name()204   absl::string_view name() const override { return "spmd-partitioning"; }
205   StatusOr<bool> Run(HloModule* module) override;
206 
207   // Transforms the given computation with SPMD instructions, replacing it with
208   // a new computation.
209   StatusOr<bool> PartitionComputation(HloComputation* computation,
210                                       const HloSharding& root_sharding,
211                                       int64* next_channel_id,
212                                       SpmdLogger* logger);
213 
214   // Creates all-gather(s) based on HloSharding. Can be overridden to customize.
215   // The default uses a single all-gather even if there are multiple sharded
216   // dimensions, and adds potential reshapes and transposes to achieve that.
217   // If it returns false, the partitioner will fall back to all-reduce.
218   // `selected_dims` specifies the dimensions along which the all-gather happens
219   // in the tiled sharding, which allows potentially creating a subgroup
220   // all-gather.
221   virtual HloInstruction* AllGatherShards(
222       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
223       int64* next_channel_id, absl::Span<const int64> selected_dims,
224       const SPMDCollectiveOpsCreator& collectives_creator);
225 
226   // Creates all-reduce(s) across devices along selected_dims in sharding. Can
227   // be overridden to customize.
228   virtual HloInstruction* AllReduceAlongShardingDims(
229       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
230       int64* next_channel_id, absl::Span<const int64> selected_dims,
231       const SPMDCollectiveOpsCreator& collectives_creator,
232       HloComputation* reduction);
233 
options()234   const SpmdPartitionerOptions& options() { return options_; }
235 
236  protected:
237   virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
238       HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
239       const SPMDCollectiveOpsCreator& collective_ops_creator,
240       int64* next_channel_id, SpmdLogger* logger,
241       SpmdPartitionerOptions options);
242 
243   HloInstruction* AllGatherShardsInternal(
244       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
245       int64* next_channel_id, absl::Span<const int64> selected_dims,
246       const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
247   HloInstruction* AllReduceAlongShardingDimsInternal(
248       SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
249       int64* next_channel_id, absl::Span<const int64> selected_dims,
250       const SPMDCollectiveOpsCreator& collectives_creator,
251       HloComputation* reduction, bool per_dim_ar);
252 
253   // Verifies that the sharding of instructions in the module are valid, and
254   // also fill in missing sharding information.
255   virtual Status PreprocessSharding(HloModule* module);
256 
257   // Returns if the given side-effecting instruction is allowed to have
258   // replicated sharding.
CanSideEffectingHaveReplicatedSharding(const HloInstruction * hlo)259   virtual bool CanSideEffectingHaveReplicatedSharding(
260       const HloInstruction* hlo) {
261     return hlo->opcode() == HloOpcode::kInfeed ||
262            hlo->opcode() == HloOpcode::kOutfeed;
263   }
264 
265   // Preprocesses the graph to simplify some communication patterns. E.g., merge
266   // pad->slice into a single pad with potentially negative padding to avoid
267   // multiple halo exchanges.
268   Status PreprocessHlos(HloModule* module);
269 
270   const int64 num_partitions_;
271   const int64 num_replicas_;
272 
273   SpmdPartitionerOptions options_;
274   SPMDCollectiveOpsCreator collective_ops_creator_;
275   std::vector<std::vector<int64>> device_groups_;
276 };
277 
278 // Class describes partition state of the data represented by an HLO created
279 // during SPMD partitioning pass.
280 //
281 // Data on some devices may include padding region, if the base (full) shape
282 // could not be evenly partitioned.
283 class PartitionedHlo {
284  public:
285   // Return value for ReshardAsWindowedInput which describes the resharded HLO,
286   // the window for the user on the shard, and if necessary, the dynamic slice
287   // offsets to be applied to the output of the op being sharded.
288   struct WindowedInputShardReturnValue {
289     HloInstruction* sharded_input;
290     Window shard_window;
291     absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
292   };
293   // A cache for resharding each partitioned HLO.
294   struct ReshardCache {
295     struct PerHloCache {
296       std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
297       std::vector<
298           std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
299           window_reshard_cache;
300     };
301     // Use std::unordered_map for pointer stability.
302     std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
303     // Caches for nested partitioning of grouped sharding. Each string key
304     // represents a unique way of grouping devices.
305     absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
306         groupd_caches;
307   };
308   struct PartitioningState {
309     SpmdBuilder* b;
310     HloModule* module;
311     int64 num_replicas;
312     HloInstruction* partition_id;
313     SPMDCollectiveOpsCreator collective_ops_creator;
314     int64* next_channel_id;
315     ReshardCache* reshard_cache;
316     SpmdPartitioner* partitioner;
317   };
PartitionedHlo(HloInstruction * hlo,Shape base_shape,PartitioningState state)318   PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
319       : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
320     CHECK(hlo->has_sharding())
321         << "PartitionedHlo is missing sharding:" << hlo->ToString();
322     // If the tuple shape instruction does not have a tuple sharding, reassign
323     // to use the tuple sharding. Reshard() implementation assumes this.
324     if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
325       hlo_->set_sharding(
326           hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
327     }
328   }
329 
330   // Reshards the current SPMD instruction to a new sharding. Could only modify
331   // the reshard cache.
332   PartitionedHlo Reshard(const HloSharding& target);
333 
334   // Pads the garbage area of the output with the provided value. Normally,
335   // unevenly partitioned dimensions are padded on the right, but this function
336   // allows specifying left-padded dimensions, which can be used during the
337   // handling of kReverse, etc.
338   PartitionedHlo PadWithValue(HloInstruction* pad_value,
339                               absl::Span<const int64> left_padded_dims = {},
340                               absl::Span<const int64> skipped_dims = {}) const;
341 
342   // Returns the SPMD instruction.
hlo()343   HloInstruction* hlo() const { return hlo_; }
344 
345   // Returns the sharding of the SPMD instruction.
sharding()346   const HloSharding& sharding() const { return hlo_->sharding(); }
347 
348   // Original full shape of the data.
base_shape()349   const Shape& base_shape() const { return base_shape_; }
350 
NewChannel()351   int64 NewChannel() const { return (*state_.next_channel_id)++; }
352 
353   // Reshards the HLO to a usable partitioned input for a windowed user. Could
354   // only modify the reshard cache.
355   absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
356       const Window& window, const HloSharding& target,
357       HloInstruction* pad_value, bool mask_invalid_region = true);
358 
state()359   const PartitioningState& state() const { return state_; }
360 
361   // Helper function to replicate the data on all devices. Could only modify
362   // the reshard cache.
363   PartitionedHlo Replicate();
364 
365   // Helper function to replicate the data for partitions along the given dims.
366   HloInstruction* ReplicatePartial(absl::Span<const int64> dims);
367 
368   // Set state of the partitoned HLO.
set_state(PartitioningState state)369   void set_state(PartitioningState state) { state_ = std::move(state); }
370 
371  private:
372   // Same as Reshard except that it does not explicitly modify the reshard
373   // cache, although it would indirectly modify by calling Replicate().
374   PartitionedHlo ReshardNoCache(const HloSharding& target);
375 
376   // Helper function to broadcast data from a single device to all devices.
377   PartitionedHlo Broadcast() const;
378 
379   // Helper function to reshard the tensor using AllToAll (instead of the
380   // default of Replicate followed by Slice).
381   PartitionedHlo ReshardWithAllToAll(
382       const HloSharding& target,
383       absl::Span<const std::pair<int64, int64>> source_target_dims) const;
384 
385   // Helper function to reshard the tensor using CollectivePermute.
386   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
387 
388   // Helper function to reshard to partial replicate using AllGather.
389   absl::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
390       const HloSharding& target);
391 
392   // Helper function to reshard from partial replicate using DynamicSlice.
393   absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
394       const HloSharding& target);
395 
396   // Helper function to reshard from partial replicate using AllToAll.
397   absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
398       const HloSharding& target);
399 
400   // SPMD instruction.
401   HloInstruction* hlo_;
402 
403   // The original shape of the data before SPMD transformation is applied.
404   Shape base_shape_;
405 
406   PartitioningState state_;
407 };
408 
409 struct DotConvDimsMapping {
410   // The dimension numbers for the operands and output corresponding to a
411   // logical dimension (e.g., batch, contracting, non-contracting). If an
412   // operand or the output doesn't have the logical dimension, it is set to
413   // -1.
414   struct DimsMapping {
415     int64 lhs;
416     int64 rhs;
417     int64 output;
418     // input mapped to index in input_spatial_dimensions().
419     int64 spatial;
420   };
421   std::vector<DimsMapping> batch_dims;
422   std::vector<DimsMapping> contracting_dims;
423   std::vector<DimsMapping> lhs_non_contracting_dims;
424   std::vector<DimsMapping> rhs_non_contracting_dims;
425   std::vector<DimsMapping> conv_spatial_dims;
426 };
427 
428 class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
429  public:
430   SpmdPartitioningVisitor(
431       HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
432       const SPMDCollectiveOpsCreator& collective_ops_creator,
433       int64* next_channel_id, SpmdLogger* logger,
434       SpmdPartitionerOptions options, SpmdPartitioner* partitioner);
435 
436   Status DefaultAction(HloInstruction* hlo) override;
437   Status HandleAllReduce(HloInstruction* hlo) override;
438   Status HandleBroadcast(HloInstruction* hlo) override;
439   Status HandleConstant(HloInstruction* hlo) override;
440   Status HandleCustomCall(HloInstruction* hlo) override;
441   Status HandleDot(HloInstruction* hlo) override;
442   Status HandleDynamicSlice(HloInstruction* hlo) override;
443   Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
444   Status HandleFft(HloInstruction* hlo) override;
445   Status HandleGather(HloInstruction* hlo) override;
446   Status HandleGetTupleElement(HloInstruction* hlo) override;
447   Status HandleInfeed(HloInstruction* hlo) override;
448   Status HandleOutfeed(HloInstruction* hlo) override;
449   Status HandlePad(HloInstruction* hlo) override;
450   Status HandleParameter(HloInstruction* hlo) override;
451   Status HandleReduce(HloInstruction* hlo) override;
452   Status HandleReverse(HloInstruction* hlo) override;
453   Status HandleWhile(HloInstruction* hlo) override;
454   Status HandleConditional(HloInstruction* hlo) override;
455   Status HandleReduceWindow(HloInstruction* hlo) override;
456   Status HandleSelectAndScatter(HloInstruction* hlo) override;
457   Status HandleTuple(HloInstruction* hlo) override;
458   Status HandleRng(HloInstruction* hlo) override;
459   Status HandleConvolution(HloInstruction* hlo) override;
460   Status HandleConcatenate(HloInstruction* hlo) override;
461   Status HandleScatter(HloInstruction* hlo) override;
462   Status HandleSlice(HloInstruction* hlo) override;
463   Status HandleSort(HloInstruction* hlo) override;
464   Status HandleTranspose(HloInstruction* hlo) override;
465   Status HandleReshape(HloInstruction* hlo) override;
466   Status HandleIota(HloInstruction* hlo) override;
467   Status HandlePartitionId(HloInstruction* hlo) override;
468 
469   // Implementation of dot partitioning given DotGeneralDimsMapping.
470   Status HandleDotHelper(HloInstruction* hlo,
471                          const DotConvDimsMapping& dims_mapping,
472                          const std::function<StatusOr<HloInstruction*>(
473                              HloInstruction*, HloInstruction*, SpmdBuilder*,
474                              const Window& conv_window)>& create_sharded_dot);
475 
476   // Common handle for elementwise HLOs.
477   Status HandleElementwise(HloInstruction* hlo);
478 
479   // Common handle for HLOs that runs on a single device.
480   Status HandleSingleDevice(const HloInstruction* hlo);
481 
482   // CustomCall handlers per call target.
483   Status HandleCustomCallTopK(HloInstruction* hlo);
484   // Convenient custom ops defined by the partitioner itself.
485   Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);
486 
487   // Returns the PartitionedHlo that corresponds to the original hlo.
GetPartitionedHlo(const HloInstruction * hlo)488   PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
489     CHECK_EQ(partitioned_instructions_.count(hlo), 1);
490     return partitioned_instructions_.find(hlo)->second;
491   }
492 
493   // Sets the PartitionedHlo for the original hlo.
SetPartitionedHlo(const HloInstruction * hlo,const PartitionedHlo & partitioned_hlo)494   void SetPartitionedHlo(const HloInstruction* hlo,
495                          const PartitionedHlo& partitioned_hlo) {
496     CHECK_EQ(partitioned_instructions_.count(hlo), 0);
497     partitioned_instructions_.emplace(hlo, partitioned_hlo);
498     changed_ = true;
499   }
500 
501   // Convenient wrapper that creates PartitionedHlo from the result of the func
502   // and maps it to the given original hlo.
SetPartitionedHlo(const HloInstruction * hlo,const std::function<HloInstruction * ()> & func)503   void SetPartitionedHlo(const HloInstruction* hlo,
504                          const std::function<HloInstruction*()>& func) {
505     HloInstruction* new_hlo = func();
506     new_hlo->set_sharding(hlo->sharding());
507     SetPartitionedHlo(
508         hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
509     changed_ = true;
510   }
511 
NewChannel()512   int64 NewChannel() { return (*next_channel_id_)++; }
513 
514   PartitionedHlo::PartitioningState MakePartitioningState();
515 
builder()516   SpmdBuilder* builder() { return &b_; }
517 
518   virtual StatusOr<bool> DoPartition(HloComputation* computation,
519                                      const HloSharding& root_sharding,
520                                      const SpmdPartitionerOptions& options);
521 
522   // Information about a loop created for windowed dot-general. Used when
523   // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
524   // finishes traversing the graph.
525   struct WindowedDotGeneralLoop {
526     HloInstruction* while_loop;
527     int64 windowed_operand;
528     bool windowed_in_contracting_dims;
529     bool windowed_in_batch_dims;
530     bool operands_sharded_at_contracting_dims;
531     int64 num_partitions;
532     std::vector<ReplicaGroup> loop_replica_groups;
533   };
534 
535  protected:
536   Status Preprocess(HloInstruction* hlo) override;
537   Status Postprocess(HloInstruction* hlo) override;
538 
539   // Performs code motion for windowed dot-general loops in
540   // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
541   // the graph.
542   Status DoCodeMotionForWindowedDotGeneralLoops(
543       HloComputation* computation, const SpmdPartitionerOptions& options);
544 
545   bool changed_;
546   HloModule* module_;
547   int64 num_partitions_;
548   int64 num_replicas_;
549 
550   SPMDCollectiveOpsCreator collective_ops_creator_;
551 
552   // Tracks the next channel id to use for cross-partition all-reduce.
553   int64* next_channel_id_;
554   SpmdBuilder b_;
555 
556   std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;
557 
558   HloInstruction* partition_id_;
559 
560  private:
561   PartitionedHlo::ReshardCache reshard_cache_;
562 
563   // Mapping from the instruction in the original computation to the new SPMD
564   // partitioned instruction.
565   ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;
566 
567   HloInstruction* visiting_hlo_;
568   SpmdLogger* logger_;
569   const SpmdPartitionerOptions options_;
570   SpmdPartitioner* partitioner_;
571   std::vector<HloSharding> visiting_hlo_operand_shardings_;
572   absl::optional<HloSharding> visiting_hlo_sharding_;
573   absl::optional<int64> visiting_num_partitions_;
574   std::vector<PartitionedHlo::PartitioningState> visiting_state_;
575   std::vector<std::vector<int64>> device_groups_;
576 };
577 
578 }  // namespace spmd
579 }  // namespace xla
580 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
581