/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"

#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace {

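// Wraps a CompleteGroupAsync RPC to a remote worker (normally the group
// leader): the request is filled in from the local CollGroupParams and the
// caller's device name, and the response is left in resp_ for the caller to
// consume once the call completes.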
class CompleteGroupCall : public CancellableCall {
 public:
  CompleteGroupCall(const CollGroupParams& group, const string& device_name,
                    CancellationManager* cancel_mgr,
                    const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_device_type(group.device_type.type_string());
    req_.add_device_name(device_name);
  }
  ~CompleteGroupCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
  }

  CompleteGroupRequest req_;
  CompleteGroupResponse resp_;
};

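// Wraps a CompleteInstanceAsync RPC to a remote worker, carrying the full
// group and instance description (collective type, data type, shape, subdiv
// offsets, device, and the is_source flag) and leaving the response in resp_.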
class CompleteInstanceCall : public CancellableCall {
 public:
  CompleteInstanceCall(const CollGroupParams& group,
                       const CollInstanceParams& instance,
                       const string& node_name, const string& device_name,
                       bool is_source, CancellationManager* cancel_mgr,
                       const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_name(node_name);
    req_.set_type(instance.type);
    req_.set_data_type(instance.data_type);
    instance.shape.AsProto(req_.mutable_shape());
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_instance_key(instance.instance_key);
    req_.set_device_type(group.device_type.type_string());
    for (int32 offset : instance.impl_details.subdiv_offsets) {
      req_.add_subdiv_offset(offset);
    }
    req_.set_device(device_name);
    req_.set_is_source(is_source);
  }

  ~CompleteInstanceCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
  }

  CompleteInstanceRequest req_;
  CompleteInstanceResponse resp_;
};

}  // namespace

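// group_leader_ holds the task name of the collective group leader from the
// config; it is deliberately left empty when this task is itself the leader,
// so later methods can branch on group_leader_.empty() to choose between
// local resolution and an RPC to the leader.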
CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
    const string& task_name)
    : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, task_name),
      worker_cache_(worker_cache),
      group_leader_(task_name == config.experimental().collective_group_leader()
                        ? ""
                        : config.experimental().collective_group_leader()) {
  VLOG(1) << "CollectiveParamResolverDistributed ctor task={" << task_name
          << "} config.collective_group_leader={"
          << config.experimental().collective_group_leader() << "}";
}

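// Completes the collective parameters for one device in two steps: first the
// group (possibly via the group leader), then the instance.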
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupDistributed(device, cp, cancel_mgr,
                           [this, device, cp, cancel_mgr, done](
                               const Status& s, const GroupRec* gr) {
                             if (s.ok()) {
                               CompleteInstanceDistributed(device, gr, cp,
                                                           cancel_mgr, done);
                             } else {
                               done(s);
                             }
                           });
}

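// Handles a CompleteGroupRequest (typically one sent by a non-leader task's
// CompleteGroupCall): rebuilds a CollectiveParams from the request, resolves
// the group, then copies the resulting group record into the response.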
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CollectiveParams cp;
  cp.group.group_key = request->group_key();
  cp.group.group_size = request->group_size();
  cp.group.device_type = DeviceType(request->device_type());
  for (const string& dn : request->device_name()) {
    cp.instance.device_names.push_back(dn);
  }
  CompleteGroupDistributed(
      cp.instance.device_names[0], &cp, cancel_mgr,
      [this, response, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          mutex_lock l(gr->mu);
          response->set_group_key(gr->group.group_key);
          response->set_group_size(gr->group.group_size);
          response->set_device_type(gr->group.device_type.type_string());
          response->set_num_tasks(gr->task_set.size());
          for (const string& dn : gr->device_list) {
            response->add_device_name(dn);
          }
          for (const string& tn : gr->task_list) {
            response->add_task_name(tn);
          }
        } else {
          LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
        }
        done(s);
      });
}

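// Handles a CompleteInstanceRequest: reconstructs a CollectiveParams from the
// request, completes the group and then the instance, and reports the
// resolved instance_key, source_rank and (if set) communicator_key back in
// the response. cp and device are heap-allocated so they outlive the chain of
// asynchronous callbacks; done_and_cleanup frees them.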
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CollectiveParams* cp = new CollectiveParams;
  cp->name = request->name();
  cp->group.group_key = request->group_key();
  cp->group.group_size = request->group_size();
  cp->group.device_type = DeviceType(request->device_type());
  cp->instance.type = CollectiveType(request->type());
  cp->instance.instance_key = request->instance_key();
  cp->instance.data_type = request->data_type();
  cp->instance.shape = TensorShape(request->shape());
  for (int32 offset : request->subdiv_offset()) {
    cp->instance.impl_details.subdiv_offsets.push_back(offset);
  }
  string* device = new string(request->device());
  VLOG(1) << "New cp " << cp << " for device " << *device << " : "
          << cp->ToString();
  StatusCallback done_and_cleanup = [this, cp, device, done](const Status& s) {
    done(s);
    delete cp;
    delete device;
  };
  // Start by completing the group.
  CompleteGroupDistributed(
      *device, cp, cancel_mgr,
      [this, cp, device, response, cancel_mgr, done_and_cleanup](
          const Status& cg_status, const GroupRec* gr) {
        if (cg_status.ok()) {
          // Then complete the instance.
          CompleteInstanceDistributed(
              *device, gr, cp, cancel_mgr,
              [this, gr, cp, response,
               done_and_cleanup](const Status& ci_status) {
                if (ci_status.ok()) {
                  // Now source_rank should be known, so
                  // retrieve it.
                  FindInstanceRec(
                      gr, cp,
                      [this, gr, cp, response, done_and_cleanup](
                          const Status& fi_status, InstanceRec* ir) {
                        if (fi_status.ok()) {
                          mutex_lock l(ir->out_mu);
                          ir->WaitForOutMu(l);
                          response->set_instance_key(cp->instance.instance_key);
                          response->set_source_rank(ir->source_rank);
                          if (!cp->instance.communicator_key.empty()) {
                            response->set_communicator_key(
                                cp->instance.communicator_key);
                          }
                          done_and_cleanup(fi_status);
                        } else {
                          done_and_cleanup(fi_status);
                        }
                      });
                } else {
                  done_and_cleanup(ci_status);
                }
              });
        } else {
          done_and_cleanup(cg_status);
        }
      });
}

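// True iff a record for group_key is already present in group_table_.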
bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
  mutex_lock l(group_mu_);
  const auto& it = group_table_.find(group_key);
  return it != group_table_.end();
}

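// Inserts a GroupRec built from the leader's CompleteGroupResponse into
// group_table_, after checking that the device and task lists match the
// advertised group_size. An existing entry for the same group_key is left
// unchanged, since group membership never changes once established.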
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  mutex_lock grl(gr->mu);
  gr->group.device_type = DeviceType(resp.device_type());
  gr->group.group_key = resp.group_key();
  gr->group.group_size = resp.group_size();
  gr->group.num_tasks = resp.num_tasks();
  if (resp.device_name_size() != gr->group.group_size) {
    return errors::Internal(
        "CompleteGroupResponse group_size doesn't match device_name list");
  }
  for (const string& dn : resp.device_name()) {
    gr->device_set.insert(dn);
    gr->device_list.push_back(dn);
  }
  if (resp.task_name_size() != gr->group.group_size) {
    return errors::Internal(
        "CompleteGroupResponse group_size doesn't match task_name list");
  }
  for (const string& tn : resp.task_name()) {
    gr->task_list.push_back(tn);
    gr->task_set.insert(tn);
  }
  CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(gr->group.group_key);
    if (it == group_table_.end()) {
      group_table_[gr->group.group_key] = std::move(gr);
    }
  }
  return Status::OK();
}

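// Three cases: (1) this task is the group leader, so resolve locally;
// (2) the group is not yet cached, so ask the leader via CompleteGroupCall,
// cache the answer, then resolve locally; (3) the group is already cached,
// so resolve locally straight away.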
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const GroupRecCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
          << " dev: " << device << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, cp, done);
  } else if (!GroupIsCached(cp->group.group_key)) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        cp->group, device, cancel_mgr, group_leader_, worker_cache_);
    call->Start([this, device, cp, call, done](const Status& s) {
      if (s.ok()) {
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, cp, done);
        } else {
          done(status, nullptr);
        }
      } else {
        done(s, nullptr);
      }
      delete call;
    });
    return;
  } else {
    return CompleteGroupLocal(device, cp, done);
  }
}

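// True iff a record for instance_key is already present in instance_table_.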
bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
  mutex_lock l(instance_mu_);
  const auto& it = instance_table_.find(instance_key);
  return it != instance_table_.end();
}

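// Folds the leader's CompleteInstanceResponse into the local InstanceRec:
// records source_rank and communicator_key (flagging any disagreement with
// values already cached) and marks every group member as known. The
// InstanceRec is looked up asynchronously via FindInstanceRec, which is why
// the pointer is passed through a heap-allocated InstanceRecPointer.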
void CollectiveParamResolverDistributed::UpdateInstanceCache(
    const GroupRec* gr, CollectiveParams* cp,
    const CompleteInstanceResponse& resp, const StatusCallback& done) {
  using InstanceRecPointer = InstanceRec*;
  InstanceRecPointer* irp = new InstanceRecPointer(nullptr);
  int32 source_rank = resp.source_rank();
  string communicator_key = resp.communicator_key();

  auto continue_with_ir = [cp, irp, source_rank, communicator_key,
                           done](const Status& s) {
    if (!s.ok()) {
      done(s);
      delete irp;
      return;
    }
    Status status;
    InstanceRec* ir = *irp;
    do {
      mutex_lock l(ir->out_mu);
      ir->WaitForOutMu(l);
      if (ir->source_rank != source_rank) {
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal(
              "UpdateInstanceCache: CompleteInstanceResponse for instance ",
              cp->instance.instance_key, " gives source_rank=", source_rank,
              " but cache already holds value=", ir->source_rank);
          status = ir->status;
          break;
        }
        ir->source_rank = source_rank;
      }
      if (ir->communicator_key != communicator_key) {
        if (!ir->communicator_key.empty()) {
          ir->status = errors::Internal(
              "UpdateInstanceCache: CompleteInstanceResponse for instance ",
              cp->instance.instance_key,
              " gives communicator_key with size =", communicator_key.size(),
              " but cache already holds communicator_key with size=",
              ir->communicator_key.size());
          status = ir->status;
          break;
        }
        ir->communicator_key = communicator_key;
      }
      if (ir->known_count < cp->group.group_size) {
        ir->known_count = cp->group.group_size;
        if (ir->known.size() != cp->group.group_size) {
          ir->status = errors::Internal(
              "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
              cp->instance.instance_key, " has known.size()=", ir->known.size(),
              " < group_size=", cp->group.group_size);
          status = ir->status;
          break;
        }
        for (int i = 0; i < ir->known.size(); ++i) {
          ir->known[i] = true;
        }
      }
      status = ir->status;
    } while (false);
    // Callback outside of lock.
    done(status);
    delete irp;
  };

  FindInstanceRec(
      gr, cp, [this, irp, continue_with_ir](const Status s, InstanceRec* irec) {
        *irp = irec;
        continue_with_ir(s);
      });
}

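// Mirrors CompleteGroupDistributed: the leader (or any task that already has
// the instance cached) resolves locally; otherwise a CompleteInstanceCall is
// sent to the leader and its response is folded into the local cache before
// completing the instance locally.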
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else if (InstanceIsCached(cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    call->Start([this, device, gr, cp, call, done](const Status& s) {
      if (s.ok()) {
        UpdateInstanceCache(
            gr, cp, call->resp_, [this, device, gr, cp, done](const Status& s) {
              if (!s.ok()) {
                done(s);
              } else {
                CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
              }
            });
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}

}  // namespace tensorflow