/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>
#include <algorithm>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

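// Blocks until out_mu is marked available again. Callers must already hold
// out_mu (via `lock`); see InitInstanceSharedParams for why out_mu can be
// temporarily unlocked while still logically "held".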
void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
  while (!out_mu_available) out_cv.wait(lock);
}

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      task_name_(task_name) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteGroup is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

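// Looks up (or creates) the GroupRec for cp->group.group_key, verifies that
// cp is consistent with any previously recorded group attributes, and records
// `device` as a member. `done` is invoked once group_size distinct devices
// have joined; callers that arrive earlier are queued on gr->waiting and
// invoked by the caller that completes the group.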
void CollectiveParamResolverLocal::CompleteGroupLocal(
    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
  VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
          << cp->ToString();
  std::vector<StatusCallback> to_be_called;
  GroupRec* gr = nullptr;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(cp->group.group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      gr->group.group_key = cp->group.group_key;
      gr->group.group_size = cp->group.group_size;
      gr->group.device_type = cp->group.device_type;
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size;
    } else {
      gr = it->second.get();
    }
  }
  Status status;
  {
    mutex_lock gr_lock(gr->mu);
    if (!gr->device_set.empty()) {
      // Check for consistency with existing GroupRec.
      if (cp->group.device_type != gr->group.device_type) {
        status = errors::Internal(
            "Collective Op ", cp->name, " is assigned to device ", device,
            " with type ", cp->group.device_type.type_string(),
            " and group_key ", cp->group.group_key, " but that group has type ",
            gr->group.device_type.type_string());
      } else if (cp->group.group_size != gr->group.group_size) {
        status = errors::Internal(
            "Collective Op ", cp->name, " has group_size ",
            cp->group.group_size, " and group_key ", cp->group.group_key,
93 " but that group has size ", gr->group.group_size);
94 }
95 }
96 if (status.ok()) {
97 // Insert device if not already present.
98 auto it = gr->device_set.find(device);
99 if (it == gr->device_set.end()) {
100 if (gr->device_set.size() == gr->group.group_size) {
101 // The group is already full.
102 status = errors::Internal(
103 "Collective Op ", cp->name, " is assigned to device ", device,
104 " and group_key ", cp->group.group_key,
105 " but that group doesn't contain that device.");
106 } else {
107 // This is a new device that has not yet joined the group.
108 gr->device_set.insert(device);
109 gr->device_list.push_back(device);
110 DeviceNameUtils::ParsedName parsed_device;
111 DeviceNameUtils::ParseFullName(device, &parsed_device);
112 string task_name = strings::StrCat("/job:", parsed_device.job,
113 "/replica:", parsed_device.replica,
114 "/task:", parsed_device.task);
115 gr->task_set.insert(task_name);
116 gr->task_list.push_back(task_name);
117 gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
118 VLOG(1) << "group_key=" << gr->group.group_key
119 << " group_size=" << gr->group.group_size
120 << " dev_set=" << gr->device_set.size();
121 }
122 }
123 }
124
125 if (status.ok()) {
126 // If the group is not yet complete, queue to wait for it.
127 VLOG(2) << "group_size " << gr->group.group_size << " set size "
128 << gr->device_set.size() << " gr " << gr;
129
130 if (gr->device_set.size() < gr->group.group_size) {
131 gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
132 return;
133 }
134 CHECK_EQ(gr->device_set.size(), gr->group.group_size);
135 if (!gr->waiting.empty()) {
136 std::swap(to_be_called, gr->waiting);
137 }
138 }
139 }
140 done(status, gr);
141 for (int i = 0; i < to_be_called.size(); ++i) {
142 to_be_called[i](Status::OK());
143 }
144 }
145
146 namespace {
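// Per-device bookkeeping used while computing a ring order. original_rank is
// the device's index in the incoming instance params, local_rank its position
// within its own task's ring, and global_rank its final position across the
// whole group.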
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;

// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
                             const std::vector<DeviceLocality>& localities) {
  GlobalDeviceMap gdm;
  CHECK_EQ(ip.device_names.size(), ip.task_names.size());
  CHECK_EQ(ip.device_names.size(), localities.size());
  for (int i = 0; i < ip.device_names.size(); ++i) {
    TaskDeviceMap& tdm = gdm[ip.task_names[i]];
    DevRec* dr = &tdm[ip.device_names[i]];
    dr->task = ip.task_names[i];
    dr->device = ip.device_names[i];
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &localities[i];
  }
  return gdm;
}

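// Parses a comma-separated list of GPU ids (e.g. "3,2,1,0") into per-device
// local ranks: the GPU whose id appears at position k in the list gets
// local_rank k. Returns false if the string does not parse, the count does
// not match the devices in *tdm, or a device's id is absent from the list.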
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<int32> gpu_ring_order_vec;
  if (!str_util::SplitAndParseAsInts(gpu_ring_order_str, ',',
                                     &gpu_ring_order_vec)) {
    return false;
  }
  if (gpu_ring_order_vec.size() != tdm->size()) return false;
  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32 rank = 0; rank < static_cast<int32>(gpu_ring_order_vec.size());
       ++rank) {
    gpu_ranks[gpu_ring_order_vec[rank]] = rank;
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

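// Assigns a local ring order (DevRec::local_rank) to the devices of a single
// task. If gpu_ring_order parses via ParseRingOrder, that order is used.
// Otherwise the order is built greedily: start at the device with the lowest
// original_rank, then repeatedly follow the strongest InterconnectLink to a
// not-yet-selected participating device, falling back to the lowest-ranked
// remaining device when no such link exists.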
void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths. Note that this
  // algorithm is not optimal and may not always find the best ring order.
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a shared CollectiveParams is established for a shared set of
// instances we compute a good rank order for all the devices in the group,
// one that is appropriate for a ring algorithm. Where more than one good
// order exists, this order need not be the same across different instance
// groups sharing the same device group.
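// For example (hypothetical): with two tasks of two GPUs each and per-task
// local ranks {0, 1}, the devices of the task that appears first in
// task_names receive global ranks 0 and 1, and those of the second task
// receive ranks 2 and 3.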
GlobalDeviceMap EstablishGlobalRank(
    CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the order in which tasks first appear.
  std::set<string> ordered_tasks;
  int next_rank = 0;
  for (int i = 0; i < cp->instance.task_names.size(); ++i) {
    const string& task_name = cp->instance.task_names[i];
    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
      continue;
    }
    ordered_tasks.insert(task_name);
    TaskDeviceMap* tdm = &gdm[task_name];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// cp->instance.same_num_devices_per_task. Requires that
// cp->instance.task_names be sorted.
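// For example, task_names = {"/job:w/replica:0/task:0",
// "/job:w/replica:0/task:0", "/job:w/replica:0/task:1"} yields
// num_devices_per_task = {task:0 -> 2, task:1 -> 1} and
// same_num_devices_per_task = false.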
void SetDevPerTask(CollectiveParams* cp) {
  cp->instance.num_devices_per_task.clear();
  const string* last_task_name = &cp->instance.task_names[0];
  int count = 0;
  for (const string& task_name : cp->instance.task_names) {
    if (task_name == *last_task_name) {
      ++count;
    } else {
      cp->instance.num_devices_per_task[*last_task_name] = count;
      count = 1;
      last_task_name = &task_name;
    }
  }
  cp->instance.num_devices_per_task[*last_task_name] = count;

  cp->instance.same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : cp->instance.num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  cp->instance.same_num_devices_per_task = true;
  CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
}

// Sort cp->instance.device_names lexicographically, but do so by first
// computing a reordering permutation so that cp->instance.task_names can be
// kept in corresponding order.
void SortDevicesAndTasks(CollectiveParams* cp) {
  VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
  CHECK(cp);
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
  CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
  std::vector<int> perm(cp->group.group_size);
  // TODO(tucker): substitute std::iota when the windows build supports it.
  // std::iota(perm.begin(), perm.end(), 0);
  for (int i = 0; i < perm.size(); ++i) {
    perm[i] = i;
  }
  std::sort(perm.begin(), perm.end(), [cp](int a, int b) {
    return cp->instance.device_names[a] < cp->instance.device_names[b];
  });
  std::vector<string> new_devs;
  std::vector<string> new_tasks;
  new_devs.reserve(cp->group.group_size);
  new_tasks.reserve(cp->group.group_size);
  for (int pi : perm) {
    new_devs.push_back(cp->instance.device_names[pi]);
    new_tasks.push_back(cp->instance.task_names[pi]);
  }
  cp->instance.device_names = std::move(new_devs);
  cp->instance.task_names = std::move(new_tasks);
  VLOG(1) << "Modified device_names on " << cp;
  SetDevPerTask(cp);
}
}  // namespace

void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  for (int i = 0; i < cp->group.group_size; ++i) {
    cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
  }
}

void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->instance.device_names[i] == device) {
      cp->default_rank = i;
      break;
    }
  }
}

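// Initializes the shared CollectiveParams in *ir from *gr and *cp: copies the
// group and instance fields, sorts the device list, marks which members are
// local to this task, then asynchronously fetches DeviceLocality for every
// device and completes the default (ring-friendly) rank order. `done` is
// invoked on the locality callback's thread.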
void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  ir->shared.instance = cp->instance;
  {
    mutex_lock gl(gr->mu);
    ir->shared.group = gr->group;
    ir->shared.instance.device_names.assign(gr->device_list.begin(),
                                            gr->device_list.end());
    ir->shared.instance.task_names.assign(gr->task_list.begin(),
                                          gr->task_list.end());
    VLOG(2) << "Initialized names for instance: "
            << ir->shared.instance.ToString();
  }
  ir->shared.default_rank = -1;

  // Sort device_names lexicographically, keeping task_names in corresponding
  // order. Also set number of devices per task.
  SortDevicesAndTasks(&ir->shared);

  // Get Locality data for all devices.

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceLocalitiesAsync. In a distributed context this function can be
  // called by a derived class, in which case some of the devices may be
  // non-local and GetDeviceLocalitiesAsync will use those fields to launch
  // RPCs.
  CompleteTaskIsLocal(task_name_, &ir->shared);

  // Because the callback may execute in a different thread, we release
  // ir->out_mu here. Before releasing, we mark it as unavailable for other
  // threads.
  ir->out_mu_available = false;
  ir->out_mu.unlock();
  std::vector<DeviceLocality>* localities = new std::vector<DeviceLocality>;
  dev_resolver_->GetDeviceLocalitiesAsync(
      ir->shared.instance, localities,
      [this, gr, cp, ir, localities, done](const Status& s)
          EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
            // Then we recover the lock in the callback thread that will hold
            // it through the rest of the call chain. Signal the cv now; any
            // waiting threads will wake only when out_mu is released later.
            ir->out_mu.lock();
            DCHECK(!ir->out_mu_available);
            ir->out_mu_available = true;
            ir->out_cv.notify_all();
            if (s.ok()) {
              CompleteDefaultRanking(gr, cp, ir, *localities);
              done(Status::OK());
            } else {
              done(s);
            }
            delete localities;
          });
}

// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime. This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const std::vector<DeviceLocality>& localities) {
  // Establish an instance-specific default rank order for devices
  // based on localities. This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
  // Reflect the new global ranking on shared
  size_t num_devices = ir->shared.group.group_size;
  std::vector<string> new_device_names(num_devices, "");
  std::vector<string> new_task_names(num_devices, "");
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_device_names[dr.global_rank] =
          ir->shared.instance.device_names[dr.original_rank];
      new_task_names[dr.global_rank] =
          ir->shared.instance.task_names[dr.original_rank];
    }
  }

  ir->shared.instance.device_names = new_device_names;
  ir->shared.instance.task_names = new_task_names;
  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
    VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
  }
}

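// Reads irec->status under its lock (waiting for out_mu to become available)
// and forwards it, along with irec itself, to `done`.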
void CollectiveParamResolverLocal::CallbackWithStatus(
    const InstanceRecCallback& done, InstanceRec* irec) {
  Status s;
  {
    mutex_lock l(irec->out_mu);
    irec->WaitForOutMu(l);
    s = irec->status;
  }
  done(s, irec);
}

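// Looks up the InstanceRec for cp->instance.instance_key, creating and
// initializing it if this is the first request for that instance. If another
// thread is already initializing the record, `done` is queued on
// irec->init_waiters and invoked when initialization finishes.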
void CollectiveParamResolverLocal::FindInstanceRec(
    const GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
  InstanceRec* irec = nullptr;
  bool exit_outside_locks = false;
  {
    mutex_lock l(instance_mu_);
    auto it = instance_table_.find(cp->instance.instance_key);
    if (it != instance_table_.end()) {
      irec = it->second.get();
      {
        mutex_lock l(irec->in_mu);
        if (irec->is_init) {
          exit_outside_locks = true;
        } else {
          irec->init_waiters.push_back([this, done](InstanceRec* irec) {
            CallbackWithStatus(done, irec);
          });
          return;
        }
      }
    } else {
      // Create new InstanceRec.
      irec = new InstanceRec;
      instance_table_[cp->instance.instance_key].reset(irec);
    }
  }
  if (exit_outside_locks) {
    CallbackWithStatus(done, irec);
    return;
  }

  CallInitInstanceSharedParams(gr, cp, irec, done);
}

void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const InstanceRecCallback& done) NO_THREAD_SAFETY_ANALYSIS {
  // This function serves merely to make a function call that should
  // be thread/mutex safe but violates the simple model applied by
  // static analysis, so we turn off analysis only within this
  // function body.
  //
  // A lock on ir->out_mu must be held* throughout the _bodies_ of the
  // chain of function calls initiated here, each of which calls
  // another as its last action, but it will be dropped within the
  // callback defined below, which means that the lock can be dropped
  // before all the function stack frames pop. The static analysis will
  // not allow that.
  //
  // *the lock is dropped just before calling GetDeviceLocalitiesAsync, because
  // there is no guarantee that the thread that executes the callback is the
  // same as the one that locked ir->out_mu. To prevent other threads from
  // grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
  // principle, the lock is held throughout.
  ir->out_mu.lock();
  DCHECK(ir->out_mu_available);
  ir->known.resize(cp->group.group_size, false);
  InitInstanceSharedParams(
      gr, cp, ir,
      [this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
        DCHECK(ir->out_mu_available);
        ir->status.Update(s);
        ir->out_mu.unlock();
        // Prepare to invoke any waiters that accumulated during
        // initialization.
        std::vector<IRConsumer> init_waiters;
        {
          mutex_lock tl(instance_mu_);
          {
            mutex_lock l(ir->in_mu);
            ir->is_init = true;
            if (!ir->init_waiters.empty()) {
              std::swap(init_waiters, ir->init_waiters);
            }
          }
        }
        CallbackWithStatus(done, ir);
        for (auto& f : init_waiters) {
          f(ir);
        }
      });
}

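// Entry point for local (single-task) param resolution: completes the group
// first, then the instance. A minimal usage sketch (hypothetical names),
// assuming group and instance keys are already set on `cp`:
//
//   param_resolver->CompleteParamsAsync(
//       "/job:localhost/replica:0/task:0/device:GPU:0", cp, cancel_mgr,
//       [](const Status& s) { TF_CHECK_OK(s); });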
void CollectiveParamResolverLocal::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupLocal(
      device, cp,
      [this, device, cp, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
        } else {
          done(s);
        }
      });
}

void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation. Ideally the choice would take the topology and link
// strengths into account before picking a particular implementation.
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
  } else if (cp->instance.type == REDUCTION_COLLECTIVE) {
    if (nccl_) {
      cp->instance.impl_details.collective_name = "NcclReduce";
    } else {
      cp->instance.impl_details.collective_name = "RingReduce";
    }
  } else if (cp->instance.type == GATHER_COLLECTIVE) {
    cp->instance.impl_details.collective_name = "RingGather";
  } else {
    cp->instance.impl_details.collective_name = "undef";
  }
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

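// Completes the instance portion of *cp for `device`: copies the group fields
// from *gr, finds (or creates) the shared InstanceRec, and then fills in the
// per-device fields from the initialized record.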
void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    bool is_source, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " gr " << gr;

  // Populate the group portion of *cp from *gr. Most of it should already
  // match.
  DCHECK_EQ(cp->group.group_key, gr->group.group_key);
  DCHECK_EQ(cp->group.group_size, gr->group.group_size);
  DCHECK_EQ(cp->group.device_type, gr->group.device_type);
  cp->group = gr->group;

  // Get the shared InstanceRec for this instance.
  FindInstanceRec(gr, cp,
                  [this, device, gr, cp, is_source, done](const Status& s,
                                                          InstanceRec* ir) {
                    if (s.ok()) {
                      CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
                                                          is_source, done);
                    } else {
                      done(s);
                    }
                  });
}

void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    InstanceRec* ir, bool is_source, const StatusCallback& done) {
  // Populate the fields common across instance.
  {
    mutex_lock l(ir->out_mu);
    ir->WaitForOutMu(l);
    // custom operator= does a deep copy.
    cp->instance = ir->shared.instance;
  }
  // Populate the fields common across task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);
  CompleteTaskIsLocal(task_name_, cp);

  CollectiveImplementationInterface* col_impl;
  Status status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (status.ok()) {
    status = col_impl->InitializeInstanceBeforeGroupDiscovery(cp);
  }
  if (!status.ok()) {
    done(status);
    return;
  }

  // We may need to wait for the group if:
  // * this is a broadcast, for source discovery;
  // * we are using NCCL with more than 1 worker, for the communicator key from
  //   rank 0.
  bool broadcast = cp->instance.type == BROADCAST_COLLECTIVE;
  bool nccl = cp->instance.type == REDUCTION_COLLECTIVE &&
              cp->instance.impl_details.collective_name == "NcclReduce" &&
              cp->group.num_tasks > 1;
  if (broadcast || nccl) {
    WaitForGroup(ir, cp, is_source, broadcast, nccl,
                 [col_impl, ir, device, cp, done](InstanceRec* irec) {
                   Status s;
                   if (ir != irec) {
                     s = errors::Internal("Expected ir ", ir, " and irec ",
                                          irec, " to be equal");
                   } else {
                     mutex_lock l(irec->out_mu);
                     irec->WaitForOutMu(l);
                     s = irec->status;
                     cp->source_rank = irec->source_rank;
                     cp->instance.communicator_key = irec->communicator_key;
                   }
                   if (s.ok()) {
                     s = col_impl->InitializeCollectiveParams(cp);
                   }
                   done(s);
                 });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

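// Records that cp->default_rank has checked in for this instance. For a
// broadcast, a caller with is_source records itself as the source rank; for a
// multi-task NCCL reduction, rank 0 records its communicator key. If the
// group is not yet complete, `f` is queued on ir->known_waiters; otherwise
// `f` is invoked along with any accumulated waiters.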
void CollectiveParamResolverLocal::WaitForGroup(
    InstanceRec* ir, CollectiveParams* cp, bool is_source, bool init_source,
    bool init_nccl, const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
  {
    mutex_lock l(ir->out_mu);
    ir->WaitForOutMu(l);
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (init_source && is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ", ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
      if (init_nccl && cp->default_rank == 0) {
        // Initialize communicator key.
        if (!ir->communicator_key.empty()) {
          ir->status =
              errors::Internal("Instance ", cp->instance.instance_key,
                               " already has communicator_key ",
                               str_util::CEscape(ir->communicator_key),
                               ", received second claim from device ",
                               cp->instance.device_names[cp->default_rank]);
        } else {
          ir->communicator_key = cp->instance.communicator_key;
        }
      }
    }
    if (ir->known_count < ir->shared.group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, ir->shared.group.group_size);
    if (init_source && ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast. This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (init_nccl && ir->communicator_key.empty()) {
      ir->status = errors::Internal(
          "Instance ", cp->instance.instance_key, " device ",
          cp->instance.device_names[cp->default_rank],
          " did not find rank 0 for setting communicator key. This is an "
          "internal error in collective param resolution");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  }
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

}  // namespace tensorflow