/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"

#include "absl/strings/escaping.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace {

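// Wraps a CompleteGroup RPC to a remote worker (normally the group leader),
// carrying the caller's group parameters and device attributes. Cancellation
// support comes from the CancellableCall base class.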
class CompleteGroupCall : public CancellableCall {
 public:
  CompleteGroupCall(const CollGroupParams& group,
                    const DeviceAttributes& device,
                    const CollectiveType& collective_type,
                    CancellationManager* cancel_mgr,
                    const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_device_type(group.device_type.type_string());
    *req_.mutable_device_attributes() = device;
    req_.set_collective_type(collective_type);
  }
  ~CompleteGroupCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
  }

  CompleteGroupRequest req_;
  CompleteGroupResponse resp_;
};

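// Wraps a CompleteInstance RPC to a remote worker, carrying the group and
// instance parameters needed to resolve a single collective instance.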
class CompleteInstanceCall : public CancellableCall {
 public:
  CompleteInstanceCall(const CollGroupParams& group,
                       const CollInstanceParams& instance,
                       const string& node_name, const string& device_name,
                       bool is_source, CancellationManager* cancel_mgr,
                       const string& remote_worker, WorkerCacheInterface* wc)
      : CancellableCall(cancel_mgr, remote_worker, wc) {
    req_.set_name(node_name);
    req_.set_type(instance.type);
    req_.set_data_type(instance.data_type);
    instance.shape.AsProto(req_.mutable_shape());
    req_.set_group_key(group.group_key);
    req_.set_group_size(group.group_size);
    req_.set_instance_key(instance.instance_key);
    req_.set_device_type(group.device_type.type_string());
    for (int32 offset : instance.impl_details.subdiv_offsets) {
      req_.add_subdiv_offset(offset);
    }
    req_.set_device(device_name);
    req_.set_is_source(is_source);
  }

  ~CompleteInstanceCall() override {}

  void IssueCall(const StatusCallback& done) override {
    wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
  }

  CompleteInstanceRequest req_;
  CompleteInstanceResponse resp_;
};

}  // namespace

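// If this task is the configured group leader, group_leader_ is left empty so
// that group and instance resolution happen locally; otherwise unresolved
// requests are forwarded to the leader. A rough construction sketch (the
// surrounding objects and task name below are illustrative, not prescriptive):
//
//   CollectiveParamResolverDistributed resolver(
//       config, device_mgr, &dev_resolver, &worker_cache,
//       "/job:worker/replica:0/task:0");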
CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
    const string& task_name)
    : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, task_name),
      worker_cache_(worker_cache),
      group_leader_(task_name == config.experimental().collective_group_leader()
                        ? ""
                        : config.experimental().collective_group_leader()) {
  VLOG(1) << "CollectiveParamResolverDistributed ctor task={" << task_name
          << "} config.collective_group_leader={"
          << config.experimental().collective_group_leader() << "}"
          << " config.collective_nccl={"
          << config.experimental().collective_nccl() << "}";
}

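// Completes both group and instance params for one device: group resolution
// runs first, the resolved device attributes are pushed into the device
// resolver, and only then does instance resolution proceed.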
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
          << ": " << cp->ToString();
  CompleteGroupDistributed(
      device, cp, cancel_mgr,
      [this, device, cp, cancel_mgr, done](Status s, const GroupRec* gr) {
        if (s.ok()) {
          std::vector<DeviceAttributes> attributes;
          mutex_lock l(gr->mu);
          for (const auto& item : gr->devices) {
            attributes.push_back(item.second);
          }
          s = dev_resolver_->UpdateDeviceAttributes(attributes);
        }
        if (s.ok()) {
          CompleteInstanceDistributed(device.name(), gr, cp, cancel_mgr, done);
        } else {
          done(s);
        }
      });
}

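// Server-side handler for CompleteGroup RPCs from other workers, normally
// served by the group leader. The request must carry device attributes;
// otherwise the sender is likely running an incompatible TensorFlow version.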
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of TensorFlow on all workers."));
    return;
  }
  auto* cp = new CollectiveParams();
  core::ScopedUnref unref(cp);
  cp->group.group_key = request->group_key();
  cp->group.group_size = request->group_size();
  cp->group.device_type = DeviceType(request->device_type());
  cp->instance.type = CollectiveType(request->collective_type());
  CompleteGroupDistributed(
      request->device_attributes(), cp, cancel_mgr,
      [response, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          mutex_lock l(gr->mu);
          response->set_group_key(gr->group.group_key);
          response->set_group_size(gr->group.group_size);
          response->set_device_type(gr->group.device_type.type_string());
          response->set_num_tasks(gr->group.num_tasks);
          for (const auto& item : gr->devices) {
            *response->add_device_attributes() = item.second;
          }
          response->set_communicator_key(
              gr->group.runtime_details.communicator_key);
        } else {
          LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
        }
        done(s);
      });
}

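// Server-side handler for CompleteInstance RPCs from other workers. The group
// must already be fully resolved and cached on this task; otherwise the call
// fails with FailedPrecondition.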
void CollectiveParamResolverDistributed::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  GroupRec* gr = GetCachedGroup(request->group_key());
  if (gr == nullptr) {
    done(errors::FailedPrecondition(
        "group ", request->group_key(),
        " not found. This normally means the server has restarted"));
    return;
  }
  {
    mutex_lock l(gr->mu);
    if (!gr->status.ok() || gr->devices.size() != gr->group.group_size) {
      done(errors::FailedPrecondition(
          "group ", request->group_key(),
          " failed to resolve. This normally means the server has restarted"));
      return;
    }
  }
  CollectiveParams* cp = new CollectiveParams;
  cp->name = request->name();
  cp->group.group_key = request->group_key();
  cp->group.group_size = request->group_size();
  cp->group.device_type = DeviceType(request->device_type());
  cp->instance.type = CollectiveType(request->type());
  cp->instance.instance_key = request->instance_key();
  cp->instance.data_type = request->data_type();
  cp->instance.shape = TensorShape(request->shape());
  for (int32 offset : request->subdiv_offset()) {
    cp->instance.impl_details.subdiv_offsets.push_back(offset);
  }
  StatusCallback done_and_cleanup = [cp, done](const Status& s) {
    done(s);
    cp->Unref();
  };
  CompleteInstanceDistributed(
      request->device(), gr, cp, cancel_mgr,
      [this, gr, cp, response, done_and_cleanup](Status status) {
        if (status.ok()) {
          // Now source_rank should be known, so retrieve it.
          bool created_irec;
          InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
          {
            mutex_lock l(ir->mu);
            status = ir->status;
            if (ir->status.ok()) {
              response->set_instance_key(cp->instance.instance_key);
              response->set_source_rank(ir->source_rank);
            }
          }
        }
        done_and_cleanup(status);
      });
}

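// Returns the cached GroupRec for group_key, or nullptr if the group has not
// yet been resolved on this task.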
CollectiveParamResolverDistributed::GroupRec*
CollectiveParamResolverDistributed::GetCachedGroup(int32 group_key) {
  mutex_lock l(group_mu_);
  auto it = group_table_.find(group_key);
  if (it == group_table_.end()) {
    return nullptr;
  }
  return it->second.get();
}

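// Populates the local group table from the leader's CompleteGroupResponse.
// Group membership is immutable once cached; a response whose communicator_key
// disagrees with an existing entry is reported as an internal error.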
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  {
    mutex_lock grl(gr->mu);
    gr->group.device_type = DeviceType(resp.device_type());
    gr->group.group_key = resp.group_key();
    gr->group.group_size = resp.group_size();
    gr->group.num_tasks = resp.num_tasks();
    if (resp.device_attributes().empty()) {
      return errors::Internal(
          "CompleteGroupResponse device_attributes is empty. Make sure you're "
          "running the same version of TensorFlow on all workers.");
    }
    if (resp.device_attributes_size() != gr->group.group_size) {
      return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
    }
    for (const DeviceAttributes& device : resp.device_attributes()) {
      gr->devices[device.name()] = device;
    }
    gr->group.runtime_details.communicator_key = resp.communicator_key();
    FinishGroup(gr.get());
  }
  GroupRec* previous_gr = nullptr;
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(resp.group_key());
    if (it == group_table_.end()) {
      VLOG(2) << "UpdateGroupCache: communicator_key="
              << absl::CEscape(resp.communicator_key());
      group_table_[gr->group.group_key] = std::move(gr);
    } else {
      previous_gr = it->second.get();
    }
  }
  if (previous_gr != nullptr) {
    mutex_lock grl(previous_gr->mu);
    if (previous_gr->group.runtime_details.communicator_key !=
        resp.communicator_key()) {
      return errors::Internal(
          "UpdateGroupCache: CompleteGroupResponse for group ",
          resp.group_key(),
          " gives communicator_key=", absl::CEscape(resp.communicator_key()),
          " but cache already holds communicator_key=",
          absl::CEscape(previous_gr->group.runtime_details.communicator_key));
    }
  }
  return Status::OK();
}

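// Resolves group params locally when this task is the leader or the group is
// already cached; otherwise issues a CompleteGroupCall to the leader and
// caches the response. The outstanding call is registered with
// abortion_cancel_mgr_ so that StartAbort can cancel it.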
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const GroupRecCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << cp->group.group_key
          << " dev: " << device.name()
          << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, cp, done, cancel_mgr);
  } else if (GetCachedGroup(cp->group.group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call =
        new CompleteGroupCall(cp->group, device, cp->instance.type, cancel_mgr,
                              group_leader_, worker_cache_);
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"), nullptr);
      delete call;
      return;
    }
    call->Start([this, device, cp, call, cancel_mgr, abortion_token,
                 done](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, cp, done, cancel_mgr);
        } else {
          done(status, nullptr);
        }
      } else {
        done(s, nullptr);
      }
      delete call;
    });
    return;
  } else {
    return CompleteGroupLocal(device, cp, done, cancel_mgr);
  }
}

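// Returns true if an instance record for (group_key, instance_key) already
// exists in the local instance table.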
bool CollectiveParamResolverDistributed::InstanceIsCached(int32 group_key,
                                                          int32 instance_key) {
  mutex_lock l(instance_mu_);
  auto group_it = instance_table_.find(group_key);
  if (group_it == instance_table_.end()) {
    return false;
  }
  auto instance_it = group_it->second.find(instance_key);
  return instance_it != group_it->second.end();
}

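// Folds the leader's CompleteInstanceResponse into the local InstanceRec:
// records the source rank and marks every group member as known. A source
// rank that conflicts with a previously cached value is an internal error.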
Status CollectiveParamResolverDistributed::UpdateInstanceCache(
    const GroupRec* gr, CollectiveParams* cp,
    const CompleteInstanceResponse& resp) {
  int32 source_rank = resp.source_rank();
  bool created_irec;
  InstanceRec* ir = GetOrCreateInstanceRec(gr, cp, &created_irec);
  mutex_lock l(ir->mu);
  if (!ir->status.ok()) {
    return ir->status;
  }
  if (ir->source_rank != source_rank) {
    if (ir->source_rank >= 0) {
      ir->status = errors::Internal(
          "UpdateInstanceCache: CompleteInstanceResponse for instance ",
          cp->instance.instance_key, " gives source_rank=", source_rank,
          " but cache already holds value=", ir->source_rank);
      return ir->status;
    }
    ir->source_rank = source_rank;
  }
  if (ir->known_count < cp->group.group_size) {
    ir->known_count = cp->group.group_size;
    const int ir_known_size = ir->known.size();
    if (ir_known_size != cp->group.group_size) {
      ir->status = errors::Internal(
          "UpdateInstanceCache: CompleteInstanceResponse for instance ",
          cp->instance.instance_key, " has known.size()=", ir->known.size(),
          " != group_size=", cp->group.group_size);
      return ir->status;
    }
    for (int i = 0; i < ir_known_size; ++i) {
      ir->known[i] = true;
    }
  }
  return ir->status;
}

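// Resolves instance params locally when this task is the leader or the
// instance is already cached; otherwise issues a CompleteInstanceCall to the
// leader and updates the instance cache from the response before finishing
// locally.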
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    call->Start([this, device, gr, cp, call, abortion_token, done](Status s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        s = UpdateInstanceCache(gr, cp, call->resp_);
      }
      if (s.ok()) {
        CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}

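// Records the first abort status, aborts pending local resolutions, and
// cancels any outstanding RPCs to the group leader. Subsequent aborts are
// logged and ignored.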
void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  StartAbortLocal(s);
  abortion_cancel_mgr_.StartCancel();
}

}  // namespace tensorflow