/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>

#include <algorithm>
#include <unordered_set>
#include <utility>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      task_name_(task_name) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteGroup is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

namespace {
const char* GetCollectiveName(const CollectiveParams* cp, bool nccl) {
  switch (cp->instance.type) {
    case BROADCAST_COLLECTIVE:
      return nccl ? "NcclBroadcast" : "HierarchicalTreeBroadcast";

    case REDUCTION_COLLECTIVE:
      return nccl ? "NcclReduce" : "RingReduce";

    case GATHER_COLLECTIVE:
      return nccl ? "NcclGather" : "RingGather";

    case PERMUTE_COLLECTIVE:
      return "Permute";

    default:
      return "undef";
  }
}

string TaskNameFromDeviceName(const string& device_name) {
  DeviceNameUtils::ParsedName parsed_device;
  CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_device));
  string task_name;
  CHECK(DeviceNameUtils::GetTaskName(parsed_device, &task_name));
  return task_name;
}
}  // namespace

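// Looks up or creates the GroupRec for cp->group.group_key and records that
// `device` has joined the group. `done` is queued until either every member
// of the group has checked in (the last arrival calls FinishGroup to fill in
// the remaining group fields) or an error is recorded for the group. If
// `cancel_mgr` is non-null, cancellation of the local op also fails the
// callback.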
void CollectiveParamResolverLocal::CompleteGroupLocal(
    const DeviceAttributes& device, CollectiveParams* cp,
    const GroupRecCallback& done, CancellationManager* cancel_mgr) {
  VLOG(1) << "CompleteGroupLocal device=" << device.name() << " cp: " << cp
          << ": " << cp->ToString();
  std::vector<StatusCallback> to_be_called;
  // Keep a reference to `cp` to avoid racing with deletion due to
  // cancellation.
  cp->Ref();
  core::ScopedUnref cp_unref(cp);

  std::function<void(const Status& s, GroupRec* gr)> done_with_cleanup;
  if (cancel_mgr != nullptr) {
    auto cancelled_mu = std::make_shared<mutex>();
    // Some callers delete `cancel_mgr` as soon as `done` is called once,
    // meaning we can't rely on it to avoid calling `done` twice if the local
    // op is cancelled but the group succeeds.
    auto cancelled = std::make_shared<bool>(false);
    const CancellationToken token = cancel_mgr->get_cancellation_token();
    const bool already_cancelled =
        !cancel_mgr->RegisterCallback(token, [done, cancelled_mu, cancelled]() {
          {
            mutex_lock l(*cancelled_mu);
            *cancelled = true;
          }
          done(errors::Cancelled("op cancelled"), nullptr);
        });
    if (already_cancelled) {
      done(errors::Cancelled("op cancelled"), nullptr);
      return;
    }
    done_with_cleanup = [cancel_mgr, done, cancelled_mu, cancelled, token](
                            const Status& s, GroupRec* gr) {
      {
        mutex_lock l(*cancelled_mu);
        if (*cancelled || !cancel_mgr->TryDeregisterCallback(token)) {
          return;
        }
      }
      // The operation was never cancelled, so we'll return a normal status.
      done(s, gr);
    };
  } else {
    done_with_cleanup = done;
  }

  GroupRec* gr = nullptr;
  Status status;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(cp->group.group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      mutex_lock grl(gr->mu);
      gr->group.group_key = cp->group.group_key;
      gr->group.group_size = cp->group.group_size;
      gr->group.device_type = cp->group.device_type;
      gr->group.gpu_ring_order = cp->group.gpu_ring_order;

      // Initialize group runtime details.
      CollectiveImplementationInterface* col_impl;
      // Try to lookup a NCCL collective kernel. This will return error status
      // if `NcclReduce` kernel is not present in the registry, e.g. on an
      // environment that does not support NCCL.
      status = CollectiveRegistry::LookupParamResolverInstance("NcclReduce",
                                                               &col_impl);
      if (!status.ok()) {
        // Fallback to non-NCCL collective.
        status = CollectiveRegistry::LookupParamResolverInstance(
            GetCollectiveName(cp, /*nccl=*/false), &col_impl);
      }
      if (status.ok()) {
        status = col_impl->InitializeCollectiveGroupRuntimeDetails(
            &gr->group.runtime_details);
      }

      if (!status.ok()) {
        done_with_cleanup(status, gr);
        return;
      }

      // Store GroupRec in group_table_ which is shared between all devices on
      // this worker.
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size
              << " runtime_details=" << gr->group.runtime_details.ToString();
    } else {
      gr = it->second.get();
    }
  }
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    done_with_cleanup(status, nullptr);
    return;
  }
  {
    mutex_lock gr_lock(gr->mu);
    // If there is ever an error associated with a group key, we store the
    // error status and invoke all waiting and future callbacks with this
    // error status.
    VLOG(2) << "gr device_type=" << gr->group.device_type
            << " cp device_type=" << cp->group.device_type
            << " current device=" << device.name();
    if (gr->status.ok()) {
      // Check for consistency with existing GroupRec.
      if (cp->group.device_type != gr->group.device_type) {
        gr->status = errors::Internal(
            "Collective Op ", cp->name, " is assigned to device ",
            device.name(), " with type ", cp->group.device_type.type_string(),
            " and group_key ", cp->group.group_key,
            " but that group has type ", gr->group.device_type.type_string());
      } else if (cp->group.group_size != gr->group.group_size) {
        gr->status = errors::Internal(
            "Collective Op ", cp->name, " has group_size ",
            cp->group.group_size, " and group_key ", cp->group.group_key,
            " but that group has size ", gr->group.group_size);
      }
    }
    bool new_device = false;
    if (gr->status.ok()) {
      // Insert device if not already present.
      auto it = gr->devices.find(device.name());
      if (it == gr->devices.end()) {
        if (gr->devices.size() == gr->group.group_size) {
          // The group is already full.
          gr->status = errors::Internal(
              "Collective Op ", cp->name, " is assigned to device ",
              device.name(), " and group_key ", cp->group.group_key,
              " but that group doesn't contain that device.");
        } else {
          // This is a new device that has not yet joined the group.
          gr->devices[device.name()] = device;
          new_device = true;
          if (VLOG_IS_ON(1)) {
            string dev_buf;
            for (const auto& d : gr->devices) {
              strings::StrAppend(&dev_buf, ",", d.first);
            }
            VLOG(1) << "CompleteGroupLocal group_key=" << gr->group.group_key
                    << " group_size=" << gr->group.group_size << " (current"
                    << " devices)=(" << dev_buf << ") (number of"
                    << " devices pending)="
                    << (gr->group.group_size - gr->devices.size());
          }
        }
      } else {
        // If the device already exists, check if the incarnation matches.
        if (it->second.incarnation() != device.incarnation()) {
          gr->status = errors::FailedPrecondition(
              "Device ", device.name(),
              " current incarnation doesn't match with one in the group. This "
              "usually means this worker has restarted but the collective "
              "leader hasn't, or this worker connects to a wrong cluster.");
        }
      }
    }

    if (gr->status.ok()) {
      // If the group is not yet complete, queue to wait for it.
      VLOG(2) << "group_size " << gr->group.group_size << " set size "
              << gr->devices.size() << " gr " << gr;

      if (gr->devices.size() < gr->group.group_size) {
        gr->waiting.push_back(
            std::bind(done_with_cleanup, std::placeholders::_1, gr));
        return;
      }
      CHECK_EQ(gr->devices.size(), gr->group.group_size);
      // We get a full group. Fill in remaining fields in gr->group.
      if (new_device) {
        FinishGroup(gr);
      }
    }
    // At this point, we either have a full group, or an error status. Ensure
    // that all callbacks are invoked with the appropriate status.
    if (!gr->waiting.empty()) {
      std::swap(to_be_called, gr->waiting);
    }
    status = gr->status;
  }
  done_with_cleanup(status, gr);
  for (int i = 0; i < to_be_called.size(); ++i) {
    to_be_called[i](status);
  }
}

namespace {
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
GlobalDeviceMap BuildDevRecs(const CollGroupParams& gp,
                             const std::vector<DeviceAttributes>& attributes) {
  GlobalDeviceMap gdm;
  CHECK_EQ(gp.device_names.size(), gp.task_names.size());
  CHECK_EQ(gp.device_names.size(), attributes.size());
  for (int i = 0; i < gp.device_names.size(); ++i) {
    TaskDeviceMap& tdm = gdm[gp.task_names[i]];
    DevRec* dr = &tdm[gp.device_names[i]];
    dr->task = gp.task_names[i];
    dr->device = gp.device_names[i];
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &attributes[i].locality();
  }
  return gdm;
}

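// Parses a `gpu_ring_order` string of comma-separated GPU device ids and, if
// it names exactly the devices in `tdm`, assigns each device the local rank
// given by its position in the string. For example (hypothetical value),
// "3,1,2,0" gives GPU 3 local rank 0, GPU 1 rank 1, GPU 2 rank 2 and GPU 0
// rank 3. Returns false if the string is empty, malformed, or does not match
// the devices in `tdm`.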
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<string> split_gpu_ring_order_str =
      str_util::Split(gpu_ring_order_str, ',');
  if (split_gpu_ring_order_str.size() != tdm->size()) return false;

  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32 rank = 0;
       rank < static_cast<int32>(split_gpu_ring_order_str.size()); ++rank) {
    int32 tmp;
    if (strings::safe_strto32(split_gpu_ring_order_str[rank], &tmp)) {
      gpu_ranks[tmp] = rank;
    } else {
      return false;
    }
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

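// Assigns a local rank to every device in `tdm`. If `gpu_ring_order` parses
// and covers the participating devices, that order is used; otherwise a ring
// order is built greedily by starting from the device with the lowest
// original rank and repeatedly following the strongest InterconnectLink to a
// not-yet-selected participating device, falling back to the lowest-ranked
// remaining device when no such link exists.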
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths. Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a CollGroupParams is established for a group we compute a
// good rank order for all the devices in the group, one that is appropriate
// for a ring algorithm.
GlobalDeviceMap EstablishGlobalRank(
    const CollGroupParams& gp,
    const std::vector<DeviceAttributes>& attributes) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(gp, attributes);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(gp.gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the order in which tasks first appear.
  std::set<string> ordered_tasks;
  int next_rank = 0;
  for (int i = 0; i < gp.task_names.size(); ++i) {
    const string& task_name = gp.task_names[i];
    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
      continue;
    }
    ordered_tasks.insert(task_name);
    TaskDeviceMap* tdm = &gdm[task_name];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// gp->same_num_devices_per_task. Requires gp->task_names
// to be sorted.
void SetDevPerTask(CollGroupParams* gp) {
  gp->num_devices_per_task.clear();
  const string* last_task_name = &gp->task_names[0];
  int count = 0;
  for (const string& task_name : gp->task_names) {
    if (task_name == *last_task_name) {
      ++count;
    } else {
      gp->num_devices_per_task[*last_task_name] = count;
      count = 1;
      last_task_name = &task_name;
    }
  }
  gp->num_devices_per_task[*last_task_name] = count;

  gp->same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : gp->num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  gp->same_num_devices_per_task = true;
  CHECK_EQ((gp->group_size % gp->num_tasks), 0);
}

// Sort gp->device_names lexicographically, but do so by first computing a
// reordering permutation so that gp->task_names can be kept in corresponding
// order.
void SortDevicesAndTasks(CollGroupParams* gp) {
  VLOG(1) << "SortDevicesAndTasks " << gp << " " << gp;
  CHECK(gp);
  CHECK_EQ(gp->group_size, gp->device_names.size());
  CHECK_EQ(gp->group_size, gp->task_names.size());
  std::vector<int> perm(gp->group_size);
  // TODO(tucker): substitute std::iota when the windows build supports it.
  // std::iota(perm.begin(), perm.end(), 0);
  for (int i = 0; i < perm.size(); ++i) {
    perm[i] = i;
  }
  std::sort(perm.begin(), perm.end(), [gp](int a, int b) {
    return gp->device_names[a] < gp->device_names[b];
  });
  std::vector<string> new_devs;
  std::vector<string> new_tasks;
  new_devs.reserve(gp->group_size);
  new_tasks.reserve(gp->group_size);
  for (int pi : perm) {
    new_devs.push_back(gp->device_names[pi]);
    new_tasks.push_back(gp->task_names[pi]);
  }
  gp->device_names = std::move(new_devs);
  gp->task_names = std::move(new_tasks);
  VLOG(1) << "Modified device_names on " << gp;
  SetDevPerTask(gp);
}
}  // namespace

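// Called once the group is full. Copies the device and task names out of
// gr->devices, counts the distinct tasks, sorts devices (keeping task names
// aligned), and then establishes the default rank order from device
// localities via CompleteDefaultRanking.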
void CollectiveParamResolverLocal::FinishGroup(GroupRec* gr) {
  gr->group.device_names.reserve(gr->devices.size());
  gr->group.task_names.reserve(gr->devices.size());
  std::vector<DeviceAttributes> attributes;
  // Unique tasks. It's used to calculate num_tasks.
  std::unordered_set<string> tasks;
  attributes.reserve(gr->devices.size());
  for (const auto& item : gr->devices) {
    gr->group.device_names.push_back(item.first);
    string task_name = TaskNameFromDeviceName(item.first);
    gr->group.task_names.push_back(task_name);
    tasks.insert(task_name);
    attributes.push_back(item.second);
  }
  gr->group.num_tasks = static_cast<int32>(tasks.size());
  // Sort device_names lexicographically, keeping task_names in corresponding
  // order. Also set number of devices per task.
  SortDevicesAndTasks(&gr->group);
  // Establish the final order of gp->device_names and gp->task_names by
  // considering localities of all devices.
  CompleteDefaultRanking(attributes, &gr->group);
}

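// Fills cp->task.is_local with one entry per group member, true iff that
// member's task matches `task_name`.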
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  for (int i = 0; i < cp->group.group_size; ++i) {
    cp->task.is_local[i] = (cp->group.task_names[i] == task_name);
  }
}

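// Sets cp->default_rank to the index of `device` in the group's ordered
// device_names list.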
void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->group.device_names.size()) << cp;
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->group.device_names[i] == device) {
      cp->default_rank = i;
      break;
    }
  }
}

void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir) {
  ir->shared->instance = cp->instance;
  ir->shared->default_rank = -1;

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceAttributesAsync. In a distributed context this function can be
  // called by a derived class; some of the devices may then be non-local, and
  // GetDeviceAttributesAsync will use those fields to launch RPCs.
  CompleteTaskIsLocal(task_name_, ir->shared);
}

// NOTE(ayushd): The DeviceLocality objects in attributes will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
    const std::vector<DeviceAttributes>& attributes, CollGroupParams* gp) {
  // Establish an instance-specific default rank order for devices
  // based on localities. This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(*gp, attributes);
  // Reflect the new global ranking on shared
  size_t num_devices = gp->group_size;
  std::vector<string> new_device_names(num_devices, "");
  std::vector<string> new_task_names(num_devices, "");
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_device_names[dr.global_rank] = gp->device_names[dr.original_rank];
      new_task_names[dr.global_rank] = gp->task_names[dr.original_rank];
    }
  }

  gp->device_names = new_device_names;
  gp->task_names = new_task_names;
  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
    VLOG(2) << "Optimized device order for group " << gp->group_key << ": "
            << buf;
  }
}

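// Looks up the InstanceRec for (group_key, instance_key), creating and
// initializing a new one if none exists. Sets *created accordingly and, if
// the resolver has already been aborted, copies the abort status into the
// returned record.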
CollectiveParamResolverLocal::InstanceRec*
CollectiveParamResolverLocal::GetOrCreateInstanceRec(const GroupRec* gr,
                                                     CollectiveParams* cp,
                                                     bool* created) {
  *created = false;
  InstanceRec* irec = nullptr;
  {
    mutex_lock l(instance_mu_);
    auto group_it = instance_table_.find(gr->group.group_key);
    if (group_it != instance_table_.end()) {
      auto instance_it = group_it->second.find(cp->instance.instance_key);
      if (instance_it != group_it->second.end()) {
        irec = instance_it->second.get();
      }
    }
    if (irec == nullptr) {
      // Create new InstanceRec.
      irec = new InstanceRec;
      *created = true;
      {
        mutex_lock il(irec->mu);
        irec->known.resize(cp->group.group_size, false);
      }
      InitInstanceSharedParams(gr, cp, irec);
      instance_table_[gr->group.group_key][cp->instance.instance_key].reset(
          irec);
    }
  }
  Status status;
  {
    mutex_lock l(status_mu_);
    status = status_;
  }
  if (!status.ok()) {
    mutex_lock l(irec->mu);
    irec->status = status;
  }
  return irec;
}

void CollectiveParamResolverLocal::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device.name() << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupLocal(
      device, cp,
      [this, device, cp, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          CompleteInstanceLocal(device.name(), gr, cp, cp->is_source, done);
        } else {
          done(s);
        }
      },
      cancel_mgr);
}

void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation. The ideal way would depend upon the topology and link
// strength before picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  // We use the NCCL implementation if this is an environment which supports
  // NCCL, i.e. `LookupParamResolverInstance` for `NcclReduce` returns OK, and
  // also if indicated either in `ConfigProto` or `communication_hint`.
  //
  // After enough testing, we may simplify this logic to use NCCL whenever
  // available.
  CollectiveImplementationInterface* col_impl;
  bool use_nccl =
      (nccl_ || cp->instance.impl_details.communication_hint == "nccl") &&
      CollectiveRegistry::LookupParamResolverInstance("NcclReduce", &col_impl)
          .ok();
  cp->instance.impl_details.collective_name = GetCollectiveName(cp, use_nccl);
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

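// Completes the instance portion of *cp for `device`: copies the group fields
// from *gr, looks up or creates the matching InstanceRec, verifies that a
// preexisting record agrees on type and data_type, and then finishes the
// per-device fields.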
void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    bool is_source, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " gr " << gr;

  // Populate the group portion of *cp from *gr. Most of it should already
  // match.
  {
    mutex_lock l(gr->mu);
    DCHECK_EQ(cp->group.group_key, gr->group.group_key);
    DCHECK_EQ(cp->group.group_size, gr->group.group_size);
    DCHECK_EQ(cp->group.device_type, gr->group.device_type);
    cp->group = gr->group;
  }

  bool created_irec;
  InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
  if (!created_irec) {
    // Check that the preexisting IRec is consistent with the params passed
    // into this invocation.
    if (ir->shared->instance.type != cp->instance.type ||
        ir->shared->instance.data_type != cp->instance.data_type) {
      done(errors::Internal("Collective instance ", cp->instance.instance_key,
                            " expected type ", ir->shared->instance.type,
                            " and data_type ", ir->shared->instance.data_type,
                            " but got type ", cp->instance.type,
                            " and data_type ", cp->instance.data_type));
      return;
    }
  }
  CompleteInstanceFromInitializedIRec(device, gr, cp, ir, is_source, done);
}

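// Fills in the remaining fields of *cp from the shared InstanceRec: checks
// that the expected shape matches, assigns the collective implementation,
// sets the default rank and task locality, and for broadcasts waits for the
// whole group so the source rank is known before initializing the collective
// params.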
void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    InstanceRec* ir, bool is_source, const StatusCallback& done) {
  auto expected_shape = cp->instance.shape;
  Status status;
  // Populate the fields common across instance.
  {
    mutex_lock l(ir->mu);
    status = ir->status;
    if (status.ok()) {
      // custom operator= does a deep copy.
      cp->instance = ir->shared->instance;
    }
  }
  if (!status.ok()) {
    done(status);
    return;
  }
  if (expected_shape != cp->instance.shape) {
    done(errors::InvalidArgument(
        "Shape mismatch in the collective instance ",
        cp->instance.instance_key, ". Op at device ", device,
        " expected shape ", expected_shape.DebugString(),
        " but another member in the group ", "expected shape ",
        cp->instance.shape.DebugString(), ". This is likely",
        " due to different input shapes at different members of the collective",
        " op."));
    return;
  }
  // Populate the fields common across task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);
  CompleteTaskIsLocal(task_name_, cp);

  CollectiveImplementationInterface* col_impl;
  status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (!status.ok()) {
    done(status);
    return;
  }

  // We may need to wait for the group, if this is a broadcast, for source
  // discovery.
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    WaitForGroup(ir, cp, is_source,
                 [col_impl, ir, device, cp, done](InstanceRec* irec) {
                   Status s;
                   if (ir != irec) {
                     s = errors::Internal("Expected ir ", ir, " and irec ",
                                          irec, " to be equal");
                   } else {
                     mutex_lock l(irec->mu);
                     s = irec->status;
                     cp->source_rank = irec->source_rank;
                   }
                   if (s.ok()) {
                     s = col_impl->InitializeCollectiveParams(cp);
                   }
                   done(s);
                 });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

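// Records that this rank has joined the instance and, if `is_source`, claims
// the broadcast source rank. The consumer `f` runs immediately once every
// member of the group is known (or on error); otherwise it is queued and
// invoked when the last member arrives.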
void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir,
                                                CollectiveParams* cp,
                                                bool is_source,
                                                const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
  do {
    mutex_lock l(ir->mu);
    if (!ir->status.ok()) {
      break;
    }
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ",
                                        ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
    }
    if (ir->known_count < cp->group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, cp->group.group_size);
    if (ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast. This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  } while (false);
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

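// Records the abort status (only the first abort is kept; later ones are
// ignored) and then fails every callback currently waiting on group or
// instance resolution.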
void CollectiveParamResolverLocal::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverLocal already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  StartAbortLocal(s);
}

void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) {
  {
    mutex_lock l(group_mu_);
    for (const auto& item : group_table_) {
      GroupRec* gr = item.second.get();
      std::vector<StatusCallback> waiting;
      {
        mutex_lock gl(gr->mu);
        gr->status = s;
        waiting.swap(gr->waiting);
      }
      for (const StatusCallback& done : waiting) {
        done(s);
      }
    }
  }
  std::vector<InstanceRec*> instances;
  {
    mutex_lock l(instance_mu_);
    for (const auto& group_entry : instance_table_) {
      for (const auto& item : group_entry.second) {
        instances.push_back(item.second.get());
      }
    }
  }
  for (InstanceRec* ir : instances) {
    std::vector<IRConsumer> known_waiters;
    {
      mutex_lock il(ir->mu);
      ir->status = s;
      known_waiters.swap(ir->known_waiters);
    }
    for (const IRConsumer& done : known_waiters) {
      done(ir);
    }
  }
}

}  // namespace tensorflow