/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_

#include "absl/synchronization/mutex.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

// Common place for all collective thunks to source nccl/rccl headers.
// Also, all the RunNcclCollective() functions for various thunks should use
// XLA_ENABLE_XCCL (and not GOOGLE_XCCL) to guard NCCL/RCCL usage; see the
// illustrative sketch below.
#if GOOGLE_XCCL
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define XLA_ENABLE_XCCL 1
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // GOOGLE_XCCL

#if XLA_ENABLE_XCCL
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#else
#error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
#endif

// Also include this file required by all collective thunks.
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#endif  // XLA_ENABLE_XCCL
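
// Illustrative sketch (not part of this header): how a concrete thunk's
// RunNcclCollective() implementation is expected to guard NCCL/RCCL usage.
// The NcclAllReduceThunk name and the body below are assumptions chosen for
// illustration; only the XLA_ENABLE_XCCL guard pattern is the point.
//
//   Status NcclAllReduceThunk::RunNcclCollective(const ExecuteParams& params,
//                                                ncclComm_t comm) {
//   #if XLA_ENABLE_XCCL
//     // Safe to reference NCCL/RCCL symbols here, e.g. ncclAllReduce(...).
//     return Status::OK();
//   #else
//     return Unimplemented(
//         "NCCL support is not available: this binary was not built with a "
//         "CUDA or ROCm compiler, which is required to build the NCCL/RCCL "
//         "library.");
//   #endif  // XLA_ENABLE_XCCL
//   }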

struct ncclComm;
using ncclComm_t = ncclComm*;

namespace xla {
namespace gpu {

class NcclClique;

struct NcclCollectiveConfig {
  NcclCollectiveConfig();
  NcclCollectiveConfig(NcclCollectiveConfig&&);
  ~NcclCollectiveConfig();

  NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);

  int64 operand_count;
  std::vector<PrimitiveType> operand_element_type;
  std::vector<ReplicaGroup> replica_groups;
  RendezvousKey::CollectiveOpKind collective_op_kind;
  int64 op_id;
  CollectiveOpGroupMode group_mode;

  template <typename OpT>
  void SetCollectiveOpKindAndID(OpT op);
  bool IsDegenerate(int64_t replica_count, int64_t partition_count) const;
};

template <typename OpT>
void NcclCollectiveConfig::SetCollectiveOpKindAndID(OpT op) {
  if (op.channel_id()) {
    collective_op_kind = RendezvousKey::kCrossModule;
    op_id = static_cast<int64>(op.channel_id()->handle().getInt());
  } else {
    collective_op_kind = RendezvousKey::kCrossReplica;
    mlir::ModuleOp parent = op->template getParentOfType<mlir::ModuleOp>();
    mlir::IntegerAttr unique_id =
        parent->getAttrOfType<mlir::IntegerAttr>("hlo.unique_id");
    op_id = static_cast<int64>(unique_id.getInt());
  }
}

template <typename OpT>
NcclCollectiveConfig GetNcclCollectiveConfigForMlir(
    OpT op, absl::optional<bool> use_global_device_ids) {
  NcclCollectiveConfig config;
  config.operand_count = op.operands().size();
  config.operand_element_type.reserve(config.operand_count);
  for (int i = 0; i < config.operand_count; i++) {
    const Shape shape = GetShape(op.operands()[i]);
    config.operand_element_type.push_back(shape.element_type());
  }
  config.replica_groups =
      ConvertReplicaGroups(op.replica_groups()).ValueOrDie();
  config.SetCollectiveOpKindAndID(op);
  config.group_mode = GetCollectiveOpGroupMode(op.channel_id().hasValue(),
                                               use_global_device_ids)
                          .ValueOrDie();
  return config;
}
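
// Illustrative sketch (not part of this header): a concrete thunk typically
// builds its config from the lowered MLIR op in its constructor. The
// NcclAllGatherThunk name and the mlir::lmhlo::AllGatherOp type below are
// assumptions for illustration only.
//
//   NcclAllGatherThunk::NcclAllGatherThunk(ThunkInfo thunk_info,
//                                          mlir::lmhlo::AllGatherOp op,
//                                          std::vector<Buffer> buffers)
//       : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
//         config_(GetNcclCollectiveConfigForMlir(
//             op, op.use_global_device_ids())),
//         buffers_(std::move(buffers)) {}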

// Thunk base class for NCCL collective operations.
class NcclCollectiveThunk : public Thunk {
 public:
  using Thunk::Thunk;

  struct Buffer {
    int64 element_count;
    BufferAllocation::Slice source_buffer;
    BufferAllocation::Slice destination_buffer;
  };

  // Returns whether NCCL operations appear possible to perform; e.g. if we
  // haven't done a build with the CUDA compiler enabled, we can't compile the
  // NCCL header, and thus this will be false.
  //
  // When this is false, the ExecuteOnStream() call will simply return a status
  // error.
  static bool NcclIsEnabled();

  Status ExecuteOnStream(const ExecuteParams& params) override;

 protected:
  virtual Status RunNcclCollective(const ExecuteParams& params,
                                   ncclComm_t comm) = 0;
  virtual const NcclCollectiveConfig& config() const = 0;

  // Logging support.
  std::string GetDeviceString(const ExecuteParams& params) const;
};
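
// Illustrative sketch (not part of this header): the general shape of a
// concrete subclass. The NcclAllReduceThunk name, constructor arguments, and
// member layout are assumptions for illustration; a real thunk also carries
// its op-specific state.
//
//   class NcclAllReduceThunk : public NcclCollectiveThunk {
//    public:
//     NcclAllReduceThunk(ThunkInfo thunk_info, /*op-specific args,*/
//                        std::vector<Buffer> buffers);
//
//    protected:
//     Status RunNcclCollective(const ExecuteParams& params,
//                              ncclComm_t comm) override;
//     const NcclCollectiveConfig& config() const override { return config_; }
//
//    private:
//     const NcclCollectiveConfig config_;
//     const std::vector<Buffer> buffers_;
//   };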

// Returns whether the given data type is supported by NCCL.
// Note: Keep this in sync with ToNcclDataType().
bool IsTypeSupportedByNccl(PrimitiveType element_type);
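
// Illustrative sketch (not part of this header): how an emitter might use
// this check before constructing a collective thunk. The loop and error
// message below are assumptions for illustration only.
//
//   for (PrimitiveType type : config.operand_element_type) {
//     if (!IsTypeSupportedByNccl(type)) {
//       return Unimplemented("Element type %s is not supported by NCCL",
//                            PrimitiveType_Name(type));
//     }
//   }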

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_