• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
21 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache.h"
23 #include "tensorflow/core/framework/device_attributes.pb.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class CompleteGroupCall : public CancellableCall {
33  public:
CompleteGroupCall(const CollGroupParams & group,const DeviceAttributes & device,const CollectiveType & collective_type,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)34   CompleteGroupCall(const CollGroupParams& group,
35                     const DeviceAttributes& device,
36                     const CollectiveType& collective_type,
37                     CancellationManager* cancel_mgr,
38                     const string& remote_worker, WorkerCacheInterface* wc)
39       : CancellableCall(cancel_mgr, remote_worker, wc) {
40     req_.set_group_key(group.group_key);
41     req_.set_group_size(group.group_size);
42     req_.set_device_type(group.device_type.type_string());
43     *req_.mutable_device_attributes() = device;
44     req_.set_collective_type(collective_type);
45   }
~CompleteGroupCall()46   ~CompleteGroupCall() override {}
47 
IssueCall(const StatusCallback & done)48   void IssueCall(const StatusCallback& done) override {
49     wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
50   }
51 
52   CompleteGroupRequest req_;
53   CompleteGroupResponse resp_;
54 };
55 
56 class CompleteInstanceCall : public CancellableCall {
57  public:
CompleteInstanceCall(const CollGroupParams & group,const CollInstanceParams & instance,const string & node_name,const string & device_name,bool is_source,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)58   CompleteInstanceCall(const CollGroupParams& group,
59                        const CollInstanceParams& instance,
60                        const string& node_name, const string& device_name,
61                        bool is_source, CancellationManager* cancel_mgr,
62                        const string& remote_worker, WorkerCacheInterface* wc)
63       : CancellableCall(cancel_mgr, remote_worker, wc) {
64     req_.set_name(node_name);
65     req_.set_type(instance.type);
66     req_.set_data_type(instance.data_type);
67     instance.shape.AsProto(req_.mutable_shape());
68     req_.set_group_key(group.group_key);
69     req_.set_group_size(group.group_size);
70     req_.set_instance_key(instance.instance_key);
71     req_.set_device_type(group.device_type.type_string());
72     for (int32 offset : instance.impl_details.subdiv_offsets) {
73       req_.add_subdiv_offset(offset);
74     }
75     req_.set_device(device_name);
76     req_.set_is_source(is_source);
77   }
78 
~CompleteInstanceCall()79   ~CompleteInstanceCall() override {}
80 
IssueCall(const StatusCallback & done)81   void IssueCall(const StatusCallback& done) override {
82     wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
83   }
84 
85   CompleteInstanceRequest req_;
86   CompleteInstanceResponse resp_;
87 };
88 
89 }  // namespace
90 
CollectiveParamResolverDistributed(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverDistributed * dev_resolver,WorkerCacheInterface * worker_cache,const string & task_name)91 CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
92     const ConfigProto& config, const DeviceMgr* dev_mgr,
93     DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
94     const string& task_name)
95     : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, task_name),
96       worker_cache_(worker_cache),
97       group_leader_(task_name == config.experimental().collective_group_leader()
98                         ? ""
99                         : config.experimental().collective_group_leader()) {
100   VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
101           << "} config.collective_group_leader={"
102           << config.experimental().collective_group_leader() << "}"
103           << " config.collective_nccl={"
104           << config.experimental().collective_nccl() << "}";
105 }
106 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)107 void CollectiveParamResolverDistributed::CompleteParamsAsync(
108     const DeviceAttributes& device, CollectiveParams* cp,
109     CancellationManager* cancel_mgr, const StatusCallback& done) {
110   VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
111           << ": " << cp->ToString();
112   CompleteGroupDistributed(
113       device, cp, cancel_mgr,
114       [this, device, cp, cancel_mgr, done](Status s, const GroupRec* gr) {
115         if (s.ok()) {
116           std::vector<DeviceAttributes> attributes;
117           mutex_lock l(gr->mu);
118           for (const auto& item : gr->devices) {
119             attributes.push_back(item.second);
120           }
121           s = dev_resolver_->UpdateDeviceAttributes(attributes);
122         }
123         if (s.ok()) {
124           CompleteInstanceDistributed(device.name(), gr, cp, cancel_mgr, done);
125         } else {
126           done(s);
127         }
128       });
129 }
130 
CompleteGroupAsync(const CompleteGroupRequest * request,CompleteGroupResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)131 void CollectiveParamResolverDistributed::CompleteGroupAsync(
132     const CompleteGroupRequest* request, CompleteGroupResponse* response,
133     CancellationManager* cancel_mgr, const StatusCallback& done) {
134   if (!request->has_device_attributes()) {
135     done(errors::Internal(
136         "CompleteGroupRequest device_attributes is not set. Make sure you're "
137         "running the same version of Tensorflow on all workers."));
138     return;
139   }
140   auto* cp = new CollectiveParams();
141   core::ScopedUnref unref(cp);
142   cp->group.group_key = request->group_key();
143   cp->group.group_size = request->group_size();
144   cp->group.device_type = DeviceType(request->device_type());
145   cp->instance.type = CollectiveType(request->collective_type());
146   CompleteGroupDistributed(
147       request->device_attributes(), cp, cancel_mgr,
148       [response, done](const Status& s, const GroupRec* gr) {
149         if (s.ok()) {
150           mutex_lock l(gr->mu);
151           response->set_group_key(gr->group.group_key);
152           response->set_group_size(gr->group.group_size);
153           response->set_device_type(gr->group.device_type.type_string());
154           response->set_num_tasks(gr->group.num_tasks);
155           for (const auto& item : gr->devices) {
156             *response->add_device_attributes() = item.second;
157           }
158           response->set_communicator_key(
159               gr->group.runtime_details.communicator_key);
160         } else {
161           LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
162         }
163         done(s);
164       });
165 }
166 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)167 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
168     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
169     CancellationManager* cancel_mgr, const StatusCallback& done) {
170   GroupRec* gr = GetCachedGroup(request->group_key());
171   if (gr == nullptr) {
172     done(errors::FailedPrecondition(
173         "group ", request->group_key(),
174         " not found. This normally means the server has restarted"));
175     return;
176   }
177   {
178     mutex_lock l(gr->mu);
179     if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
180       done(errors::FailedPrecondition(
181           "group ", request->group_key(),
182           " failed to resolve. This normally means the server has restarted"));
183       return;
184     }
185   }
186   CollectiveParams* cp = new CollectiveParams;
187   cp->name = request->name();
188   cp->group.group_key = request->group_key();
189   cp->group.group_size = request->group_size();
190   cp->group.device_type = DeviceType(request->device_type());
191   cp->instance.type = CollectiveType(request->type());
192   cp->instance.instance_key = request->instance_key();
193   cp->instance.data_type = request->data_type();
194   cp->instance.shape = TensorShape(request->shape());
195   for (int32 offset : request->subdiv_offset()) {
196     cp->instance.impl_details.subdiv_offsets.push_back(offset);
197   }
198   StatusCallback done_and_cleanup = [cp, done](const Status& s) {
199     done(s);
200     cp->Unref();
201   };
202   CompleteInstanceDistributed(
203       request->device(), gr, cp, cancel_mgr,
204       [this, gr, cp, response, done_and_cleanup](Status status) {
205         if (status.ok()) {
206           // Now source_rank should be known, so retrieve it.
207           bool created_irec;
208           InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
209           {
210             mutex_lock l(ir->mu);
211             status = ir->status;
212             if (ir->status.ok()) {
213               response->set_instance_key(cp->instance.instance_key);
214               response->set_source_rank(ir->source_rank);
215             }
216           }
217         }
218         done_and_cleanup(status);
219       });
220 }
221 
222 CollectiveParamResolverDistributed::GroupRec*
GetCachedGroup(int32 group_key)223 CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
224   mutex_lock l(group_mu_);
225   auto it = group_table_.find(group_key);
226   if (it == group_table_.end()) {
227     return nullptr;
228   }
229   return it->second.get();
230 }
231 
UpdateGroupCache(const CompleteGroupResponse & resp)232 Status CollectiveParamResolverDistributed::UpdateGroupCache(
233     const CompleteGroupResponse& resp) {
234   // Build a new record from resp.
235   std::unique_ptr<GroupRec> gr(new GroupRec);
236   {
237     mutex_lock grl(gr->mu);
238     gr->group.device_type = DeviceType(resp.device_type());
239     gr->group.group_key = resp.group_key();
240     gr->group.group_size = resp.group_size();
241     gr->group.num_tasks = resp.num_tasks();
242     if (resp.device_attributes().empty()) {
243       return errors::Internal(
244           "CompleteGroupResponse device_attributes is empty. Make sure you're "
245           "running the same version of Tensorflow on all workers.");
246     }
247     if (resp.device_attributes_size() != gr->group.group_size) {
248       return errors::Internal(
249           "CompleteGroupResponse group_size doesn't match device_name list");
250     }
251     for (const DeviceAttributes& device : resp.device_attributes()) {
252       gr->devices[device.name()] = device;
253     }
254     gr->group.runtime_details.communicator_key = resp.communicator_key();
255     FinishGroup(gr.get());
256   }
257   GroupRec* previous_gr = nullptr;
258   {
259     // Group membership should never change. Once a record is in group_table_
260     // it never gets removed.
261     mutex_lock l(group_mu_);
262     auto it = group_table_.find(resp.group_key());
263     if (it == group_table_.end()) {
264       VLOG(2) << "UpdateGroupCache: communicator_key="
265               << absl::CEscape(resp.communicator_key());
266       group_table_[gr->group.group_key] = std::move(gr);
267     } else {
268       previous_gr = it->second.get();
269     }
270   }
271   if (previous_gr != nullptr) {
272     mutex_lock grl(previous_gr->mu);
273     if (previous_gr->group.runtime_details.communicator_key !=
274         resp.communicator_key()) {
275       return errors::Internal(
276           "UpdateGroupCache: CompleteGroupResponse for group ",
277           resp.group_key(),
278           " gives communicator_key=", absl::CEscape(resp.communicator_key()),
279           " but cache already holds communicator_key=",
280           absl::CEscape(previous_gr->group.runtime_details.communicator_key));
281     }
282   }
283   return Status::OK();
284 }
285 
CompleteGroupDistributed(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const GroupRecCallback & done)286 void CollectiveParamResolverDistributed::CompleteGroupDistributed(
287     const DeviceAttributes& device, CollectiveParams* cp,
288     CancellationManager* cancel_mgr, const GroupRecCallback& done) {
289   VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
290           << " dev: " << device.name()
291           << " is_leader=" << (group_leader_.empty());
292   if (group_leader_.empty()) {
293     // This is the group leader, so resolution is local.
294     return CompleteGroupLocal(device, cp, done, cancel_mgr);
295   } else if (GetCachedGroup(cp->group.group_key) == nullptr) {
296     // Need to update Group cache from the leader.
297     CompleteGroupCall* call =
298         new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
299                               group_leader_, worker_cache_);
300     CancellationToken abortion_token =
301         abortion_cancel_mgr_.get_cancellation_token();
302     bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
303         abortion_token, [call] { call->Cancel(); });
304     if (already_aborted) {
305       done(errors::Cancelled("collective ops already aborted"), nullptr);
306       delete call;
307       return;
308     }
309     call->Start([this, device, cp, call, cancel_mgr, abortion_token,
310                  done](const Status& s) {
311       abortion_cancel_mgr_.DeregisterCallback(abortion_token);
312       if (s.ok()) {
313         Status status = UpdateGroupCache(call->resp_);
314         if (status.ok()) {
315           CompleteGroupLocal(device, cp, done, cancel_mgr);
316         } else {
317           done(status, nullptr);
318         }
319       } else {
320         done(s, nullptr);
321       }
322       delete call;
323     });
324     return;
325   } else {
326     return CompleteGroupLocal(device, cp, done, cancel_mgr);
327   }
328 }
329 
InstanceIsCached(int32 group_key,int32 instance_key)330 bool CollectiveParamResolverDistributed::InstanceIsCached(int32 group_key,
331                                                           int32 instance_key) {
332   mutex_lock l(instance_mu_);
333   auto group_it = instance_table_.find(group_key);
334   if (group_it == instance_table_.end()) {
335     return false;
336   }
337   auto instance_it = group_it->second.find(instance_key);
338   return instance_it != group_it->second.end();
339 }
340 
UpdateInstanceCache(const GroupRec * gr,CollectiveParams * cp,const CompleteInstanceResponse & resp)341 Status CollectiveParamResolverDistributed::UpdateInstanceCache(
342     const GroupRec* gr, CollectiveParams* cp,
343     const CompleteInstanceResponse& resp) {
344   int32 source_rank = resp.source_rank();
345   bool created_irec;
346   InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
347   mutex_lock l(ir->mu);
348   if (!ir->status.ok()) {
349     return ir->status;
350   }
351   if (ir->source_rank != source_rank) {
352     if (ir->source_rank >= 0) {
353       ir->status = errors::Internal(
354           "UpdateInstanceCache: CompleteInstanceResponse for instance ",
355           cp->instance.instance_key, " gives source_rank=", source_rank,
356           " but cache already holds value=", ir->source_rank);
357       return ir->status;
358     }
359     ir->source_rank = source_rank;
360   }
361   if (ir->known_count < cp->group.group_size) {
362     ir->known_count = cp->group.group_size;
363     const int ir_known_size = ir->known.size();
364     if (ir_known_size != cp->group.group_size) {
365       ir->status = errors::Internal(
366           "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
367           cp->instance.instance_key, " has known.size()=", ir->known.size(),
368           " < group_size=", cp->group.group_size);
369       return ir->status;
370     }
371     for (int i = 0; i < ir_known_size; ++i) {
372       ir->known[i] = true;
373     }
374   }
375   return ir->status;
376 }
377 
CompleteInstanceDistributed(const string & device,const GroupRec * gr,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)378 void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
379     const string& device, const GroupRec* gr, CollectiveParams* cp,
380     CancellationManager* cancel_mgr, const StatusCallback& done) {
381   if (group_leader_.empty()) {
382     // This is the group leader so resolution is local.
383     return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
384   } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
385     return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
386   } else {
387     CompleteInstanceCall* call = new CompleteInstanceCall(
388         cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
389         group_leader_, worker_cache_);
390     CancellationToken abortion_token =
391         abortion_cancel_mgr_.get_cancellation_token();
392     bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
393         abortion_token, [call] { call->Cancel(); });
394     if (already_aborted) {
395       done(errors::Cancelled("collective ops already aborted"));
396       delete call;
397       return;
398     }
399     call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
400       abortion_cancel_mgr_.DeregisterCallback(abortion_token);
401       if (s.ok()) {
402         s = UpdateInstanceCache(gr, cp, call->resp_);
403       }
404       if (s.ok()) {
405         CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
406       } else {
407         done(s);
408       }
409       delete call;
410     });
411     return;
412   }
413 }
414 
StartAbort(const Status & s)415 void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
416   {
417     mutex_lock l(status_mu_);
418     if (!status_.ok()) {
419       VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
420                  "subsequent abortion with status: "
421               << s;
422       return;
423     }
424     status_ = s;
425   }
426   StartAbortLocal(s);
427   abortion_cancel_mgr_.StartCancel();
428 }
429 
430 }  // namespace tensorflow
431