/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>

#include <algorithm>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver,
    NcclCommunicatorInterface* nccl_communicator, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      nccl_communicator_(nccl_communicator),
      task_name_(task_name),
      gpu_ring_order_(
          config.gpu_options().experimental().collective_ring_order()) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupLocal(device, group_params, cancel_mgr, done);
}

namespace {
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  switch (cp->instance.type) {
    case BROADCAST_COLLECTIVE:
      return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";

    case REDUCTION_COLLECTIVE:
      return nccl ? "NcclReduce" : "RingReduce";

    case GATHER_COLLECTIVE:
      return nccl ? "NcclGather" : "RingGather";

    case PERMUTE_COLLECTIVE:
      return "Permute";

    case ALL_TO_ALL_COLLECTIVE:
      return "AllToAll";

    default:
      return "undef";
  }
}

string TaskNameFromDeviceName(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_device;
  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
  string task_name;
  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
  return task_name;
}
}  // namespace

void CollectiveParamResolverLocal::CompleteGroupLocal(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, StatusCallback done) {
  VLOG(1) << "CompleteGroup device=" << device.name() << ": "
          << group_params->ToString();
  std::vector<StatusCallback> to_be_called;

  GroupRec* gr = nullptr;
  Status status;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_params->group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      mutex_lock grl(gr->mu);
      gr->group.group_key = group_params->group_key;
      gr->group.group_size = group_params->group_size;
      gr->group.device_type = group_params->device_type;
      if (nccl_communicator_ != nullptr) {
        gr->group.runtime_details.communicator_key =
            nccl_communicator_->GenerateCommunicatorKey();
      }
      // Store GroupRec in group_table_ which is shared between all devices on
      // this worker.
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size
              << " runtime_details=" << gr->group.runtime_details.ToString();
    } else {
      gr = it->second.get();
    }
  }
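  // Check the resolver-wide status; if StartAbort has already been called,
  // fail fast with that status.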
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    done(status);
    return;
  }

  if (cancel_mgr != nullptr) {
    CancellationToken token = cancel_mgr->get_cancellation_token();
    bool is_cancelled = !cancel_mgr->RegisterCallback(
        token, std::bind(&CollectiveParamResolverLocal::CancelGroup, this,
                         group_params->group_key));
    if (is_cancelled) {
      done(errors::Cancelled("CompleteGroup is cancelled before it starts"));
      return;
    }
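    // Wrap `done` so the cancellation callback is deregistered once group
    // resolution finishes, whatever the outcome.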
    done = [cancel_mgr, token,
            original_done = std::move(done)](const Status& status) {
      cancel_mgr->TryDeregisterCallback(token);
      original_done(status);
    };
  }

  {
    mutex_lock gr_lock(gr->mu);
    // If there is ever an error associated with a group key, we store the error
    // status and invoke all waiting and future callbacks with this error
    // status.
    VLOG(2) << "gr device_type=" << gr->group.device_type
            << " cp device_type=" << group_params->device_type
            << " current device=" << device.name();
    if (gr->status.ok()) {
      // Check for consistency with existing GroupRec.
      if (group_params->device_type != gr->group.device_type) {
        gr->status = errors::Internal(
            "Device ", device.name(),
            " is joining a group with incompatible device type ",
            gr->group.device_type.type_string(),
            " (group_key=", gr->group.group_key, ")");
      } else if (group_params->group_size != gr->group.group_size) {
        gr->status = errors::Internal(
            "Device ", device.name(), " is joining a group with size ",
            group_params->group_size, ", but that group has size ",
            gr->group.group_size, " (group_key=", gr->group.group_key, ")");
      }
    }
    bool new_device = false;
    if (gr->status.ok()) {
      // Insert device if not already present.
      auto it = gr->incarnations_by_device_name.find(device.name());
      if (it == gr->incarnations_by_device_name.end()) {
        if (gr->group.devices.size() == gr->group.group_size) {
          // The group is already full.
          gr->status =
              errors::Internal("Device ", device.name(),
                               " is joining a group that is already full",
                               " (group_key=", gr->group.group_key, ")");
        } else {
          // This is a new device that has not yet joined the group.
          gr->incarnations_by_device_name[device.name()] = device.incarnation();
          gr->group.devices.push_back(device);
          new_device = true;
          if (VLOG_IS_ON(1)) {
            string dev_buf;
            for (const auto& d : gr->group.devices) {
              strings::StrAppend(&dev_buf, ",", d.name());
            }
            VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                    << " group_size=" << gr->group.group_size << " (current"
                    << " devices)=(" << dev_buf << ") (number of"
                    << " devices pending)="
                    << (gr->group.group_size - gr->group.devices.size());
          }
        }
      } else {
        // If the device already exists, check if the incarnation matches.
        if (it->second != device.incarnation()) {
          gr->status = errors::FailedPrecondition(
              "Device ", device.name(),
              " current incarnation doesn't match with one in the group. This "
              "usually means this worker has restarted but the collective "
              "leader hasn't, or this worker connects to a wrong cluster.");
        }
      }
    }

    if (gr->status.ok()) {
      // If the group is not yet complete, queue to wait for it.
      VLOG(2) << "group_size " << gr->group.group_size << " set size "
              << gr->group.devices.size() << " gr " << gr;

      if (gr->group.devices.size() < gr->group.group_size) {
        gr->pending_done.push_back(std::move(done));
        gr->pending_params.push_back(group_params);
        return;
      }
      CHECK_EQ(gr->group.devices.size(), gr->group.group_size);
      // The group is now full.  Fill in the remaining fields of gr->group.
      if (new_device) {
        FinishGroup(gr);
      }
      // Copy the completed group params to all pending CollGroupParams.
      *group_params = gr->group;
      for (auto* params : gr->pending_params) {
        *params = gr->group;
      }
    }
    // At this point, we either have a full group, or an error status.  Ensure
    // that all callbacks are invoked with the appropriate status.
    to_be_called.swap(gr->pending_done);
    gr->pending_params.clear();
    status = gr->status;
  }
  done(status);
  for (int i = 0; i < to_be_called.size(); ++i) {
    to_be_called[i](status);
  }
}

namespace {
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
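// TaskDeviceMap is keyed by device name; GlobalDeviceMap is keyed by task
// name.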
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollGroupParams and localities.
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp) {
  GlobalDeviceMap gdm;
  CHECK_EQ(gp.devices.size(), gp.task_names.size());
  for (int i = 0; i < gp.devices.size(); ++i) {
    TaskDeviceMap& tdm = gdm[gp.task_names[i]];
    DevRec* dr = &tdm[gp.devices[i].name()];
    dr->task = gp.task_names[i];
    dr->device = gp.devices[i].name();
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &gp.devices[i].locality();
  }
  return gdm;
}

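// Parses the comma-separated `gpu_ring_order_str` (a list of GPU ids, as
// supplied via ConfigProto's collective_ring_order) and assigns each device in
// *tdm the local rank given by the position of its GPU id in the list.  For
// example, with "2,0,1", GPU 2 gets local_rank 0, GPU 0 gets 1, and GPU 1
// gets 2.  Returns false if the string does not describe exactly the devices
// in *tdm.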
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<string> split_gpu_ring_order_str =
      str_util::Split(gpu_ring_order_str, ',');
  if (split_gpu_ring_order_str.size() != tdm->size()) return false;

  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32_t rank = 0;
       rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
    int32_t tmp;
    if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
      gpu_ranks[tmp] = rank;
    } else {
      return false;
    }
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths.  Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

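  // Greedily build the ring: starting from the device with the least initial
  // rank, repeatedly follow the strongest interconnect link to a
  // not-yet-selected participating device, assigning local ranks in visit
  // order.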
  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge.
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a CollGroupParams is established for a group, we compute a
// good rank order for all the devices in the group that is appropriate for a
// ring algorithm.
GlobalDeviceMap EstablishGlobalRank(const CollGroupParams& gp,
                                    const string& gpu_ring_order) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(gp);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the order in which tasks first appear.
  std::set<string> ordered_tasks;
  int next_rank = 0;
  for (int i = 0; i < gp.task_names.size(); ++i) {
    const string& task_name = gp.task_names[i];
    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
      continue;
    }
    ordered_tasks.insert(task_name);
    TaskDeviceMap* tdm = &gdm[task_name];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// gp->same_num_devices_per_task.  Requires gp->task_names to be sorted.
void SetDevPerTask(CollGroupParams* gp) {
  gp->num_devices_per_task.clear();
  const string* last_task_name = &gp->task_names[0];
  int count = 0;
  for (const string& task_name : gp->task_names) {
    if (task_name == *last_task_name) {
      ++count;
    } else {
      gp->num_devices_per_task[*last_task_name] = count;
      count = 1;
      last_task_name = &task_name;
    }
  }
  gp->num_devices_per_task[*last_task_name] = count;

  gp->same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : gp->num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  gp->same_num_devices_per_task = true;
}

}  // namespace

void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
  // Sort devices lexicographically first.
  std::sort(gr->group.devices.begin(), gr->group.devices.end(),
            [](const DeviceAttributes& lhs, const DeviceAttributes& rhs) {
              return lhs.name() < rhs.name();
            });
  // Build task_names, which is needed by CompleteDefaultRanking.
  gr->group.task_names.reserve(gr->group.devices.size());
  for (const DeviceAttributes& device : gr->group.devices) {
    gr->group.task_names.push_back(TaskNameFromDeviceName(device.name()));
  }
  // Establish the final order of gr->group.devices and gr->group.task_names
  // by considering localities of all devices.
  CompleteDefaultRanking(&gr->group);
  SetDevPerTask(&gr->group);
  gr->group.num_tasks =
      static_cast<int32>(gr->group.num_devices_per_task.size());
}

void CollectiveParamResolverLocal::CancelGroup(int32 group_key) {
  std::vector<StatusCallback> pending_done;
  GroupRec* gr = nullptr;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(group_key);
    if (it == group_table_.end()) {
      return;
    }
    gr = it->second.get();
  }
  {
    mutex_lock l(gr->mu);
    if (gr->group.devices.size() == gr->group.group_size) {
      // The group is already complete. There's no need to cancel.
      return;
    }
    gr->status = errors::Cancelled("group is cancelled");
    pending_done.swap(gr->pending_done);
    gr->pending_params.clear();
  }
  for (const StatusCallback& done : pending_done) {
    done(errors::Cancelled("group is cancelled"));
  }
}

void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  for (int i = 0; i < cp->group.group_size; ++i) {
    cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
  }
}

void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->group.devices.size()) << cp->ToString();
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->group.devices[i].name() == device) {
      cp->default_rank = i;
      break;
    }
  }
}

void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const CollectiveParams* cp, InstanceRec* ir) {
  ir->shared->instance = cp->instance;
  ir->shared->default_rank = -1;

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceAttributesAsync.  In a distributed context this function can be
  // called by a derived class; some of the devices may be non-local, and
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
  CompleteTaskIsLocal(task_name_, ir->shared);
}

// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime.  This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(CollGroupParams* gp) {
  // Establish an instance-specific default rank order for devices
  // based on localities.  This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(*gp, gpu_ring_order_);
  // Reflect the new global ranking in gp->devices and gp->task_names.
  std::vector<DeviceAttributes> new_devices(gp->group_size);
  std::vector<string> new_task_names(gp->group_size);
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_devices[dr.global_rank] = gp->devices[dr.original_rank];
      new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
    }
  }

  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& d : new_devices) strings::StrAppend(&buf, "\n", d.name());
    VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
            << buf;
  }
  gp->devices = std::move(new_devices);
  gp->task_names = std::move(new_task_names);
}

CollectiveParamResolverLocal::InstanceRec*
CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp,
                                                     bool* created) {
  *created = false;
  InstanceRec* irec = nullptr;
  {
    mutex_lock l(instance_mu_);
    auto group_it = instance_table_.find(cp->group.group_key);
    if (group_it != instance_table_.end()) {
      auto instance_it = group_it->second.find(cp->instance.instance_key);
      if (instance_it != group_it->second.end()) {
        irec = instance_it->second.get();
      }
    }
    if (irec == nullptr) {
      // Create new InstanceRec.
      irec = new InstanceRec;
      *created = true;
      {
        mutex_lock il(irec->mu);
        irec->known.resize(cp->group.group_size, false);
      }
      InitInstanceSharedParams(cp, irec);
      instance_table_[cp->group.group_key][cp->instance.instance_key].reset(
          irec);
    }
  }
  Status status;
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
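  // If the resolver has already been aborted, propagate that status to the
  // InstanceRec.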
  if (!status.ok()) {
    mutex_lock l(irec->mu);
    irec->status = status;
  }
  return irec;
}

void CollectiveParamResolverLocal::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupLocal(device, &cp->group, cancel_mgr,
                     [this, device, cp, done](const Status& s) {
                       if (s.ok()) {
                         CompleteInstanceLocal(device.name(), cp, done);
                       } else {
                         done(s);
                       }
                     });
}

void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation.  Ideally the choice would consider the topology and link
// strength when picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  // We use the NCCL implementation if this is an environment which supports
  // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and
  // if NCCL use is indicated either in `ConfigProto` or `communication_hint`.
  //
  // After enough testing, we may simplify this logic to use NCCL whenever
  // available.
  CollectiveImplementationInterface* col_impl;
  bool use_nccl =
      (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
      CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
          .ok();
  cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, CollectiveParams* cp, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " group_key "
          << cp->group.group_key;

  bool created_irec;
  InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
  if (!created_irec) {
    // Check that the preexisting IRec is consistent with the params passed into
    // this invocation.
    if (ir->shared->instance.type != cp->instance.type ||
        ir->shared->instance.data_type != cp->instance.data_type) {
      done(errors::Internal("Collective instance ", cp->instance.instance_key,
                            " expected type ", ir->shared->instance.type,
                            " and data_type ", ir->shared->instance.data_type,
                            " but got type ", cp->instance.type,
                            " and data_type ", cp->instance.data_type));
      return;
    }
  }
  CompleteInstanceFromInitializedIRec(device, cp, ir, done);
}

void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  auto expected_shape = cp->instance.shape;
  Status status;
  // Populate the fields that are common across the instance.
  {
    mutex_lock l(ir->mu);
    status = ir->status;
    if (status.ok()) {
      // The custom operator= does a deep copy.
      cp->instance = ir->shared->instance;
    }
  }
  if (!status.ok()) {
    done(status);
    return;
  }
  if (expected_shape != cp->instance.shape) {
    done(errors::InvalidArgument(
        "Shape mismatch in the collective instance ", cp->instance.instance_key,
        ". Op at device ", device, " expected shape ",
        expected_shape.DebugString(), " but another member in the group ",
        "expected shape ", cp->instance.shape.DebugString(), ". This is likely",
        " due to different input shapes at different members of the collective",
        " op."));
    return;
  }
  // Populate the fields that are common across the task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);
  CompleteTaskIsLocal(task_name_, cp);

  CollectiveImplementationInterface* col_impl;
  status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (!status.ok()) {
    done(status);
    return;
  }

  // If this is a broadcast, we may need to wait for the rest of the group in
  // order to discover the source rank.
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) {
      Status s;
      if (ir != irec) {
        s = errors::Internal("Expected ir ", ir, " and irec ", irec,
                             " to be equal");
      } else {
        mutex_lock l(irec->mu);
        s = irec->status;
        cp->source_rank = irec->source_rank;
      }
      if (s.ok()) {
        s = col_impl->InitializeCollectiveParams(cp);
      }
      done(s);
    });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
                                                CollectiveParams* cp,
                                                const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
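  // The do/while(false) below is a scoping trick: `break` releases ir->mu
  // before any callbacks are invoked.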
  do {
    mutex_lock l(ir->mu);
    if (!ir->status.ok()) {
      break;
    }
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (cp->is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ", ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
    }
    if (ir->known_count < cp->group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, cp->group.group_size);
    if (ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast.  This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  } while (false);
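  // Invoke the callbacks outside the lock.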
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

void CollectiveParamResolverLocal::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  StartAbortLocal(s);
}

void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
  std::vector<StatusCallback> pending_done;
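  // Fail all pending group resolutions, collecting their callbacks so they
  // can be invoked outside the lock.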
  {
    mutex_lock l(group_mu_);
    for (const auto& item : group_table_) {
      GroupRec* gr = item.second.get();
      {
        mutex_lock gl(gr->mu);
        gr->status = s;
        for (auto& done : gr->pending_done) {
          pending_done.push_back(std::move(done));
        }
        gr->pending_done.clear();
        gr->pending_params.clear();
      }
    }
  }
  for (const StatusCallback& done : pending_done) {
    done(s);
  }
  std::vector<InstanceRec*> instances;
  {
    mutex_lock l(instance_mu_);
    for (const auto& group_entry : instance_table_) {
      for (const auto& item : group_entry.second) {
        instances.push_back(item.second.get());
      }
    }
  }
  for (InstanceRec* ir : instances) {
    std::vector<IRConsumer> known_waiters;
    {
      mutex_lock il(ir->mu);
      ir->status = s;
      known_waiters.swap(ir->known_waiters);
    }
    for (const IRConsumer& done : known_waiters) {
      done(ir);
    }
  }
}

}  // namespace tensorflow