/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
21 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache.h"
23 #include "tensorflow/core/framework/device_attributes.pb.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class CompleteGroupCall : public CancellableCall {
33  public:
CompleteGroupCall(const CollGroupParams & group,const DeviceAttributes & device,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)34   CompleteGroupCall(const CollGroupParams& group,
35                     const DeviceAttributes& device,
36                     CancellationManager* cancel_mgr,
37                     const string& remote_worker, WorkerCacheInterface* wc)
38       : CancellableCall(cancel_mgr, remote_worker, wc) {
39     req_.set_group_key(group.group_key);
40     req_.set_group_size(group.group_size);
41     req_.set_device_type(group.device_type.type_string());
42     *req_.mutable_device_attributes() = device;
43   }
~CompleteGroupCall()44   ~CompleteGroupCall() override {}
45 
IssueCall(const StatusCallback & done)46   void IssueCall(const StatusCallback& done) override {
47     wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
48   }
49 
50   CompleteGroupRequest req_;
51   CompleteGroupResponse resp_;
52 };
53 
54 class CompleteInstanceCall : public CancellableCall {
55  public:
CompleteInstanceCall(const CollGroupParams & group,const CollInstanceParams & instance,const string & node_name,const string & device_name,bool is_source,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)56   CompleteInstanceCall(const CollGroupParams& group,
57                        const CollInstanceParams& instance,
58                        const string& node_name, const string& device_name,
59                        bool is_source, CancellationManager* cancel_mgr,
60                        const string& remote_worker, WorkerCacheInterface* wc)
61       : CancellableCall(cancel_mgr, remote_worker, wc) {
62     req_.set_name(node_name);
63     req_.set_type(instance.type);
64     req_.set_data_type(instance.data_type);
65     instance.shape.AsProto(req_.mutable_shape());
66     req_.set_group_key(group.group_key);
67     req_.set_group_size(group.group_size);
68     req_.set_instance_key(instance.instance_key);
69     req_.set_device_type(group.device_type.type_string());
70     for (int32_t offset : instance.impl_details.subdiv_offsets) {
71       req_.add_subdiv_offset(offset);
72     }
73     req_.set_device(device_name);
74     req_.set_is_source(is_source);
75   }
76 
~CompleteInstanceCall()77   ~CompleteInstanceCall() override {}
78 
IssueCall(const StatusCallback & done)79   void IssueCall(const StatusCallback& done) override {
80     wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
81   }
82 
83   CompleteInstanceRequest req_;
84   CompleteInstanceResponse resp_;
85 };
86 
87 }  // namespace
88 
CollectiveParamResolverDistributed(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverDistributed * dev_resolver,NcclCommunicatorInterface * nccl_communicator,WorkerCacheInterface * worker_cache,const string & task_name)89 CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
90     const ConfigProto& config, const DeviceMgr* dev_mgr,
91     DeviceResolverDistributed* dev_resolver,
92     NcclCommunicatorInterface* nccl_communicator,
93     WorkerCacheInterface* worker_cache, const string& task_name)
94     : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver,
95                                    nccl_communicator, task_name),
96       worker_cache_(worker_cache),
97       group_leader_(task_name == config.experimental().collective_group_leader()
98                         ? ""
99                         : config.experimental().collective_group_leader()) {
100   VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
101           << "} config.collective_group_leader={"
102           << config.experimental().collective_group_leader() << "}"
103           << " config.collective_nccl={"
104           << config.experimental().collective_nccl() << "}";
105 }
106 
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)107 void CollectiveParamResolverDistributed::CompleteParamsAsync(
108     const DeviceAttributes& device, CollectiveParams* cp,
109     CancellationManager* cancel_mgr, const StatusCallback& done) {
110   VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
111           << ": " << cp->ToString();
112   CompleteGroupDistributed(
113       device, &cp->group, cancel_mgr,
114       [this, device, cp, cancel_mgr, done](Status s) {
115         if (s.ok()) {
116           s = dev_resolver_->UpdateDeviceAttributes(cp->group.devices);
117         }
118         if (s.ok()) {
119           CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
120         } else {
121           done(s);
122         }
123       });
124 }
125 
CompleteGroupAsync(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,const StatusCallback & done)126 void CollectiveParamResolverDistributed::CompleteGroupAsync(
127     const DeviceAttributes& device, CollGroupParams* group_params,
128     CancellationManager* cancel_mgr, const StatusCallback& done) {
129   CompleteGroupDistributed(device, group_params, cancel_mgr, done);
130 }
131 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)132 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
133     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
134     CancellationManager* cancel_mgr, const StatusCallback& done) {
135   GroupRec* gr = GetCachedGroup(request->group_key());
136   if (gr == nullptr) {
137     done(errors::FailedPrecondition(
138         "group ", request->group_key(),
139         " not found. This normally means the server has restarted"));
140     return;
141   }
142   CollectiveParams* cp = new CollectiveParams;
143   {
144     mutex_lock l(gr->mu);
145     if (!gr->status.ok()) {
146       done(gr->status);
147       return;
148     } else if (gr->group.devices.size() != gr->group.group_size) {
149       done(errors::FailedPrecondition(
150           "group ", request->group_key(),
151           " failed to resolve. This normally means the server has restarted"));
152       return;
153     }
154     cp->group = gr->group;
155   }
156   cp->name = request->name();
157   cp->instance.type = CollectiveType(request->type());
158   cp->instance.instance_key = request->instance_key();
159   cp->instance.data_type = request->data_type();
160   cp->instance.shape = TensorShape(request->shape());
161   cp->is_source = request->is_source();
162   for (int32_t offset : request->subdiv_offset()) {
163     cp->instance.impl_details.subdiv_offsets.push_back(offset);
164   }
165   StatusCallback done_and_cleanup = [cp, done](const Status& s) {
166     done(s);
167     cp->Unref();
168   };
169   CompleteInstanceDistributed(
170       request->device(), cp, cancel_mgr,
171       [this, cp, response, done_and_cleanup](Status status) {
172         if (status.ok()) {
173           // Now source_rank should be known, so retrieve it.
174           bool created_irec;
175           InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
176           {
177             mutex_lock l(ir->mu);
178             status = ir->status;
179             if (ir->status.ok()) {
180               response->set_instance_key(cp->instance.instance_key);
181               response->set_source_rank(ir->source_rank);
182             }
183           }
184         }
185         done_and_cleanup(status);
186       });
187 }
188 
189 CollectiveParamResolverDistributed::GroupRec*
GetCachedGroup(int32_t group_key)190 CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) {
191   mutex_lock l(group_mu_);
192   auto it = group_table_.find(group_key);
193   if (it == group_table_.end()) {
194     return nullptr;
195   }
196   return it->second.get();
197 }
198 
UpdateGroupCache(const CompleteGroupResponse & resp)199 Status CollectiveParamResolverDistributed::UpdateGroupCache(
200     const CompleteGroupResponse& resp) {
201   // Build a new record from resp.
202   std::unique_ptr<GroupRec> gr(new GroupRec);
203   {
204     mutex_lock grl(gr->mu);
205     gr->group.device_type = DeviceType(resp.device_type());
206     gr->group.group_key = resp.group_key();
207     gr->group.group_size = resp.group_size();
208     gr->group.num_tasks = resp.num_tasks();
209     if (resp.device_attributes().empty()) {
210       return errors::Internal(
211           "CompleteGroupResponse device_attributes is empty. Make sure you're "
212           "running the same version of Tensorflow on all workers.");
213     }
214     if (resp.device_attributes_size() != gr->group.group_size) {
215       return errors::Internal(
216           "CompleteGroupResponse group_size doesn't match device_name list");
217     }
218     gr->group.devices.reserve(resp.device_attributes().size());
219     for (const DeviceAttributes& device : resp.device_attributes()) {
220       gr->group.devices.push_back(device);
221       gr->incarnations_by_device_name[device.name()] = device.incarnation();
222     }
223     gr->group.runtime_details.communicator_key = resp.communicator_key();
224     FinishGroup(gr.get());
225   }
226   GroupRec* previous_gr = nullptr;
227   {
228     // Group membership should never change. Once a record is in group_table_
229     // it never gets removed.
230     mutex_lock l(group_mu_);
231     auto it = group_table_.find(resp.group_key());
232     if (it == group_table_.end()) {
233       VLOG(2) << "UpdateGroupCache: communicator_key="
234               << absl::CEscape(resp.communicator_key());
235       group_table_[gr->group.group_key] = std::move(gr);
236     } else {
237       previous_gr = it->second.get();
238     }
239   }
240   if (previous_gr != nullptr) {
241     mutex_lock grl(previous_gr->mu);
242     if (previous_gr->group.runtime_details.communicator_key !=
243         resp.communicator_key()) {
244       return errors::Internal(
245           "UpdateGroupCache: CompleteGroupResponse for group ",
246           resp.group_key(),
247           " gives communicator_key=", absl::CEscape(resp.communicator_key()),
248           " but cache already holds communicator_key=",
249           absl::CEscape(previous_gr->group.runtime_details.communicator_key));
250     }
251   }
252   return Status::OK();
253 }
254 
CompleteGroupDistributed(const DeviceAttributes & device,CollGroupParams * group_params,CancellationManager * cancel_mgr,const StatusCallback & done)255 void CollectiveParamResolverDistributed::CompleteGroupDistributed(
256     const DeviceAttributes& device, CollGroupParams* group_params,
257     CancellationManager* cancel_mgr, const StatusCallback& done) {
258   VLOG(1) << "CompleteGroupDistributed group_key=" << group_params->group_key
259           << " dev: " << device.name()
260           << " is_leader=" << (group_leader_.empty());
261   if (group_leader_.empty()) {
262     // This is the group leader, so resolution is local.
263     return CompleteGroupLocal(device, group_params, cancel_mgr, done);
264   } else if (GetCachedGroup(group_params->group_key) == nullptr) {
265     // Need to update Group cache from the leader.
266     CompleteGroupCall* call = new CompleteGroupCall(
267         *group_params, device, cancel_mgr, group_leader_, worker_cache_);
268     CancellationToken abortion_token =
269         abortion_cancel_mgr_.get_cancellation_token();
270     bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
271         abortion_token, [call] { call->Cancel(); });
272     if (already_aborted) {
273       done(errors::Cancelled("collective ops already aborted"));
274       delete call;
275       return;
276     }
277     call->Start([this, device, group_params, call, cancel_mgr, abortion_token,
278                  done](const Status& s) {
279       abortion_cancel_mgr_.DeregisterCallback(abortion_token);
280       if (s.ok()) {
281         Status status = UpdateGroupCache(call->resp_);
282         if (status.ok()) {
283           CompleteGroupLocal(device, group_params, cancel_mgr, done);
284         } else {
285           done(status);
286         }
287       } else {
288         done(s);
289       }
290       delete call;
291     });
292     return;
293   } else {
294     return CompleteGroupLocal(device, group_params, cancel_mgr, done);
295   }
296 }
297 
InstanceIsCached(int32_t group_key,int32_t instance_key)298 bool CollectiveParamResolverDistributed::InstanceIsCached(
299     int32_t group_key, int32_t instance_key) {
300   mutex_lock l(instance_mu_);
301   auto group_it = instance_table_.find(group_key);
302   if (group_it == instance_table_.end()) {
303     return false;
304   }
305   auto instance_it = group_it->second.find(instance_key);
306   return instance_it != group_it->second.end();
307 }
308 
UpdateInstanceCache(CollectiveParams * cp,const CompleteInstanceResponse & resp)309 Status CollectiveParamResolverDistributed::UpdateInstanceCache(
310     CollectiveParams* cp, const CompleteInstanceResponse& resp) {
311   int32_t source_rank = resp.source_rank();
312   bool created_irec;
313   InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
314   mutex_lock l(ir->mu);
315   if (!ir->status.ok()) {
316     return ir->status;
317   }
318   if (ir->source_rank != source_rank) {
319     if (ir->source_rank >= 0) {
320       ir->status = errors::Internal(
321           "UpdateInstanceCache: CompleteInstanceResponse for instance ",
322           cp->instance.instance_key, " gives source_rank=", source_rank,
323           " but cache already holds value=", ir->source_rank);
324       return ir->status;
325     }
326     ir->source_rank = source_rank;
327   }
328   if (ir->known_count < cp->group.group_size) {
329     ir->known_count = cp->group.group_size;
330     const int ir_known_size = ir->known.size();
331     if (ir_known_size != cp->group.group_size) {
332       ir->status = errors::Internal(
333           "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
334           cp->instance.instance_key, " has known.size()=", ir->known.size(),
335           " < group_size=", cp->group.group_size);
336       return ir->status;
337     }
338     for (int i = 0; i < ir_known_size; ++i) {
339       ir->known[i] = true;
340     }
341   }
342   return ir->status;
343 }
344 
CompleteInstanceDistributed(const string & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)345 void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
346     const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
347     const StatusCallback& done) {
348   if (group_leader_.empty()) {
349     // This is the group leader so resolution is local.
350     return CompleteInstanceLocal(device, cp, done);
351   } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
352     return CompleteInstanceLocal(device, cp, done);
353   } else {
354     CompleteInstanceCall* call = new CompleteInstanceCall(
355         cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
356         group_leader_, worker_cache_);
357     CancellationToken abortion_token =
358         abortion_cancel_mgr_.get_cancellation_token();
359     bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
360         abortion_token, [call] { call->Cancel(); });
361     if (already_aborted) {
362       done(errors::Cancelled("collective ops already aborted"));
363       delete call;
364       return;
365     }
366     call->Start([this, device, cp, call, abortion_token, done](Status s) {
367       abortion_cancel_mgr_.DeregisterCallback(abortion_token);
368       if (s.ok()) {
369         s = UpdateInstanceCache(cp, call->resp_);
370       }
371       if (s.ok()) {
372         CompleteInstanceLocal(device, cp, done);
373       } else {
374         done(s);
375       }
376       delete call;
377     });
378     return;
379   }
380 }
381 
StartAbort(const Status & s)382 void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
383   {
384     mutex_lock l(status_mu_);
385     if (!status_.ok()) {
386       VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
387                  "subsequent abortion with status: "
388               << s;
389       return;
390     }
391     status_ = s;
392   }
393   StartAbortLocal(s);
394   abortion_cancel_mgr_.StartCancel();
395 }
396 
397 }  // namespace tensorflow
398