/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"

namespace xla {
namespace spmd {

struct SpmdPartitionerOptions {
  // Always exchange halo on the LHS for all convolutions. If false, the
  // backprop filter convolution exchanges halo on the RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of top memory-profile instructions to report.
  int64 report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand for the operand to be
  // considered for the windowed implementation in an HLO loop.
  int64 threshold_for_windowed_einsum_mib = 256;

  // Whether to unroll the windowed einsum loop by a degree of two.
  bool unroll_windowed_einsum = false;

  // Whether to use bidirectional collective permute in the windowed einsum
  // loop.
  bool bidirectional_windowed_einsum = false;

  // Whether the entry computation's signature may change after partitioning.
  bool allow_module_signature_change = false;

  // Whether to use a cached all-gather to avoid repeatedly replicating a
  // tiled tensor. If set to false, the result tends to be more
  // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
  // pass to CSE all-gathers that are relatively close to each other.
  bool cache_all_gather = true;

  // When trading off windowed einsum speed against memory usage, prefer
  // speed if true.
  bool choose_faster_windowed_einsum_over_mem = false;

  // Whether to use bidirectional communication when decomposing independent
  // all-gathers.
  bool bidirectional_decomposed_all_gather = false;
};
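
// A minimal configuration sketch (illustrative only; the field values below
// are arbitrary examples, not recommended settings):
//
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 512;
//   options.choose_faster_windowed_einsum_over_mem = true;
//   options.allow_module_signature_change = true;
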
// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
 public:
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name) {
    visiting_hlo_ = hlo;
  }

  HloInstruction* AddInstruction(
      std::unique_ptr<HloInstruction> instruction) override;

  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) {
    return instructions_.at(hlo);
  }

  void set_visiting_hlo(HloInstruction* hlo) {
    visiting_hlo_ = hlo;
    instructions_[hlo];
  }

  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

  // Wrapper of queries to broadcast_dims_.
  absl::optional<const absl::flat_hash_set<int64>*> BroadcastDimsForCreatedHlo(
      const HloInstruction* hlo) {
    auto it = broadcast_dims_.find(hlo);
    if (it == broadcast_dims_.end()) {
      return absl::nullopt;
    }
    return &it->second;
  }

 private:
  // The instruction currently being visited.
  HloInstruction* visiting_hlo_;

  // Map from the currently visited (old) instruction to the new instructions
  // created for it during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;

  // Maps each created instruction to a set of dimensions that come from
  // broadcasts or elementwise ops over broadcasts. Elements along these
  // dimensions have the same value.
  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64>>
      broadcast_dims_;
};

// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
  // Function used to create a partition ID HLO.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Function used to create a cross-partition all-reduce HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64_t channel_id)>
      create_cross_partition_all_reduce;

  // Function used to create a cross-partition collective-permute HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64, int64>>& src_dst_pairs,
      int64_t next_channel_id)>
      create_cross_partition_collective_permute;

  // Function used to create a cross-partition all-to-all HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64_t channel_id, absl::optional<int64> split_dimension)>
      create_cross_partition_all_to_all;

  // Function used to create a cross-partition all-gather HLO. This is
  // optional: if it is nullptr, the partitioner will use all-reduce instead.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64_t channel_id, int64_t all_gather_dimension)>
      create_cross_partition_all_gather;
};

// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
                                                        int64_t num_replicas);
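
// Illustrative sketch: callers can start from the default creator and
// override individual factories, e.g. to stub out collectives in tests
// (the lambda body below is hypothetical):
//
//   SPMDCollectiveOpsCreator creator = GetDefaultCollectiveOpsCreator(
//       /*num_partitions=*/8, /*num_replicas=*/1);
//   creator.create_partition_id = [](SpmdBuilder* b) {
//     return b->AddInstruction(HloInstruction::CreatePartitionId());
//   };
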
// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
 public:
  explicit SpmdLogger(int64_t report_instruction_count)
      : report_instruction_count_(report_instruction_count) {}
  static std::string ReportBeforePartition(const HloModule& module,
                                           int64_t report_instruction_count);
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64_t report_instruction_count);

  // Registers the logging for the group of instructions created to transform
  // the given hlo.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  std::string MakeReport();

 private:
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module,
                                       const F& filter,
                                       int64_t report_instruction_count);

  // A vector of logging messages (one for each original HLO instruction),
  // where the first element of each pair is the size of HBM used.
  std::vector<std::pair<int64, std::string>> entries_;

  int64 report_instruction_count_;
};

class SpmdPartitioningVisitor;

class SpmdPartitioner : public HloModulePass {
 public:
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options);
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}
  absl::string_view name() const override { return "spmd-partitioning"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Transforms the given computation with SPMD instructions, replacing it
  // with a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather(s) based on HloSharding. Can be overridden to
  // customize. The default uses a single all-gather even if there are
  // multiple sharded dimensions, and adds potential reshapes and transposes
  // to achieve that. If it returns nullptr, the partitioner will fall back
  // to all-reduce. `selected_dims` specifies the dimensions along which the
  // all-gather happens in the tiled sharding, which allows potentially
  // creating a subgroup all-gather.
  virtual HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator);

  // Creates all-reduce(s) across devices along `selected_dims` in
  // `sharding`. Can be overridden to customize.
  virtual HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction);

  const SpmdPartitionerOptions& options() { return options_; }

 protected:
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64_t num_partitions,
      int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  HloInstruction* AllGatherShardsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
  HloInstruction* AllReduceAlongShardingDimsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64* next_channel_id, absl::Span<const int64> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction, bool per_dim_ar);

  // Verifies that the sharding of instructions in the module is valid, and
  // also fills in missing sharding information.
  virtual Status PreprocessSharding(HloModule* module);

  // Returns whether the given side-effecting instruction is allowed to have
  // replicated sharding.
  virtual bool CanSideEffectingHaveReplicatedSharding(
      const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kInfeed ||
           hlo->opcode() == HloOpcode::kOutfeed;
  }

  // Preprocesses the graph to simplify some communication patterns. E.g.,
  // merges pad->slice into a single pad with potentially negative padding to
  // avoid multiple halo exchanges.
  Status PreprocessHlos(HloModule* module);

  const int64 num_partitions_;
  const int64 num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
  std::vector<std::vector<int64>> device_groups_;
};
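
// Illustrative pipeline usage (a minimal sketch; it assumes the module has
// already been annotated with shardings, e.g. by ShardingPropagation):
//
//   HloPassPipeline pipeline("spmd-partitioning");
//   pipeline.AddPass<ShardingPropagation>(/*is_spmd=*/true);
//   pipeline.AddPass<SpmdPartitioner>(/*num_partitions=*/8,
//                                     /*num_replicas=*/1, options);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
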
// Describes the partition state of the data represented by an HLO created
// during the SPMD partitioning pass.
//
// Data on some devices may include a padding region, if the base (full)
// shape could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput which describes the resharded
  // HLO, the window for the user on the shard, and if necessary, the dynamic
  // slice offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    // Use std::unordered_map for pointer stability.
    std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
    // Caches for nested partitioning of grouped sharding. Each string key
    // represents a unique way of grouping devices.
    absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
        groupd_caches;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64 num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape,
                 PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding:" << hlo->ToString();
    // If the tuple-shaped instruction does not have a tuple sharding,
    // reassign it to use the tuple sharding. The Reshard() implementation
    // assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding. May only modify
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this
  // function allows specifying left-padded dimensions, which can be used
  // during the handling of kReverse, etc.
  PartitionedHlo PadWithValue(HloInstruction* pad_value,
                              absl::Span<const int64> left_padded_dims = {},
                              absl::Span<const int64> skipped_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64 NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. May
  // only modify the reshard cache.
  absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

  // Helper function to replicate the data on all devices. May only modify
  // the reshard cache.
  PartitionedHlo Replicate();

  // Helper function to replicate the data for partitions along the given
  // dims.
  HloInstruction* ReplicatePartial(absl::Span<const int64> dims);

  // Sets the state of the partitioned HLO.
  void set_state(PartitioningState state) { state_ = std::move(state); }

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it may modify the cache indirectly by calling
  // Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target);

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(
      const HloSharding& target,
      absl::Span<const std::pair<int64, int64>> source_target_dims) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // Helper function to reshard to partial replicate using AllGather.
  absl::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using DynamicSlice.
  absl::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  absl::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before the SPMD transformation was
  // applied.
  Shape base_shape_;

  PartitioningState state_;
};
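
// Illustrative sketch of typical use inside an op handler (hypothetical
// handler code; GetPartitionedHlo is declared on SpmdPartitioningVisitor
// below):
//
//   PartitionedHlo operand = GetPartitionedHlo(hlo->operand(0));
//   // Replicate the operand so each shard can read the full data.
//   PartitionedHlo replicated = operand.Reshard(HloSharding::Replicate());
//   HloInstruction* full_data = replicated.hlo();
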
struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64 lhs;
    int64 rhs;
    int64 output;
    // Input dimension mapped to an index in input_spatial_dimensions().
    int64 spatial;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};
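
// Worked example (illustrative): for a batched matmul with LHS [B, M, K] and
// RHS [B, K, N] producing [B, M, N], the mapping would be:
//
//   DotConvDimsMapping mapping;
//   mapping.batch_dims.push_back(
//       {/*lhs=*/0, /*rhs=*/0, /*output=*/0, /*spatial=*/-1});
//   mapping.contracting_dims.push_back(
//       {/*lhs=*/2, /*rhs=*/1, /*output=*/-1, /*spatial=*/-1});
//   mapping.lhs_non_contracting_dims.push_back(
//       {/*lhs=*/1, /*rhs=*/-1, /*output=*/1, /*spatial=*/-1});
//   mapping.rhs_non_contracting_dims.push_back(
//       {/*lhs=*/-1, /*rhs=*/2, /*output=*/2, /*spatial=*/-1});
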
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64_t num_partitions,
      int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given DotConvDimsMapping.
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handler for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handler for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // CustomCall handlers per call target.
  Status HandleCustomCallTopK(HloInstruction* hlo);
  // Convenience custom ops defined by the partitioner itself.
  Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenience wrapper that creates a PartitionedHlo from the result of the
  // func and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }
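
  // Illustrative sketch of the wrapper above inside a handler (hypothetical
  // handler body; the per-shard instruction is built through the builder):
  //
  //   SetPartitionedHlo(hlo, [&]() {
  //     return b_.AddInstruction(/* per-shard version of hlo */);
  //   });
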
  int64 NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState();

  SpmdBuilder* builder() { return &b_; }

  virtual StatusOr<bool> DoPartition(HloComputation* computation,
                                     const HloSharding& root_sharding,
                                     const SpmdPartitionerOptions& options);

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64 windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
    bool operands_sharded_at_contracting_dims;
    int64 num_partitions;
    std::vector<ReplicaGroup> loop_replica_groups;
  };

 protected:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes
  // traversing the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(
      HloComputation* computation, const SpmdPartitionerOptions& options);

  bool changed_;
  HloModule* module_;
  int64 num_partitions_;
  int64 num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64* next_channel_id_;
  SpmdBuilder b_;

  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* partition_id_;

 private:
  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from the instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
  std::vector<HloSharding> visiting_hlo_operand_shardings_;
  absl::optional<HloSharding> visiting_hlo_sharding_;
  absl::optional<int64> visiting_num_partitions_;
  std::vector<PartitionedHlo::PartitioningState> visiting_state_;
  std::vector<std::vector<int64>> device_groups_;
};

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_