• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
18 
19 #include "absl/synchronization/mutex.h"
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
23 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
24 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
25 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/types.h"
29 
// Common place for all collective thunks to include the NCCL/RCCL headers.
// Also, the RunNcclCollective() implementations of the various thunks should
// use XLA_ENABLE_XCCL (not GOOGLE_XCCL) to guard NCCL/RCCL usage.
33 #if GOOGLE_XCCL
34 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
35 #define XLA_ENABLE_XCCL 1
36 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
37 #endif  // GOOGLE_XCCL
38 
39 #if XLA_ENABLE_XCCL
40 #if GOOGLE_CUDA
41 #include "third_party/nccl/nccl.h"
42 #elif TENSORFLOW_USE_ROCM
43 #include "rocm/include/rccl/rccl.h"
44 #else
45 #error "Neither CUDA nor ROCm enabled but NCCL/RCCL enabled"
46 #endif
47 
// Also include nccl_utils.h, which all collective thunks require.
49 #include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"
50 
51 #endif  // XLA_ENABLE_XCCL
52 
// Opaque NCCL/RCCL communicator handle. nccl.h / rccl.h are only included
// when XLA_ENABLE_XCCL is set, so the handle type is declared here as well;
// this lets the thunk interface below name ncclComm_t in every build.
typedef struct ncclComm* ncclComm_t;
55 
56 namespace xla {
57 namespace gpu {
58 
59 struct NcclClique;
60 
61 struct NcclCollectiveConfig {
62   NcclCollectiveConfig();
63   NcclCollectiveConfig(NcclCollectiveConfig&&);
64   ~NcclCollectiveConfig();
65 
66   NcclCollectiveConfig& operator=(NcclCollectiveConfig&&);
67 
68   int64 operand_count;
69   std::vector<PrimitiveType> operand_element_type;
70   int64 replica_count;
71   std::vector<ReplicaGroup> replica_groups;
72   RendezvousKey::CollectiveOpKind collective_op_kind;
73   int64 op_id;
74 };
75 
76 
77 template <typename OpT>
GetNcclCollectiveConfigForMlir(OpT op,int64 replica_count)78 NcclCollectiveConfig GetNcclCollectiveConfigForMlir(OpT op,
79                                                     int64 replica_count) {
80   NcclCollectiveConfig config;
81   config.operand_count = op.operands().size();
82   config.operand_element_type.reserve(config.operand_count);
83   for (int i = 0; i < config.operand_count; i++) {
84     const Shape shape = TypeToShape(op.operands()[i].getType());
85     config.operand_element_type.push_back(shape.element_type());
86   }
87   config.replica_count = replica_count;
88   config.replica_groups =
89       ConvertReplicaGroups(op.replica_groups()).ValueOrDie();
90 
91   if (!op.IsCrossReplica()) {
92     config.collective_op_kind = RendezvousKey::kCrossModule;
93     config.op_id = op.channel_id()->handle().getUInt();
94   } else {
95     config.collective_op_kind = RendezvousKey::kCrossReplica;
96     mlir::ModuleOp parent = op->template getParentOfType<mlir::ModuleOp>();
97     mlir::IntegerAttr unique_id =
98         parent->getAttrOfType<mlir::IntegerAttr>("hlo.unique_id");
99     config.op_id = static_cast<int64>(unique_id.getInt());
100   }
101   return config;
102 }
103 
104 // Thunk base class for NCCL collective operations.
105 class NcclCollectiveThunk : public Thunk {
106  public:
107   using Thunk::Thunk;
108 
109   struct Buffer {
110     int64 element_count;
111     BufferAllocation::Slice source_buffer;
112     BufferAllocation::Slice destination_buffer;
113   };
114 
115   // Returns whether NCCL operations appear possible to perform; e.g. if we
116   // haven't done a build with the CUDA compiler enabled, we can't compile the
117   // NCCL header, and thus this will be false.
118   //
119   // When this is false, the ExecuteOnStream() call will simply return a status
120   // error.
121   static bool NcclIsEnabled();
122 
123   Status ExecuteOnStream(const ExecuteParams& params) override;
124 
125  protected:
126   virtual Status RunNcclCollective(const ExecuteParams& params,
127                                    ncclComm_t comm) = 0;
128   virtual const NcclCollectiveConfig& config() const = 0;
129 };
130 
131 // Returns if the given data type is supported by NCCL.
132 // Note: Keep this in sync with ToNcclDataType().
133 bool IsTypeSupportedByNccl(PrimitiveType element_type);
134 
135 }  // namespace gpu
136 }  // namespace xla
137 
138 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_COLLECTIVE_THUNK_H_
139