• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"
32 
33 namespace xla {
34 namespace gpu {
35 
ToNcclReduction(ReductionKind kind)36 ncclRedOp_t ToNcclReduction(ReductionKind kind) {
37   switch (kind) {
38     case ReductionKind::SUM:
39       return ncclSum;
40     case ReductionKind::PRODUCT:
41       return ncclProd;
42     case ReductionKind::MIN:
43       return ncclMin;
44     case ReductionKind::MAX:
45       return ncclMax;
46   }
47 }
48 
49 namespace {
50 
ToNcclDataType(PrimitiveType element_type)51 StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
52   switch (element_type) {
53     case S8:
54       return ncclInt8;
55     case PRED:
56     case U8:
57       return ncclUint8;
58     case S32:
59       return ncclInt32;
60     case U32:
61       return ncclUint32;
62     case S64:
63       return ncclInt64;
64     case U64:
65       return ncclUint64;
66     case F16:
67       return ncclFloat16;
68     case F32:
69     case C64:
70       return ncclFloat32;
71     case F64:
72     case C128:
73       return ncclFloat64;
74 #if defined(__CUDA_BF16_TYPES_EXIST__)
75     case BF16:
76       return ncclBfloat16;
77 #endif
78     default:
79       return tensorflow::errors::InvalidArgument(absl::StrFormat(
80           "Unsupported data type: %s", PrimitiveType_Name(element_type)));
81   }
82 }
83 
84 }  // namespace
85 
ToNcclDataTypeAndCountMultiplier(PrimitiveType element_type)86 StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
87     PrimitiveType element_type) {
88   TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, ToNcclDataType(element_type));
89   bool is_complex = primitive_util::IsComplexType(element_type);
90   return std::make_pair(dtype, is_complex ? 2 : 1);
91 }
92 
// Reports whether the NCCL_COMM_ID environment variable was set at the time of
// the first call. The answer is computed once and cached for the process.
bool IsGlobalNcclConfig() {
  static const bool has_comm_id = [] {
    const char* comm_id = std::getenv("NCCL_COMM_ID");
    return comm_id != nullptr;
  }();
  return has_comm_id;
}
97 
// Reports whether the NCCL_LAUNCH_MODE environment variable equals "PARALLEL"
// at the time of the first call. The answer is cached for the process.
//
// An unset variable makes std::getenv return nullptr. The previous code built a
// string view directly from that pointer, which is undefined behavior whenever
// absl::string_view is an alias of std::string_view. The explicit null check
// makes an unset variable simply mean "not parallel".
bool IsNcclLaunchModeParallel() {
  static const bool is_launch_mode_parallel = [] {
    const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
    return launch_mode != nullptr && std::strcmp(launch_mode, "PARALLEL") == 0;
  }();
  return is_launch_mode_parallel;
}
103 
ToStatus(ncclResult_t s,const char * file,int64_t line,const char * expr)104 Status ToStatus(ncclResult_t s, const char* file, int64_t line,
105                 const char* expr) {
106   if (s == ncclSuccess) {
107     return Status::OK();
108   }
109   return tensorflow::errors::Internal(
110       absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
111                       ncclGetErrorString(s)));
112 }
113 
ToStatus(cudaError_t s,const char * file,int64_t line,const char * expr)114 Status ToStatus(cudaError_t s, const char* file, int64_t line,
115                 const char* expr) {
116   if (s == cudaSuccess) {
117     return Status::OK();
118   }
119   return tensorflow::errors::Internal(
120       absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
121                       cudaGetErrorString(s)));
122 }
123 
NcclClique(absl::flat_hash_map<int,NcclComm> comms_by_device_ordinal)124 NcclClique::NcclClique(
125     absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
126     : comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}
127 
GetCommForDeviceOrdinal(int device_ordinal) const128 ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
129   return comms_by_device_ordinal_.at(device_ordinal).get();
130 }
131 
NcclCliqueCache()132 NcclCliqueMap& NcclCliqueCache() {
133   // Global cache of NCCL cliques.  An entry in this map is always kept alive.
134   //
135   // A consequence of the fact that this is process-global is that we'll only
136   // ever have one clique alive for a given set of GPUs.  This means that a
137   // process will never do two collective operations concurrently on the same
138   // set of GPUs.
139   static auto& cache = *new NcclCliqueMap();
140   return cache;
141 }
142 
143 namespace {
144 
// Deleter installed on every NcclComm built by CreateNcclClique: destroys the
// NCCL communicator `comm`. The function returns void, so a failure from
// ncclCommDestroy cannot be propagated; the WARN_IF_ERROR macro presumably
// only logs it (NOTE(review): confirm macro semantics).
void DestroyNcclComm(ncclComm_t comm) {
  VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
  XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}
149 
ToNcclUniqueId(const std::string & str_id,ncclUniqueId * nccl_id)150 Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
151   if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
152     return InvalidArgument(
153         "ncclUniqueId string must have %d bytes, got %d bytes", str_id.size(),
154         NCCL_UNIQUE_ID_BYTES);
155   }
156   // NcclUniqueId is internally just a char[].
157   static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
158                 "NCCL_UNIQUE_ID_BYTES");
159   std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
160   return Status::OK();
161 }
162 
LocalParticipantsToString(const std::vector<LocalParticipant> & local_participants)163 std::string LocalParticipantsToString(
164     const std::vector<LocalParticipant>& local_participants) {
165   std::vector<std::string> parts;
166   for (const LocalParticipant& local_participant : local_participants) {
167     parts.push_back(absl::StrFormat("%d/rank=%d",
168                                     local_participant.device_ordinal,
169                                     local_participant.rank));
170   }
171   return absl::StrJoin(parts, ",");
172 }
173 
// Creates NCCL communicators for every participant in `local_participants`,
// all joining the clique identified by `key`, and wraps them in a NcclClique.
//
// The NCCL unique id is obtained from `callback` when it is provided and some
// participants are non-local (multi-host). Otherwise it is generated locally
// with ncclGetUniqueId. In that case the clique must be purely local, unless
// IsGlobalNcclConfig() reports that NCCL_COMM_ID is set.
//
// The current CUDA device is restored on return (via the cleanup object), and
// every communicator that was created is owned by an RAII NcclComm even when
// an error is returned.
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
    const NcclCliqueKey& key,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  // Total clique size across all hosts; `local_participants` may be a subset.
  int num_participants = key.devices().size();
  ncclUniqueId unique_id;
  bool only_local_participants = num_participants == local_participants.size();
  if (callback && !only_local_participants) {  // Multi-host collective.
    TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
    TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
  } else {
    TF_RET_CHECK(only_local_participants || IsGlobalNcclConfig())
        << "If non-local devices are taking part of a collective API on GPU, "
           "the nccl_unique_id_callback must be provided by the client.";
    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
  }

  VLOG(3) << "Initializing nccl comms for local participants: "
          << LocalParticipantsToString(local_participants);

  // Restore CUDA device after running this.  XLA shouldn't care, but maybe
  // another consumer does.
  int initial_cuda_device;
  XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
  auto cuda_device_restorer = MakeCleanup(
      [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });

  // When using ncclGroupStart/End it seems that the ncclComm_t's are not
  // populated until the End() call.
  std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  // The init loop runs in a lambda so that an early error return still falls
  // through to ncclGroupEnd() below instead of leaving the group open.
  Status status = [&] {
    for (int i = 0; i < local_participants.size(); ++i) {
      XLA_CUDA_RETURN_IF_ERROR(
          cudaSetDevice(local_participants[i].device_ordinal));
      XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
                                                unique_id,
                                                local_participants[i].rank));
    }
    return Status::OK();
  }();
  // Always call ncclGroupEnd().
  status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));

  // Always copy raw comms to RAII type, so they are cleaned up properly.
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
  for (int i = 0; i < raw_comms.size(); ++i) {
    int device_ordinal = local_participants[i].device_ordinal;
    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
                                  device_ordinal, raw_comms[i]);
    // A null communicator is only acceptable if some step above failed.
    CHECK(raw_comms[i] != nullptr || !status.ok());
    comms_by_device_ordinal.emplace(device_ordinal,
                                    NcclComm(raw_comms[i], &DestroyNcclComm));
  }

  // Now we can check if there was an error creating the communicators.
  TF_RETURN_IF_ERROR(status);
  return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}
