• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/nccl/nccl_manager.h"
16 
17 #include <utility>
18 
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 
21 #include "absl/base/call_once.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/lib/core/refcount.h"
24 #include "tensorflow/core/lib/core/threadpool.h"
25 #include "tensorflow/core/platform/blocking_counter.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/unbounded_work_queue.h"
28 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
29 #include "tensorflow/core/profiler/lib/connected_traceme.h"
30 #include "tensorflow/core/profiler/lib/traceme.h"
31 #if GOOGLE_CUDA
32 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
33 #elif TENSORFLOW_USE_ROCM
34 #include "tensorflow/core/platform/rocm.h"
35 #endif
36 
37 namespace tensorflow {
38 
// Platform selection. The CUDA branch is taken if GOOGLE_CUDA is set (even if
// TENSORFLOW_USE_ROCM is also set); otherwise the ROCm branch applies.
// ScopedActivateExecutorContext is used by LoopKernelLaunches with the
// stream's executor.
#if GOOGLE_CUDA
using se::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
using se::rocm::ScopedActivateExecutorContext;
// Local hipify of cuda symbols: lets the CUDA-spelled code in this file
// (CUDA_RETURN_IF_ERROR, cudaGetDevice/cudaSetDevice, cudaStream_t) compile
// unchanged against HIP.
#define cudaError_t hipError_t
#define cudaStream_t hipStream_t
#define cudaGetErrorString hipGetErrorString
#define cudaGetDevice hipGetDevice
#define cudaSetDevice hipSetDevice
#define cudaSuccess hipSuccess
// Count of NcclManager instances, excluding the singleton: incremented by the
// constructor, decremented by the destructor, and decremented once by
// instance() for the singleton. The Collective constructor reports an error
// when it exceeds 1.
int NcclManager::instance_count = 0;
#endif
52 
// Evaluates an expression of type ncclResult_t. If the result is not
// ncclSuccess, returns an Internal error (including ncclGetErrorString's text)
// from the enclosing Status-returning function.
#define NCCL_RETURN_IF_ERROR(...)                                        \
  do {                                                                   \
    ncclResult_t nccl_status = (__VA_ARGS__);                            \
    if (nccl_status != ncclSuccess) {                                    \
      return errors::Internal("NCCL: ", ncclGetErrorString(nccl_status), \
                              ". Set NCCL_DEBUG=WARN for detail.");      \
    }                                                                    \
  } while (0)

// Same pattern for CUDA runtime calls returning cudaError_t. On ROCm builds
// the cuda* names used here resolve to their HIP equivalents through the
// local hipify #defines in this file.
#define CUDA_RETURN_IF_ERROR(...)                                         \
  do {                                                                    \
    cudaError_t cuda_status = (__VA_ARGS__);                              \
    if (cuda_status != cudaSuccess) {                                     \
      return errors::Internal("CUDA: ", cudaGetErrorString(cuda_status)); \
    }                                                                     \
  } while (0)
69 
70 // Contains data for a single stream used for nccl communication; this includes
71 // a background thread that calls NcclManager::LoopKernelLaunches.
72 struct NcclManager::NcclStream : public core::RefCounted {
73  public:
74   NcclStream() = default;
75   ~NcclStream() = default;
76 
77   se::StreamExecutor* executor = nullptr;
78 
79   // The stream on which to run the nccl collective.
80   // This is a different stream than the tensorflow compute stream.
81 #if TENSORFLOW_USE_ROCM
82   // On ROCm, we borrow the nccl stream from the device context.
83   se::Stream* stream = nullptr;
84 #else
85   std::unique_ptr<se::Stream> stream;
86 #endif
87 
88   // `mu` protects access to `pending_launches_`, which is the list of
89   // collectives ready but whose kernels are yet to be launched.  When the
90   // NcclManager object that owns this NcclStream object is destroyed, it
91   // signals `cv` to unblock the thread waiting on more collectives.
92   mutex mu;
93   condition_variable cv;
94   // Has (collective, participant_idx) pairs.
95   std::deque<std::pair<Collective*, int>> pending_launches_ TF_GUARDED_BY(mu);
96   bool shutdown_requested TF_GUARDED_BY(mu) = false;
97 };
98 
99 struct NcclManager::CommunicatorMember {
100  public:
CommunicatorMembertensorflow::NcclManager::CommunicatorMember101   CommunicatorMember() {}
~CommunicatorMembertensorflow::NcclManager::CommunicatorMember102   ~CommunicatorMember() {
103     if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
104   }
105 
106   ncclComm_t nccl_comm = nullptr;
107   // Owned by NcclManager::device_to_comm_streams_ and LoopKernelLaunches.
108   NcclStream* nccl_stream = nullptr;
109 };
110 
111 struct NcclManager::Communicator {
112  public:
Communicatortensorflow::NcclManager::Communicator113   explicit Communicator(std::vector<CommunicatorMember> members,
114                         const string& key)
115       : num_devices(members.size()), members(std::move(members)), key(key) {}
116 
117   const int num_devices;
118   std::vector<CommunicatorMember> members;
119   const string key;
120 };
121 
122 namespace {
123 
124 static constexpr DataTypeSet kValidDataTypes =
125     ToSet(DT_HALF) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) |
126     ToSet(DT_INT64);
127 
ToNcclType(DataType t)128 ncclDataType_t ToNcclType(DataType t) {
129   switch (t) {
130     case DT_HALF:
131       return ncclHalf;
132     case DT_FLOAT:
133       return ncclFloat;
134     case DT_DOUBLE:
135       return ncclDouble;
136     case DT_INT32:
137       return ncclInt;
138     case DT_INT64:
139       return ncclInt64;
140     default:
141       return ncclFloat;
142   }
143 }
144 
StringToNcclUniqueId(const string & str_id,ncclUniqueId * nccl_id)145 void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
146   if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
147     memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
148   }
149 }
150 
151 }  // namespace
152 
153 // A `Collective` encapsulates state for a collective instance at one node.
154 // Typically, an instance in TensorFlow context would be defined by a collective
155 // group and the (step, frame iteration) for that execution.
156 //
157 // For each collective instance there will be one `Collective` object per node.
158 // For example,  a NCCL collective that runs on a single node with 4 GPUs would
159 // have a single `Collective` per step.  However, a collective that executes on
160 // 3 nodes with 4 GPUs each would have a `Collective` per node, each of which is
161 // tracking the 4 GPUs local to that node.
162 struct NcclManager::Collective : public core::RefCounted {
Collectivetensorflow::NcclManager::Collective163   Collective(const string& collective_key_in, DataType data_type_in,
164              CollectiveType type_in, ncclRedOp_t reduction_op_in,
165              int num_local_devices_in, int num_global_devices_in,
166              const string& communicator_key_in)
167       : collective_key(collective_key_in),
168         data_type(data_type_in),
169         type(type_in),
170         reduction_op(reduction_op_in),
171         num_local_devices(num_local_devices_in),
172         num_global_devices(num_global_devices_in),
173         single_node(num_local_devices_in == num_global_devices_in),
174         communicator_key(communicator_key_in) {
175     participants.reserve(num_local_devices_in);
176 #if TENSORFLOW_USE_ROCM
177     // On ROCm platform, this allows caller to either use the singleton instance
178     // or to manage one non-singleton NcclManager instance.
179     // For example, the nccl_manager_test will use both paradigms in the same
180     // executable, but not running concurrently (which would hang otherwise).
181     if (NcclManager::instance_count > 1) {
182       status = errors::Internal(
183           "ROCm cannot use multi-node NCCL collectives on a single node");
184     }
185 #endif
186   }
187 
188   const string collective_key;  // A unique key for debugging.
189   const DataType data_type;
190   const CollectiveType type;
191   const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.
192   const int num_local_devices;     // devices local to this node
193   const int num_global_devices;    // devices across all nodes
194   const bool single_node;          // true if all devices are at one node
195   const string communicator_key;
196 
197   Communicator* communicator = nullptr;
198 
199   // All collective participants.
200   //
201   // Adding values in this vector is guarded by the mutex of the containing
202   // NcclManager.
203   std::vector<std::unique_ptr<Participant>> participants;
204 
205   // For collective types that have a root (e.g. the root of broadcast is the
206   // sender), this is the rank of the root.
207   int root_rank = -1;
208 
209   // How many participants have been registered so far. The Collective is
210   // eligible for running with <available_participants> == num_local_devices.
211   //
212   // If this is a multi-node collective, we additionally have to synchronize
213   // across nodes.  The caller would need to signal multi node readiness by
214   // calling NcclManager::SignalMultiNodeReady, which sets `multi_node_ready` to
215   // true.
216   //
217   // Guarded by the mutex of the containing Communicator.
218   int available_participants = 0;
219   bool multi_node_ready = false;
220   // trace_context is used by tracing system to associate collective
221   // scheduling and execution (cooperative kernel launch), which happen
222   // on different threads.
223   uint64 trace_context = 0;
224 
225   Status status;
226 };
227 
NcclManager()228 NcclManager::NcclManager() {
229   VLOG(2) << "New NcclManager " << this;
230 #if TENSORFLOW_USE_ROCM
231   ++instance_count;
232 #endif
233 }
~NcclManager()234 NcclManager::~NcclManager() {
235   VLOG(2) << "~NcclManager " << this;
236 #if TENSORFLOW_USE_ROCM
237   --instance_count;
238 #endif
239   for (auto& it : device_to_comm_streams_) {
240     for (NcclStream* nccl_stream : it.second) {
241       {
242         mutex_lock l(nccl_stream->mu);
243         nccl_stream->shutdown_requested = true;
244         nccl_stream->cv.notify_all();
245       }
246       nccl_stream->Unref();
247     }
248   }
249 }
instance()250 NcclManager* NcclManager::instance() {
251   static NcclManager* instance = new NcclManager();
252 #if TENSORFLOW_USE_ROCM
253   // singleton does not count against total instances
254   // see comment above in Collective constructor concerning ROCm platform
255   static absl::once_flag once;
256   absl::call_once(once, [] { --NcclManager::instance_count; });
257 #endif
258   return instance;
259 }
260 
GenerateCommunicatorKey()261 string NcclManager::GenerateCommunicatorKey() {
262   ncclUniqueId nccl_id;
263   ncclGetUniqueId(&nccl_id);
264   return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
265 }
266 
GetCommunicator(NcclManager::Collective * collective,NcclManager::Communicator ** communicator)267 Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
268                                     NcclManager::Communicator** communicator) {
269   // Sort by device ID, executor, and global rank to make ordering of
270   // participants deterministic.
271   std::sort(collective->participants.begin(), collective->participants.end(),
272             [](const std::unique_ptr<Participant>& a,
273                const std::unique_ptr<Participant>& b) {
274               if (a->gpu_device_id != b->gpu_device_id) {
275                 return a->gpu_device_id < b->gpu_device_id;
276               }
277               if (a->executor != b->executor) {
278                 return a->executor < b->executor;
279               }
280               return a->global_rank < b->global_rank;
281             });
282 
283   mutex_lock l(mu_);
284   if (!status_.ok()) {
285     return status_;
286   }
287 
288   if (collective->communicator_key.empty()) {
289     // For single-node collectives, when the caller does not specify a
290     // `communicator_key`, we identify a communicator uniquely by the set of
291     // devices participating in the collective.  For example, if a collective is
292     // for GPUs 0, 1, and 2 then this will scan to find the communicator for
293     // GPUs 0, 1, and 2.
294     //
295     // Note that each executor identifies a context on one device, so this is
296     // the same as getting the communicator connecting the devices in the
297     // collective. A device can be in different communicators as well - for
298     // example, a communicator for GPUs 0 and 1 is separate from one for GPUs 0,
299     // 1, and 2.
300     //
301     // Since it's expected that a small number of distinct communicators will
302     // be needed, communicators_ is not garbage collected currently.
303     //
304     // Launching of kernels must be serialized so that, given collectives A and
305     // B, and an order of them (e.g., A before B), then for each comm_stream
306     // involved, the kernel for A is launched before the kernel for B. This is
307     // guaranteed currently by a global mutex controlling additions of the
308     // kernels to per-stream launch queues.  The launch queues are processed by
309     // LoopKernelLaunches.
310     for (auto& comm : communicators_) {
311       if (comm->num_devices == collective->num_global_devices) {
312         int i;
313         for (i = 0; i < collective->num_local_devices; ++i) {
314           if (comm->members[i].nccl_stream->executor !=
315               collective->participants[i]->executor) {
316             break;
317           }
318         }
319         if (i == collective->num_local_devices) {
320           *communicator = comm.get();
321           return Status::OK();
322         }
323       }
324     }
325   } else {
326 #if NCCL_MAJOR < 2
327     return errors::Internal(
328         "Cannot use multi-node NCCL collectives with NCCL 1.x");
329 #endif
330     if (collective->communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
331       return errors::Internal("Expected communicator_key of size ",
332                               NCCL_UNIQUE_ID_BYTES, " but found size ",
333                               collective->communicator_key.size());
334     }
335     // This is an instance of multi-node collective.  We have previously
336     // created a NCCL unique id and shared with all workers.  Now we find the
337     // `Communicator` corresponding to this id.
338     for (auto& comm : communicators_) {
339       if (comm->key == collective->communicator_key) {
340         *communicator = comm.get();
341         return Status::OK();
342       }
343     }
344   }
345 
346   auto* env = Env::Default();
347   std::set<NcclStream*> used_streams;
348 
349   // Create and initialize a new communicator.
350   // Note that this is done under the lock; performance is not expected to
351   // matter as this happens a very small number of times.
352   std::vector<CommunicatorMember> members(collective->num_local_devices);
353   std::vector<int> devices(collective->num_local_devices);
354   for (int i = 0; i < collective->num_local_devices; ++i) {
355     auto* executor = collective->participants[i]->executor;
356 
357     // Find a communication stream to use for the device.
358     auto& streams = device_to_comm_streams_[executor];
359     NcclStream* nccl_stream = nullptr;
360     for (const auto& s : streams) {
361       if (used_streams.insert(s).second) {
362         nccl_stream = s;
363         break;
364       }
365     }
366     if (nccl_stream == nullptr) {
367       nccl_stream = new NcclStream();
368       nccl_stream->executor = executor;
369 #if TENSORFLOW_USE_ROCM
370       nccl_stream->stream = collective->participants[i]->context->nccl_stream();
371 #else
372       nccl_stream->stream.reset(new se::Stream(executor));
373       nccl_stream->stream->Init();
374 #endif
375 
376       streams.emplace_back(nccl_stream);
377       used_streams.insert(nccl_stream);
378 
379       nccl_stream->Ref();
380       env->SchedClosure([this, nccl_stream]() {
381         LoopKernelLaunches(nccl_stream);
382         nccl_stream->Unref();
383       });
384     }
385 
386     members[i].nccl_stream = nccl_stream;
387     devices[i] = collective->participants[i]->gpu_device_id;
388   }
389 
390   std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
391 #if NCCL_MAJOR >= 2
392   // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
393   // group primitives.
394   ncclUniqueId nccl_id;
395   if (collective->single_node) {
396     NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
397   } else {
398     StringToNcclUniqueId(collective->communicator_key, &nccl_id);
399   }
400   int saved_device = 0;
401   CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
402   NCCL_RETURN_IF_ERROR(ncclGroupStart());
403   for (int i = 0; i < collective->num_local_devices; ++i) {
404     // Set rank to `participant->global_rank` if provided, else `i`.
405     const int rank = collective->participants[i]->global_rank >= 0
406                          ? collective->participants[i]->global_rank
407                          : i;
408     CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[i]));
409     NCCL_RETURN_IF_ERROR(ncclCommInitRank(
410         nccl_comms.data() + i, collective->num_global_devices, nccl_id, rank));
411   }
412   NCCL_RETURN_IF_ERROR(ncclGroupEnd());
413   CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
414 #else
415   // Since NCCL 1 is single node only, we use ncclCommInitAll.  We could have
416   // used ncclCommInitRank with NCCL 1 as well, but then we would have to
417   // issue each init call from a different thread
418   // (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/nccl1.html).
419   NCCL_RETURN_IF_ERROR(ncclCommInitAll(
420       nccl_comms.data(), collective->num_local_devices, devices.data()));
421 #endif
422 
423   for (int i = 0; i < collective->num_local_devices; ++i) {
424     members[i].nccl_comm = nccl_comms[i];
425   }
426   communicators_.emplace_back(
427       new Communicator(std::move(members), collective->communicator_key));
428   *communicator = communicators_.back().get();
429   return Status::OK();
430 }
431 
AddToAllReduce(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)432 void NcclManager::AddToAllReduce(std::unique_ptr<Participant> participant,
433                                  const Context& context,
434                                  ncclRedOp_t reduction_op) {
435   AddParticipant(std::move(participant), context, kAllReduce, reduction_op);
436 }
437 
AddToAllGather(std::unique_ptr<Participant> participant,const Context & context)438 void NcclManager::AddToAllGather(std::unique_ptr<Participant> participant,
439                                  const Context& context) {
440   AddParticipant(std::move(participant), context, kAllGather,
441                  ncclSum /* unused */);
442 }
443 
AddBroadcastSend(std::unique_ptr<Participant> participant,const Context & context)444 void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
445                                    const Context& context) {
446   participant->root = true;
447   AddParticipant(std::move(participant), context, kBroadcast,
448                  ncclSum /* unused */);
449 }
450 
AddBroadcastRecv(std::unique_ptr<Participant> participant,const Context & context)451 void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
452                                    const Context& context) {
453   AddParticipant(std::move(participant), context, kBroadcast,
454                  ncclSum /* unused */);
455 }
456 
AddReduceSend(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)457 void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
458                                 const Context& context,
459                                 ncclRedOp_t reduction_op) {
460   AddParticipant(std::move(participant), context, kReduce, reduction_op);
461 }
462 
AddReduceRecv(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)463 void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
464                                 const Context& context,
465                                 ncclRedOp_t reduction_op) {
466   participant->root = true;
467   AddParticipant(std::move(participant), context, kReduce, reduction_op);
468 }
469 
SignalMultiNodeReady(const string & collective_key)470 void NcclManager::SignalMultiNodeReady(const string& collective_key) {
471   Collective* to_run = nullptr;
472   {
473     mutex_lock l(mu_);
474     auto collective_it = collectives_.find(collective_key);
475     if (collective_it != collectives_.end()) {
476       Collective* collective = collective_it->second;
477       collective->multi_node_ready = true;
478       if (CheckReady(collective_key, collective)) {
479         to_run = collective;
480       }
481       VLOG(2) << "SignalMultiNodeReady collective " << collective_key
482               << " to_run " << to_run;
483     }
484   }
485 
486   if (to_run != nullptr) RunCollective(to_run);
487 }
488 
AddParticipant(std::unique_ptr<Participant> participant,const Context & context,CollectiveType collective_type,ncclRedOp_t reduction_op)489 void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
490                                  const Context& context,
491                                  CollectiveType collective_type,
492                                  ncclRedOp_t reduction_op) {
493   Collective* to_run = nullptr;
494   DataType data_type;
495   Status nccl_manager_status;
496   if (participant->input != nullptr) {
497     data_type = participant->input->dtype();
498   } else {
499     data_type = participant->output->dtype();
500   }
501   {
502     mutex_lock l(mu_);
503     nccl_manager_status = status_;
504     if (nccl_manager_status.ok()) {
505       auto collective_it = collectives_.find(context.collective_key);
506       Collective* collective = nullptr;
507       if (collective_it == collectives_.end()) {
508         collective = new Collective(
509             context.collective_key, data_type, collective_type, reduction_op,
510             context.num_local_devices, context.num_global_devices,
511             context.communicator_key);
512         collectives_.emplace(context.collective_key, collective);
513       } else {
514         collective = collective_it->second;
515       }
516 
517       // Check `collective` is correct and consistent.
518       if (collective->status.ok() && !collective->single_node &&
519           collective->communicator_key.empty()) {
520         collective->status = errors::Internal(
521             "Collective ", reduction_op,
522             " is multi node with num_local_devices=",
523             collective->num_local_devices,
524             " and num_global_devices=", collective->num_global_devices,
525             " but has an empty communicator_key");
526       }
527       if (collective->status.ok() && collective->communicator_key.size() !=
528                                          context.communicator_key.size()) {
529         collective->status =
530             errors::Internal("Collective ", reduction_op,
531                              " mismatch in member communicator_key with size ",
532                              collective->communicator_key.size(),
533                              " and arg communicator_key with size ",
534                              context.communicator_key.size());
535       }
536       if (collective->status.ok() && collective->type != collective_type) {
537         collective->status = errors::Internal(
538             "Collective ", reduction_op, " previously initialized with type ",
539             collective->type, " but now got type ", collective_type);
540       }
541       if (collective->status.ok() &&
542           collective->num_global_devices != context.num_global_devices) {
543         collective->status =
544             errors::Internal("Collective ", reduction_op,
545                              " previously initialized with num_global_devices ",
546                              collective->num_global_devices, " but now got ",
547                              context.num_global_devices);
548       }
549       if (collective->status.ok() &&
550           collective->num_local_devices != context.num_local_devices) {
551         collective->status =
552             errors::Internal("Collective ", reduction_op,
553                              "previously initialized with num_local_devices ",
554                              collective->num_local_devices, " but now got ",
555                              context.num_local_devices);
556       }
557       if (collective->status.ok() &&
558           collective->participants.size() >= collective->num_local_devices) {
559         collective->status = errors::Internal(
560             "Collective ", reduction_op, " expected ",
561             collective->num_local_devices, " participants but now has ",
562             collective->participants.size(),
563             " with one more participant being added");
564       }
565       if (collective->status.ok() && collective->root_rank >= 0 &&
566           context.source_rank >= 0 &&
567           collective->root_rank != context.source_rank) {
568         collective->status = errors::Internal(
569             "Collective ", collective->collective_key,
570             " already has root_rank ", collective->root_rank,
571             " but new participant has root_rank ", context.source_rank);
572       }
573       if (collective->status.ok() &&
574           !kValidDataTypes.Contains(collective->data_type)) {
575         collective->status = errors::Internal(
576             "Collective ", collective->collective_key,
577             " expected data types compatible with NCCL but instead got ",
578             DataTypeString(collective->data_type));
579       }
580 
581       if (context.source_rank >= 0) {
582         collective->root_rank = context.source_rank;
583       }
584 
585       collective->participants.emplace_back(std::move(participant));
586       ++collective->available_participants;
587 
588       if (CheckReady(context.collective_key, collective)) {
589         to_run = collective;
590       }
591     }
592   }
593   if (!nccl_manager_status.ok()) {
594     participant->done_callback(nccl_manager_status);
595     return;
596   }
597   if (to_run != nullptr) RunCollective(to_run);
598 }
599 
CheckReady(const string & collective_key,Collective * collective)600 bool NcclManager::CheckReady(const string& collective_key,
601                              Collective* collective) {
602   if (collective->available_participants == collective->num_local_devices) {
603     if (collective->num_global_devices == collective->num_local_devices ||
604         collective->multi_node_ready) {
605       // Ownership transferred to callee.
606       collectives_.erase(collective_key);
607       return true;
608     }
609   }
610   return false;
611 }
612 
// Runs a collective whose participants have all arrived (see CheckReady).
//
// Takes over the caller's pointer to `collective` and releases it with
// Unref() before returning. If the collective's status is an error,
// GetCommunicator fails, or the root rank is inconsistent or missing for a
// broadcast, every participant's done_callback is invoked with the error.
// Otherwise one (collective, participant index) entry is queued on each
// member's NcclStream, each holding its own reference, for LoopKernelLaunches
// to launch.
void NcclManager::RunCollective(Collective* collective) {
  // Pairs with the TraceMeConsumer("Run Collective", trace_context) in
  // LoopKernelLaunches, linking scheduling to the kernel-launch thread.
  tensorflow::profiler::TraceMeProducer traceme("Schedule Collective");
  collective->trace_context = traceme.GetContextId();

  static mutex collective_mu(LINKER_INITIALIZED);

  Status status = collective->status;
  if (status.ok()) {
    status = GetCommunicator(collective, &collective->communicator);
  }

  for (int i = 0; status.ok() && i < collective->num_local_devices; ++i) {
    Participant* p = collective->participants[i].get();
    NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
    CHECK(nccl_stream != nullptr);
    // Same rank rule as GetCommunicator: explicit global_rank, else local
    // index.
    const int rank = p->global_rank >= 0 ? p->global_rank : i;

    if (p->input != nullptr) {
      // Wait to ensure that the kernel that produces the data in the input
      // tensor has finished running before the nccl kernel runs on the
      // communication stream.
      nccl_stream->stream->ThenWaitFor(p->tensor_stream);
    }
    if (p->root) {
      if (collective->root_rank == -1) {
        collective->root_rank = rank;
      } else if (collective->root_rank != rank) {
        status = errors::Internal(
            "Inconsistent root rank ", collective->root_rank, " and GPU id ",
            p->gpu_device_id, " rank ", rank, " also marked as root.");
      }
    }
    VLOG(2) << "RunCollective rank " << rank << " global_rank "
            << p->global_rank << " root_rank " << collective->root_rank;
  }

  if (status.ok() && collective->type == kBroadcast &&
      collective->root_rank < 0) {
    status = errors::Internal("Root rank not indicated for collective ",
                              collective->collective_key);
  }

  if (!status.ok()) {
    for (int i = 0; i < collective->num_local_devices; ++i) {
      collective->participants[i]->done_callback(status);
    }
    collective->Unref();
    return;
  }

  {
    // Allow only one collective at a time to queue kernels for launching. This
    // is to prevent collectives from deadlocking each other.
    // Note that it would be possible to run multiple collectives at once, if
    // they have non-intersecting sets of devices.
    mutex_lock l(collective_mu);
    for (int i = 0; i < collective->num_local_devices; ++i) {
      NcclStream* nccl_stream =
          collective->communicator->members[i].nccl_stream;
      mutex_lock l(nccl_stream->mu);
      nccl_stream->pending_launches_.push_front(std::make_pair(collective, i));
      // Ownership is shared between LoopKernelLaunches for each stream in this
      // collective.
      collective->Ref();
      nccl_stream->cv.notify_all();
    }
  }
  // Drop the reference handed over by the caller; the queued entries now
  // hold their own.
  collective->Unref();
}
683 
684 namespace {
685 // For tracing purpose.
ComputeBufferSize(const NcclManager::Participant * p,DataType data_type)686 size_t ComputeBufferSize(const NcclManager::Participant* p,
687                          DataType data_type) {
688   size_t num_elements = 0;
689   if (p->output) {
690     num_elements += p->output->NumElements();
691   } else if (p->input) {
692     num_elements += p->input->NumElements();
693   }
694   return num_elements * DataTypeSize(data_type);
695 }
696 }  // namespace
697 
LoopKernelLaunches(NcclStream * nccl_stream)698 void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
699 #if TENSORFLOW_USE_ROCM
700   se::Stream* comm_stream = nccl_stream->stream;
701 #else
702   se::Stream* comm_stream = nccl_stream->stream.get();
703 #endif
704   ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
705   const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
706       comm_stream->implementation()->GpuStreamMemberHack());
707 
708   while (true) {
709     // Find collective to run.
710     std::pair<Collective*, int> next_launch;
711     {
712       VLOG(3) << "Locking mutex nccl_stream " << nccl_stream;
713       mutex_lock l(nccl_stream->mu);
714       while (nccl_stream->pending_launches_.empty()) {
715         if (nccl_stream->shutdown_requested) {
716           // No work and shutdown requested, exit.
717           return;
718         }
719         nccl_stream->cv.wait(l);
720       }
721       next_launch = nccl_stream->pending_launches_.back();
722       nccl_stream->pending_launches_.pop_back();
723     }
724 
725     // Launch the nccl kernel.
726     Collective* collective = next_launch.first;
727     tensorflow::profiler::TraceMeConsumer traceme("Run Collective",
728                                                   collective->trace_context);
729 
730     ncclDataType_t data_type = ToNcclType(collective->data_type);
731     int p_idx = next_launch.second;
732     Participant* p = collective->participants[p_idx].get();
733     auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
734     ncclResult_t nccl_result = ncclSuccess;
735     switch (collective->type) {
736       case kAllReduce: {
737         const void* sendbuff = p->input->tensor_data().data();
738         void* recvbuff = const_cast<char*>(p->output->tensor_data().data());
739 
740         VLOG(2) << "call NcclAllReduce collective_key "
741                 << collective->collective_key << " participant " << p_idx
742                 << " num_participants " << collective->participants.size()
743                 << " sendbuff " << sendbuff << " recvbuff " << recvbuff
744                 << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
745                 << " cuda_stream " << cu_stream;
746         profiler::AnnotatedTraceMe traceme([&] {
747           return profiler::TraceMeEncode(
748               "ncclAllReduce",
749               {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
750                {"collective_type", "all_reduce"}});
751         });
752         nccl_result = ncclAllReduce(sendbuff, recvbuff, p->input->NumElements(),
753                                     data_type, collective->reduction_op,
754                                     nccl_comm, *cu_stream);
755         break;
756       }
757       case kBroadcast: {
758         const void* sendbuff = nullptr;
759         void* recvbuff = nullptr;
760         int num_elements = -1;
761         if (p->input) {
762           sendbuff = p->input->tensor_data().data();
763           num_elements = p->input->NumElements();
764         }
765         if (p->output) {
766           recvbuff = const_cast<char*>(p->output->tensor_data().data());
767           num_elements = p->output->NumElements();
768         } else {
769           // Operate in-place if no output (for the src node).
770           recvbuff = const_cast<void*>(sendbuff);
771         }
772         if (num_elements < 0) {
773           p->done_callback(errors::Internal(
774               "Both input and output are null in ncclBroadcast"));
775           collective->Unref();
776           continue;
777         }
778         VLOG(2) << "call NcclBroadcast collective_key "
779                 << collective->collective_key << " participant " << p_idx
780                 << " sendbuff " << sendbuff << " recvbuff " << recvbuff
781                 << " nccl_comm " << nccl_comm << " comm_stream " << comm_stream
782                 << " cuda_stream " << cu_stream;
783         profiler::AnnotatedTraceMe traceme([&] {
784           return profiler::TraceMeEncode(
785               "ncclBroadcast",
786               {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
787                {"collective_type", "broadcast"}});
788         });
789         nccl_result =
790             ncclBroadcast(sendbuff, recvbuff, num_elements, data_type,
791                           collective->root_rank, nccl_comm, *cu_stream);
792         break;
793       }
794       case kReduce: {
795         const void* sendbuff = p->input->tensor_data().data();
796         void* recvbuff =
797             p->output ? const_cast<char*>(p->output->tensor_data().data())
798                       : nullptr;
799         profiler::AnnotatedTraceMe traceme([&] {
800           return profiler::TraceMeEncode(
801               "buffer_size",
802               {{"output_size", ComputeBufferSize(p, collective->data_type)},
803                {"collective_type", "reduce"}});
804         });
805         nccl_result = ncclReduce(sendbuff, recvbuff, p->input->NumElements(),
806                                  data_type, collective->reduction_op,
807                                  collective->root_rank, nccl_comm, *cu_stream);
808         break;
809       }
810       case kAllGather: {
811         const void* sendbuff = p->input->tensor_data().data();
812         void* recvbuff = const_cast<char*>(p->output->tensor_data().data());
813 
814         VLOG(2) << "call NcclAllGather collective_key "
815                 << collective->collective_key << " participant " << p_idx
816                 << " sendbuff " << sendbuff << " sendcount "
817                 << p->input->NumElements() << " recvbuff " << recvbuff
818                 << " recvcount " << p->output->NumElements() << " nccl_comm "
819                 << nccl_comm << " comm_stream " << comm_stream
820                 << " cuda_stream " << cu_stream;
821         profiler::AnnotatedTraceMe traceme([&] {
822           return profiler::TraceMeEncode(
823               "ncclAllGather",
824               {{"buffer_size", ComputeBufferSize(p, collective->data_type)},
825                {"collective_type", "all_gather"}});
826         });
827         nccl_result = ncclAllGather(sendbuff, recvbuff, p->input->NumElements(),
828                                     data_type, nccl_comm, *cu_stream);
829         break;
830       }
831     }
832 
833     // Run the done_callback when the nccl kernel finishes running.
834     auto done_callback = [collective, p_idx, nccl_result]() {
835       VLOG(2) << "done Nccl kernel collective_key "
836               << collective->collective_key << " participant " << p_idx
837               << " ncclResult " << nccl_result;
838       if (nccl_result == ncclSuccess) {
839         collective->participants[p_idx]->done_callback(Status::OK());
840       } else {
841         // Propagate the error, but note that if other members of the collective
842         // did launch their kernels, then they are hanging.
843         collective->participants[p_idx]->done_callback(errors::Unknown(
844             "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
845       }
846       collective->Unref();
847     };
848     p->event_mgr->ThenExecute(comm_stream, done_callback);
849   }
850 }
851 
StartAbort(const Status & s)852 void NcclManager::StartAbort(const Status& s) {
853   absl::flat_hash_map<string, Collective*> collectives;
854   std::vector<std::unique_ptr<Communicator>> communicators;
855   {
856     mutex_lock l(mu_);
857     if (!status_.ok()) {
858       LOG(WARNING)
859           << "NcclManager already aborted, ignoring subsequent StartAbort with "
860           << s;
861       return;
862     }
863     status_ = s;
864     collectives.swap(collectives_);
865     communicators.swap(communicators_);
866   }
867   VLOG(2) << "Aborted NcclManager " << this << " with " << collectives.size()
868           << " collectives and " << communicators.size()
869           << " comms with status " << s;
870   // collectives_ contains pending launches that haven't been dispatched to
871   // kernel launch threads, so we can simply invoke the done callbacks of them.
872   for (const auto& item : collectives) {
873     for (const std::unique_ptr<Participant>& p : item.second->participants) {
874       p->done_callback(s);
875     }
876     item.second->Unref();
877   }
878   // Abort ncclComm. Note that there could be multiple ncclComm per device,
879   // and ncclCommAbort contains cuda calls that requires device
880   // synchronization. That is a collective on nccl_comm_0 can block
881   // ncclCommAbort(nccl_comm_1), so we need to abort all ncclComm in a
882   // concurrent fashion. This assumes that there's only one active NcclManager
883   // at a time.
884   UnboundedWorkQueue queue(Env::Default(), "nccl_abort");
885   int num_comms = 0;
886   for (std::unique_ptr<Communicator>& communicator : communicators) {
887     num_comms += communicator->members.size();
888   }
889   BlockingCounter pending(num_comms);
890   for (std::unique_ptr<Communicator>& communicator : communicators) {
891     for (CommunicatorMember& member : communicator->members) {
892       queue.Schedule([&member, &pending]() {
893         ncclCommAbort(member.nccl_comm);
894         member.nccl_comm = nullptr;
895         pending.DecrementCount();
896       });
897     }
898   }
899   pending.Wait();
900 }
901 
Reset()902 void NcclManager::Reset() {
903   mutex_lock l(mu_);
904   status_ = Status();
905   VLOG(2) << "Reset NcclManager " << this;
906 }
907 
908 }  // namespace tensorflow
909 
910 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
911