/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>

#include <algorithm>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver,
    NcclCommunicatorInterface* nccl_communicator, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      nccl_communicator_(nccl_communicator),
      task_name_(task_name),
      gpu_ring_order_(
          config.gpu_options().experimental().collective_ring_order()) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupLocal(device, group_params, cancel_mgr, done);
}

namespace {
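// Returns the name of the registered collective implementation to use for the
// given instance type, preferring the NCCL kernels when `nccl` is true.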
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  switch (cp->instance.type) {
    case BROADCAST_COLLECTIVE:
      return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";

    case REDUCTION_COLLECTIVE:
      return nccl ? "NcclReduce" : "RingReduce";

    case GATHER_COLLECTIVE:
      return nccl ? "NcclGather" : "RingGather";

    case PERMUTE_COLLECTIVE:
      return "Permute";

    case ALL_TO_ALL_COLLECTIVE:
      return "AllToAll";

    default:
      return "undef";
  }
}

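// Extracts the "/job:<j>/replica:<r>/task:<t>" prefix from a fully qualified
// device name.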
string TaskNameFromDeviceName(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_device;
  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
  string task_name;
  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
  return task_name;
}
}  // namespace

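// Looks up (or creates) the GroupRec for group_params->group_key, records the
// joining device, and either invokes `done` immediately (on error or once the
// group is full) or queues it until the remaining group members check in.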
void CollectiveParamResolverLocal::CompleteGroupLocal(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, StatusCallback done) {
  VLOG(1) << "CompleteGroup device=" << device.name() << ": "
          << group_params->ToString();
  std::vector<StatusCallback> to_be_called;

  GroupRec* gr = nullptr;
  Status status;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_params->group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      mutex_lock grl(gr->mu);
      gr->group.group_key = group_params->group_key;
      gr->group.group_size = group_params->group_size;
      gr->group.device_type = group_params->device_type;
      if (nccl_communicator_ != nullptr) {
        gr->group.runtime_details.communicator_key =
            nccl_communicator_->GenerateCommunicatorKey();
      }
      // Store GroupRec in group_table_ which is shared between all devices on
      // this worker.
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size
              << " runtime_details=" << gr->group.runtime_details.ToString();
    } else {
      gr = it->second.get();
    }
  }
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    done(status);
    return;
  }

  if (cancel_mgr != nullptr) {
    CancellationToken token = cancel_mgr->get_cancellation_token();
    bool is_cancelled = !cancel_mgr->RegisterCallback(
        token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this,
                         group_params->group_key));
    if (is_cancelled) {
      done(errors::Cancelled("CompleteGroup is cancelled before it starts"));
      return;
    }
    done = [cancel_mgr, token,
            original_done = std::move(done)](const Status& status) {
      cancel_mgr->TryDeregisterCallback(token);
      original_done(status);
    };
  }

  {
    mutex_lock gr_lock(gr->mu);
    // If there is ever an error associated with a group key, we store the
    // error status and invoke all waiting and future callbacks with this
    // error status.
    VLOG(2) << "gr device_type=" << gr->group.device_type
            << " cp device_type=" << group_params->device_type
            << " current device=" << device.name();
    if (gr->status.ok()) {
      // Check for consistency with existing GroupRec.
      if (group_params->device_type != gr->group.device_type) {
        gr->status = errors::Internal(
            "Device ", device.name(),
            " is joining a group with incompatible device type ",
            gr->group.device_type.type_string(),
            " (group_key=", gr->group.group_key, ")");
      } else if (group_params->group_size != gr->group.group_size) {
        gr->status = errors::Internal(
            "Device ", device.name(), " is joining a group with size ",
            group_params->group_size, ", but that group has size ",
            gr->group.group_size, " (group_key=", gr->group.group_key, ")");
      }
    }
    bool new_device = false;
    if (gr->status.ok()) {
      // Insert device if not already present.
      auto it = gr->incarnations_by_device_name.find(device.name());
      if (it == gr->incarnations_by_device_name.end()) {
        if (gr->group.devices.size() == gr->group.group_size) {
          // The group is already full.
          gr->status =
              errors::Internal("Device ", device.name(),
                               " is joining a group that is already full",
                               " (group_key=", gr->group.group_key, ")");
        } else {
          // This is a new device that has not yet joined the group.
          gr->incarnations_by_device_name[device.name()] = device.incarnation();
          gr->group.devices.push_back(device);
          new_device = true;
          if (VLOG_IS_ON(1)) {
            string dev_buf;
            for (const auto& d : gr->group.devices) {
              strings::StrAppend(&dev_buf, ",", d.name());
            }
            VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                    << " group_size=" << gr->group.group_size << " (current"
                    << " devices)=(" << dev_buf << ") (number of"
                    << " devices pending)="
                    << (gr->group.group_size - gr->group.devices.size());
          }
        }
      } else {
        // If the device already exists, check if the incarnation matches.
        if (it->second != device.incarnation()) {
          gr->status = errors::FailedPrecondition(
              "Device ", device.name(),
              " current incarnation doesn't match with one in the group. This "
              "usually means this worker has restarted but the collective "
              "leader hasn't, or this worker connects to a wrong cluster.");
        }
      }
    }

    if (gr->status.ok()) {
      // If the group is not yet complete, queue to wait for it.
      VLOG(2) << "group_size " << gr->group.group_size << " set size "
              << gr->group.devices.size() << " gr " << gr;

      if (gr->group.devices.size() < gr->group.group_size) {
        gr->pending_done.push_back(std::move(done));
        gr->pending_params.push_back(group_params);
        return;
      }
      CHECK_EQ(gr->group.devices.size(), gr->group.group_size);
      // We now have a full group. Fill in the remaining fields of gr->group.
      if (new_device) {
        FinishGroup(gr);
      }
      // Copy the completed group params to all pending CollGroupParams.
      *group_params = gr->group;
      for (auto* params : gr->pending_params) {
        *params = gr->group;
      }
    }
    // At this point, we either have a full group, or an error status. Ensure
    // that all callbacks are invoked with the appropriate status.
    to_be_called.swap(gr->pending_done);
    gr->pending_params.clear();
    status = gr->status;
  }
  done(status);
  for (int i = 0; i < to_be_called.size(); ++i) {
    to_be_called[i](status);
  }
}

namespace {
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollGroupParams and localities.
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) {
  GlobalDeviceMap gdm;
  CHECK_EQ(gp.devices.size(), gp.task_names.size());
  for (int i = 0; i < gp.devices.size(); ++i) {
    TaskDeviceMap& tdm = gdm[gp.task_names[i]];
    DevRec* dr = &tdm[gp.devices[i].name()];
    dr->task = gp.task_names[i];
    dr->device = gp.devices[i].name();
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &gp.devices[i].locality();
  }
  return gdm;
}

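// Parses the comma-separated list of GPU ids in `gpu_ring_order_str` and uses
// it to assign a local rank to every device in *tdm. Returns false if the
// string does not parse or does not cover every device in the task.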
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<string> split_gpu_ring_order_str =
      str_util::Split(gpu_ring_order_str, ',');
  if (split_gpu_ring_order_str.size() != tdm->size()) return false;

  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32_t rank = 0;
       rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
    int32_t tmp;
    if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
      gpu_ranks[tmp] = rank;
    } else {
      return false;
    }
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

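// Assigns a local rank to every device belonging to one task, either from the
// user-provided gpu_ring_order or, failing that, by greedily following the
// strongest interconnect links starting from the lowest-ranked device.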
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths. Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge.
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a CollGroupParams is established for a group we compute a
// good rank order for all the devices in the group, one that is appropriate
// for a ring algorithm.
GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp,
                                    const string& gpu_ring_order) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(gp);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the order in which tasks first appear.
  std::set<string> ordered_tasks;
  int next_rank = 0;
  for (int i = 0; i < gp.task_names.size(); ++i) {
    const string& task_name = gp.task_names[i];
    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
      continue;
    }
    ordered_tasks.insert(task_name);
    TaskDeviceMap* tdm = &gdm[task_name];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// gp->same_num_devices_per_task. Requires gp->task_names
// be sorted.
void SetDevPerTask(CollGroupParams* gp) {
  gp->num_devices_per_task.clear();
  const string* last_task_name = &gp->task_names[0];
  int count = 0;
  for (const string& task_name : gp->task_names) {
    if (task_name == *last_task_name) {
      ++count;
    } else {
      gp->num_devices_per_task[*last_task_name] = count;
      count = 1;
      last_task_name = &task_name;
    }
  }
  gp->num_devices_per_task[*last_task_name] = count;

  gp->same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : gp->num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  gp->same_num_devices_per_task = true;
}

}  // namespace

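// Called with gr->mu held once the final member has joined: orders the
// devices, fills in task_names, and computes the per-task device counts.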
void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
  // Sort devices lexicographically first.
  std::sort(gr->group.devices.begin(), gr->group.devices.end(),
            [](const DeviceAttributes& lhs, const DeviceAttributes& rhs) {
              return lhs.name() < rhs.name();
            });
  // Build task_names, which is needed by CompleteDefaultRanking.
  gr->group.task_names.reserve(gr->group.devices.size());
  for (const DeviceAttributes& device : gr->group.devices) {
    gr->group.task_names.push_back(TaskNameFromDeviceName(device.name()));
  }
  // Establish the final order of gr->group.devices and gr->group.task_names
  // by considering localities of all devices.
  CompleteDefaultRanking(&gr->group);
  SetDevPerTask(&gr->group);
  gr->group.num_tasks =
      static_cast<int32>(gr->group.num_devices_per_task.size());
}

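// Cancellation callback registered in CompleteGroupLocal: fails any callbacks
// still waiting on an incomplete group with a Cancelled status.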
void CollectiveParamResolverLocal::CancelGroup(int32 group_key) {
  std::vector<StatusCallback> pending_done;
  GroupRec* gr = nullptr;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_key);
    if (it == group_table_.end()) {
      return;
    }
    gr = it->second.get();
  }
  {
    mutex_lock l(gr->mu);
    if (gr->group.devices.size() == gr->group.group_size) {
      // The group is already complete. There's no need to cancel.
      return;
    }
    gr->status = errors::Cancelled("group is cancelled");
    pending_done.swap(gr->pending_done);
    gr->pending_params.clear();
  }
  for (const StatusCallback& done : pending_done) {
    done(errors::Cancelled("group is cancelled"));
  }
}

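// Sets cp->task.is_local[i] to true for every group member that runs on
// `task_name`.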
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  for (int i = 0; i < cp->group.group_size; ++i) {
    cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
  }
}

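// Sets cp->default_rank to this device's index in the ordered group.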
void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->group.devices.size()) << cp->ToString();
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->group.devices[i].name() == device) {
      cp->default_rank = i;
      break;
    }
  }
}

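// Initializes the shared CollectiveParams of a newly created InstanceRec from
// the instance fields of `cp`.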
void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const CollectiveParams* cp, InstanceRec* ir) {
  ir->shared->instance = cp->instance;
  ir->shared->default_rank = -1;

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceAttributesAsync. In a distributed context this function can be
  // called by a derived class; some of the devices may be non-local, and
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
  CompleteTaskIsLocal(task_name_, ir->shared);
}

// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) {
  // Establish an instance-specific default rank order for devices
  // based on localities. This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_);
  // Reflect the new global ranking on shared
  std::vector<DeviceAttributes> new_devices(gp->group_size);
  std::vector<string> new_task_names(gp->group_size);
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_devices[dr.global_rank] = gp->devices[dr.original_rank];
      new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
    }
  }

  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& d : new_devices) strings::StrAppend(&buf, "\n", d.name());
    VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
            << buf;
  }
  gp->devices = std::move(new_devices);
  gp->task_names = std::move(new_task_names);
}

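// Returns the InstanceRec for this group/instance key pair, creating and
// initializing a new record if none exists; *created reports whether a new
// record was made. If the resolver has already been aborted, the record's
// status is set to the abort status.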
CollectiveParamResolverLocal::InstanceRec*
CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp,
                                                     bool* created) {
  *created = false;
  InstanceRec* irec = nullptr;
  {
    mutex_lock l(instance_mu_);
    auto group_it = instance_table_.find(cp->group.group_key);
    if (group_it != instance_table_.end()) {
      auto instance_it = group_it->second.find(cp->instance.instance_key);
      if (instance_it != group_it->second.end()) {
        irec = instance_it->second.get();
      }
    }
    if (irec == nullptr) {
      // Create new InstanceRec.
      irec = new InstanceRec;
      *created = true;
      {
        mutex_lock il(irec->mu);
        irec->known.resize(cp->group.group_size, false);
      }
      InitInstanceSharedParams(cp, irec);
      instance_table_[cp->group.group_key][cp->instance.instance_key].reset(
          irec);
    }
  }
  Status status;
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    mutex_lock l(irec->mu);
    irec->status = status;
  }
  return irec;
}

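// Completes group membership first, then the instance-specific parameters for
// this device.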
void CollectiveParamResolverLocal::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupLocal(device, &cp->group, cancel_mgr,
                     [this, device, cp, done](const Status& s) {
                       if (s.ok()) {
                         CompleteInstanceLocal(device.name(), cp, done);
                       } else {
                         done(s);
                       }
                     });
}

void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation. Ideally the choice would take the topology and link
// strengths into account when picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  // We use the NCCL implementation if this environment supports NCCL, i.e. if
  // `LookupParamResolverInstance` succeeds for `NcclReduce`, and if NCCL use
  // is requested either in `ConfigProto` or via `communication_hint`.
  //
  // After enough testing, we may simplify this logic to use NCCL whenever
  // available.
  CollectiveImplementationInterface* col_impl;
  bool use_nccl =
      (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
      CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
          .ok();
  cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

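// Resolves instance-level parameters for `device`, checking them against any
// preexisting InstanceRec for the same instance key.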
void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, CollectiveParams* cp, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " group_key "
          << cp->group.group_key;

  bool created_irec;
  InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
  if (!created_irec) {
    // Check that the preexisting IRec is consistent with the params passed
    // into this invocation.
    if (ir->shared->instance.type != cp->instance.type ||
        ir->shared->instance.data_type != cp->instance.data_type) {
      done(errors::Internal("Collective instance ", cp->instance.instance_key,
                            " expected type ", ir->shared->instance.type,
                            " and data_type ", ir->shared->instance.data_type,
                            " but got type ", cp->instance.type,
                            " and data_type ", cp->instance.data_type));
      return;
    }
  }
  CompleteInstanceFromInitializedIRec(device, cp, ir, done);
}

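// Copies the shared instance fields into `cp`, verifies the shape, selects the
// collective implementation, and assigns ranks; for broadcasts it additionally
// waits until the source rank is known before calling `done`.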
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  auto expected_shape = cp->instance.shape;
  Status status;
  // Populate the fields common across instance.
  {
    mutex_lock l(ir->mu);
    status = ir->status;
    if (status.ok()) {
      // custom operator= does a deep copy.
      cp->instance = ir->shared->instance;
    }
  }
  if (!status.ok()) {
    done(status);
    return;
  }
  if (expected_shape != cp->instance.shape) {
    done(errors::InvalidArgument(
        "Shape mismatch in the collective instance ", cp->instance.instance_key,
        ". Op at device ", device, " expected shape ",
        expected_shape.DebugString(), " but another member in the group ",
        "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
        " due to different input shapes at different members of the collective",
        " op."));
    return;
  }
  // Populate the fields common across task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);
  CompleteTaskIsLocal(task_name_, cp);

  CollectiveImplementationInterface* col_impl;
  status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (!status.ok()) {
    done(status);
    return;
  }

  // If this is a broadcast, we may need to wait for all group members to
  // arrive so that the source rank can be discovered.
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) {
      Status s;
      if (ir != irec) {
        s = errors::Internal("Expected ir ", ir, " and irec ", irec,
                             " to be equal");
      } else {
        mutex_lock l(irec->mu);
        s = irec->status;
        cp->source_rank = irec->source_rank;
      }
      if (s.ok()) {
        s = col_impl->InitializeCollectiveParams(cp);
      }
      done(s);
    });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

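// Records that the member with cp->default_rank has joined this broadcast
// instance. If the group is not yet complete, `f` is queued; once every member
// is known (and a source rank has been claimed), `f` and all previously queued
// waiters are invoked.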
void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
                                                CollectiveParams* cp,
                                                const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
  do {
    mutex_lock l(ir->mu);
    if (!ir->status.ok()) {
      break;
    }
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (cp->is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ", ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
    }
    if (ir->known_count < cp->group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, cp->group.group_size);
    if (ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast. This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  } while (false);
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

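// Aborts all pending and future param resolution with status `s`. Only the
// first abort status is kept; later calls are ignored.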
void CollectiveParamResolverLocal::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  StartAbortLocal(s);
}

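// Propagates the abort status to every known group and instance record and
// invokes their pending callbacks.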
void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
  std::vector<StatusCallback> pending_done;
  {
    mutex_lock l(group_mu_);
    for (const auto& item : group_table_) {
      GroupRec* gr = item.second.get();
      {
        mutex_lock gl(gr->mu);
        gr->status = s;
        for (auto& done : gr->pending_done) {
          pending_done.push_back(std::move(done));
        }
        gr->pending_done.clear();
        gr->pending_params.clear();
      }
    }
  }
  for (const StatusCallback& done : pending_done) {
    done(s);
  }
  std::vector<InstanceRec*> instances;
  {
    mutex_lock l(instance_mu_);
    for (const auto& group_entry : instance_table_) {
      for (const auto& item : group_entry.second) {
        instances.push_back(item.second.get());
      }
    }
  }
  for (InstanceRec* ir : instances) {
    std::vector<IRConsumer> known_waiters;
    {
      mutex_lock il(ir->mu);
      ir->status = s;
      known_waiters.swap(ir->known_waiters);
    }
    for (const IRConsumer& done : known_waiters) {
      done(ir);
    }
  }
}

}  // namespace tensorflow