233 
234 class NcclCliqueRendezvous
235     : public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
236  public:
NcclCliqueRendezvous(const RendezvousKey & rendezvous_key,const std::vector<LocalParticipant> & local_participants,const NcclUniqueIdCallback * callback)237   NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
238                        const std::vector<LocalParticipant>& local_participants,
239                        const NcclUniqueIdCallback* callback)
240       : Rendezvous(rendezvous_key),
241         key_(std::move(rendezvous_key.global_devices)),
242         local_participants_(local_participants),
243         callback_(callback),
244         counter_(nullptr) {}
245 
RunCollectiveOp(const NcclCliqueParticipantData &)246   StatusOr<LockedNcclClique> RunCollectiveOp(
247       const NcclCliqueParticipantData&) override {
248     tensorflow::mutex_lock lock(mu_);
249     bool primary = !initialized_;
250     if (primary) {
251       maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
252           key_, [&](const NcclCliqueKey& key) {
253             return CreateNcclClique(key, local_participants_, callback_);
254           });
255       initialized_ = true;
256     }
257     TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
258     std::unique_ptr<absl::MutexLock> clique_lock;
259     if (primary) {
260       clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
261       counter_ = new absl::BlockingCounter(local_participants_.size());
262     }
263     return LockedNcclClique(*clique, std::move(clique_lock), counter_);
264   }
265 
266  private:
267   NcclCliqueKey key_;
268   const std::vector<LocalParticipant>& local_participants_;
269   const NcclUniqueIdCallback* callback_;
270 
271   StatusOr<NcclClique*> maybe_clique_;
272   absl::BlockingCounter* counter_;
273 };
274 
275 }  // namespace
276 
GetLocalParticipants(const std::vector<GlobalDeviceId> & participants,const std::vector<GlobalDeviceId> * local_devices)277 StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
278     const std::vector<GlobalDeviceId>& participants,
279     const std::vector<GlobalDeviceId>* local_devices) {
280   std::vector<LocalParticipant> local_participants;
281   if (local_devices) {
282     absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
283     for (int rank = 0; rank < participants.size(); ++rank) {
284       auto result = device_ranks.emplace(participants[rank], rank);
285       TF_RET_CHECK(result.second) << "Duplicate device found";
286     }
287 
288     local_participants.reserve(local_devices->size());
289     for (int device_ordinal = 0; device_ordinal < local_devices->size();
290          ++device_ordinal) {
291       auto it = device_ranks.find((*local_devices)[device_ordinal]);
292       if (it != device_ranks.end()) {
293         local_participants.push_back({device_ordinal, /*rank=*/it->second});
294       }
295     }
296   } else {  // Single host, so use identity mapping (device ordinal == id).
297     local_participants.reserve(participants.size());
298     for (int rank = 0; rank < participants.size(); ++rank) {
299       int device_ordinal = participants[rank].value();
300       local_participants.push_back({device_ordinal, rank});
301     }
302   }
303 
304   return local_participants;
305 }
306 
LockedNcclClique(NcclClique & clique,std::unique_ptr<absl::MutexLock> lock,absl::BlockingCounter * counter)307 LockedNcclClique::LockedNcclClique(NcclClique& clique,
308                                    std::unique_ptr<absl::MutexLock> lock,
309                                    absl::BlockingCounter* counter)
310     : clique(clique), lock_(std::move(lock)), counter_(counter) {}
311 
LockedNcclClique(LockedNcclClique && other)312 LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
313     : clique(other.clique),
314       lock_(std::move(other.lock_)),
315       counter_(std::exchange(other.counter_, nullptr)) {}
316 
~LockedNcclClique()317 LockedNcclClique::~LockedNcclClique() {
318   if (counter_) {
319     counter_->DecrementCount();
320     if (lock_) {
321       counter_->Wait();  // Don't release lock until all threads are finished.
322       delete counter_;
323     }
324   }
325 }
326 
GetOrTryCreateIfAbsent(const NcclCliqueKey & key,const std::function<StatusOr<std::unique_ptr<NcclClique>> (const NcclCliqueKey &)> & value_factory)327 StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
328     const NcclCliqueKey& key,
329     const std::function<StatusOr<std::unique_ptr<NcclClique>>(
330         const NcclCliqueKey&)>& value_factory) {
331   {
332     absl::MutexLock lock(&mu_);
333     auto it = map_.find(key);
334     if (it != map_.end()) {
335       return it->second.get();
336     }
337   }
338   // We release the lock to allow different cliques to be created in parallel
339   // (avoiding a potential deadlock in multi-host settings). This is safe
340   // provided that there aren't two threads trying to create cliques with the
341   // same key - which we know will not happen as this method is only called by
342   // the primary thread from the clique rendezvous. If this assumption is not
343   // valid, the method will return an error.
344   TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
345   absl::MutexLock lock(&mu_);
346   auto result = map_.emplace(key, std::move(value));
347   TF_RET_CHECK(result.second) << "Clique already in cache.";
348   return result.first->second.get();
349 }
350 
ForEach(const std::function<void (const NcclCliqueKey &,const NcclClique &)> & fn)351 void NcclCliqueMap::ForEach(
352     const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
353   absl::MutexLock lock(&mu_);
354   for (const auto& kv : map_) {
355     fn(kv.first, *kv.second);
356   }
357 }
358 
AcquireNcclClique(const NcclCliqueParticipantData & participant,const std::vector<LocalParticipant> & local_participants,const NcclUniqueIdCallback * callback)359 StatusOr<LockedNcclClique> AcquireNcclClique(
360     const NcclCliqueParticipantData& participant,
361     const std::vector<LocalParticipant>& local_participants,
362     const NcclUniqueIdCallback* callback) {
363   VLOG(2) << "Rendezvous key: " << participant.rendezvous_key.ToString()
364           << ", local participants: "
365           << LocalParticipantsToString(local_participants);
366 
367   static auto& rendezvous_map =
368       *new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();
369 
370   return NcclCliqueRendezvous::SubmitParticipant(
371       /*rendezvous_getter=*/
372       [&, participant] {
373         return rendezvous_map.GetOrCreateIfAbsent(
374             participant.rendezvous_key,
375             [&](const RendezvousKey& rendezvous_key) {
376               return std::make_unique<NcclCliqueRendezvous>(
377                   participant.rendezvous_key, local_participants, callback);
378             });
379       },
380       participant);
381 }
382 
383 }  // namespace gpu
384 }  // namespace xla
385