/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_

#include "absl/synchronization/mutex.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

// Common place for all collective thunks to source nccl/rccl headers.
// Also, the RunNcclCollective() overrides in the various thunks should use
// XLA_ENABLE_XCCL (and not GOOGLE_XCCL) to guard NCCL/RCCL usage; see the
// sketch after the include block below.
#if GOOGLE_XCCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define XLA_ENABLE_XCCL 1
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // GOOGLE_XCCL

#if XLA_ENABLE_XCCL
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
#endif

// Also include this header, which is required by all collective thunks.
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#endif  // XLA_ENABLE_XCCL
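
// For example, a RunNcclCollective() override would follow the convention
// above roughly like this (a hypothetical sketch; `NcclAllReduceThunk` and
// `RunAllReduce` are illustrative names, not part of this header):
//
//   Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
//                                                ncclComm_t comm) {
//   #if XLA_ENABLE_XCCL
//     return RunAllReduce(params, comm);
//   #else
//     return Unimplemented("NCCL/RCCL support was not enabled in this build.");
//   #endif  // XLA_ENABLE_XCCL
//   }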

// Forward-declare ncclComm so that the declarations below remain valid even
// when the NCCL/RCCL headers are not included (i.e. XLA_ENABLE_XCCL is off).
struct ncclComm;
using ncclComm_t = ncclComm*;

namespace xla {
namespace gpu {

struct NcclClique;

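// Configuration common to all NCCL collective thunks. For ops lowered from
// MLIR, `op_id` identifies the collective: the channel id for a cross-module
// collective, or the module's `hlo.unique_id` attribute for a cross-replica
// collective (see GetNcclCollectiveConfigForMlir() below).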
struct NcclCollectiveConfig {
  NcclCollectiveConfig();
  NcclCollectiveConfig(NcclCollectiveConfig&&);
  ~NcclCollectiveConfig();

  NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);

  int64 operand_count;
  std::vector<PrimitiveType> operand_element_type;
  int64 replica_count;
  std::vector<ReplicaGroup> replica_groups;
  RendezvousKey::CollectiveOpKind collective_op_kind;
  int64 op_id;
};

template <typename OpT>
NcclCollectiveConfig GetNcclCollectiveConfigForMlir(OpT op,
                                                    int64 replica_count) {
  NcclCollectiveConfig config;
  config.operand_count = op.operands().size();
  config.operand_element_type.reserve(config.operand_count);
  for (int i = 0; i < config.operand_count; i++) {
    const Shape shape = TypeToShape(op.operands()[i].getType());
    config.operand_element_type.push_back(shape.element_type());
  }
  config.replica_count = replica_count;
  config.replica_groups =
      ConvertReplicaGroups(op.replica_groups()).ValueOrDie();

  if (!op.IsCrossReplica()) {
    config.collective_op_kind = RendezvousKey::kCrossModule;
    config.op_id = op.channel_id()->handle().getUInt();
  } else {
    config.collective_op_kind = RendezvousKey::kCrossReplica;
    mlir::ModuleOp parent = op->template getParentOfType<mlir::ModuleOp>();
    mlir::IntegerAttr unique_id =
        parent->getAttrOfType<mlir::IntegerAttr>("hlo.unique_id");
    config.op_id = static_cast<int64>(unique_id.getInt());
  }
  return config;
}
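
// As a usage sketch, a thunk for a specific collective would typically embed
// the result in its own config struct. The `NcclAllGatherConfig` wrapper and
// the use of `mlir::lmhlo::AllGatherOp` below are hypothetical illustrations:
//
//   struct NcclAllGatherConfig {
//     NcclCollectiveConfig config;
//   };
//
//   NcclAllGatherConfig GetNcclAllGatherConfig(mlir::lmhlo::AllGatherOp op,
//                                              int64 replica_count) {
//     NcclAllGatherConfig config;
//     config.config = GetNcclCollectiveConfigForMlir(op, replica_count);
//     return config;
//   }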

// Thunk base class for NCCL collective operations.
class NcclCollectiveThunk : public Thunk {
 public:
  using Thunk::Thunk;

  struct Buffer {
    int64 element_count;
    BufferAllocation::Slice source_buffer;
    BufferAllocation::Slice destination_buffer;
  };

  // Returns whether NCCL operations appear possible to perform; e.g. if we
  // haven't done a build with the CUDA compiler enabled, we can't compile the
  // NCCL header, and thus this will be false.
  //
  // When this is false, the ExecuteOnStream() call will simply return a status
  // error.
  static bool NcclIsEnabled();

  Status ExecuteOnStream(const ExecuteParams& params) override;

 protected:
  virtual Status RunNcclCollective(const ExecuteParams& params,
                                   ncclComm_t comm) = 0;
  virtual const NcclCollectiveConfig& config() const = 0;
};
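
// A concrete collective thunk subclasses NcclCollectiveThunk and implements
// RunNcclCollective() and config(). A minimal hypothetical sketch
// (`NcclAllReduceThunk` and its members are illustrative only):
//
//   class NcclAllReduceThunk : public NcclCollectiveThunk {
//    public:
//     NcclAllReduceThunk(ThunkInfo thunk_info, NcclCollectiveConfig config,
//                        std::vector<Buffer> buffers)
//         : NcclCollectiveThunk(Thunk::kNcclAllReduce, thunk_info),
//           config_(std::move(config)),
//           buffers_(std::move(buffers)) {}
//
//    protected:
//     Status RunNcclCollective(const ExecuteParams& params,
//                              ncclComm_t comm) override;
//     const NcclCollectiveConfig& config() const override { return config_; }
//
//    private:
//     const NcclCollectiveConfig config_;
//     const std::vector<Buffer> buffers_;
//   };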

// Returns whether the given data type is supported by NCCL.
// Note: Keep this in sync with ToNcclDataType().
bool IsTypeSupportedByNccl(PrimitiveType element_type);
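// For example, assuming ToNcclDataType() maps F32 to ncclFloat32 and S32 to
// ncclInt32, IsTypeSupportedByNccl(F32) and IsTypeSupportedByNccl(S32) would
// both return true, while an element type with no NCCL counterpart would
// return false.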

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_