1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" 21 #include "tensorflow/compiler/xla/service/collective_ops_utils.h" 22 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" 23 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/xla_data.pb.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace xla { 29 namespace gpu { 30 31 struct NcclCollectivePermuteConfig : public NcclCollectiveConfig { 32 public: 33 // During a collective permute, every node optionally sends its data to 34 // another node (including possibly itself) and received data from another 35 // node. For each node, remember who it receives data from (source) and who 36 // it send data to (target). Either are optional. 37 struct SourceTargetMapEntry { 38 absl::optional<int64> source; 39 absl::optional<int64> target; 40 }; 41 42 absl::flat_hash_map<int64, SourceTargetMapEntry> id_to_source_target; 43 44 // Returns the source and target ID corresponding to the given ID (these IDs 45 // are replica_ids for cross replica permute or partition_ids for cross 46 // partition permute). The source ID is the id which will send data to this 47 // ID and the target ID is the id to which this ID will send its data. Either 48 // can be optional. GetSourceTargetNcclCollectivePermuteConfig49 SourceTargetMapEntry GetSourceTarget(int64_t id) const { 50 auto it = id_to_source_target.find(id); 51 if (it != id_to_source_target.end()) return it->second; 52 return SourceTargetMapEntry{}; 53 } 54 }; 55 56 // Thunk that performs a NCCL-based collective permute. 57 class NcclCollectivePermuteThunk : public NcclCollectiveThunk { 58 public: 59 NcclCollectivePermuteThunk(ThunkInfo thunk_info, 60 mlir::lmhlo::CollectivePermuteOp op, 61 int64_t replica_count, int64_t partition_count, 62 const Buffer& buffer); 63 64 // Returns whether the given instruction can be lowered to a nccl collective 65 // permute thunk. 66 static bool CanImplement(mlir::lmhlo::CollectivePermuteOp op); 67 GetName()68 static const char* GetName() { return "CollectivePermute"; } 69 static bool IsDegenerate(mlir::lmhlo::CollectivePermuteOp op, 70 int64_t replica_count, int64_t partition_count); GetGroupMode(mlir::lmhlo::CollectivePermuteOp op)71 static CollectiveOpGroupMode GetGroupMode( 72 mlir::lmhlo::CollectivePermuteOp op) { 73 return GetCollectiveOpGroupMode(op.channel_id().hasValue(), absl::nullopt) 74 .ValueOrDie(); 75 } 76 77 protected: 78 Status RunNcclCollective(const ExecuteParams& params, 79 ncclComm_t comm) override; 80 config()81 const NcclCollectiveConfig& config() const override { return config_; } 82 83 private: 84 static NcclCollectivePermuteConfig GetNcclCollectivePermuteConfig( 85 mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count, 86 int64_t partition_count); 87 88 const NcclCollectivePermuteConfig config_; 89 const Buffer buffer_; 90 }; 91 92 } // namespace gpu 93 } // namespace xla 94 95 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_ 96