/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/nccl/nccl_manager.h"

#include <utility>

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "absl/base/call_once.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif

namespace tensorflow {

#if GOOGLE_CUDA
using se::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
using se::rocm::ScopedActivateExecutorContext;
// Local hipify of cuda symbols
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
int NcclManager::instance_count = 0;
#endif

#define NCCL_RETURN_IF_ERROR(...)                                        \
  do {                                                                   \
    ncclResult_t nccl_status = (__VA_ARGS__);                            \
    if (nccl_status != ncclSuccess) {                                    \
      return errors::Internal("NCCL: ", ncclGetErrorString(nccl_status), \
                              ". Set NCCL_DEBUG=WARN for detail.");      \
    }                                                                    \
  } while (0)

#define CUDA_RETURN_IF_ERROR(...)                                         \
  do {                                                                    \
    cudaError_t cuda_status = (__VA_ARGS__);                              \
    if (cuda_status != cudaSuccess) {                                     \
      return errors::Internal("CUDA: ", cudaGetErrorString(cuda_status)); \
    }                                                                     \
  } while (0)
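
// Usage sketch (added comment; illustrative only): both macros early-return an
// errors::Internal Status from the enclosing function, so they can only be
// used inside functions returning Status, e.g.
//
//   Status InitOnDevice(int device_id) {  // hypothetical helper
//     CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id));
//     ncclUniqueId nccl_id;
//     NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
//     return Status::OK();
//   }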

// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
struct NcclManager::NcclStream : public core::RefCounted {
 public:
  NcclStream() = default;
  ~NcclStream() = default;

  se::StreamExecutor* executor = nullptr;

  // The stream on which to run the nccl collective.
  // This is a different stream than the tensorflow compute stream.
#if TENSORFLOW_USE_ROCM
  // On ROCm, we borrow the nccl stream from the device context.
  se::Stream* stream = nullptr;
#else
  std::unique_ptr<se::Stream> stream;
#endif

  // `mu` protects access to `pending_launches_`, which is the list of
  // collectives ready but whose kernels are yet to be launched. When the
  // NcclManager object that owns this NcclStream object is destroyed, it
  // signals `cv` to unblock the thread waiting on more collectives.
  mutex mu;
  condition_variable cv;
  // Has (collective, participant_idx) pairs.
  std::deque<std::pair<Collective*, int>> pending_launches_ TF_GUARDED_BY(mu);
  bool shutdown_requested TF_GUARDED_BY(mu) = false;
};

struct NcclManager::CommunicatorMember {
 public:
  CommunicatorMember() {}
  ~CommunicatorMember() {
    if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
  }

  ncclComm_t nccl_comm = nullptr;
  // Owned by NcclManager::device_to_comm_streams_ and LoopKernelLaunches.
  NcclStream* nccl_stream = nullptr;
};

struct NcclManager::Communicator {
 public:
  explicit Communicator(std::vector<CommunicatorMember> members,
                        const string& key)
      : num_devices(members.size()), members(std::move(members)), key(key) {}

  const int num_devices;
  std::vector<CommunicatorMember> members;
  const string key;
};

namespace {

static constexpr DataTypeSet kValidDataTypes =
    ToSet(DT_HALF) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) |
    ToSet(DT_INT64);

ncclDataType_t ToNcclType(DataType t) {
  switch (t) {
    case DT_HALF:
      return ncclHalf;
    case DT_FLOAT:
      return ncclFloat;
    case DT_DOUBLE:
      return ncclDouble;
    case DT_INT32:
      return ncclInt;
    case DT_INT64:
      return ncclInt64;
    default:
      return ncclFloat;
  }
}

void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
    memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
  }
}
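// Note (added comment): if `str_id` is not exactly NCCL_UNIQUE_ID_BYTES long,
// this helper silently leaves `nccl_id` untouched; GetCommunicator() validates
// the key size before calling it, so the silent no-op is never hit in practice.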

}  // namespace

// A `Collective` encapsulates state for a collective instance at one node.
// Typically, an instance in TensorFlow context would be defined by a collective
// group and the (step, frame iteration) for that execution.
//
// For each collective instance there will be one `Collective` object per node.
// For example, a NCCL collective that runs on a single node with 4 GPUs would
// have a single `Collective` per step. However, a collective that executes on
// 3 nodes with 4 GPUs each would have a `Collective` per node, each of which is
// tracking the 4 GPUs local to that node.
struct NcclManager::Collective : public core::RefCounted {
  Collective(const string& collective_key_in, DataType data_type_in,
             CollectiveType type_in, ncclRedOp_t reduction_op_in,
             int num_local_devices_in, int num_global_devices_in,
             const string& communicator_key_in)
      : collective_key(collective_key_in),
        data_type(data_type_in),
        type(type_in),
        reduction_op(reduction_op_in),
        num_local_devices(num_local_devices_in),
        num_global_devices(num_global_devices_in),
        single_node(num_local_devices_in == num_global_devices_in),
        communicator_key(communicator_key_in) {
    participants.reserve(num_local_devices_in);
#if TENSORFLOW_USE_ROCM
    // On the ROCm platform, this allows the caller either to use the singleton
    // instance or to manage one non-singleton NcclManager instance.
    // For example, nccl_manager_test uses both paradigms in the same
    // executable, but never concurrently (which would otherwise hang).
    if (NcclManager::instance_count > 1) {
      status = errors::Internal(
          "ROCm cannot use multi-node NCCL collectives on a single node");
    }
#endif
  }

  const string collective_key;  // A unique key for debugging.
  const DataType data_type;
  const CollectiveType type;
  const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.
  const int num_local_devices;     // devices local to this node
  const int num_global_devices;    // devices across all nodes
  const bool single_node;          // true if all devices are at one node
  const string communicator_key;

  Communicator* communicator = nullptr;

  // All collective participants.
  //
  // Adding values in this vector is guarded by the mutex of the containing
  // NcclManager.
  std::vector<std::unique_ptr<Participant>> participants;

  // For collective types that have a root (e.g. the root of broadcast is the
  // sender), this is the rank of the root.
  int root_rank = -1;

  // How many participants have been registered so far. The Collective is
  // eligible for running with <available_participants> == num_local_devices.
  //
  // If this is a multi-node collective, we additionally have to synchronize
  // across nodes. The caller would need to signal multi node readiness by
  // calling NcclManager::SignalMultiNodeReady, which sets `multi_node_ready` to
  // true.
  //
  // Guarded by the mutex of the containing NcclManager.
  int available_participants = 0;
  bool multi_node_ready = false;
  // trace_context is used by tracing system to associate collective
  // scheduling and execution (cooperative kernel launch), which happen
  // on different threads.
  uint64 trace_context = 0;

  Status status;
};

NcclManager::NcclManager() {
  VLOG(2) << "New NcclManager " << this;
#if TENSORFLOW_USE_ROCM
  ++instance_count;
#endif
}
NcclManager::~NcclManager() {
  VLOG(2) << "~NcclManager " << this;
#if TENSORFLOW_USE_ROCM
  --instance_count;
#endif
  for (auto& it : device_to_comm_streams_) {
    for (NcclStream* nccl_stream : it.second) {
      {
        mutex_lock l(nccl_stream->mu);
        nccl_stream->shutdown_requested = true;
        nccl_stream->cv.notify_all();
      }
      nccl_stream->Unref();
    }
  }
}
NcclManager* NcclManager::instance() {
  static NcclManager* instance = new NcclManager();
#if TENSORFLOW_USE_ROCM
  // singleton does not count against total instances
  // see comment above in Collective constructor concerning ROCm platform
  static absl::once_flag once;
  absl::call_once(once, [] { --NcclManager::instance_count; });
#endif
  return instance;
}

string NcclManager::GenerateCommunicatorKey() {
  ncclUniqueId nccl_id;
  ncclGetUniqueId(&nccl_id);
  return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
}
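
// Multi-node flow sketch (added comment; the key-distribution mechanism lives
// outside this file and is only an assumption here): one worker calls
// GenerateCommunicatorKey(), the resulting NCCL_UNIQUE_ID_BYTES-long string is
// shared with all workers out of band, and each node passes it back in as
// Context::communicator_key so that GetCommunicator() below can reconstruct
// the ncclUniqueId via StringToNcclUniqueId().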

Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
                                    NcclManager::Communicator** communicator) {
  // Sort by device ID, executor, and global rank to make ordering of
  // participants deterministic.
  std::sort(collective->participants.begin(), collective->participants.end(),
            [](const std::unique_ptr<Participant>& a,
               const std::unique_ptr<Participant>& b) {
              if (a->gpu_device_id != b->gpu_device_id) {
                return a->gpu_device_id < b->gpu_device_id;
              }
              if (a->executor != b->executor) {
                return a->executor < b->executor;
              }
              return a->global_rank < b->global_rank;
            });

  mutex_lock l(mu_);
  if (!status_.ok()) {
    return status_;
  }

  if (collective->communicator_key.empty()) {
    // For single-node collectives, when the caller does not specify a
    // `communicator_key`, we identify a communicator uniquely by the set of
    // devices participating in the collective. For example, if a collective is
    // for GPUs 0, 1, and 2 then this will scan to find the communicator for
    // GPUs 0, 1, and 2.
    //
    // Note that each executor identifies a context on one device, so this is
    // the same as getting the communicator connecting the devices in the
    // collective. A device can be in different communicators as well - for
    // example, a communicator for GPUs 0 and 1 is separate from one for GPUs 0,
    // 1, and 2.
    //
    // Since it's expected that a small number of distinct communicators will
    // be needed, communicators_ is not garbage collected currently.
    //
    // Launching of kernels must be serialized so that, given collectives A and
    // B, and an order of them (e.g., A before B), then for each comm_stream
    // involved, the kernel for A is launched before the kernel for B. This is
    // guaranteed currently by a global mutex controlling additions of the
    // kernels to per-stream launch queues. The launch queues are processed by
    // LoopKernelLaunches.
    for (auto& comm : communicators_) {
      if (comm->num_devices == collective->num_global_devices) {
        int i;
        for (i = 0; i < collective->num_local_devices; ++i) {
          if (comm->members[i].nccl_stream->executor !=
              collective->participants[i]->executor) {
            break;
          }
        }
        if (i == collective->num_local_devices) {
          *communicator = comm.get();
          return Status::OK();
        }
      }
    }
  } else {
#if NCCL_MAJOR < 2
    return errors::Internal(
        "Cannot use multi-node NCCL collectives with NCCL 1.x");
#endif
    if (collective->communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
      return errors::Internal("Expected communicator_key of size ",
                              NCCL_UNIQUE_ID_BYTES, " but found size ",
                              collective->communicator_key.size());
    }
    // This is an instance of multi-node collective. We have previously
    // created a NCCL unique id and shared with all workers. Now we find the
    // `Communicator` corresponding to this id.
    for (auto& comm : communicators_) {
      if (comm->key == collective->communicator_key) {
        *communicator = comm.get();
        return Status::OK();
      }
    }
  }

  auto* env = Env::Default();
  std::set<NcclStream*> used_streams;

  // Create and initialize a new communicator.
  // Note that this is done under the lock; performance is not expected to
  // matter as this happens a very small number of times.
  std::vector<CommunicatorMember> members(collective->num_local_devices);
  std::vector<int> devices(collective->num_local_devices);
  for (int i = 0; i < collective->num_local_devices; ++i) {
    auto* executor = collective->participants[i]->executor;

    // Find a communication stream to use for the device.
    auto& streams = device_to_comm_streams_[executor];
    NcclStream* nccl_stream = nullptr;
    for (const auto& s : streams) {
      if (used_streams.insert(s).second) {
        nccl_stream = s;
        break;
      }
    }
    if (nccl_stream == nullptr) {
      nccl_stream = new NcclStream();
      nccl_stream->executor = executor;
#if TENSORFLOW_USE_ROCM
      nccl_stream->stream = collective->participants[i]->context->nccl_stream();
#else
      nccl_stream->stream.reset(new se::Stream(executor));
      nccl_stream->stream->Init();
#endif

      streams.emplace_back(nccl_stream);
      used_streams.insert(nccl_stream);

      nccl_stream->Ref();
      env->SchedClosure([this, nccl_stream]() {
        LoopKernelLaunches(nccl_stream);
        nccl_stream->Unref();
      });
    }

    members[i].nccl_stream = nccl_stream;
    devices[i] = collective->participants[i]->gpu_device_id;
  }

  std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
#if NCCL_MAJOR >= 2
  // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
  // group primitives.
  ncclUniqueId nccl_id;
  if (collective->single_node) {
    NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
  } else {
    StringToNcclUniqueId(collective->communicator_key, &nccl_id);
  }
  int saved_device = 0;
  CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
  NCCL_RETURN_IF_ERROR(ncclGroupStart());
  for (int i = 0; i < collective->num_local_devices; ++i) {
    // Set rank to `participant->global_rank` if provided, else `i`.
    const int rank = collective->participants[i]->global_rank >= 0
                         ? collective->participants[i]->global_rank
                         : i;
    CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[i]));
    NCCL_RETURN_IF_ERROR(ncclCommInitRank(
        nccl_comms.data() + i, collective->num_global_devices, nccl_id, rank));
  }
  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
  CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
#else
  // Since NCCL 1 is single node only, we use ncclCommInitAll. We could have
  // used ncclCommInitRank with NCCL 1 as well, but then we would have to
  // issue each init call from a different thread
  // (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/nccl1.html).
  NCCL_RETURN_IF_ERROR(ncclCommInitAll(
      nccl_comms.data(), collective->num_local_devices, devices.data()));
#endif

  for (int i = 0; i < collective->num_local_devices; ++i) {
    members[i].nccl_comm = nccl_comms[i];
  }
  communicators_.emplace_back(
      new Communicator(std::move(members), collective->communicator_key));
  *communicator = communicators_.back().get();
  return Status::OK();
}

void NcclManager::AddToAllReduce(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kAllReduce, reduction_op);
}
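
// Caller-side sketch (added comment; illustrative only -- the exact
// Participant and Context constructor signatures live in nccl_manager.h and
// are assumed here). A typical all-reduce registration looks roughly like:
//
//   auto participant = absl::make_unique<NcclManager::Participant>(
//       executor, tensor_stream, event_mgr, gpu_device_id, &input, &output,
//       /*global_rank=*/-1, std::move(done_callback));
//   NcclManager::instance()->AddToAllReduce(
//       std::move(participant),
//       {collective_key, num_local_devices, num_global_devices,
//        communicator_key, /*source_rank=*/-1},
//       ncclSum);
//
// The done_callback fires once the NCCL kernel for this participant has
// completed on the communication stream (see LoopKernelLaunches below).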

void NcclManager::AddToAllGather(std::unique_ptr<Participant> participant,
                                 const Context& context) {
  AddParticipant(std::move(participant), context, kAllGather,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  participant->root = true;
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}

void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  participant->root = true;
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}

void NcclManager::SignalMultiNodeReady(const string& collective_key) {
  Collective* to_run = nullptr;
  {
    mutex_lock l(mu_);
    auto collective_it = collectives_.find(collective_key);
    if (collective_it != collectives_.end()) {
      Collective* collective = collective_it->second;
      collective->multi_node_ready = true;
      if (CheckReady(collective_key, collective)) {
        to_run = collective;
      }
      VLOG(2) << "SignalMultiNodeReady collective " << collective_key
              << " to_run " << to_run;
    }
  }

  if (to_run != nullptr) RunCollective(to_run);
}

void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 CollectiveType collective_type,
                                 ncclRedOp_t reduction_op) {
  Collective* to_run = nullptr;
  DataType data_type;
  Status nccl_manager_status;
  if (participant->input != nullptr) {
    data_type = participant->input->dtype();
  } else {
    data_type = participant->output->dtype();
  }
  {
    mutex_lock l(mu_);
    nccl_manager_status = status_;
    if (nccl_manager_status.ok()) {
      auto collective_it = collectives_.find(context.collective_key);
      Collective* collective = nullptr;
      if (collective_it == collectives_.end()) {
        collective = new Collective(
            context.collective_key, data_type, collective_type, reduction_op,
            context.num_local_devices, context.num_global_devices,
            context.communicator_key);
        collectives_.emplace(context.collective_key, collective);
      } else {
        collective = collective_it->second;
      }

      // Check `collective` is correct and consistent.
      if (collective->status.ok() && !collective->single_node &&
          collective->communicator_key.empty()) {
        collective->status = errors::Internal(
            "Collective ", reduction_op,
            " is multi node with num_local_devices=",
            collective->num_local_devices,
            " and num_global_devices=", collective->num_global_devices,
            " but has an empty communicator_key");
      }
      if (collective->status.ok() && collective->communicator_key.size() !=
                                         context.communicator_key.size()) {
        collective->status =
            errors::Internal("Collective ", reduction_op,
                             " mismatch in member communicator_key with size ",
                             collective->communicator_key.size(),
                             " and arg communicator_key with size ",
                             context.communicator_key.size());
      }
      if (collective->status.ok() && collective->type != collective_type) {
        collective->status = errors::Internal(
            "Collective ", reduction_op, " previously initialized with type ",
            collective->type, " but now got type ", collective_type);
      }
      if (collective->status.ok() &&
          collective->num_global_devices != context.num_global_devices) {
        collective->status =
            errors::Internal("Collective ", reduction_op,
                             " previously initialized with num_global_devices ",
                             collective->num_global_devices, " but now got ",
                             context.num_global_devices);
      }
      if (collective->status.ok() &&
          collective->num_local_devices != context.num_local_devices) {
        collective->status =
            errors::Internal("Collective ", reduction_op,
                             " previously initialized with num_local_devices ",
                             collective->num_local_devices, " but now got ",
                             context.num_local_devices);
      }
      if (collective->status.ok() &&
          collective->participants.size() >= collective->num_local_devices) {
        collective->status = errors::Internal(
            "Collective ", reduction_op, " expected ",
            collective->num_local_devices, " participants but now has ",
            collective->participants.size(),
            " with one more participant being added");
      }
      if (collective->status.ok() && collective->root_rank >= 0 &&
          context.source_rank >= 0 &&
          collective->root_rank != context.source_rank) {
        collective->status = errors::Internal(
            "Collective ", collective->collective_key,
            " already has root_rank ", collective->root_rank,
            " but new participant has root_rank ", context.source_rank);
      }
      if (collective->status.ok() &&
          !kValidDataTypes.Contains(collective->data_type)) {
        collective->status = errors::Internal(
            "Collective ", collective->collective_key,
            " expected data types compatible with NCCL but instead got ",
            DataTypeString(collective->data_type));
      }

      if (context.source_rank >= 0) {
        collective->root_rank = context.source_rank;
      }

      collective->participants.emplace_back(std::move(participant));
      ++collective->available_participants;

      if (CheckReady(context.collective_key, collective)) {
        to_run = collective;
      }
    }
  }
  if (!nccl_manager_status.ok()) {
    participant->done_callback(nccl_manager_status);
    return;
  }
  if (to_run != nullptr) RunCollective(to_run);
}

bool NcclManager::CheckReady(const string& collective_key,
                             Collective* collective) {
  if (collective->available_participants == collective->num_local_devices) {
    if (collective->num_global_devices == collective->num_local_devices ||
        collective->multi_node_ready) {
      // Ownership transferred to callee.
      collectives_.erase(collective_key);
      return true;
    }
  }
  return false;
}
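
// Readiness example (added comment; numbers follow the `Collective` doc
// above): for a collective spanning 3 nodes with 4 GPUs each, the Collective
// on a given node becomes ready only after its 4 local participants have been
// added *and* SignalMultiNodeReady has been called for the collective_key. A
// single-node 4-GPU collective runs as soon as the 4th participant arrives.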

void NcclManager::RunCollective(Collective* collective) {
  // For the TraceMeConsumer in LoopKernelLaunches().
  tensorflow::profiler::TraceMeProducer traceme("Schedule Collective");
  collective->trace_context = traceme.GetContextId();

  static mutex collective_mu(LINKER_INITIALIZED);

  Status status = collective->status;
  if (status.ok()) {
    status = GetCommunicator(collective, &collective->communicator);
  }

  for (int i = 0; status.ok() && i < collective->num_local_devices; ++i) {
    Participant* p = collective->participants[i].get();
    NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
    CHECK(nccl_stream != nullptr);
    const int rank = p->global_rank >= 0 ? p->global_rank : i;

    if (p->input != nullptr) {
      // Wait to ensure that the kernel that produces the data in the input
      // tensor has finished running before the nccl kernel runs on the
      // communication stream.
      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
    }
    if (p->root) {
      if (collective->root_rank == -1) {
        collective->root_rank = rank;
      } else if (collective->root_rank != rank) {
        status = errors::Internal(
            "Inconsistent root rank ", collective->root_rank, " and GPU id ",
            p->gpu_device_id, " rank ", rank, " also marked as root.");
      }
    }
    VLOG(2) << "RunCollective rank " << rank << " global_rank "
            << p->global_rank << " root_rank " << collective->root_rank;
  }

  if (status.ok() && collective->type == kBroadcast &&
      collective->root_rank < 0) {
    status = errors::Internal("Root rank not indicated for collective ",
                              collective->collective_key);
  }

  if (!status.ok()) {
    for (int i = 0; i < collective->num_local_devices; ++i) {
      collective->participants[i]->done_callback(status);
    }
    collective->Unref();
    return;
  }

  {
    // Allow only one collective at a time to queue kernels for launching. This
    // is to prevent collectives from deadlocking each other.
    // Note that it would be possible to run multiple collectives at once, if
    // they have non-intersecting sets of devices.
    mutex_lock l(collective_mu);
    for (int i = 0; i < collective->num_local_devices; ++i) {
      NcclStream* nccl_stream =
          collective->communicator->members[i].nccl_stream;
      mutex_lock l(nccl_stream->mu);
      nccl_stream->pending_launches_.push_front(std::make_pair(collective, i));
      // Ownership is shared between LoopKernelLaunches for each stream in this
      // collective.
      collective->Ref();
      nccl_stream->cv.notify_all();
    }
  }
  collective->Unref();
}

namespace {
// For tracing purpose.
size_t ComputeBufferSize(const NcclManager::Participant* p,
                         DataType data_type) {
  size_t num_elements = 0;
  if (p->output) {
    num_elements += p->output->NumElements();
  } else if (p->input) {
    num_elements += p->input->NumElements();
  }
  return num_elements * DataTypeSize(data_type);
}
}  // namespace

void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
#if TENSORFLOW_USE_ROCM
  se::Stream* comm_stream = nccl_stream->stream;
#else
  se::Stream* comm_stream = nccl_stream->stream.get();
#endif
  ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
  const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
      comm_stream->implementation()->GpuStreamMemberHack());

  while (true) {
    // Find collective to run.
    std::pair<Collective*, int> next_launch;
    {
      VLOG(3) << "Locking mutex nccl_stream " << nccl_stream;
      mutex_lock l(nccl_stream->mu);
      while (nccl_stream->pending_launches_.empty()) {
        if (nccl_stream->shutdown_requested) {
          // No work and shutdown requested, exit.
          return;
        }
        nccl_stream->cv.wait(l);
      }
      next_launch = nccl_stream->pending_launches_.back();
      nccl_stream->pending_launches_.pop_back();
    }

    // Launch the nccl kernel.
    Collective* collective = next_launch.first;
    tensorflow::profiler::TraceMeConsumer traceme("Run Collective",
                                                  collective->trace_context);

    ncclDataType_t data_type = ToNcclType(collective->data_type);
    int p_idx = next_launch.second;
    Participant* p = collective->participants[p_idx].get();
    auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
    ncclResult_t nccl_result = ncclSuccess;
    switch (collective->type) {
      case kAllReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllReduce collective_key "
                << collective->collective_key << " participant " << p_idx
                << " num_participants " << collective->participants.size()
                << " sendbuff " << sendbuff << " recvbuff " << recvbuff
                << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        profiler::AnnotatedTraceMe traceme([&] {
          return profiler::TraceMeEncode(
              "ncclAllReduce",
              {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
               {"collective_type", "all_reduce"}});
        });
        nccl_result = ncclAllReduce(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, collective->reduction_op,
                                    nccl_comm, *cu_stream);
        break;
      }
      case kBroadcast: {
        const void* sendbuff = nullptr;
        void* recvbuff = nullptr;
        int num_elements = -1;
        if (p->input) {
          sendbuff = p->input->tensor_data().data();
          num_elements = p->input->NumElements();
        }
        if (p->output) {
          recvbuff = const_cast<char*>(p->output->tensor_data().data());
          num_elements = p->output->NumElements();
        } else {
          // Operate in-place if no output (for the src node).
          recvbuff = const_cast<void*>(sendbuff);
        }
        if (num_elements < 0) {
          p->done_callback(errors::Internal(
              "Both input and output are null in ncclBroadcast"));
          collective->Unref();
          continue;
        }
        VLOG(2) << "call NcclBroadcast collective_key "
                << collective->collective_key << " participant " << p_idx
                << " sendbuff " << sendbuff << " recvbuff " << recvbuff
                << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        profiler::AnnotatedTraceMe traceme([&] {
          return profiler::TraceMeEncode(
              "ncclBroadcast",
              {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
               {"collective_type", "broadcast"}});
        });
        nccl_result =
            ncclBroadcast(sendbuff, recvbuff, num_elements, data_type,
                          collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff =
            p->output ? const_cast<char*>(p->output->tensor_data().data())
                      : nullptr;
        profiler::AnnotatedTraceMe traceme([&] {
          return profiler::TraceMeEncode(
              "ncclReduce",
              {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
               {"collective_type", "reduce"}});
        });
        nccl_result = ncclReduce(sendbuff, recvbuff, p->input->NumElements(),
                                 data_type, collective->reduction_op,
                                 collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kAllGather: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllGather collective_key "
                << collective->collective_key << " participant " << p_idx
                << " sendbuff " << sendbuff << " sendcount "
                << p->input->NumElements() << " recvbuff " << recvbuff
                << " recvcount " << p->output->NumElements() << " nccl_comm "
                << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        profiler::AnnotatedTraceMe traceme([&] {
          return profiler::TraceMeEncode(
              "ncclAllGather",
              {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
               {"collective_type", "all_gather"}});
        });
        nccl_result = ncclAllGather(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, nccl_comm, *cu_stream);
        break;
      }
    }

    // Run the done_callback when the nccl kernel finishes running.
    auto done_callback = [collective, p_idx, nccl_result]() {
      VLOG(2) << "done Nccl kernel collective_key "
              << collective->collective_key << " participant " << p_idx
              << " ncclResult " << nccl_result;
      if (nccl_result == ncclSuccess) {
        collective->participants[p_idx]->done_callback(Status::OK());
      } else {
        // Propagate the error, but note that if other members of the collective
        // did launch their kernels, then they are hanging.
        collective->participants[p_idx]->done_callback(errors::Unknown(
            "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
      }
      collective->Unref();
    };
    p->event_mgr->ThenExecute(comm_stream, done_callback);
  }
}

void NcclManager::StartAbort(const Status& s) {
  absl::flat_hash_map<string, Collective*> collectives;
  std::vector<std::unique_ptr<Communicator>> communicators;
  {
    mutex_lock l(mu_);
    if (!status_.ok()) {
      LOG(WARNING)
          << "NcclManager already aborted, ignoring subsequent StartAbort with "
          << s;
      return;
    }
    status_ = s;
    collectives.swap(collectives_);
    communicators.swap(communicators_);
  }
  VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size()
          << " collectives and " << communicators.size()
          << " comms with status " << s;
  // collectives_ contains pending launches that haven't been dispatched to
  // kernel launch threads, so we can simply invoke their done callbacks.
  for (const auto& item : collectives) {
    for (const std::unique_ptr<Participant>& p : item.second->participants) {
      p->done_callback(s);
    }
    item.second->Unref();
  }
  // Abort ncclComm. Note that there could be multiple ncclComm per device,
  // and ncclCommAbort contains cuda calls that require device
  // synchronization. That is, a collective on nccl_comm_0 can block
  // ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a
  // concurrent fashion. This assumes that there's only one active NcclManager
  // at a time.
  UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
  int num_comms = 0;
  for (std::unique_ptr<Communicator>& communicator : communicators) {
    num_comms += communicator->members.size();
  }
  BlockingCounter pending(num_comms);
  for (std::unique_ptr<Communicator>& communicator : communicators) {
    for (CommunicatorMember& member : communicator->members) {
      queue.Schedule([&member, &pending]() {
        ncclCommAbort(member.nccl_comm);
        member.nccl_comm = nullptr;
        pending.DecrementCount();
      });
    }
  }
  pending.Wait();
}

void NcclManager::Reset() {
  mutex_lock l(mu_);
  status_ = Status();
  VLOG(2) << "Reset NcclManager " << this;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM