/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

ncclRedOp_t ToNcclReduction(ReductionKind kind) {
  switch (kind) {
    case ReductionKind::SUM:
      return ncclSum;
    case ReductionKind::PRODUCT:
      return ncclProd;
    case ReductionKind::MIN:
      return ncclMin;
    case ReductionKind::MAX:
      return ncclMax;
  }
}

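// Maps an XLA primitive type to the NCCL datatype used on the wire.  PRED
// (XLA's boolean type, stored as one byte per element) is transferred as
// uint8.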
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
      return ncclInt8;
    case PRED:
    case U8:
      return ncclUint8;
    case S32:
      return ncclInt32;
    case U32:
      return ncclUint32;
    case S64:
      return ncclInt64;
    case U64:
      return ncclUint64;
    case F16:
      return ncclFloat16;
    case F32:
      return ncclFloat32;
    case F64:
      return ncclFloat64;
    default:
      return tensorflow::errors::InvalidArgument(absl::StrFormat(
          "Unsupported data type: %s", PrimitiveType_Name(element_type)));
  }
}

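// NCCL_COMM_ID being set indicates that NCCL has been configured out of band
// with a global communicator id, in which case a clique may legitimately span
// devices this process cannot see, even when no unique-id callback is
// provided (see the TF_RET_CHECK in CreateNcclClique below).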
bool IsGlobalNcclConfig() {
  static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
  return global_nccl_config;
}

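// Converts NCCL and CUDA error codes into Status values.  The file, line, and
// expression arguments are supplied by the XLA_CUDA_* wrapper macros so that
// failures report where the offending call was made.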
Status ToStatus(ncclResult_t s, const char* file, int64 line,
                const char* expr) {
  if (s == ncclSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
                      ncclGetErrorString(s)));
}

Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
  if (s == cudaSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
                      cudaGetErrorString(s)));
}

NcclClique::NcclClique(
    absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
    : comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}

ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
  return comms_by_device_ordinal_.at(device_ordinal).get();
}

NcclCliqueMap& NcclCliqueCache() {
  // Global cache of NCCL cliques.  An entry in this map is always kept alive.
  //
  // A consequence of the fact that this is process-global is that we'll only
  // ever have one clique alive for a given set of GPUs.  This means that a
  // process will never do two collective operations concurrently on the same
  // set of GPUs.
  static auto& cache = *new NcclCliqueMap();
  return cache;
}

namespace {

void DestroyNcclComm(ncclComm_t comm) {
  VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
  XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}

Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
    return InvalidArgument(
        "ncclUniqueId string must have %d bytes, got %d bytes",
        NCCL_UNIQUE_ID_BYTES, str_id.size());
  }
  // NcclUniqueId is internally just a char[].
  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
                "NCCL_UNIQUE_ID_BYTES");
  std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
  return Status::OK();
}

std::string LocalParticipantsToString(
    const std::vector<LocalParticipant>& local_participants) {
  std::vector<std::string> parts;
  for (const LocalParticipant& local_participant : local_participants) {
    parts.push_back(absl::StrFormat("%d/rank=%d",
                                    local_participant.device_ordinal,
                                    local_participant.rank));
  }
  return absl::StrJoin(parts, ",");
}

StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
    const NcclCliqueKey& key,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  int num_participants = key.devices().size();
  ncclUniqueId unique_id;
  if (callback) {  // Multi-host collective.
    TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
    TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
  } else {
    TF_RET_CHECK((num_participants == local_participants.size()) ||
                 IsGlobalNcclConfig())
        << "If non-local devices are taking part in a collective API on GPU, "
           "the nccl_unique_id_callback must be provided by the client.";
    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
  }

  VLOG(3) << "Initializing nccl comms for local participants: "
          << LocalParticipantsToString(local_participants);

  // Restore the CUDA device after running this.  XLA shouldn't care, but maybe
  // another consumer does.
  int initial_cuda_device;
  XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
  auto cuda_device_restorer = MakeCleanup(
      [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });

  // When using ncclGroupStart/End it seems that the ncclComm_t's are not
  // populated until the End() call.
  std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  Status status = [&] {
    for (int i = 0; i < local_participants.size(); ++i) {
      XLA_CUDA_RETURN_IF_ERROR(
          cudaSetDevice(local_participants[i].device_ordinal));
      XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
                                                unique_id,
                                                local_participants[i].rank));
    }
    return Status::OK();
  }();
  // Always call ncclGroupEnd().
  status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));

  // Always copy the raw comms into the RAII type, so they are cleaned up
  // properly even if initialization failed partway through.
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
  for (int i = 0; i < raw_comms.size(); ++i) {
    int device_ordinal = local_participants[i].device_ordinal;
    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
                                  device_ordinal, raw_comms[i]);
    CHECK(raw_comms[i] != nullptr || !status.ok());
    comms_by_device_ordinal.emplace(device_ordinal,
                                    NcclComm(raw_comms[i], &DestroyNcclComm));
  }

  // Now we can check if there was an error creating the communicators.
  TF_RETURN_IF_ERROR(status);
  return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}

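// Participants of a collective op rendezvous here to obtain a shared, locked
// NcclClique.  The first ("primary") thread to arrive creates or fetches the
// clique from the process-global cache and takes its lock; a BlockingCounter
// shared through LockedNcclClique keeps that lock held until every
// participant has finished using the clique.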
struct NcclCliqueParticipantData : public ParticipantData {
  using ParticipantData::ParticipantData;
  std::string ToString() const override { return ""; }
};

class NcclCliqueRendezvous
    : public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
 public:
  NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
                       const std::vector<LocalParticipant>& local_participants,
                       const NcclUniqueIdCallback* callback)
      : Rendezvous(rendezvous_key),
        key_(std::move(rendezvous_key.global_devices)),
        local_participants_(local_participants),
        callback_(callback),
        counter_(nullptr) {}

  StatusOr<LockedNcclClique> RunCollectiveOp(
      const NcclCliqueParticipantData&) override {
    tensorflow::mutex_lock lock(mu_);
    bool primary = !initialized_;
    if (primary) {
      maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
          key_, [&](const NcclCliqueKey& key) {
            return CreateNcclClique(key, local_participants_, callback_);
          });
      initialized_ = true;
    }
    TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
    std::unique_ptr<absl::MutexLock> clique_lock;
    if (primary) {
      clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
      counter_ = new absl::BlockingCounter(local_participants_.size());
    }
    return LockedNcclClique(*clique, std::move(clique_lock), counter_);
  }

 private:
  NcclCliqueKey key_;
  const std::vector<LocalParticipant>& local_participants_;
  const NcclUniqueIdCallback* callback_;

  StatusOr<NcclClique*> maybe_clique_;
  absl::BlockingCounter* counter_;
};

}  // namespace

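// Maps each device of this process to its rank within `participants`; local
// devices that do not appear in `participants` are skipped.  When no local
// device list is given (the single-host case), global device ids are taken to
// be the local device ordinals.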
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices) {
  std::vector<LocalParticipant> local_participants;
  if (local_devices) {
    absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      auto result = device_ranks.emplace(participants[rank], rank);
      TF_RET_CHECK(result.second) << "Duplicate device found";
    }

    local_participants.reserve(local_devices->size());
    for (int device_ordinal = 0; device_ordinal < local_devices->size();
         ++device_ordinal) {
      auto it = device_ranks.find((*local_devices)[device_ordinal]);
      if (it != device_ranks.end()) {
        local_participants.push_back({device_ordinal, /*rank=*/it->second});
      }
    }
  } else {  // Single host, so use identity mapping (device ordinal == id).
    local_participants.reserve(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      int device_ordinal = participants[rank].value();
      local_participants.push_back({device_ordinal, rank});
    }
  }

  return local_participants;
}

LockedNcclClique::LockedNcclClique(NcclClique& clique,
                                   std::unique_ptr<absl::MutexLock> lock,
                                   absl::BlockingCounter* counter)
    : clique(clique), lock_(std::move(lock)), counter_(counter) {}

LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
    : clique(other.clique),
      lock_(std::move(other.lock_)),
      counter_(std::exchange(other.counter_, nullptr)) {}

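// Every participant decrements the shared counter when it is done with the
// clique.  The primary participant (the one that actually holds the clique
// lock) additionally waits for all other participants before releasing the
// lock and deleting the counter.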
LockedNcclClique::~LockedNcclClique() {
  if (counter_) {
    counter_->DecrementCount();
    if (lock_) {
      counter_->Wait();  // Don't release lock until all threads are finished.
      delete counter_;
    }
  }
}

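// Returns the cached clique for `key`, creating it with `value_factory` on
// first use.  If the factory fails, nothing is inserted into the map, so a
// later call may retry the creation.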
StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
    const NcclCliqueKey& key,
    const std::function<StatusOr<std::unique_ptr<NcclClique>>(
        const NcclCliqueKey&)>& value_factory) {
  absl::MutexLock lock(&mu_);
  auto it = map_.find(key);
  if (it == map_.end()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
    it = map_.emplace(key, std::move(value)).first;
  }
  return it->second.get();
}

void NcclCliqueMap::ForEach(
    const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
  absl::MutexLock lock(&mu_);
  for (const auto& kv : map_) {
    fn(kv.first, *kv.second);
  }
}

StatusOr<LockedNcclClique> AcquireNcclClique(
    const RendezvousKey& rendezvous_key, int local_device_ordinal,
    se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
          << ", local participants: "
          << LocalParticipantsToString(local_participants);

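  // Intentionally leaked: the map must outlive any in-flight rendezvous.
  // Entries are reference-counted (see RefcountingHashMap), so a rendezvous
  // object lives only as long as some participant still holds a reference.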
  static auto& rendezvous_map =
      *new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();

  NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
                                        stream);
  return NcclCliqueRendezvous::SubmitParticipant(
      /*rendezvous_getter=*/
      [&] {
        return rendezvous_map.GetOrCreateIfAbsent(
            rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
              return std::make_unique<NcclCliqueRendezvous>(
                  rendezvous_key, local_participants, callback);
            });
      },
      participant);
}

}  // namespace gpu
}  // namespace xla