/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <vector>

// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when
// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here,
// incAtomic and other CUDA specific symbols are no longer recognized.
#ifndef gpu_assert
#define gpu_assert(x)
#endif

#include "absl/container/flat_hash_map.h"
#if GOOGLE_CUDA
#include "third_party/nccl/nccl.h"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/rccl/rccl.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#endif
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// NCCL manager is used to make the asynchronous communicator calls and to
// manage the per-device streams used for communication.
//
// See nccl_ops.cc for example usage, including description of memory
// management and stream synchronization.
class NcclManager {
 public:
  typedef std::function<void(Status)> DoneCallback;
  NcclManager();
  ~NcclManager();

  static NcclManager* instance();

#if TENSORFLOW_USE_ROCM
  static int instance_count;
#endif

  // Calls `ncclGetUniqueId` and returns the id as a string. The returned value
  // may be shared with other participants on different nodes and passed in to
  // multi-node collective invocations.
  string GenerateCommunicatorKey();

  // A participant in a Collective.
  struct Participant {
    Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
                const DeviceBase::GpuDeviceInfo* info, const Tensor* input,
                Tensor* output, int global_rank, DoneCallback done_callback)
        : executor(executor),
          tensor_stream(tensor_stream),
          event_mgr(info->event_mgr),
          gpu_device_id(info->gpu_id),
#if TENSORFLOW_USE_ROCM
          context(static_cast<GPUDeviceContext*>(info->default_context)),
#endif
          input(input),
          output(output),
          global_rank(global_rank),
          done_callback(std::move(done_callback)),
          root(false) {
      DCHECK(executor != nullptr);
      DCHECK(event_mgr != nullptr);
      DCHECK(tensor_stream != nullptr);
    }

    // StreamExecutor for the device. Expected to be live for process lifetime.
    se::StreamExecutor* const executor = nullptr;

    // `tensor_stream` is the stream that should be waited on to ensure
    // `input`'s data is available on the GPU for the communication stream to
    // access.
    // It is also the stream that will use the produced data;
    // `done_callback` is not called until the next kernel launched on
    // `stream` would see the data. Owned by the caller, who must keep it live
    // until `done_callback` is called.
    se::Stream* const tensor_stream;

    // EventMgr which polls on executor.
    // Owned by the caller, who must keep it live until `done_callback` is
    // called.
    EventMgr* const event_mgr;

    const int gpu_device_id;

#if TENSORFLOW_USE_ROCM
    GPUDeviceContext* const context;
#endif

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only receive data.
    const Tensor* input;

    // Owned by the caller, who must keep it live until `done_callback` is
    // called. Is NULL for participants that only send data.
    Tensor* output;

    // Rank across all devices and all nodes.
    // `global_rank` is not required for single-node collectives.
    const int global_rank;

    // The callback which is called at the completion of the NCCL operation.
    // When called, `output` has been set to the result of the operation.
    // (Note: the stream may not yet have been synced.)
    DoneCallback done_callback;

    // True if this is the root of the collective, e.g. the source of a
    // broadcast.
    bool root;
  };

  // Data that provides context for the collective operation, including the
  // operation key, number of participants, and communicator key.
  struct Context {
    Context(const string& collective_key, int num_local_devices,
            int num_global_devices, const string& communicator_key,
            int source_rank)
        : collective_key(collective_key),
          num_local_devices(num_local_devices),
          num_global_devices(num_global_devices),
          communicator_key(communicator_key),
          source_rank(source_rank) {}

    // Unique key for this collective instance.
    const string& collective_key;

    // Devices local to this node.
    int num_local_devices;

    // Devices across all nodes.
    int num_global_devices;

    // In order to use NCCL across nodes, the callee first has to generate a
    // `communicator_key` via the `GenerateCommunicatorKey()` function and
    // share it with all the other nodes. Each node should pass this
    // `communicator_key` to the `NcclManager` functions. `communicator_key`
    // is not required for single-node collectives and can be empty.
    const string& communicator_key;

    // Rank of the broadcast source.
    int source_rank;
  };

  // Adds one participant to an all-reduce.
  void AddToAllReduce(std::unique_ptr<Participant> participant,
                      const Context& context, ncclRedOp_t reduction_op);

  // Adds one participant to an all-gather.
  void AddToAllGather(std::unique_ptr<Participant> participant,
                      const Context& context);

  // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
  // to all receivers.
  void AddBroadcastSend(std::unique_ptr<Participant> participant,
                        const Context& context);
  void AddBroadcastRecv(std::unique_ptr<Participant> participant,
                        const Context& context);

  // AddReduceSend and AddReduceRecv combine to send data from all senders
  // to one receiver.
  void AddReduceSend(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);
  void AddReduceRecv(std::unique_ptr<Participant> participant,
                     const Context& context, ncclRedOp_t reduction_op);

  // Signals that the `Collective` corresponding to `collective_key` is ready
  // to launch across all nodes participating in this multi-node collective
  // operation.
  //
  // This should only be called for multi-node collectives; single-node
  // collectives are implicitly ready when all participants have called an
  // Add* function.
  void SignalMultiNodeReady(const string& collective_key);

  // Aborts all collectives. After abortion, no further collectives can be
  // launched with this NcclManager.
  void StartAbort(const Status& s);

  // Resets a previously aborted NcclManager, making it available for future
  // collectives.
  void Reset();

 private:
  enum CollectiveType {
    kAllReduce = 1,
    kBroadcast = 2,
    kReduce = 3,
    kAllGather = 4,
  };
  struct Collective;
  struct Communicator;
  struct CommunicatorMember;
  struct NcclStream;

  // Gets the `Communicator` object that will be used to enqueue NCCL kernels
  // for `collective`, and returns it via `communicator`.
  //
  // This may involve creating CUDA streams and NCCL initialization. If an NCCL
  // or CUDA error occurs in the process, this returns an INTERNAL error with
  // the corresponding NCCL/CUDA error string.
  Status GetCommunicator(Collective* collective, Communicator** communicator);

  // Adds a participant device to the local `Collective` instance corresponding
  // to `collective_key`. Launches the `Collective` if it is ready, which it
  // checks by calling `CheckReady()`. Also performs consistency and sanity
  // checks before launching.
  void AddParticipant(std::unique_ptr<Participant> participant,
                      const Context& context, CollectiveType collective_type,
                      ncclRedOp_t reduction_op);

  // If `collective` is ready to run, removes it from the `collectives_` map
  // and returns true. Otherwise returns false.
  // Assumes `collective_key` corresponds to `collective`.
  //
  // A collective is ready to run when all local participants have called an
  // Add* function, and the collective has been signalled globally ready via
  // `SignalMultiNodeReady`.
  bool CheckReady(const string& collective_key, Collective* collective)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Runs `collective`. This call takes ownership of `collective`.
  void RunCollective(Collective* collective);
  void LoopKernelLaunches(NcclStream* stream);

  mutex mu_;

  // Maps key to collectives currently being assembled or run.
  absl::flat_hash_map<string, Collective*> collectives_ TF_GUARDED_BY(mu_);

  // Maps a device to the communication streams that make up its collective.
  // This is used to share the stream across different communicators that
  // include the same device.
  absl::flat_hash_map<se::StreamExecutor*, std::vector<NcclStream*>>
      device_to_comm_streams_ TF_GUARDED_BY(mu_);

  std::vector<std::unique_ptr<Communicator>> communicators_ TF_GUARDED_BY(mu_);

  Status status_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(NcclManager);
};

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_
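
// -----------------------------------------------------------------------------
// Illustrative usage sketch only; the authoritative call sites are in
// nccl_ops.cc. It assumes two local GPUs on a single node and hypothetical
// per-device values `executor`, `compute_stream`, `gpu_info`, `input`, and
// `output` obtained from the launching kernel (none of these names come from
// this header). Each device would enqueue its participant roughly like:
//
//   const string collective_key = "allreduce_op_1";
//   const string communicator_key = "";  // empty: single-node collective
//   NcclManager::Context context(collective_key, /*num_local_devices=*/2,
//                                /*num_global_devices=*/2, communicator_key,
//                                /*source_rank=*/-1);
//   auto participant = absl::make_unique<NcclManager::Participant>(
//       executor, compute_stream, gpu_info, input, output,
//       /*global_rank=*/-1, [](Status s) { /* handle completion */ });
//   NcclManager::instance()->AddToAllReduce(std::move(participant), context,
//                                           ncclSum);
//
// The collective launches once every local device has added its participant;
// `done_callback` runs when `output` holds the result, though the stream may
// not yet have been synchronized.
// -----------------------------------------------------------------------------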