/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_

#include <cstdint>

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

struct NcclCollectivePermuteConfig : public NcclCollectiveConfig {
  // During a collective permute, every node optionally sends its data to
  // another node (including possibly itself) and optionally receives data
  // from another node. For each node, remember which node it receives data
  // from (source) and which node it sends data to (target). Both are
  // optional.
  struct SourceTargetMapEntry {
    absl::optional<int64_t> source;
    absl::optional<int64_t> target;
  };

  absl::flat_hash_map<int64_t, SourceTargetMapEntry> id_to_source_target;
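
  // For example (illustrative values): an op whose source_target_pairs
  // attribute is {{0,1},{1,2}} populates id_to_source_target as
  //   0 -> {source = nullopt, target = 1}        (only sends)
  //   1 -> {source = 0,       target = 2}        (receives and sends)
  //   2 -> {source = 1,       target = nullopt}  (only receives)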

  // Returns the source and target IDs corresponding to the given ID (these
  // IDs are replica_ids for a cross-replica permute or partition_ids for a
  // cross-partition permute). The source ID is the ID that will send data to
  // this ID, and the target ID is the ID to which this ID will send its data.
  // Either may be absent.
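  //
  // Example (illustrative values): if id_to_source_target contains the entry
  // {1, {/*source=*/0, /*target=*/2}}, GetSourceTarget(1) returns
  // {source = 0, target = 2}, while GetSourceTarget(7) returns an entry with
  // both fields empty, meaning ID 7 neither sends nor receives.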
  SourceTargetMapEntry GetSourceTarget(int64_t id) const {
    auto it = id_to_source_target.find(id);
    if (it != id_to_source_target.end()) return it->second;
    return SourceTargetMapEntry{};
  }
};

// Thunk that performs an NCCL-based collective permute.
class NcclCollectivePermuteThunk : public NcclCollectiveThunk {
 public:
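  // `thunk_info` carries metadata such as profile annotations; `op` is the
  // lowered LMHLO collective-permute op; `replica_count` and
  // `partition_count` give the number of replicas and partitions the module
  // was compiled for; `buffer` describes the source and destination slices
  // of the permuted data.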
  NcclCollectivePermuteThunk(ThunkInfo thunk_info,
                             mlir::lmhlo::CollectivePermuteOp op,
                             int64_t replica_count, int64_t partition_count,
                             const Buffer& buffer);

  // Returns whether the given instruction can be lowered to an NCCL
  // collective permute thunk.
  static bool CanImplement(mlir::lmhlo::CollectivePermuteOp op);

  static const char* GetName() { return "CollectivePermute"; }
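
  // Returns true if the op is degenerate for the given replica and partition
  // counts, i.e. it needs no cross-device communication and can be lowered
  // to a copy (for example, when there is only a single participant).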
  static bool IsDegenerate(mlir::lmhlo::CollectivePermuteOp op,
                           int64_t replica_count, int64_t partition_count);
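
  // A collective permute with a channel_id operates across partitions;
  // without one it operates across replicas.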
  static CollectiveOpGroupMode GetGroupMode(
      mlir::lmhlo::CollectivePermuteOp op) {
    return GetCollectiveOpGroupMode(op.channel_id().hasValue(), absl::nullopt)
        .ValueOrDie();
  }

 protected:
  Status RunNcclCollective(const ExecuteParams& params,
                           ncclComm_t comm) override;

  const NcclCollectiveConfig& config() const override { return config_; }

 private:
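  // Builds the thunk's config; id_to_source_target is derived from the op's
  // source_target_pairs attribute.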
  static NcclCollectivePermuteConfig GetNcclCollectivePermuteConfig(
      mlir::lmhlo::CollectivePermuteOp op, int64_t replica_count,
      int64_t partition_count);

  const NcclCollectivePermuteConfig config_;
  const Buffer buffer_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_PERMUTE_THUNK_H_