1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace tensorflow {
39
// Aborts `rendez` with status `s` and releases the reference the caller
// transfers in.  NOTE(review): the name carries a historical typo
// ("Rendevous"); it is file-local, so renaming is a separate cleanup.
static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
  rendez->StartAbort(s);
  rendez->Unref();  // Balances the reference held by the caller (e.g. table_).
}
44
// Constructs a rendezvous manager bound to `worker_env`.  The environment is
// borrowed, not owned, and must outlive this manager.
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
    : worker_env_(worker_env) {}
47
~BaseRendezvousMgr()48 BaseRendezvousMgr::~BaseRendezvousMgr() {
49 for (auto& p : table_) {
50 auto rendez = p.second;
51 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
52 }
53 }
54
// Returns the rendezvous for `step_id`, creating it if necessary.  The caller
// receives a new reference (FindOrCreate calls Ref()) and must Unref() it.
RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
  return FindOrCreate(step_id);
}
58
FindOrCreate(int64 step_id)59 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
60 mutex_lock l(mu_);
61 auto iter = table_.find(step_id);
62 if (iter == table_.end()) {
63 auto rr = Create(step_id, worker_env_);
64 iter = table_.insert({step_id, rr}).first;
65 }
66 iter->second->Ref();
67 return iter->second;
68 }
69
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)70 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
71 const Rendezvous::ParsedKey& parsed,
72 Rendezvous::DoneCallback done) {
73 auto rendez = FindOrCreate(step_id);
74 auto done_cb = [rendez, done = std::move(done)](
75 const Status& s, const Rendezvous::Args& send_args,
76 const Rendezvous::Args& recv_args, const Tensor& v,
77 bool dead) {
78 rendez->Unref();
79 done(s, send_args, recv_args, v, dead);
80 };
81 rendez->RecvLocalAsync(parsed, std::move(done_cb));
82 }
83
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)84 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
85 const Rendezvous::ParsedKey& parsed,
86 Tensor* val, bool* is_dead) {
87 Status ret;
88 Notification n;
89 RecvLocalAsync(step_id, parsed,
90 [val, is_dead, &ret, &n](const Status& s,
91 const Rendezvous::Args& send_args,
92 const Rendezvous::Args& recv_args,
93 const Tensor& v, const bool dead) {
94 ret = s;
95 *val = v;
96 *is_dead = dead;
97 n.Notify();
98 });
99 n.WaitForNotification();
100 return ret;
101 }
102
Cleanup(int64 step_id)103 void BaseRendezvousMgr::Cleanup(int64 step_id) {
104 Rendezvous* rendez = nullptr;
105 {
106 mutex_lock l(mu_);
107 auto iter = table_.find(step_id);
108 if (iter != table_.end()) {
109 rendez = iter->second;
110 table_.erase(iter);
111 }
112 }
113 if (rendez) {
114 StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
115 }
116 }
117
CleanupAll()118 void BaseRendezvousMgr::CleanupAll() {
119 std::vector<Rendezvous*> rendezs;
120 {
121 mutex_lock l(mu_);
122 for (const auto& entry : table_) {
123 rendezs.push_back(entry.second);
124 }
125 table_.clear();
126 }
127 for (auto rendez : rendezs) {
128 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
129 }
130 }
131
// Creates a rendezvous for one step.  A fresh local rendezvous is owned by
// this object (released in the destructor); session_ stays null until
// Initialize() is called.
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}
137
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  // Every RecvTensor call must have been deregistered (or aborted) before the
  // rendezvous dies; a non-empty active_ set here is a programming error.
  CHECK(active_.empty());
  local_->Unref();  // Drop the reference created in the constructor.
}
142
143 // Returns true if "device_name" is a valid full name of local device
144 // of the "worker". This helper is purely based on the worker name
145 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)146 static bool IsLocalDevice(const StringPiece worker_name,
147 const StringPiece device_name) {
148 return absl::StartsWith(device_name, worker_name);
149 }
150
Initialize(WorkerSession * session)151 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
152 CHECK_NE(session, nullptr) << "session must not be null!";
153 std::vector<DeferredCall> deferred_calls;
154 {
155 mutex_lock l(init_mu_);
156 if (session_ != nullptr) {
157 if (session_->worker_name() == session->worker_name()) {
158 VLOG(1) << "Skipping rendezvous re-initialization.";
159 return Status::OK();
160 }
161 Status s = errors::Internal(
162 "Double init! Worker names would have changed from: ",
163 session_->worker_name(), " -> ", session->worker_name());
164 LOG(WARNING) << s;
165 return s;
166 }
167 session_ = session;
168 std::swap(deferred_calls, deferred_calls_);
169 }
170 for (auto& call : deferred_calls) {
171 RecvLocalAsyncInternal(call.parsed, std::move(call.done));
172 }
173 return Status::OK();
174 }
175
// Returns the bound WorkerSession (nullptr before Initialize() succeeds).
// Takes a shared lock only to read the pointer consistently.
WorkerSession* BaseRemoteRendezvous::session() {
  tf_shared_lock l(init_mu_);
  return session_;
}
180
// True once Initialize() has bound a session.  Shared-lock variant of
// is_initialized_locked() for callers not already holding init_mu_.
bool BaseRemoteRendezvous::is_initialized() {
  tf_shared_lock l(init_mu_);
  return is_initialized_locked();
}
185
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)186 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
187 const Rendezvous::Args& args,
188 const Tensor& val, const bool is_dead) {
189 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
190 WorkerSession* sess = nullptr;
191 {
192 tf_shared_lock l(init_mu_);
193 if (!status_.ok()) return status_;
194 DCHECK(is_initialized_locked());
195 sess = session_;
196 }
197
198 if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
199 return errors::InvalidArgument(
200 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
201 sess->worker_name());
202 }
203
204 // Buffers "val" and "device_context" in local_.
205 return local_->Send(parsed, args, val, is_dead);
206 }
207
ValidateDevices(const ParsedKey & parsed,bool is_src)208 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
209 bool is_src) {
210 // Cache session pointer to avoid repeatedly taking & releasing the lock
211 // (e.g. calling session())
212 WorkerSession* sess = nullptr;
213 {
214 tf_shared_lock l(init_mu_);
215 if (!status_.ok()) return status_;
216 if (!is_initialized_locked()) {
217 return errors::Internal("ValidateDevices called before initialization.");
218 }
219 sess = session_;
220 }
221 if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
222 return errors::InvalidArgument(
223 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
224 sess->worker_name());
225 }
226 if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
227 return errors::InvalidArgument(
228 "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
229 sess->worker_name());
230 }
231 return Status::OK();
232 }
233
// Completes a receive whose src and dst devices live on the same worker:
// either aliases the tensor (host->host) or issues a device copy via DMA.
// `done` is invoked exactly once with the final status; on the DMA path,
// *out must stay alive until `done` fires (the caller arranges this).
void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;  // Buffer-sharing assignment; no data copy.
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both endpoints against this worker's device manager; either
  // lookup failure terminates the receive with that error.
  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  MEMDEBUG_CACHE_STEPID(0);
  // Note that it would be nice to cache the step_id here, but it's not
  // available.
  MEMDEBUG_CACHE_OP("SameWorkerRecvDone");
  // The output allocator honors the receiver's attributes, upgraded to
  // gpu_compatible if either side requested it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  // A zero frontier means the device wants synchronous compute on the dst
  // side instead of frontier-based allocation gating.
  bool sync_dst_compute = (safe_alloc_frontier == 0);
  // NOTE: captures `safe_alloc_frontier` (a local) by reference; the pointer
  // to this lambda is only consumed by the Tensor constructor below, within
  // this stack frame.
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
306
// True when src and dst resolve to the same address space (same worker
// process), in which case the transfer can be served from local_.
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}
311
// Receives the tensor named by `parsed`: served from local_ when src and dst
// are on the same worker, otherwise delegated to the subclass's
// RecvFromRemoteAsync.  `done` is invoked exactly once on every path.
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
  Status s = ValidateDevices(parsed, false /*!is_src*/);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // ValidateDevices() returns an error status if the rendezvous is not
  // initialized.
  DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
                           << parsed.FullKey() << ").";

  MEMDEBUG_CACHE_OP("RecvAsync");
  MEMDEBUG_CACHE_STEPID(0);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
          VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
                  << parsed.FullKey();
          // Heap-allocated so it outlives this callback: SameWorkerRecvDone
          // fills it asynchronously and final_callback deletes it after
          // handing the value to `done`.
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            // Local recv failed; still report through final_callback so the
            // Tensor is freed and `done` fires exactly once.
            final_callback(status);
          }
        });
    return;
  } else {
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}
358
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)359 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
360 DoneCallback done) {
361 // Test whether the rendezvous is initialized using a shared lock, to avoid
362 // the need for exclusive access in the common case.
363 if (TF_PREDICT_FALSE(!is_initialized())) {
364 mutex_lock l(init_mu_);
365 if (!is_initialized_locked()) {
366 // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
367 // remote worker) before the RunStep (or PartialRunStep) RPC from the
368 // master arrives. RecvLocalAsync thus buffers the arguments until after
369 // the RemoteRendezvous is Initialize()'d, when it completes the
370 // rendezvous logic. At some point after Initialize() is called, a Tensor
371 // is produced locally that will then be sent in response to the incoming
372 // RPC.
373 DeferredCall call(parsed, std::move(done));
374 deferred_calls_.push_back(call);
375 return;
376 }
377 }
378 RecvLocalAsyncInternal(parsed, std::move(done));
379 }
380
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)381 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
382 DoneCallback done) {
383 Status s = ValidateDevices(parsed, true /* is_src */);
384 if (!s.ok()) {
385 done(s, Args(), Args(), Tensor(), false);
386 return;
387 }
388 local_->RecvAsync(parsed, Args(), std::move(done));
389 }
390
// Aborts the rendezvous and every in-flight RecvTensor call with (a derived
// form of) `s`.  Idempotent with respect to status_: only the first abort
// sets it and drains active_.
void BaseRemoteRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  // Use a "derived" status as the status for the rendezvous. Derived
  // status messages are ignored when aggregating errors across devices: this
  // allows us to prefer our original status message over any cancellation
  // related errors.
  Status derived_status = StatusGroup::MakeDerived(s);

  local_->StartAbort(derived_status);
  {
    // Aborts all active RecvTensor calls.
    // Lock order: init_mu_ before active_mu_ (must be consistent file-wide).
    mutex_lock l(init_mu_);
    mutex_lock l2(active_mu_);
    if (status_.ok()) {
      status_ = derived_status;
      for (auto& entry : active_) {
        entry.first->StartAbort(derived_status);
        entry.second();  // Run the call's deregistration callback.
      }
      active_.clear();
    }
  }
}
414
// Tracks an outstanding RecvTensor call so StartAbort() can cancel it, and
// wires it to `args.cancellation_manager` (if any).  If the rendezvous is
// already in an error state, or cancellation already happened, the call is
// aborted immediately instead of being registered.
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  Status captured_status;
  {
    tf_shared_lock l(init_mu_);
    if (!status_.ok()) {
      captured_status = status_;
    }
  }
  // Abort outside the lock if the rendezvous already failed.
  if (!captured_status.ok()) {
    call->StartAbort(captured_status);
    return;
  }

  bool already_cancelled = false;
  InactiveCallback callback = [] {};  // No-op when there is no manager.
  if (cm != nullptr) {
    auto token = cm->get_cancellation_token();
    already_cancelled = !cm->RegisterCallback(token, [this, call] {
      {
        // Only abort calls still tracked in active_; a call that was already
        // deregistered must not be touched (it may be destroyed).
        tf_shared_lock l(active_mu_);
        if (active_.find(call) == active_.end()) return;
      }
      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
    });
    callback = [cm, token] { cm->TryDeregisterCallback(token); };
  }

  if (already_cancelled) {
    call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
  } else {
    mutex_lock l(active_mu_);
    // A call may be registered at most once.
    CHECK(active_.emplace(call, callback).second);
  }
}
451
DeregisterCall(BaseRecvTensorCall * call)452 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
453 mutex_lock l(active_mu_);
454 auto it = active_.find(call);
455 if (it != active_.end()) {
456 // Deregister the cancellation callback, if one was registered.
457 it->second();
458 active_.erase(it);
459 }
460 }
461
// Captures a pre-initialization receive: the parsed key is copied, the
// completion callback is moved in and invoked once the rendezvous is
// initialized (see Initialize()).
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
                                                 DoneCallback done)
    : parsed(parsed), done(std::move(done)) {}
465
466 } // end namespace tensorflow
467