1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace tensorflow {
40
StartAbortRendevous(Rendezvous * rendez,const Status & s)41 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
42 rendez->StartAbort(s);
43 rendez->Unref();
44 }
45
// Constructs the manager. `worker_env` is borrowed (not owned) and must
// outlive this object; it is passed to Create() for each new rendezvous.
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
    : worker_env_(worker_env) {}
48
~BaseRendezvousMgr()49 BaseRendezvousMgr::~BaseRendezvousMgr() {
50 for (auto& p : table_) {
51 auto rendez = p.second;
52 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
53 }
54 }
55
// Returns the rendezvous for `step_id`, creating one if necessary. The
// caller receives a new reference (taken in FindOrCreate) and must Unref
// the returned rendezvous when done with it.
RemoteRendezvous* BaseRendezvousMgr::Find(int64_t step_id) {
  return FindOrCreate(step_id);
}
59
FindOrCreate(int64_t step_id)60 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64_t step_id) {
61 mutex_lock l(mu_);
62 auto iter = table_.find(step_id);
63 if (iter == table_.end()) {
64 auto rr = Create(step_id, worker_env_);
65 iter = table_.insert({step_id, rr}).first;
66 }
67 iter->second->Ref();
68 return iter->second;
69 }
70
RecvLocalAsync(int64_t step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)71 void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
72 const Rendezvous::ParsedKey& parsed,
73 Rendezvous::DoneCallback done) {
74 auto rendez = FindOrCreate(step_id);
75 auto done_cb = [rendez, done = std::move(done)](
76 const Status& s, const Rendezvous::Args& send_args,
77 const Rendezvous::Args& recv_args, const Tensor& v,
78 bool dead) {
79 rendez->Unref();
80 done(s, send_args, recv_args, v, dead);
81 };
82 rendez->RecvLocalAsync(parsed, std::move(done_cb));
83 }
84
RecvLocal(int64_t step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)85 Status BaseRendezvousMgr::RecvLocal(int64_t step_id,
86 const Rendezvous::ParsedKey& parsed,
87 Tensor* val, bool* is_dead) {
88 Status ret;
89 Notification n;
90 RecvLocalAsync(step_id, parsed,
91 [val, is_dead, &ret, &n](const Status& s,
92 const Rendezvous::Args& send_args,
93 const Rendezvous::Args& recv_args,
94 const Tensor& v, const bool dead) {
95 ret = s;
96 *val = v;
97 *is_dead = dead;
98 n.Notify();
99 });
100 n.WaitForNotification();
101 return ret;
102 }
103
Cleanup(int64_t step_id)104 void BaseRendezvousMgr::Cleanup(int64_t step_id) {
105 Rendezvous* rendez = nullptr;
106 {
107 mutex_lock l(mu_);
108 auto iter = table_.find(step_id);
109 if (iter != table_.end()) {
110 rendez = iter->second;
111 table_.erase(iter);
112 }
113 }
114 if (rendez) {
115 StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
116 }
117 }
118
// Constructs a rendezvous for a single step. `env` is borrowed and must
// outlive this object. A fresh local rendezvous buffers tensors sent and
// received on this worker; `session_` remains null until Initialize().
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
                                           int64_t step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}
