/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_utils.h"

#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/synchronization/blocking_counter.h"
24 #include "absl/synchronization/mutex.h"
25 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
26 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
27 #include "tensorflow/compiler/xla/service/global_device_id.h"
28 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/core/platform/errors.h"
32
namespace xla {
namespace gpu {

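// Maps an XLA ReductionKind to the corresponding NCCL reduction operator.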
ncclRedOp_t ToNcclReduction(ReductionKind kind) {
  switch (kind) {
    case ReductionKind::SUM:
      return ncclSum;
    case ReductionKind::PRODUCT:
      return ncclProd;
    case ReductionKind::MIN:
      return ncclMin;
    case ReductionKind::MAX:
      return ncclMax;
  }
}

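// Maps an XLA primitive type to the NCCL data type used on the wire. PRED
// (boolean) values are transferred as uint8; element types NCCL does not
// support yield an InvalidArgument error.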
StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType element_type) {
  switch (element_type) {
    case S8:
      return ncclInt8;
    case PRED:
    case U8:
      return ncclUint8;
    case S32:
      return ncclInt32;
    case U32:
      return ncclUint32;
    case S64:
      return ncclInt64;
    case U64:
      return ncclUint64;
    case F16:
      return ncclFloat16;
    case F32:
      return ncclFloat32;
    case F64:
      return ncclFloat64;
    default:
      return tensorflow::errors::InvalidArgument(absl::StrFormat(
          "Unsupported data type: %s", PrimitiveType_Name(element_type)));
  }
}

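// Returns true if the NCCL_COMM_ID environment variable is set, i.e. NCCL has
// a global configuration, in which case unique-id exchange via the client
// callback is not required (see CreateNcclClique below).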
bool IsGlobalNcclConfig() {
  static bool global_nccl_config = std::getenv("NCCL_COMM_ID") != nullptr;
  return global_nccl_config;
}

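// Converts an NCCL or CUDA error code into a Status that records the file,
// line, and expression of the failing call.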
Status ToStatus(ncclResult_t s, const char* file, int64 line,
                const char* expr) {
  if (s == ncclSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: NCCL operation %s failed: %s", file, line, expr,
                      ncclGetErrorString(s)));
}

Status ToStatus(cudaError_t s, const char* file, int64 line, const char* expr) {
  if (s == cudaSuccess) {
    return Status::OK();
  }
  return tensorflow::errors::Internal(
      absl::StrFormat("%s:%d: CUDA operation %s failed: %s", file, line, expr,
                      cudaGetErrorString(s)));
}

NcclClique::NcclClique(
    absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal)
    : comms_by_device_ordinal_(std::move(comms_by_device_ordinal)) {}

ncclComm_t NcclClique::GetCommForDeviceOrdinal(int device_ordinal) const {
  return comms_by_device_ordinal_.at(device_ordinal).get();
}

NcclCliqueMap& NcclCliqueCache() {
  // Global cache of NCCL cliques. An entry in this map is always kept alive.
  //
  // A consequence of the fact that this is process-global is that we'll only
  // ever have one clique alive for a given set of GPUs. This means that a
  // process will never do two collective operations concurrently on the same
  // set of GPUs.
  static auto& cache = *new NcclCliqueMap();
  return cache;
}

namespace {

void DestroyNcclComm(ncclComm_t comm) {
  VLOG(3) << absl::StreamFormat("Destroying comm %p", comm);
  XLA_CUDA_WARN_IF_ERROR(ncclCommDestroy(comm));
}

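// Deserializes an ncclUniqueId from the string form produced by the client's
// unique-id callback.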
Status ToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
    return InvalidArgument(
        "ncclUniqueId string must have %d bytes, got %d bytes",
        NCCL_UNIQUE_ID_BYTES, str_id.size());
  }
  // ncclUniqueId is internally just a char[].
  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
                "NCCL_UNIQUE_ID_BYTES");
  std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
  return Status::OK();
}

std::string LocalParticipantsToString(
    const std::vector<LocalParticipant>& local_participants) {
  std::vector<std::string> parts;
  for (const LocalParticipant& local_participant : local_participants) {
    parts.push_back(absl::StrFormat("%d/rank=%d",
                                    local_participant.device_ordinal,
                                    local_participant.rank));
  }
  return absl::StrJoin(parts, ",");
}

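// Creates NCCL communicators for every local participant of the clique. The
// ncclUniqueId comes from the client-provided callback in the multi-host case
// and from ncclGetUniqueId() when all participants are local (or a global
// NCCL config is in use).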
StatusOr<std::unique_ptr<NcclClique>> CreateNcclClique(
    const NcclCliqueKey& key,
    const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  int num_participants = key.devices().size();
  ncclUniqueId unique_id;
  if (callback) {  // Multi-host collective.
    TF_ASSIGN_OR_RETURN(std::string id_string, (*callback)(key));
    TF_RETURN_IF_ERROR(ToNcclUniqueId(id_string, &unique_id));
  } else {
    TF_RET_CHECK((num_participants == local_participants.size()) ||
                 IsGlobalNcclConfig())
        << "If non-local devices are taking part in a collective API on GPU, "
           "the nccl_unique_id_callback must be provided by the client.";
    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&unique_id));
  }

  VLOG(3) << "Initializing nccl comms for local participants: "
          << LocalParticipantsToString(local_participants);

  // Restore CUDA device after running this. XLA shouldn't care, but maybe
  // another consumer does.
  int initial_cuda_device;
  XLA_CUDA_RETURN_IF_ERROR(cudaGetDevice(&initial_cuda_device));
  auto cuda_device_restorer = MakeCleanup(
      [&] { XLA_CUDA_WARN_IF_ERROR(cudaSetDevice(initial_cuda_device)); });

  // When using ncclGroupStart/End it seems that the ncclComm_t's are not
  // populated until the End() call.
  std::vector<ncclComm_t> raw_comms(local_participants.size(), nullptr);
  XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
  Status status = [&] {
    for (int i = 0; i < local_participants.size(); ++i) {
      XLA_CUDA_RETURN_IF_ERROR(
          cudaSetDevice(local_participants[i].device_ordinal));
      XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i], num_participants,
                                                unique_id,
                                                local_participants[i].rank));
    }
    return Status::OK();
  }();
  // Always call ncclGroupEnd().
  status.Update(XLA_CUDA_STATUS(ncclGroupEnd()));

  // Always copy raw comms to RAII type, so they are cleaned up properly.
  absl::flat_hash_map<int, NcclComm> comms_by_device_ordinal(raw_comms.size());
  for (int i = 0; i < raw_comms.size(); ++i) {
    int device_ordinal = local_participants[i].device_ordinal;
    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
                                  device_ordinal, raw_comms[i]);
    CHECK(raw_comms[i] != nullptr || !status.ok());
    comms_by_device_ordinal.emplace(device_ordinal,
                                    NcclComm(raw_comms[i], &DestroyNcclComm));
  }

  // Now we can check if there was an error creating the communicators.
  TF_RETURN_IF_ERROR(status);
  return std::make_unique<NcclClique>(std::move(comms_by_device_ordinal));
}

struct NcclCliqueParticipantData : public ParticipantData {
  using ParticipantData::ParticipantData;
  std::string ToString() const override { return ""; }
};

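// Rendezvous through which all local participants of a collective synchronize.
// The first (primary) participant looks up or creates the shared NcclClique in
// the process-global cache and takes its lock; the remaining participants
// receive a reference to the already-locked clique.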
class NcclCliqueRendezvous
    : public Rendezvous<NcclCliqueParticipantData, LockedNcclClique> {
 public:
  NcclCliqueRendezvous(const RendezvousKey& rendezvous_key,
                       const std::vector<LocalParticipant>& local_participants,
                       const NcclUniqueIdCallback* callback)
      : Rendezvous(rendezvous_key),
        key_(std::move(rendezvous_key.global_devices)),
        local_participants_(local_participants),
        callback_(callback),
        counter_(nullptr) {}

  StatusOr<LockedNcclClique> RunCollectiveOp(
      const NcclCliqueParticipantData&) override {
    tensorflow::mutex_lock lock(mu_);
    bool primary = !initialized_;
    if (primary) {
      maybe_clique_ = NcclCliqueCache().GetOrTryCreateIfAbsent(
          key_, [&](const NcclCliqueKey& key) {
            return CreateNcclClique(key, local_participants_, callback_);
          });
      initialized_ = true;
    }
    TF_ASSIGN_OR_RETURN(NcclClique * clique, maybe_clique_);
    std::unique_ptr<absl::MutexLock> clique_lock;
    if (primary) {
      clique_lock = std::make_unique<absl::MutexLock>(clique->mu());
      counter_ = new absl::BlockingCounter(local_participants_.size());
    }
    return LockedNcclClique(*clique, std::move(clique_lock), counter_);
  }

 private:
  NcclCliqueKey key_;
  const std::vector<LocalParticipant>& local_participants_;
  const NcclUniqueIdCallback* callback_;

  StatusOr<NcclClique*> maybe_clique_;
  absl::BlockingCounter* counter_;
};

}  // namespace

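// Splits the list of global participating devices into the subset that is
// local to this process, recording for each one its device ordinal and its
// rank within the collective.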
StatusOr<std::vector<LocalParticipant>> GetLocalParticipants(
    const std::vector<GlobalDeviceId>& participants,
    const std::vector<GlobalDeviceId>* local_devices) {
  std::vector<LocalParticipant> local_participants;
  if (local_devices) {
    absl::flat_hash_map<GlobalDeviceId, int> device_ranks(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      auto result = device_ranks.emplace(participants[rank], rank);
      TF_RET_CHECK(result.second) << "Duplicate device found";
    }

    local_participants.reserve(local_devices->size());
    for (int device_ordinal = 0; device_ordinal < local_devices->size();
         ++device_ordinal) {
      auto it = device_ranks.find((*local_devices)[device_ordinal]);
      if (it != device_ranks.end()) {
        local_participants.push_back({device_ordinal, /*rank=*/it->second});
      }
    }
  } else {  // Single host, so use identity mapping (device ordinal == id).
    local_participants.reserve(participants.size());
    for (int rank = 0; rank < participants.size(); ++rank) {
      int device_ordinal = participants[rank].value();
      local_participants.push_back({device_ordinal, rank});
    }
  }

  return local_participants;
}

LockedNcclClique::LockedNcclClique(NcclClique& clique,
                                   std::unique_ptr<absl::MutexLock> lock,
                                   absl::BlockingCounter* counter)
    : clique(clique), lock_(std::move(lock)), counter_(counter) {}

LockedNcclClique::LockedNcclClique(LockedNcclClique&& other)
    : clique(other.clique),
      lock_(std::move(other.lock_)),
      counter_(std::exchange(other.counter_, nullptr)) {}

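// Each participant decrements the shared counter on destruction. The primary
// participant (the one holding the clique lock) additionally waits for all
// other participants to finish before releasing the lock and deleting the
// counter.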
LockedNcclClique::~LockedNcclClique() {
  if (counter_) {
    counter_->DecrementCount();
    if (lock_) {
      counter_->Wait();  // Don't release lock until all threads are finished.
      delete counter_;
    }
  }
}

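// Returns the clique for `key`, invoking `value_factory` to create it if it is
// not already in the cache. A failed creation is returned to the caller and is
// not cached.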
StatusOr<NcclClique*> NcclCliqueMap::GetOrTryCreateIfAbsent(
    const NcclCliqueKey& key,
    const std::function<StatusOr<std::unique_ptr<NcclClique>>(
        const NcclCliqueKey&)>& value_factory) {
  absl::MutexLock lock(&mu_);
  auto it = map_.find(key);
  if (it == map_.end()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<NcclClique> value, value_factory(key));
    it = map_.emplace(key, std::move(value)).first;
  }
  return it->second.get();
}

void NcclCliqueMap::ForEach(
    const std::function<void(const NcclCliqueKey&, const NcclClique&)>& fn) {
  absl::MutexLock lock(&mu_);
  for (const auto& kv : map_) {
    fn(kv.first, *kv.second);
  }
}

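// Acquires the NcclClique for the given rendezvous key. All local participants
// submit themselves to a shared NcclCliqueRendezvous, so the clique is looked
// up (or created) exactly once and handed back to every participant.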
StatusOr<LockedNcclClique> AcquireNcclClique(
    const RendezvousKey& rendezvous_key, int local_device_ordinal,
    se::Stream* stream, const std::vector<LocalParticipant>& local_participants,
    const NcclUniqueIdCallback* callback) {
  VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
          << ", local participants: "
          << LocalParticipantsToString(local_participants);

  static auto& rendezvous_map =
      *new RefcountingHashMap<RendezvousKey, NcclCliqueRendezvous>();

  NcclCliqueParticipantData participant(rendezvous_key, local_device_ordinal,
                                        stream);
  return NcclCliqueRendezvous::SubmitParticipant(
      /*rendezvous_getter=*/
      [&] {
        return rendezvous_map.GetOrCreateIfAbsent(
            rendezvous_key, [&](const RendezvousKey& rendezvous_key) {
              return std::make_unique<NcclCliqueRendezvous>(
                  rendezvous_key, local_participants, callback);
            });
      },
      participant);
}

}  // namespace gpu
}  // namespace xla