1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
16
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
21 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache.h"
23 #include "tensorflow/core/framework/device_attributes.pb.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28
29 namespace tensorflow {
30 namespace {
31
32 class CompleteGroupCall : public CancellableCall {
33 public:
CompleteGroupCall(const CollGroupParams & group,const DeviceAttributes & device,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)34 CompleteGroupCall(const CollGroupParams& group,
35 const DeviceAttributes& device,
36 CancellationManager* cancel_mgr,
37 const string& remote_worker, WorkerCacheInterface* wc)
38 : CancellableCall(cancel_mgr, remote_worker, wc) {
39 req_.set_group_key(group.group_key);
40 req_.set_group_size(group.group_size);
41 req_.set_device_type(group.device_type.type_string());
42 *req_.mutable_device_attributes() = device;
43 }
~CompleteGroupCall()44 ~CompleteGroupCall() override {}
45
IssueCall(const StatusCallback & done)46 void IssueCall(const StatusCallback& done) override {
47 wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
48 }
49
50 CompleteGroupRequest req_;
51 CompleteGroupResponse resp_;
52 };
53
54 class CompleteInstanceCall : public CancellableCall {
55 public:
CompleteInstanceCall(const CollGroupParams & group,const CollInstanceParams & instance,const string & node_name,const string & device_name,bool is_source,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)56 CompleteInstanceCall(const CollGroupParams& group,
57 const CollInstanceParams& instance,
58 const string& node_name, const string& device_name,
59 bool is_source, CancellationManager* cancel_mgr,
60 const string& remote_worker, WorkerCacheInterface* wc)
61 : CancellableCall(cancel_mgr, remote_worker, wc) {
62 req_.set_name(node_name);
63 req_.set_type(instance.type);
64 req_.set_data_type(instance.data_type);
65 instance.shape.AsProto(req_.mutable_shape());
66 req_.set_group_key(group.group_key);
67 req_.set_group_size(group.group_size);
68 req_.set_instance_key(instance.instance_key);
69 req_.set_device_type(group.device_type.type_string());
70 for (int32_t offset : instance.impl_details.subdiv_offsets) {
71 req_.add_subdiv_offset(offset);
72 }
73 req_.set_device(device_name);
74 req_.set_is_source(is_source);
75 }
76
~CompleteInstanceCall()77 ~CompleteInstanceCall() override {}
78
IssueCall(const StatusCallback & done)79 void IssueCall(const StatusCallback& done) override {
80 wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
81 }
82
83 CompleteInstanceRequest req_;
84 CompleteInstanceResponse resp_;
85 };
86
87 } // namespace
88
CollectiveParamResolverDistributed(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverDistributed * dev_resolver,NcclCommunicatorInterface * nccl_communicator,WorkerCacheInterface * worker_cache,const string & task_name)89 CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
90 const ConfigProto& config, const DeviceMgr* dev_mgr,
91 DeviceResolverDistributed* dev_resolver,
92 NcclCommunicatorInterface* nccl_communicator,
93 WorkerCacheInterface* worker_cache, const string& task_name)
94 : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver,
95 nccl_communicator, task_name),
96 worker_cache_(worker_cache),
97 group_leader_(task_name == config.experimental().collective_group_leader()
98 ? ""
99 : config.experimental().collective_group_leader()) {
100 VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
101 << "} config.collective_group_leader={"
102 << config.experimental().collective_group_leader() << "}"
103 << " config.collective_nccl={"
104 << config.experimental().collective_nccl() << "}";
105 }
106
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,const StatusCallback & done)107 void CollectiveParamResolverDistributed::CompleteParamsAsync(
108 const DeviceAttributes& device, CollectiveParams* cp,
109 CancellationManager* cancel_mgr, const StatusCallback& done) {
110 VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
111 << ": " << cp->ToString();
112 CompleteGroupDistributed(
113 device, &cp->group, cancel_mgr,
114 [this, device, cp, cancel_mgr, done](Status s) {
115 if (s.ok()) {
116 s = dev_resolver_->UpdateDeviceAttributes(cp->group.devices);
117 }
118 if (s.ok()) {
119 CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
120 } else {
121 done(s);
122 }
123 });
124 }
125
// Public entry point for group resolution.  Forwards directly to
// CompleteGroupDistributed, which chooses between local resolution (on the
// leader) and an RPC to the leader.
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupDistributed(device, group_params, cancel_mgr, done);
}
131
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)132 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
133 const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
134 CancellationManager* cancel_mgr, const StatusCallback& done) {
135 GroupRec* gr = GetCachedGroup(request->group_key());
136 if (gr == nullptr) {
137 done(errors::FailedPrecondition(
138 "group ", request->group_key(),
139 " not found. This normally means the server has restarted"));
140 return;
141 }
142 CollectiveParams* cp = new CollectiveParams;
143 {
144 mutex_lock l(gr->mu);
145 if (!gr->status.ok()) {
146 done(gr->status);
147 return;
148 } else if (gr->group.devices.size() != gr->group.group_size) {
149 done(errors::FailedPrecondition(
150 "group ", request->group_key(),
151 " failed to resolve. This normally means the server has restarted"));
152 return;
153 }
154 cp->group = gr->group;
155 }
156 cp->name = request->name();
157 cp->instance.type = CollectiveType(request->type());
158 cp->instance.instance_key = request->instance_key();
159 cp->instance.data_type = request->data_type();
160 cp->instance.shape = TensorShape(request->shape());
161 cp->is_source = request->is_source();
162 for (int32_t offset : request->subdiv_offset()) {
163 cp->instance.impl_details.subdiv_offsets.push_back(offset);
164 }
165 StatusCallback done_and_cleanup = [cp, done](const Status& s) {
166 done(s);
167 cp->Unref();
168 };
169 CompleteInstanceDistributed(
170 request->device(), cp, cancel_mgr,
171 [this, cp, response, done_and_cleanup](Status status) {
172 if (status.ok()) {
173 // Now source_rank should be known, so retrieve it.
174 bool created_irec;
175 InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
176 {
177 mutex_lock l(ir->mu);
178 status = ir->status;
179 if (ir->status.ok()) {
180 response->set_instance_key(cp->instance.instance_key);
181 response->set_source_rank(ir->source_rank);
182 }
183 }
184 }
185 done_and_cleanup(status);
186 });
187 }
188
189 CollectiveParamResolverDistributed::GroupRec*
GetCachedGroup(int32_t group_key)190 CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) {
191 mutex_lock l(group_mu_);
192 auto it = group_table_.find(group_key);
193 if (it == group_table_.end()) {
194 return nullptr;
195 }
196 return it->second.get();
197 }
198
// Builds a GroupRec from the leader's CompleteGroupResponse and installs it in
// group_table_, or — if a record for this group already exists — validates
// that the response agrees with the cached communicator_key.
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  {
    mutex_lock grl(gr->mu);
    gr->group.device_type = DeviceType(resp.device_type());
    gr->group.group_key = resp.group_key();
    gr->group.group_size = resp.group_size();
    gr->group.num_tasks = resp.num_tasks();
    // An empty device_attributes list indicates the leader is running an
    // older protocol version that didn't populate it.
    if (resp.device_attributes().empty()) {
      return errors::Internal(
          "CompleteGroupResponse device_attributes is empty. Make sure you're "
          "running the same version of Tensorflow on all workers.");
    }
    if (resp.device_attributes_size() != gr->group.group_size) {
      return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
    }
    gr->group.devices.reserve(resp.device_attributes().size());
    for (const DeviceAttributes& device : resp.device_attributes()) {
      gr->group.devices.push_back(device);
      gr->incarnations_by_device_name[device.name()] = device.incarnation();
    }
    gr->group.runtime_details.communicator_key = resp.communicator_key();
    FinishGroup(gr.get());
  }
  GroupRec* previous_gr = nullptr;
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(resp.group_key());
    if (it == group_table_.end()) {
      VLOG(2) << "UpdateGroupCache: communicator_key="
              << absl::CEscape(resp.communicator_key());
      group_table_[gr->group.group_key] = std::move(gr);
    } else {
      previous_gr = it->second.get();
    }
  }
  // A record was already present: the fresh response must carry the same
  // communicator_key, otherwise the leader and this task have diverged.
  if (previous_gr != nullptr) {
    mutex_lock grl(previous_gr->mu);
    if (previous_gr->group.runtime_details.communicator_key !=
        resp.communicator_key()) {
      return errors::Internal(
          "UpdateGroupCache: CompleteGroupResponse for group ",
          resp.group_key(),
          " gives communicator_key=", absl::CEscape(resp.communicator_key()),
          " but cache already holds communicator_key=",
          absl::CEscape(previous_gr->group.runtime_details.communicator_key));
    }
  }
  return Status::OK();
}
254
// Resolves group membership for `group_params`.  On the leader
// (group_leader_ empty) resolution is purely local.  On other tasks, the
// first resolution of a group key issues a CompleteGroup RPC to the leader
// and caches the result; subsequent resolutions hit the cache and complete
// locally.
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << group_params->group_key
          << " dev: " << device.name()
          << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  } else if (GetCachedGroup(group_params->group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        *group_params, device, cancel_mgr, group_leader_, worker_cache_);
    // Register the in-flight RPC so StartAbort() can cancel it.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // `call` owns req_/resp_ and is deleted at the end of this callback.
    call->Start([this, device, group_params, call, cancel_mgr, abortion_token,
                 done](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        // Cache the leader's view of the group, then finish locally.
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, group_params, cancel_mgr, done);
        } else {
          done(status);
        }
      } else {
        done(s);
      }
      delete call;
    });
    return;
  } else {
    // Group is already cached, so resolution can proceed locally.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  }
}
297
InstanceIsCached(int32_t group_key,int32_t instance_key)298 bool CollectiveParamResolverDistributed::InstanceIsCached(
299 int32_t group_key, int32_t instance_key) {
300 mutex_lock l(instance_mu_);
301 auto group_it = instance_table_.find(group_key);
302 if (group_it == instance_table_.end()) {
303 return false;
304 }
305 auto instance_it = group_it->second.find(instance_key);
306 return instance_it != group_it->second.end();
307 }
308
UpdateInstanceCache(CollectiveParams * cp,const CompleteInstanceResponse & resp)309 Status CollectiveParamResolverDistributed::UpdateInstanceCache(
310 CollectiveParams* cp, const CompleteInstanceResponse& resp) {
311 int32_t source_rank = resp.source_rank();
312 bool created_irec;
313 InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
314 mutex_lock l(ir->mu);
315 if (!ir->status.ok()) {
316 return ir->status;
317 }
318 if (ir->source_rank != source_rank) {
319 if (ir->source_rank >= 0) {
320 ir->status = errors::Internal(
321 "UpdateInstanceCache: CompleteInstanceResponse for instance ",
322 cp->instance.instance_key, " gives source_rank=", source_rank,
323 " but cache already holds value=", ir->source_rank);
324 return ir->status;
325 }
326 ir->source_rank = source_rank;
327 }
328 if (ir->known_count < cp->group.group_size) {
329 ir->known_count = cp->group.group_size;
330 const int ir_known_size = ir->known.size();
331 if (ir_known_size != cp->group.group_size) {
332 ir->status = errors::Internal(
333 "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
334 cp->instance.instance_key, " has known.size()=", ir->known.size(),
335 " < group_size=", cp->group.group_size);
336 return ir->status;
337 }
338 for (int i = 0; i < ir_known_size; ++i) {
339 ir->known[i] = true;
340 }
341 }
342 return ir->status;
343 }
344
// Resolves instance params for `cp`.  On the leader, or when the instance is
// already cached locally, resolution is local; otherwise a CompleteInstance
// RPC is sent to the leader and its response is folded into the local cache
// before finishing locally.
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, cp, done);
  } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, cp, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    // Register the in-flight RPC so StartAbort() can cancel it.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // `call` owns req_/resp_ and is deleted at the end of this callback.
    call->Start([this, device, cp, call, abortion_token, done](Status s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        // Cache the leader's answer (e.g. source_rank) before finishing.
        s = UpdateInstanceCache(cp, call->resp_);
      }
      if (s.ok()) {
        CompleteInstanceLocal(device, cp, done);
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}
381
// Aborts all collective param resolution with status `s`.  Only the first
// abort takes effect; later calls are logged and ignored.  After latching the
// status, local state is aborted and any RPCs pending against the group
// leader are cancelled via abortion_cancel_mgr_.
void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
                 "subsequent abortion with status: "
              << s;
      return;
    }
    status_ = s;
  }
  // Abort outside the lock: StartAbortLocal and the cancellation callbacks
  // may themselves take other locks.
  StartAbortLocal(s);
  abortion_cancel_mgr_.StartCancel();
}
396
397 } // namespace tensorflow
398