/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

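// Maps an XLA ReductionKind onto the corresponding NCCL reduction operator.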
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
  switch (kind) {
    case ReductionKind::SUM:
      return ncclSum;
    case ReductionKind::PRODUCT:
      return ncclProd;
    case ReductionKind::MIN:
      return ncclMin;
    case ReductionKind::MAX:
      return ncclMax;
  }
}

namespace {

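// Maps an XLA primitive type onto the NCCL data type used on the wire. PRED
// is transferred as uint8, and complex types are transferred as their
// underlying real type (see ToNcclDataTypeAndCountMultiplier below).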
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
      return ncclInt8;
    case PRED:
    case U8:
      return ncclUint8;
    case S32:
      return ncclInt32;
    case U32:
      return ncclUint32;
    case S64:
      return ncclInt64;
    case U64:
      return ncclUint64;
    case F16:
      return ncclFloat16;
    case F32:
    case C64:
      return ncclFloat32;
    case F64:
    case C128:
      return ncclFloat64;
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case BF16:
      return ncclBfloat16;
#endif
    default:
      return tensorflow::errors::InvalidArgument(absl::StrFormat(
          "Unsupported data type: %s", PrimitiveType_Name(element_type)));
  }
}

}  // namespace

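// Returns the NCCL data type for element_type together with a count
// multiplier: a complex value is transferred as two consecutive real values,
// so callers must multiply their element counts by the returned factor.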
StatusOr<std::pair<ncclDataType_t, int>> ToNcclDataTypeAndCountMultiplier(
    PrimitiveType element_type) {
  TF_ASSIGN_OR_RETURN(ncclDataType_t dtype, ToNcclDataType(element_type));
  bool is_complex = primitive_util::IsComplexType(element_type);
  return std::make_pair(dtype, is_complex ? 2 : 1);
}
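// True iff the NCCL_COMM_ID environment variable is set, i.e. a global NCCL
// configuration was provided externally. Computed once and cached.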
bool IsGlobalNcclConfig() {
  static const bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
  return global_nccl_config;
}
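// True iff NCCL_LAUNCH_MODE is set to "PARALLEL". Computed once and cached.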
bool IsNcclLaunchModeParallel() {
  static const bool is_launch_mode_parallel = [] {
    // std::getenv may return nullptr, which must not be used to construct an
    // absl::string_view.
    const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
    return launch_mode != nullptr &&
           absl::string_view(launch_mode) == "PARALLEL";
  }();
  return is_launch_mode_parallel;
}
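// Converts an ncclResult_t into a Status, recording the failing expression
// and its source location; intended for use from the XLA_CUDA_* macros.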
Status ToStatus(ncclResult_t s, const char* file, int64_t line,
                const char* expr) {
  if (s == ncclSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
                      ncclGetErrorString(s)));
}
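// Same as above, but for CUDA runtime errors.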
Status ToStatus(cudaError_t s, const char* file, int64_t line,
                const char* expr) {
  if (s == cudaSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
                      cudaGetErrorString(s)));
}

NcclClique::NcclClique(
    absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
    : comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}

ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
  return comms_by_device_ordinal_.at(device_ordinal).get();
}

NcclCliqueMap& NcclCliqueCache() {
  // Global cache of NCCL cliques. An entry in this map is always kept alive.
  //
  // A consequence of the fact that this is process-global is that we'll only
  // ever have one clique alive for a given set of GPUs. This means that a
  // process will never do two collective operations concurrently on the same
  // set of GPUs.
  static auto& cache = *new NcclCliqueMap();
  return cache;
}

namespace {

void DestroyNcclComm(ncclComm_t comm) {
  VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
  XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}
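// Deserializes an ncclUniqueId from the string produced by the client's
// NcclUniqueIdCallback, validating its length first.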
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
    return InvalidArgument(
        "ncclUniqueId string must have %d bytes, got %d bytes",
        NCCL_UNIQUE_ID_BYTES, str_id.size());
  }
  // ncclUniqueId is internally just a char[].
  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
                "NCCL_UNIQUE_ID_BYTES");
  std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
  return Status::OK();
}
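// Renders the local participants as "<device ordinal>/rank=<rank>" pairs for
// logging.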
std::string LocalParticipantsToString(
    const std::vector<LocalParticipant>& local_participants) {
  std::vector<std::string> parts;
  for (const LocalParticipant& local_participant : local_participants) {
    parts.push_back(absl::StrFormat("%d/rank=%d",
                                    local_participant.device_ordinal,
                                    local_participant.rank));
  }
  return absl::StrJoin(parts, ",");
}
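// Creates a fresh NcclClique for `key`. For multi-host collectives the
// ncclUniqueId comes from the client-provided callback; otherwise it is
// generated locally with ncclGetUniqueId (or taken from a global NCCL_COMM_ID
// configuration). All local communicators are initialized inside a single
// ncclGroupStart/ncclGroupEnd block.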
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
    const NcclCliqueKey& key,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  int num_participants = key.devices().size();
  ncclUniqueId unique_id;
  bool only_local_participants = num_participants == local_participants.size();
  if (callback && !only_local_participants) {  // Multi-host collective.
    TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
    TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
  } else {
    TF_RET_CHECK(only_local_participants || IsGlobalNcclConfig())
        << "If non-local devices are taking part in a collective API on GPU, "
           "the nccl_unique_id_callback must be provided by the client.";
    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
  }

  VLOG(3) << "Initializing nccl comms for local participants: "
          << LocalParticipantsToString(local_participants);

  // Restore the CUDA device after running this. XLA shouldn't care, but maybe
  // another consumer does.
  int initial_cuda_device;
  XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
  auto cuda_device_restorer = MakeCleanup(
      [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });

  // When using ncclGroupStart/End it seems that the ncclComm_t's are not
  // populated until the End() call.
  std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  Status status = [&] {
    for (int i = 0; i < local_participants.size(); ++i) {
      XLA_CUDA_RETURN_IF_ERROR(
          cudaSetDevice(local_participants[i].device_ordinal));
      XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
                                                unique_id,
                                                local_participants[i].rank));
    }
    return Status::OK();
  }();
  // Always call ncclGroupEnd().
  status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));

  // Always copy the raw comms into the RAII type, so they are cleaned up
  // properly.
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
  for (int i = 0; i < raw_comms.size(); ++i) {
    int device_ordinal = local_participants[i].device_ordinal;
    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
                                  device_ordinal, raw_comms[i]);
    CHECK(raw_comms[i] != nullptr || !status.ok());
    comms_by_device_ordinal.emplace(device_ordinal,
                                    NcclComm(raw_comms[i], &DestroyNcclComm));
  }

  // Now we can check if there was an error creating the communicators.
  TF_RETURN_IF_ERROR(status);
  return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}
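// Rendezvous through which all local participants of a collective acquire the
// shared NcclClique. The first thread to arrive (the "primary" one) looks up
// or creates the clique and takes its lock; the lock is held until every
// participant has finished (see LockedNcclClique::~LockedNcclClique).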
class NcclCliqueRendezvous
    : public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
 public:
  NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
                       const std::vector<LocalParticipant>& local_participants,
                       const NcclUniqueIdCallback* callback)
      : Rendezvous(rendezvous_key),
        key_(std::move(rendezvous_key.global_devices)),
        local_participants_(local_participants),
        callback_(callback),
        counter_(nullptr) {}

  StatusOr<LockedNcclClique> RunCollectiveOp(
      const NcclCliqueParticipantData&) override {
    tensorflow::mutex_lock lock(mu_);
    bool primary = !initialized_;
    if (primary) {
      maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
          key_, [&](const NcclCliqueKey& key) {
            return CreateNcclClique(key, local_participants_, callback_);
          });
      initialized_ = true;
    }
    TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
    std::unique_ptr<absl::MutexLock> clique_lock;
    if (primary) {
      clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
      counter_ = new absl::BlockingCounter(local_participants_.size());
    }
    return LockedNcclClique(*clique, std::move(clique_lock), counter_);
  }

 private:
  NcclCliqueKey key_;
  const std::vector<LocalParticipant>& local_participants_;
  const NcclUniqueIdCallback* callback_;

  StatusOr<NcclClique*> maybe_clique_;
  absl::BlockingCounter* counter_;
};

}  // namespace

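// Returns the participants that are local to this process, paired with their
// ranks within `participants`. If `local_devices` is null, we are running
// single-host and a device's global id is its device ordinal.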
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices) {
  std::vector<LocalParticipant> local_participants;
  if (local_devices) {
    absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      auto result = device_ranks.emplace(participants[rank], rank);
      TF_RET_CHECK(result.second) << "Duplicate device found";
    }

    local_participants.reserve(local_devices->size());
    for (int device_ordinal = 0; device_ordinal < local_devices->size();
         ++device_ordinal) {
      auto it = device_ranks.find((*local_devices)[device_ordinal]);
      if (it != device_ranks.end()) {
        local_participants.push_back({device_ordinal, /*rank=*/it->second});
      }
    }
  } else {  // Single host, so use the identity mapping (device ordinal == id).
    local_participants.reserve(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      int device_ordinal = participants[rank].value();
      local_participants.push_back({device_ordinal, rank});
    }
  }

  return local_participants;
}

LockedNcclClique::LockedNcclClique(NcclClique& clique,
                                   std::unique_ptr<absl::MutexLock> lock,
                                   absl::BlockingCounter* counter)
    : clique(clique), lock_(std::move(lock)), counter_(counter) {}

LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
    : clique(other.clique),
      lock_(std::move(other.lock_)),
      counter_(std::exchange(other.counter_, nullptr)) {}

LockedNcclClique::~LockedNcclClique() {
  if (counter_) {
    counter_->DecrementCount();
    if (lock_) {
      counter_->Wait();  // Don't release the lock until all threads are finished.
      delete counter_;
    }
  }
}

StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
    const NcclCliqueKey& key,
    const std::function<StatusOr<std::unique_ptr<NcclClique>>(
        const NcclCliqueKey&)>& value_factory) {
  {
    absl::MutexLock lock(&mu_);
    auto it = map_.find(key);
    if (it != map_.end()) {
      return it->second.get();
    }
  }
  // We release the lock to allow different cliques to be created in parallel
  // (avoiding a potential deadlock in multi-host settings). This is safe
  // provided that there aren't two threads trying to create cliques with the
  // same key - which we know will not happen as this method is only called by
  // the primary thread from the clique rendezvous. If this assumption is not
  // valid, the method will return an error.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
  absl::MutexLock lock(&mu_);
  auto result = map_.emplace(key, std::move(value));
  TF_RET_CHECK(result.second) << "Clique already in cache.";
  return result.first->second.get();
}

void NcclCliqueMap::ForEach(
    const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
  absl::MutexLock lock(&mu_);
  for (const auto& kv : map_) {
    fn(kv.first, *kv.second);
  }
}
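// Acquires the clique shared by all participants of a collective op,
// rendezvousing with the other local threads on the op's RendezvousKey.
// A minimal usage sketch (the surrounding variable names are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(
//       LockedNcclClique locked_clique,
//       AcquireNcclClique(participant, local_participants, callback));
//   ncclComm_t comm =
//       locked_clique.clique.GetCommForDeviceOrdinal(device_ordinal);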
StatusOr<LockedNcclClique> AcquireNcclClique(
    const NcclCliqueParticipantData& participant,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  VLOG(2) << "Rendezvous key: " << participant.rendezvous_key.ToString()
          << ", local participants: "
          << LocalParticipantsToString(local_participants);

  static auto& rendezvous_map =
      *new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();

  return NcclCliqueRendezvous::SubmitParticipant(
      /*rendezvous_getter=*/
      [&, participant] {
        return rendezvous_map.GetOrCreateIfAbsent(
            participant.rendezvous_key,
            [&](const RendezvousKey& rendezvous_key) {
              return std::make_unique<NcclCliqueRendezvous>(
                  participant.rendezvous_key, local_participants, callback);
            });
      },
      participant);
}

}  // namespace gpu
}  // namespace xla