/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"

#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace {

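// RPC wrapper that sends a CompleteGroupRequest to a remote worker (normally
// the group leader); the call can be cancelled via the CancellationManager.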
class CompleteGroupCall : public CancellableCall {
 public:
  CompleteGroupCall(const CollGroupParams& group, const string& device_name,
                    CancellationManager* cancel_mgr,
                    const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_device_type(group.device_type.type_string());
    req_.add_device_name(device_name);
  }
  ~CompleteGroupCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
  }

  CompleteGroupRequest req_;
  CompleteGroupResponse resp_;
};

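// RPC wrapper that sends a CompleteInstanceRequest to a remote worker
// (normally the group leader); cancellable via the CancellationManager.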
class CompleteInstanceCall : public CancellableCall {
 public:
  CompleteInstanceCall(const CollGroupParams& group,
                       const CollInstanceParams& instance,
                       const string& node_name, const string& device_name,
                       bool is_source, CancellationManager* cancel_mgr,
                       const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_name(node_name);
    req_.set_type(instance.type);
    req_.set_data_type(instance.data_type);
    instance.shape.AsProto(req_.mutable_shape());
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_instance_key(instance.instance_key);
    req_.set_device_type(group.device_type.type_string());
    for (int32 offset : instance.impl_details.subdiv_offsets) {
      req_.add_subdiv_offset(offset);
    }
    req_.set_device(device_name);
    req_.set_is_source(is_source);
  }

  ~CompleteInstanceCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
  }

  CompleteInstanceRequest req_;
  CompleteInstanceResponse resp_;
};

}  // namespace

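// group_leader_ is left empty when this task is itself the configured group
// leader; otherwise it names the leader task to which group and instance
// resolution is delegated.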
CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
    const string& task_name)
    : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, task_name),
      worker_cache_(worker_cache),
      group_leader_(task_name == config.experimental().collective_group_leader()
                        ? ""
                        : config.experimental().collective_group_leader()) {
  VLOG(1) << "CollectiveParamResolverDistributed ctor task={" << task_name
          << "} config.collective_group_leader={"
          << config.experimental().collective_group_leader() << "}";
}

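// Completes the full set of collective parameters for a device: the group is
// resolved first (possibly by contacting the group leader), then the
// instance.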
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupDistributed(device, cp, cancel_mgr,
                           [this, device, cp, cancel_mgr, done](
                               const Status& s, const GroupRec* gr) {
                             if (s.ok()) {
                               CompleteInstanceDistributed(device, gr, cp,
                                                           cancel_mgr, done);
                             } else {
                               done(s);
                             }
                           });
}

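// Services a CompleteGroup RPC from another worker: unpacks the request into
// a CollectiveParams, resolves the group, then copies the group record into
// the response while holding the record's lock.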
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CollectiveParams cp;
  cp.group.group_key = request->group_key();
  cp.group.group_size = request->group_size();
  cp.group.device_type = DeviceType(request->device_type());
  for (const string& dn : request->device_name()) {
    cp.instance.device_names.push_back(dn);
  }
  CompleteGroupDistributed(
      cp.instance.device_names[0], &cp, cancel_mgr,
      [this, response, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          mutex_lock l(gr->mu);
          response->set_group_key(gr->group.group_key);
          response->set_group_size(gr->group.group_size);
          response->set_device_type(gr->group.device_type.type_string());
          response->set_num_tasks(gr->task_set.size());
          for (const string& dn : gr->device_list) {
            response->add_device_name(dn);
          }
          for (const string& tn : gr->task_list) {
            response->add_task_name(tn);
          }
        } else {
          LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
        }
        done(s);
      });
}

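// Services a CompleteInstance RPC from another worker: builds a
// heap-allocated CollectiveParams from the request, resolves the group and
// then the instance, and reports the source rank (and communicator key, if
// any) back in the response.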
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CollectiveParams* cp = new CollectiveParams;
  cp->name = request->name();
  cp->group.group_key = request->group_key();
  cp->group.group_size = request->group_size();
  cp->group.device_type = DeviceType(request->device_type());
  cp->instance.type = CollectiveType(request->type());
  cp->instance.instance_key = request->instance_key();
  cp->instance.data_type = request->data_type();
  cp->instance.shape = TensorShape(request->shape());
  for (int32 offset : request->subdiv_offset()) {
    cp->instance.impl_details.subdiv_offsets.push_back(offset);
  }
  string* device = new string(request->device());
  VLOG(1) << "New cp " << cp << " for device " << *device << " : "
          << cp->ToString();
  StatusCallback done_and_cleanup = [this, cp, device, done](const Status& s) {
    done(s);
    delete cp;
    delete device;
  };
  // Start by completing the group.
  CompleteGroupDistributed(
      *device, cp, cancel_mgr,
      [this, cp, device, response, cancel_mgr, done_and_cleanup](
          const Status& cg_status, const GroupRec* gr) {
        if (cg_status.ok()) {
          // Then complete the instance.
          CompleteInstanceDistributed(
              *device, gr, cp, cancel_mgr,
              [this, gr, cp, response,
               done_and_cleanup](const Status& ci_status) {
                if (ci_status.ok()) {
                  // Now source_rank should be known, so retrieve it.
                  FindInstanceRec(
                      gr, cp,
                      [this, gr, cp, response, done_and_cleanup](
                          const Status& fi_status, InstanceRec* ir) {
                        if (fi_status.ok()) {
                          mutex_lock l(ir->out_mu);
                          ir->WaitForOutMu(l);
                          response->set_instance_key(cp->instance.instance_key);
                          response->set_source_rank(ir->source_rank);
                          if (!cp->instance.communicator_key.empty()) {
                            response->set_communicator_key(
                                cp->instance.communicator_key);
                          }
                          done_and_cleanup(fi_status);
                        } else {
                          done_and_cleanup(fi_status);
                        }
                      });
                } else {
                  done_and_cleanup(ci_status);
                }
              });
        } else {
          done_and_cleanup(cg_status);
        }
      });
}

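// Returns true if a record for group_key is already present in the local
// group table.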
bool CollectiveParamResolverDistributed::GroupIsCached(int32 group_key) {
  mutex_lock l(group_mu_);
  const auto& it = group_table_.find(group_key);
  return it != group_table_.end();
}

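// Builds a GroupRec from a CompleteGroupResponse received from the group
// leader and installs it in group_table_ if no record for that group key is
// present yet.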
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  mutex_lock grl(gr->mu);
  gr->group.device_type = DeviceType(resp.device_type());
  gr->group.group_key = resp.group_key();
  gr->group.group_size = resp.group_size();
  gr->group.num_tasks = resp.num_tasks();
  if (resp.device_name_size() != gr->group.group_size) {
    return errors::Internal(
        "CompleteGroupResponse group_size doesn't match device_name list");
  }
  for (const string& dn : resp.device_name()) {
    gr->device_set.insert(dn);
    gr->device_list.push_back(dn);
  }
  if (resp.task_name_size() != gr->group.group_size) {
    return errors::Internal(
        "CompleteGroupResponse group_size doesn't match task_name list");
  }
  for (const string& tn : resp.task_name()) {
    gr->task_list.push_back(tn);
    gr->task_set.insert(tn);
  }
  CHECK_EQ(gr->task_set.size(), gr->group.num_tasks);
  {
    // Group membership should never change.  Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(gr->group.group_key);
    if (it == group_table_.end()) {
      group_table_[gr->group.group_key] = std::move(gr);
    }
  }
  return Status::OK();
}

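// Resolves group membership.  On the group leader (group_leader_ empty)
// resolution is purely local; otherwise, if the group is not yet cached, a
// CompleteGroup RPC is sent to the leader and the reply populates the local
// cache before resolution completes locally.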
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const GroupRecCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
          << " dev: " << device << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, cp, done);
  } else if (!GroupIsCached(cp->group.group_key)) {
    // Need to update group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        cp->group, device, cancel_mgr, group_leader_, worker_cache_);
    call->Start([this, device, cp, call, done](const Status& s) {
      if (s.ok()) {
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, cp, done);
        } else {
          done(status, nullptr);
        }
      } else {
        done(s, nullptr);
      }
      delete call;
    });
    return;
  } else {
    return CompleteGroupLocal(device, cp, done);
  }
}

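// Returns true if a record for instance_key is already present in the local
// instance table.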
bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
  mutex_lock l(instance_mu_);
  const auto& it = instance_table_.find(instance_key);
  return it != instance_table_.end();
}

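// Merges the source_rank and communicator_key from the leader's
// CompleteInstanceResponse into the locally cached InstanceRec, failing if
// they conflict with values already recorded, and marks every group member
// as known.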
void CollectiveParamResolverDistributed::UpdateInstanceCache(
    const GroupRec* gr, CollectiveParams* cp,
    const CompleteInstanceResponse& resp, const StatusCallback& done) {
  using InstanceRecPointer = InstanceRec*;
  InstanceRecPointer* irp = new InstanceRecPointer(nullptr);
  int32 source_rank = resp.source_rank();
  string communicator_key = resp.communicator_key();

  auto continue_with_ir = [cp, irp, source_rank, communicator_key,
                           done](const Status& s) {
    if (!s.ok()) {
      done(s);
      delete irp;
      return;
    }
    Status status;
    InstanceRec* ir = *irp;
    do {
      mutex_lock l(ir->out_mu);
      ir->WaitForOutMu(l);
      if (ir->source_rank != source_rank) {
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal(
              "UpdateInstanceCache: CompleteInstanceResponse for instance ",
              cp->instance.instance_key, " gives source_rank=", source_rank,
              " but cache already holds value=", ir->source_rank);
          status = ir->status;
          break;
        }
        ir->source_rank = source_rank;
      }
      if (ir->communicator_key != communicator_key) {
        if (!ir->communicator_key.empty()) {
          ir->status = errors::Internal(
              "UpdateInstanceCache: CompleteInstanceResponse for instance ",
              cp->instance.instance_key,
              " gives communicator_key with size=", communicator_key.size(),
              " but cache already holds communicator_key with size=",
              ir->communicator_key.size());
          status = ir->status;
          break;
        }
        ir->communicator_key = communicator_key;
      }
      if (ir->known_count < cp->group.group_size) {
        ir->known_count = cp->group.group_size;
        if (ir->known.size() != cp->group.group_size) {
          ir->status = errors::Internal(
              "UpdateInstanceCache: CompleteInstanceResponse for instance ",
              cp->instance.instance_key, " has known.size()=", ir->known.size(),
              " != group_size=", cp->group.group_size);
          status = ir->status;
          break;
        }
        for (int i = 0; i < ir->known.size(); ++i) {
          ir->known[i] = true;
        }
      }
      status = ir->status;
    } while (false);
    // Callback outside of lock.
    done(status);
    delete irp;
  };

  FindInstanceRec(
      gr, cp, [this, irp, continue_with_ir](const Status s, InstanceRec* irec) {
        *irp = irec;
        continue_with_ir(s);
      });
}

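// Resolves per-instance parameters.  On the group leader, or when the
// instance is already cached locally, resolution is purely local; otherwise
// a CompleteInstance RPC is sent to the leader and its reply is merged into
// the local cache before resolution completes locally.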
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else if (InstanceIsCached(cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    call->Start([this, device, gr, cp, call, done](const Status& s) {
      if (s.ok()) {
        UpdateInstanceCache(
            gr, cp, call->resp_, [this, device, gr, cp, done](const Status& s) {
              if (!s.ok()) {
                done(s);
              } else {
                CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
              }
            });
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}

}  // namespace tensorflow