/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/nccl/nccl_manager.h"
16
17 #include <utility>
18
19 #ifdef GOOGLE_CUDA
20
21 #include "tensorflow/core/lib/core/threadpool.h"
22 #include "tensorflow/core/platform/cuda.h"
23 #include "tensorflow/core/platform/env.h"
24
25 namespace tensorflow {
26
#define NCCL_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    ncclResult_t nccl_status = (__VA_ARGS__);                   \
    if (nccl_status != ncclSuccess) {                           \
      return errors::Internal(ncclGetErrorString(nccl_status)); \
    }                                                           \
  } while (0)

#define CUDA_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    cudaError_t cuda_status = (__VA_ARGS__);                    \
    if (cuda_status != cudaSuccess) {                           \
      return errors::Internal(cudaGetErrorString(cuda_status)); \
    }                                                           \
  } while (0)
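
// Both macros are meant to wrap NCCL/CUDA calls made inside functions that
// return Status; on failure the NCCL/CUDA error string is converted to
// errors::Internal. Usage sketch, based on the calls later in this file:
//   NCCL_RETURN_IF_ERROR(ncclGroupStart());
//   CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id));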

using se::cuda::ScopedActivateExecutorContext;

// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
struct NcclManager::NcclStream {
 public:
  NcclStream() {}
  ~NcclStream() {
    mutex_lock l(mu);
    shutdown_requested = true;
    cv.notify_all();
  }

  se::StreamExecutor* executor = nullptr;

  // The stream on which to run the nccl collective.
  // This is a different stream than the tensorflow compute stream.
  std::unique_ptr<se::Stream> stream;

  // See NcclManager::LoopKernelLaunches for information on these.
  std::unique_ptr<Thread> thread;
  mutex mu;
  condition_variable cv;
  // Has (collective, participant_idx) pairs.
  std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
  bool shutdown_requested GUARDED_BY(mu) = false;
};
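
// Scheduling model for NcclStream (see RunCollective and LoopKernelLaunches
// below): RunCollective() enqueues (collective, participant_idx) pairs onto
// pending_launches_ under `mu` and signals `cv`; the per-stream background
// thread running LoopKernelLaunches() dequeues entries and issues the
// corresponding NCCL kernel on `stream`.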

struct NcclManager::CommunicatorMember {
 public:
  CommunicatorMember() {}
  ~CommunicatorMember() {
    if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
  }
  // Initialized to nullptr so that destroying a member that was never assigned
  // a communicator (e.g. on an early error return) is safe.
  ncclComm_t nccl_comm = nullptr;

  // Owned by NcclManager::device_to_comm_streams_.
  NcclStream* nccl_stream = nullptr;
};

struct NcclManager::Communicator {
 public:
  explicit Communicator(std::vector<CommunicatorMember> members,
                        const string& key)
      : num_devices(members.size()), members(std::move(members)), key(key) {}

  const int num_devices;
  const std::vector<CommunicatorMember> members;
  const string key;
};

namespace {

ncclDataType_t ToNcclType(DataType t) {
  switch (t) {
    case DT_HALF:
      return ncclHalf;
    case DT_FLOAT:
      return ncclFloat;
    case DT_DOUBLE:
      return ncclDouble;
    case DT_INT32:
      return ncclInt;
    case DT_INT64:
      return ncclInt64;
    default:
      return ncclFloat;
  }
}

void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
    memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
  }
}
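
// Note: if `str_id` does not have exactly NCCL_UNIQUE_ID_BYTES bytes,
// StringToNcclUniqueId leaves `nccl_id` untouched; GetCommunicator() below
// rejects malformed communicator keys before calling it.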

}  // namespace

// A `Collective` encapsulates state for a collective instance at one node.
// Typically, in the TensorFlow context, an instance is defined by a collective
// group and the (step, frame iteration) of that execution.
//
// For each collective instance there will be one `Collective` object per node.
// For example, a NCCL collective that runs on a single node with 4 GPUs would
// have a single `Collective` per step. However, a collective that executes on
// 3 nodes with 4 GPUs each would have a `Collective` per node, each of which
// tracks the 4 GPUs local to that node.
struct NcclManager::Collective {
  Collective(DataType data_type_in, CollectiveType type_in,
             ncclRedOp_t reduction_op_in, int num_local_devices_in,
             int num_global_devices_in, const string& communicator_key_in)
      : data_type(data_type_in),
        type(type_in),
        reduction_op(reduction_op_in),
        num_local_devices(num_local_devices_in),
        num_global_devices(num_global_devices_in),
        single_node(num_local_devices_in == num_global_devices_in),
        communicator_key(communicator_key_in),
        remaining_participants(num_local_devices_in) {
    participants.reserve(num_local_devices_in);
  }

  const DataType data_type;
  const CollectiveType type;
  const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.
  const int num_local_devices;     // devices local to this node
  const int num_global_devices;    // devices across all nodes
  const bool single_node;          // true if all devices are at one node
  const string communicator_key;

  Communicator* communicator = nullptr;

  // All collective participants.
  //
  // Adding values in this vector is guarded by the mutex of the containing
  // NcclManager.
  std::vector<std::unique_ptr<Participant>> participants;

  // For collective types that have a root (e.g. the root of broadcast is the
  // sender), this is the rank of the root.
  int root_rank = -1;

  // How many participants have been registered so far. The Collective is
  // eligible to run once <available_participants> == num_local_devices.
  //
  // If this is a multi-node collective, we additionally have to synchronize
  // across nodes. The caller signals multi-node readiness by calling
  // NcclManager::SignalMultiNodeReady, which sets `multi_node_ready` to true.
  //
  // Guarded by the mutex of the containing NcclManager.
  int available_participants = 0;
  bool multi_node_ready = false;

  mutable std::atomic_int_fast32_t remaining_participants;

  Status status;
};

NcclManager::NcclManager() {}
NcclManager::~NcclManager() {}
NcclManager* NcclManager::instance() {
  static NcclManager* instance = new NcclManager();
  return instance;
}

string NcclManager::GenerateCommunicatorKey() {
  ncclUniqueId nccl_id;
  ncclGetUniqueId(&nccl_id);
  return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
}
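
// Sketch of the intended multi-node flow (the transport used to share the key
// between workers is up to the caller and is not part of this file):
//   string key = NcclManager::instance()->GenerateCommunicatorKey();
//   // ... distribute `key` to every participating worker ...
//   // Each worker passes `key` as Context::communicator_key when adding
//   // participants; GetCommunicator() converts it back into an ncclUniqueId
//   // via StringToNcclUniqueId() when it initializes the communicator.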

Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
                                    NcclManager::Communicator** communicator) {
  // Sort by executor to make ordering of executors deterministic.
  std::sort(collective->participants.begin(), collective->participants.end(),
            [](const std::unique_ptr<Participant>& a,
               const std::unique_ptr<Participant>& b) {
              return a->executor < b->executor;
            });

  mutex_lock l(mu_);

  if (collective->single_node) {
    // For single-node collectives, we identify a communicator uniquely by the
    // set of devices participating in the collective. For example, if a
    // collective is for GPUs 0, 1, and 2 then this will scan to find the
    // communicator for GPUs 0, 1, and 2.
    //
    // Note that each executor identifies a context on one device, so this is
    // the same as getting the communicator connecting the devices in the
    // collective. A device can be in different communicators as well - for
    // example, a communicator for GPUs 0 and 1 is separate from one for GPUs
    // 0, 1, and 2.
    //
    // Since it's expected that a small number of distinct communicators will
    // be needed, communicators_ is not garbage collected currently.
    //
    // Launching of kernels must be serialized so that, given collectives A and
    // B, and an order of them (e.g., A before B), then for each comm_stream
    // involved, the kernel for A is launched before the kernel for B. This is
    // currently guaranteed by a global mutex controlling additions of the
    // kernels to per-stream launch queues. The launch queues are processed by
    // LoopKernelLaunches.
227 for (auto& comm : communicators_) {
228 if (comm->num_devices == collective->num_global_devices) {
229 int i;
230 for (i = 0; i < collective->num_local_devices; ++i) {
231 if (comm->members[i].nccl_stream->executor !=
232 collective->participants[i]->executor) {
233 break;
234 }
235 }
236 if (i == collective->num_local_devices) {
237 *communicator = comm.get();
238 return Status::OK();
239 }
240 }
241 }
242 } else {
243 #if NCCL_MAJOR < 2
244 return errors::Internal(
245 "Cannot use multi-node NCCL collectives with NCCL 1.x");
246 #endif
247 if (collective->communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
248 return errors::Internal("Expected communicator_key of size ",
249 NCCL_UNIQUE_ID_BYTES, " but found size ",
250 collective->communicator_key.size());
251 }
252 // This is an instance of multi-node collective. We have previously
253 // created a NCCL unique id and shared with all workers. Now we find the
254 // `Communicator` corresponding to this id.
255 for (auto& comm : communicators_) {
256 if (comm->key == collective->communicator_key) {
257 *communicator = comm.get();
258 return Status::OK();
259 }
260 }
261 }

  auto* env = Env::Default();
  std::set<NcclStream*> used_streams;

  // Create and initialize a new communicator.
  // Note that this is done under the lock; performance is not expected to
  // matter as this happens a very small number of times.
  std::vector<CommunicatorMember> members(collective->num_local_devices);
  std::vector<int> devices(collective->num_local_devices);
  for (int i = 0; i < collective->num_local_devices; ++i) {
    auto* executor = collective->participants[i]->executor;

    // Find a communication stream to use for the device.
    auto& streams = device_to_comm_streams_[executor];
    NcclStream* nccl_stream = nullptr;
    for (const auto& s : streams) {
      if (used_streams.insert(s.get()).second) {
        nccl_stream = s.get();
        break;
      }
    }
    if (nccl_stream == nullptr) {
      nccl_stream = new NcclStream();
      nccl_stream->executor = executor;
      nccl_stream->stream.reset(new se::Stream(executor));
      nccl_stream->stream->Init();

      streams.emplace_back(nccl_stream);
      used_streams.insert(nccl_stream);

      nccl_stream->thread.reset(env->StartThread(
          ThreadOptions(), "nccl_kernel_launch",
          [this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
    }

    members[i].nccl_stream = nccl_stream;
    devices[i] = collective->participants[i]->gpu_device_id;
  }

  std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
#if NCCL_MAJOR >= 2
  // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
  // group primitives.
  ncclUniqueId nccl_id;
  if (collective->single_node) {
    NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
  } else {
    StringToNcclUniqueId(collective->communicator_key, &nccl_id);
  }
  int saved_device = 0;
  CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
  NCCL_RETURN_IF_ERROR(ncclGroupStart());
  for (int i = 0; i < collective->num_local_devices; ++i) {
    // Set rank to `participant->global_rank` if provided, else `i`.
    const int rank = collective->participants[i]->global_rank >= 0
                         ? collective->participants[i]->global_rank
                         : i;
    CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[i]));
    NCCL_RETURN_IF_ERROR(ncclCommInitRank(
        nccl_comms.data() + i, collective->num_global_devices, nccl_id, rank));
  }
  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
  CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
#else
  // Since NCCL 1 is single node only, we use ncclCommInitAll. We could have
  // used ncclCommInitRank with NCCL 1 as well, but then we would have to
  // issue each init call from a different thread
  // (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/nccl1.html).
  NCCL_RETURN_IF_ERROR(ncclCommInitAll(
      nccl_comms.data(), collective->num_local_devices, devices.data()));
#endif

  for (int i = 0; i < collective->num_local_devices; ++i) {
    members[i].nccl_comm = nccl_comms[i];
  }
  communicators_.emplace_back(
      new Communicator(std::move(members), collective->communicator_key));
  *communicator = communicators_.back().get();
  return Status::OK();
}

void NcclManager::AddToAllReduce(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kAllReduce, reduction_op);
}
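
// Illustrative call site (a sketch, not taken from this file; the exact
// Participant and Context definitions live in nccl_manager.h):
//   std::unique_ptr<NcclManager::Participant> participant = ...;  // wraps the
//       // input/output tensors, streams, event_mgr and done callback.
//   NcclManager::instance()->AddToAllReduce(std::move(participant), context,
//                                           ncclSum);
// The all-reduce only launches once all context.num_local_devices participants
// have been added (and, for multi-node runs, after SignalMultiNodeReady has
// been called for the collective key).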

void NcclManager::AddToAllGather(std::unique_ptr<Participant> participant,
                                 const Context& context) {
  AddParticipant(std::move(participant), context, kAllGather,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  participant->root = true;
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
                                   const Context& context) {
  AddParticipant(std::move(participant), context, kBroadcast,
                 ncclSum /* unused */);
}

void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}

void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
                                const Context& context,
                                ncclRedOp_t reduction_op) {
  AddParticipant(std::move(participant), context, kReduce, reduction_op);
}

void NcclManager::SignalMultiNodeReady(const string& collective_key) {
  Collective* to_run = nullptr;
  {
    mutex_lock l(mu_);
    auto collective_it = collectives_.find(collective_key);
    if (collective_it != collectives_.end()) {
      Collective* collective = collective_it->second.get();
      collective->multi_node_ready = true;
      to_run = CheckReady(collective_key, collective);
    }
  }

  if (to_run != nullptr) RunCollective(to_run);
}

void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
                                 const Context& context,
                                 CollectiveType collective_type,
                                 ncclRedOp_t reduction_op) {
  Collective* to_run = nullptr;
  const DataType data_type = participant->input->dtype();
  {
    mutex_lock l(mu_);
    auto collective_it = collectives_.find(context.collective_key);
    Collective* collective = nullptr;
    if (collective_it == collectives_.end()) {
      auto collective_unique_ptr = absl::make_unique<Collective>(
          data_type, collective_type, reduction_op, context.num_local_devices,
          context.num_global_devices, context.communicator_key);
      collective = collective_unique_ptr.get();
      collectives_.emplace(context.collective_key,
                           std::move(collective_unique_ptr));
    } else {
      collective = collective_it->second.get();
    }

    // Check that `collective` is correct and consistent.
    if (collective->status.ok() && collective->single_node &&
        !collective->communicator_key.empty()) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " is single node but has communicator_key of size ",
                           collective->communicator_key.size());
    }
    if (collective->status.ok() && collective->communicator_key.size() !=
                                       context.communicator_key.size()) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " mismatch in member communicator_key with size ",
                           collective->communicator_key.size(),
                           " and arg communicator_key with size ",
                           context.communicator_key.size());
    }
    if (collective->status.ok() && collective->type != collective_type) {
      collective->status = errors::Internal(
          "Collective ", reduction_op, " previously initialized with type ",
          collective->type, " but now got type ", collective_type);
    }
    if (collective->status.ok() &&
        collective->num_global_devices != context.num_global_devices) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " previously initialized with num_global_devices ",
                           collective->num_global_devices, " but now got ",
                           context.num_global_devices);
    }
    if (collective->status.ok() &&
        collective->num_local_devices != context.num_local_devices) {
      collective->status =
          errors::Internal("Collective ", reduction_op,
                           " previously initialized with num_local_devices ",
                           collective->num_local_devices, " but now got ",
                           context.num_local_devices);
    }
    if (collective->status.ok() &&
        collective->participants.size() >= collective->num_local_devices) {
      collective->status = errors::Internal(
          "Collective ", reduction_op, " expected ",
          collective->num_local_devices, " participants but now has ",
          collective->participants.size(),
          " with one more participant being added");
    }

    collective->participants.emplace_back(std::move(participant));
    ++collective->available_participants;

    to_run = CheckReady(context.collective_key, collective);
  }

  if (to_run != nullptr) RunCollective(to_run);
}

NcclManager::Collective* NcclManager::CheckReady(const string& collective_key,
                                                 Collective* collective) {
  Collective* to_run = nullptr;
  if (collective->available_participants == collective->num_local_devices) {
    if (collective->num_global_devices == collective->num_local_devices ||
        collective->multi_node_ready) {
      // Ownership transferred to callee: release the entry from the map
      // without deleting the Collective; it is deleted by RunCollective on
      // error, or by the last participant's done callback.
      to_run = collective;
      auto collectives_it = collectives_.find(collective_key);
      collectives_it->second.release();
      collectives_.erase(collectives_it);
    }
  }
  return to_run;
}

void NcclManager::RunCollective(Collective* collective) {
  static mutex collective_mu(LINKER_INITIALIZED);

  Status s = collective->status;
  if (s.ok()) {
    s = GetCommunicator(collective, &collective->communicator);
  }
  if (!s.ok()) {
    for (int i = 0; i < collective->num_local_devices; ++i) {
      collective->participants[i]->done_callback(s);
    }
    delete collective;
    return;
  }

  for (int i = 0; i < collective->num_local_devices; ++i) {
    Participant* p = collective->participants[i].get();
    NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
    CHECK(nccl_stream != nullptr);
    const int rank = p->global_rank >= 0 ? p->global_rank : i;

    if (p->input != nullptr) {
      // Wait to ensure that the kernel that produces the data in the input
      // tensor has finished running before the nccl kernel runs on the
      // communication stream.
      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
    }
    if (p->root) {
      CHECK_EQ(collective->root_rank, -1);
      collective->root_rank = rank;
    }
  }

  if (collective->type == kBroadcast) {
    CHECK_NE(collective->root_rank, -1);
  }

  {
    // Allow only one collective at a time to queue kernels for launching. This
    // is to prevent collectives from deadlocking each other.
    // Note that it would be possible to run multiple collectives at once, if
    // they have non-intersecting sets of devices.
    mutex_lock l(collective_mu);
    for (int i = 0; i < collective->num_local_devices; ++i) {
      NcclStream* nccl_stream =
          collective->communicator->members[i].nccl_stream;
      mutex_lock nccl_stream_lock(nccl_stream->mu);
      nccl_stream->pending_launches_.push_front(std::make_pair(collective, i));
      nccl_stream->cv.notify_all();
    }
  }
}

void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
  se::Stream* comm_stream = nccl_stream->stream.get();
  ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
  const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
      comm_stream->implementation()->GpuStreamMemberHack());

  while (true) {
    // Find collective to run.
    std::pair<Collective*, int> next_launch;
    {
      mutex_lock l(nccl_stream->mu);
      while (nccl_stream->pending_launches_.empty()) {
        if (nccl_stream->shutdown_requested) {
          // No work and shutdown requested, exit.
          return;
        }
        nccl_stream->cv.wait(l);
      }
      next_launch = nccl_stream->pending_launches_.back();
      nccl_stream->pending_launches_.pop_back();
    }

    // Launch the nccl kernel.
    Collective* collective = next_launch.first;
    ncclDataType_t data_type = ToNcclType(collective->data_type);
    int p_idx = next_launch.second;
    Participant* p = collective->participants[p_idx].get();
    auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
    ncclResult_t nccl_result = ncclSuccess;
    switch (collective->type) {
      case kAllReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllReduce participant " << p_idx << " sendbuff "
                << sendbuff << " recvbuff " << recvbuff << " nccl_comm "
                << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        nccl_result = ncclAllReduce(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, collective->reduction_op,
                                    nccl_comm, *cu_stream);
        break;
      }
      case kBroadcast: {
        const Tensor* buf_t = p->input ? p->input : p->output;
        void* buf = const_cast<char*>(buf_t->tensor_data().data());
        nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
                                collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff =
            p->output ? const_cast<char*>(p->output->tensor_data().data())
                      : nullptr;
        nccl_result = ncclReduce(sendbuff, recvbuff, p->input->NumElements(),
                                 data_type, collective->reduction_op,
                                 collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kAllGather: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllGather participant " << p_idx << " sendbuff "
                << sendbuff << " sendcount " << p->input->NumElements()
                << " recvbuff " << recvbuff << " recvcount "
                << p->output->NumElements() << " nccl_comm " << nccl_comm
                << " comm_stream " << comm_stream << " cuda_stream "
                << cu_stream;
        nccl_result = ncclAllGather(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, nccl_comm, *cu_stream);
        break;
      }
    }

    // Run the done_callback when the nccl kernel finishes running.
    auto done_callback = [collective, p_idx, nccl_result]() {
      if (nccl_result == ncclSuccess) {
        collective->participants[p_idx]->done_callback(Status::OK());
      } else {
        // Propagate the error, but note that if other members of the
        // collective did launch their kernels, then they are hanging.
        collective->participants[p_idx]->done_callback(errors::Unknown(
            "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
      }

      // TODO(cwhipkey): use RefCounted after figuring out how to use in a
      // custom op library.
      // See tensorflow/core/lib/core/refcount.h for details on this locking.
      if (collective->remaining_participants.load(std::memory_order_acquire) ==
              1 ||
          collective->remaining_participants.fetch_sub(1) == 1) {
        delete collective;
      }
    };
    p->event_mgr->ThenExecute(comm_stream, done_callback);
  }
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA