1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17
#include <memory>
#include <unordered_set>
#include <vector>
20
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace tensorflow {
40
StartAbortRendevous(Rendezvous * rendez,const Status & s)41 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
42 rendez->StartAbort(s);
43 rendez->Unref();
44 }
45
BaseRendezvousMgr(const WorkerEnv * worker_env)46 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
47 : worker_env_(worker_env) {}
48
~BaseRendezvousMgr()49 BaseRendezvousMgr::~BaseRendezvousMgr() {
50 for (auto& p : table_) {
51 auto rendez = p.second;
52 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
53 }
54 }
55
Find(int64 step_id)56 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
57 return FindOrCreate(step_id);
58 }
59
FindOrCreate(int64 step_id)60 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
61 mutex_lock l(mu_);
62 auto iter = table_.find(step_id);
63 if (iter == table_.end()) {
64 auto rr = Create(step_id, worker_env_);
65 iter = table_.insert({step_id, rr}).first;
66 }
67 iter->second->Ref();
68 return iter->second;
69 }
70
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)71 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
72 const Rendezvous::ParsedKey& parsed,
73 Rendezvous::DoneCallback done) {
74 auto rendez = FindOrCreate(step_id);
75 auto done_cb = [rendez, done = std::move(done)](
76 const Status& s, const Rendezvous::Args& send_args,
77 const Rendezvous::Args& recv_args, const Tensor& v,
78 bool dead) {
79 rendez->Unref();
80 done(s, send_args, recv_args, v, dead);
81 };
82 rendez->RecvLocalAsync(parsed, std::move(done_cb));
83 }
84
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)85 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
86 const Rendezvous::ParsedKey& parsed,
87 Tensor* val, bool* is_dead) {
88 Status ret;
89 Notification n;
90 RecvLocalAsync(step_id, parsed,
91 [val, is_dead, &ret, &n](const Status& s,
92 const Rendezvous::Args& send_args,
93 const Rendezvous::Args& recv_args,
94 const Tensor& v, const bool dead) {
95 ret = s;
96 *val = v;
97 *is_dead = dead;
98 n.Notify();
99 });
100 n.WaitForNotification();
101 return ret;
102 }
103
Cleanup(int64 step_id)104 void BaseRendezvousMgr::Cleanup(int64 step_id) {
105 Rendezvous* rendez = nullptr;
106 {
107 mutex_lock l(mu_);
108 auto iter = table_.find(step_id);
109 if (iter != table_.end()) {
110 rendez = iter->second;
111 table_.erase(iter);
112 }
113 }
114 if (rendez) {
115 StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
116 }
117 }
118
BaseRemoteRendezvous(const WorkerEnv * env,int64 step_id)119 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
120 : env_(env),
121 step_id_(step_id),
122 local_(NewLocalRendezvous()),
123 session_(nullptr) {}
124
// Destroys the rendezvous. Every call registered via RegisterCall() must have
// been removed from active_ by now (by DeregisterCall() or StartAbort());
// otherwise the CHECK fails. Releases the local rendezvous created in the
// constructor.
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  CHECK(active_.empty());
  local_->Unref();
}
129
130 // Returns true if "device_name" is a valid full name of local device
131 // of the "worker". This helper is purely based on the worker name
132 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)133 static bool IsLocalDevice(const StringPiece worker_name,
134 const StringPiece device_name) {
135 return absl::StartsWith(device_name, worker_name);
136 }
137
Initialize(WorkerSession * session)138 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
139 CHECK_NE(session, nullptr) << "session must not be null!";
140 std::vector<DeferredCall> deferred_calls;
141 {
142 mutex_lock l(mu_);
143 if (session_ != nullptr) {
144 if (session_->worker_name() == session->worker_name()) {
145 VLOG(1) << "Skipping rendezvous re-initialization.";
146 return Status::OK();
147 }
148 Status s = errors::Internal(
149 "Double init! Worker names would have changed from: ",
150 session_->worker_name(), " -> ", session->worker_name());
151 LOG(WARNING) << s;
152 return s;
153 }
154 session_ = session;
155 std::swap(deferred_calls, deferred_calls_);
156 }
157 for (auto& call : deferred_calls) {
158 RecvLocalAsyncInternal(call.parsed, std::move(call.done));
159 }
160 return Status::OK();
161 }
162
session()163 WorkerSession* BaseRemoteRendezvous::session() {
164 tf_shared_lock l(mu_);
165 return session_;
166 }
167
is_initialized()168 bool BaseRemoteRendezvous::is_initialized() {
169 tf_shared_lock l(mu_);
170 return is_initialized_locked();
171 }
172
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)173 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
174 const Rendezvous::Args& args,
175 const Tensor& val, const bool is_dead) {
176 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
177 WorkerSession* sess = nullptr;
178 {
179 tf_shared_lock l(mu_);
180 if (!status_.ok()) return status_;
181 DCHECK(is_initialized_locked());
182 sess = session_;
183 }
184
185 if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
186 return errors::InvalidArgument(
187 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
188 sess->worker_name());
189 }
190
191 // Buffers "val" and "device_context" in local_.
192 return local_->Send(parsed, args, val, is_dead);
193 }
194
ValidateDevices(const ParsedKey & parsed,bool is_src)195 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
196 bool is_src) {
197 // Cache session pointer to avoid repeatedly taking & releasing the lock
198 // (e.g. calling session())
199 WorkerSession* sess = nullptr;
200 {
201 tf_shared_lock l(mu_);
202 if (!status_.ok()) return status_;
203 if (!is_initialized_locked()) {
204 return errors::Internal("ValidateDevices called before initialization.");
205 }
206 sess = session_;
207 }
208 if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
209 return errors::InvalidArgument(
210 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
211 sess->worker_name());
212 }
213 if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
214 return errors::InvalidArgument(
215 "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
216 sess->worker_name());
217 }
218 return Status::OK();
219 }
220
// Completes a receive whose sender and receiver are on this worker: fills
// `*out` from `in` for the destination side, then reports the result via
// `done`.
//
// - If both endpoints are host memory (per the alloc attrs, or a "CPU" device
//   type in the key), `*out` shares `in`'s buffer and `done` runs at once.
// - Otherwise `in` must be DMA-safe, DT_VARIANT or DT_RESOURCE. Both devices
//   are looked up in the session's device manager. For non-variant dtypes a
//   destination tensor is allocated on the destination device, and
//   CopyTensor::ViaDMA performs the copy and invokes `done`.
//
// Errors (non-DMA-safe dtype, unknown device) are reported through `done`.
void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both endpoint devices; any lookup failure is forwarded to `done`.
  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", step_id_,
                                            "dynamic", in.dtype(), &in.shape());
  // Allocate on the destination device, GPU-compatible if either side asked
  // for GPU-compatible memory.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  // A zero frontier switches to the sync_dst_compute path: freed_by_func is not
  // attached and ViaDMA is told to synchronize with destination compute.
  // NOTE(review): presumably 0 means the device does not track a safe
  // allocation frontier — confirm against Device::SafeAllocFrontier.
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  bool sync_dst_compute = (safe_alloc_frontier == 0);
  // Refreshes safe_alloc_frontier (a local captured by reference) each time
  // the allocator queries it.
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    // allocation_attr stores a pointer to the local freed_by_func.
    // NOTE(review): assumes the allocator only calls it during the Tensor
    // allocation below, while this frame is still alive.
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
291
IsSameWorker(DeviceNameUtils::ParsedName src,DeviceNameUtils::ParsedName dst)292 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
293 DeviceNameUtils::ParsedName dst) {
294 return DeviceNameUtils::IsSameAddressSpace(src, dst);
295 }
296
RecvAsync(const ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)297 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
298 const Rendezvous::Args& recv_args,
299 DoneCallback done) {
300 VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
301 Status s = ValidateDevices(parsed, false /*!is_src*/);
302 if (!s.ok()) {
303 done(s, Args(), recv_args, Tensor(), false);
304 return;
305 }
306
307 // ValidateDevices() returns an error status if the rendezvous is not
308 // initialized.
309 DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
310 << parsed.FullKey() << ").";
311
312 ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
313 // Are src and dst in the same worker?
314 if (IsSameWorker(parsed.src, parsed.dst)) {
315 // Recv the tensor from local_.
316 local_->RecvAsync(
317 parsed, recv_args,
318 [this, parsed, done](
319 const Status& status, const Rendezvous::Args& send_args,
320 const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
321 VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
322 << parsed.FullKey();
323 Tensor* out = new Tensor;
324 StatusCallback final_callback = [done, send_args, recv_args, out,
325 is_dead](const Status& s) {
326 done(s, send_args, recv_args, *out, is_dead);
327 delete out;
328 };
329
330 if (status.ok()) {
331 SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
332 std::move(final_callback));
333 } else {
334 final_callback(status);
335 }
336 });
337 return;
338 } else {
339 RecvFromRemoteAsync(parsed, recv_args, std::move(done));
340 }
341 }
342
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)343 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
344 DoneCallback done) {
345 // Test whether the rendezvous is initialized using a shared lock, to avoid
346 // the need for exclusive access in the common case.
347 if (TF_PREDICT_FALSE(!is_initialized())) {
348 mutex_lock l(mu_);
349 if (!is_initialized_locked()) {
350 // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
351 // remote worker) before the RunStep (or PartialRunStep) RPC from the
352 // master arrives. RecvLocalAsync thus buffers the arguments until after
353 // the RemoteRendezvous is Initialize()'d, when it completes the
354 // rendezvous logic. At some point after Initialize() is called, a Tensor
355 // is produced locally that will then be sent in response to the incoming
356 // RPC.
357 DeferredCall call(parsed, std::move(done));
358 deferred_calls_.push_back(call);
359 return;
360 }
361 }
362 RecvLocalAsyncInternal(parsed, std::move(done));
363 }
364
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)365 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
366 DoneCallback done) {
367 Status s = ValidateDevices(parsed, true /* is_src */);
368 if (!s.ok()) {
369 done(s, Args(), Args(), Tensor(), false);
370 return;
371 }
372 local_->RecvAsync(parsed, Args(), std::move(done));
373 }
374
StartAbort(const Status & s)375 void BaseRemoteRendezvous::StartAbort(const Status& s) {
376 CHECK(!s.ok());
377 // If the status passed in is a cancelled or aborted error, mark it as
378 // "derived" for the rendezvous. Derived status messages are ignored when
379 // aggregating errors across devices: this allows us to prefer our original
380 // status message over any cancellation related errors.
381 Status derived_status = s;
382 if (errors::IsCancelled(s) || errors::IsAborted(s)) {
383 derived_status = StatusGroup::MakeDerived(s);
384 }
385
386 local_->StartAbort(derived_status);
387 {
388 // Aborts all active RecvTensor calls.
389 mutex_lock l(mu_);
390 if (status_.ok()) {
391 status_ = derived_status;
392 for (auto& entry : active_) {
393 entry.first->StartAbort(derived_status);
394 entry.second();
395 }
396 active_.clear();
397 }
398 }
399 }
400
// Registers `call` in active_ so that StartAbort() can abort it.
//
// - If the rendezvous has already been aborted, `call` is aborted with the
//   stored status and is not registered.
// - If `args` carries a cancellation manager, a cancellation callback is
//   registered that aborts `call` with Cancelled (only while `call` is still in
//   active_). The matching deregistration closure is stored next to `call` in
//   active_, to be run by DeregisterCall() or StartAbort().
// - If the manager is already cancelled, `call` is aborted with Cancelled and
//   not registered.
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  bool already_cancelled = false;
  // No-op unless a cancellation callback gets registered below.
  InactiveCallback callback = [] {};
  {
    mutex_lock l(mu_);
    if (!status_.ok()) {
      call->StartAbort(status_);
      return;
    }
    if (cm != nullptr) {
      auto token = cm->get_cancellation_token();
      // A false return from RegisterCallback() is treated as "cm was already
      // cancelled" (the callback was not registered).
      already_cancelled = !cm->RegisterCallback(token, [this, call] {
        {
          mutex_lock l(mu_);
          // Ignore cancellation of a call that has already been deregistered.
          if (active_.find(call) == active_.end()) return;
          call->StartAbort(
              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
        }
      });
      // This closure is invoked while mu_ is held (in DeregisterCall() and
      // StartAbort()), and the cancellation callback above also acquires mu_.
      // NOTE(review): the non-blocking TryDeregisterCallback presumably avoids
      // deadlocking against a concurrently running cancellation callback —
      // confirm against CancellationManager's documentation.
      callback = [cm, token] { cm->TryDeregisterCallback(token); };
    }
    if (already_cancelled) {
      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
    } else {
      // Each call may be registered at most once.
      CHECK(active_.emplace(call, callback).second);
    }
  }
}
431
DeregisterCall(BaseRecvTensorCall * call)432 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
433 mutex_lock l(mu_);
434 auto it = active_.find(call);
435 if (it != active_.end()) {
436 // Deregister the cancellation callback, if one was registered.
437 it->second();
438 active_.erase(it);
439 }
440 }
441
DeferredCall(const ParsedKey & parsed,DoneCallback done)442 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
443 DoneCallback done)
444 : parsed(parsed), done(std::move(done)) {}
445
446 } // end namespace tensorflow
447