125
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  // Every registered RecvTensor call must have been deregistered (or
  // aborted via StartAbort) before the rendezvous is destroyed.
  CHECK(active_.empty());
  local_->Unref();
}
130
// Returns true if "device_name" is a valid full name of local device
// of the "worker". This helper is purely based on the worker name
// and device name and does no lookups in the worker->device_mgr.
static bool IsLocalDevice(const StringPiece worker_name,
                          const StringPiece device_name) {
  // A device is considered local iff its full name begins with the
  // worker's name (full device names embed the worker/task prefix).
  return absl::StartsWith(device_name, worker_name);
}
138
// Binds this rendezvous to `session`. Re-initialization with the same
// worker name is a benign no-op; a different worker name is an internal
// error. Receives that arrived before initialization were buffered in
// deferred_calls_ and are replayed here, outside the lock.
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
  CHECK_NE(session, nullptr) << "session must not be null!";
  std::vector<DeferredCall> deferred_calls;
  {
    mutex_lock l(mu_);
    if (session_ != nullptr) {
      if (session_->worker_name() == session->worker_name()) {
        VLOG(1) << "Skipping rendezvous re-initialization.";
        return Status::OK();
      }
      Status s = errors::Internal(
          "Double init! Worker names would have changed from: ",
          session_->worker_name(), " -> ", session->worker_name());
      LOG(WARNING) << s;
      return s;
    }
    session_ = session;
    // Take ownership of the buffered calls so they can be processed
    // without holding mu_ (RecvLocalAsyncInternal may run callbacks
    // synchronously).
    std::swap(deferred_calls, deferred_calls_);
  }
  for (auto& call : deferred_calls) {
    RecvLocalAsyncInternal(call.parsed, std::move(call.done));
  }
  return Status::OK();
}
163
// Returns the session bound by Initialize(), or nullptr if Initialize()
// has not yet run.
WorkerSession* BaseRemoteRendezvous::session() {
  tf_shared_lock l(mu_);
  return session_;
}
168
// True once Initialize() has bound a session to this rendezvous.
bool BaseRemoteRendezvous::is_initialized() {
  tf_shared_lock l(mu_);
  return is_initialized_locked();
}
173
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)174 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
175 const Rendezvous::Args& args,
176 const Tensor& val, const bool is_dead) {
177 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
178 WorkerSession* sess = nullptr;
179 {
180 tf_shared_lock l(mu_);
181 if (!status_.ok()) return status_;
182 DCHECK(is_initialized_locked());
183 sess = session_;
184 }
185
186 if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
187 return errors::InvalidArgument(
188 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
189 sess->worker_name());
190 }
191
192 // Buffers "val" and "device_context" in local_.
193 return local_->Send(parsed, args, val, is_dead);
194 }
195
ValidateDevices(const ParsedKey & parsed,bool is_src)196 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
197 bool is_src) {
198 // Cache session pointer to avoid repeatedly taking & releasing the lock
199 // (e.g. calling session())
200 WorkerSession* sess = nullptr;
201 {
202 tf_shared_lock l(mu_);
203 if (!status_.ok()) return status_;
204 if (!is_initialized_locked()) {
205 return errors::Internal("ValidateDevices called before initialization.");
206 }
207 sess = session_;
208 }
209 if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
210 return errors::InvalidArgument(
211 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
212 sess->worker_name());
213 }
214 if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
215 return errors::InvalidArgument(
216 "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
217 sess->worker_name());
218 }
219 return Status::OK();
220 }
221
// Completes a receive whose sender and receiver live on the same worker:
// host->host transfers alias the buffer; any transfer involving a device
// allocates a destination tensor and copies via CopyTensor::ViaDMA.
void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU). Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both device objects; a failed lookup is reported through
  // `done` rather than returned.
  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", step_id_,
                                            "dynamic", in.dtype(), &in.shape());
  // Destination allocation honors the receiver's attributes, upgraded to
  // gpu_compatible if either side requested it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  bool sync_dst_compute = (safe_alloc_frontier == 0);
  // NOTE(review): freed_by_func captures a stack local by reference;
  // allocation_attr (which points at it) is only used by the Tensor
  // allocation below, before this frame returns.
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
292
// Two devices are on the "same worker" iff they share an address space
// (same job/replica/task), per DeviceNameUtils.
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}
297
// Receives the tensor named by `parsed` for this worker. Same-worker
// transfers are satisfied from local_ (with a possible cross-device copy
// in SameWorkerRecvDone); cross-worker transfers are delegated to the
// subclass's RecvFromRemoteAsync.
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
  Status s = ValidateDevices(parsed, false /*!is_src*/);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // ValidateDevices() returns an error status if the rendezvous is not
  // initialized.
  DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
                           << parsed.FullKey() << ").";

  ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
          VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
                  << parsed.FullKey();
          // `out` is heap-allocated because SameWorkerRecvDone may fill it
          // in asynchronously; final_callback owns and deletes it.
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            final_callback(status);
          }
        });
    return;
  } else {
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}
343
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)344 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
345 DoneCallback done) {
346 // Test whether the rendezvous is initialized using a shared lock, to avoid
347 // the need for exclusive access in the common case.
348 if (TF_PREDICT_FALSE(!is_initialized())) {
349 mutex_lock l(mu_);
350 if (!is_initialized_locked()) {
351 // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
352 // remote worker) before the RunStep (or PartialRunStep) RPC from the
353 // master arrives. RecvLocalAsync thus buffers the arguments until after
354 // the RemoteRendezvous is Initialize()'d, when it completes the
355 // rendezvous logic. At some point after Initialize() is called, a Tensor
356 // is produced locally that will then be sent in response to the incoming
357 // RPC.
358 DeferredCall call(parsed, std::move(done));
359 deferred_calls_.push_back(call);
360 return;
361 }
362 }
363 RecvLocalAsyncInternal(parsed, std::move(done));
364 }
365
// Completes a local receive: validates that the source endpoint names a
// device on this worker, then forwards to local_. On validation failure,
// `done` is invoked with the error and an empty tensor.
void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
                                                  DoneCallback done) {
  Status s = ValidateDevices(parsed, true /* is_src */);
  if (!s.ok()) {
    done(s, Args(), Args(), Tensor(), false);
    return;
  }
  local_->RecvAsync(parsed, Args(), std::move(done));
}
375
// Aborts the rendezvous with (non-OK) status `s`: the local rendezvous is
// aborted first, then every active RecvTensor call. Only the first abort
// takes effect; later calls see a non-OK status_ and do nothing.
void BaseRemoteRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  // If the status passed in is a cancelled or aborted error, mark it as
  // "derived" for the rendezvous. Derived status messages are ignored when
  // aggregating errors across devices: this allows us to prefer our original
  // status message over any cancellation related errors.
  Status derived_status = s;
  if (errors::IsCancelled(s) || errors::IsAborted(s)) {
    derived_status = StatusGroup::MakeDerived(s);
  }

  local_->StartAbort(derived_status);
  {
    // Aborts all active RecvTensor calls.
    mutex_lock l(mu_);
    if (status_.ok()) {
      status_ = derived_status;
      for (auto& entry : active_) {
        // Abort the call, then run its stored inactive-callback (which
        // deregisters any cancellation callback; see RegisterCall).
        entry.first->StartAbort(derived_status);
        entry.second();
      }
      active_.clear();
    }
  }
}
401
// Registers an in-flight RecvTensor `call` so that StartAbort() can abort
// it. If `args` carries a cancellation manager, a cancellation callback is
// registered as well; the stored InactiveCallback deregisters it when the
// call is removed (DeregisterCall) or the rendezvous is aborted.
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  bool already_cancelled = false;
  InactiveCallback callback = [] {};
  {
    mutex_lock l(mu_);
    // Rendezvous already aborted: fail the call immediately, do not track.
    if (!status_.ok()) {
      call->StartAbort(status_);
      return;
    }
    if (cm != nullptr) {
      auto token = cm->get_cancellation_token();
      // RegisterCallback returns false if cancellation already happened.
      already_cancelled = !cm->RegisterCallback(token, [this, call] {
        {
          mutex_lock l(mu_);
          // Only abort calls that are still tracked; DeregisterCall may
          // have raced with the cancellation.
          if (active_.find(call) == active_.end()) return;
          call->StartAbort(
              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
        }
      });
      callback = [cm, token] { cm->TryDeregisterCallback(token); };
    }
    if (already_cancelled) {
      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
    } else {
      // Track the call; insertion must succeed (a call is registered once).
      CHECK(active_.emplace(call, callback).second);
    }
  }
}
432
DeregisterCall(BaseRecvTensorCall * call)433 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
434 mutex_lock l(mu_);
435 auto it = active_.find(call);
436 if (it != active_.end()) {
437 // Deregister the cancellation callback, if one was registered.
438 it->second();
439 active_.erase(it);
440 }
441 }
442
// Captures the arguments of a RecvLocalAsync call that arrived before
// Initialize(), so the receive can be replayed afterwards.
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
                                                 DoneCallback done)
    : parsed(parsed), done(std::move(done)) {}
446
447 } // end namespace tensorflow
448