• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 
StartAbortRendevous(Rendezvous * rendez,const Status & s)41 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
42   rendez->StartAbort(s);
43   rendez->Unref();
44 }
45 
BaseRendezvousMgr(const WorkerEnv * worker_env)46 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
47     : worker_env_(worker_env) {}
48 
~BaseRendezvousMgr()49 BaseRendezvousMgr::~BaseRendezvousMgr() {
50   for (auto& p : table_) {
51     auto rendez = p.second;
52     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
53   }
54 }
55 
Find(int64 step_id)56 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
57   return FindOrCreate(step_id);
58 }
59 
FindOrCreate(int64 step_id)60 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
61   mutex_lock l(mu_);
62   auto iter = table_.find(step_id);
63   if (iter == table_.end()) {
64     auto rr = Create(step_id, worker_env_);
65     iter = table_.insert({step_id, rr}).first;
66   }
67   iter->second->Ref();
68   return iter->second;
69 }
70 
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)71 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
72                                        const Rendezvous::ParsedKey& parsed,
73                                        Rendezvous::DoneCallback done) {
74   auto rendez = FindOrCreate(step_id);
75   auto done_cb = [rendez, done = std::move(done)](
76                      const Status& s, const Rendezvous::Args& send_args,
77                      const Rendezvous::Args& recv_args, const Tensor& v,
78                      bool dead) {
79     rendez->Unref();
80     done(s, send_args, recv_args, v, dead);
81   };
82   rendez->RecvLocalAsync(parsed, std::move(done_cb));
83 }
84 
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)85 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
86                                     const Rendezvous::ParsedKey& parsed,
87                                     Tensor* val, bool* is_dead) {
88   Status ret;
89   Notification n;
90   RecvLocalAsync(step_id, parsed,
91                  [val, is_dead, &ret, &n](const Status& s,
92                                           const Rendezvous::Args& send_args,
93                                           const Rendezvous::Args& recv_args,
94                                           const Tensor& v, const bool dead) {
95                    ret = s;
96                    *val = v;
97                    *is_dead = dead;
98                    n.Notify();
99                  });
100   n.WaitForNotification();
101   return ret;
102 }
103 
Cleanup(int64 step_id)104 void BaseRendezvousMgr::Cleanup(int64 step_id) {
105   Rendezvous* rendez = nullptr;
106   {
107     mutex_lock l(mu_);
108     auto iter = table_.find(step_id);
109     if (iter != table_.end()) {
110       rendez = iter->second;
111       table_.erase(iter);
112     }
113   }
114   if (rendez) {
115     StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
116   }
117 }
118 
BaseRemoteRendezvous(const WorkerEnv * env,int64 step_id)119 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
120     : env_(env),
121       step_id_(step_id),
122       local_(NewLocalRendezvous()),
123       session_(nullptr) {}
124 
~BaseRemoteRendezvous()125 BaseRemoteRendezvous::~BaseRemoteRendezvous() {
126   CHECK(active_.empty());
127   local_->Unref();
128 }
129 
130 // Returns true if "device_name" is a valid full name of local device
131 // of the "worker".  This helper is purely based on the worker name
132 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)133 static bool IsLocalDevice(const StringPiece worker_name,
134                           const StringPiece device_name) {
135   return absl::StartsWith(device_name, worker_name);
136 }
137 
Initialize(WorkerSession * session)138 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
139   CHECK_NE(session, nullptr) << "session must not be null!";
140   std::vector<DeferredCall> deferred_calls;
141   {
142     mutex_lock l(mu_);
143     if (session_ != nullptr) {
144       if (session_->worker_name() == session->worker_name()) {
145         VLOG(1) << "Skipping rendezvous re-initialization.";
146         return Status::OK();
147       }
148       Status s = errors::Internal(
149           "Double init! Worker names would have changed from: ",
150           session_->worker_name(), " -> ", session->worker_name());
151       LOG(WARNING) << s;
152       return s;
153     }
154     session_ = session;
155     std::swap(deferred_calls, deferred_calls_);
156   }
157   for (auto& call : deferred_calls) {
158     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
159   }
160   return Status::OK();
161 }
162 
session()163 WorkerSession* BaseRemoteRendezvous::session() {
164   tf_shared_lock l(mu_);
165   return session_;
166 }
167 
is_initialized()168 bool BaseRemoteRendezvous::is_initialized() {
169   tf_shared_lock l(mu_);
170   return is_initialized_locked();
171 }
172 
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)173 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
174                                   const Rendezvous::Args& args,
175                                   const Tensor& val, const bool is_dead) {
176   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
177   WorkerSession* sess = nullptr;
178   {
179     tf_shared_lock l(mu_);
180     if (!status_.ok()) return status_;
181     DCHECK(is_initialized_locked());
182     sess = session_;
183   }
184 
185   if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
186     return errors::InvalidArgument(
187         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
188         sess->worker_name());
189   }
190 
191   // Buffers "val" and "device_context" in local_.
192   return local_->Send(parsed, args, val, is_dead);
193 }
194 
ValidateDevices(const ParsedKey & parsed,bool is_src)195 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
196                                              bool is_src) {
197   // Cache session pointer to avoid repeatedly taking & releasing the lock
198   // (e.g. calling session())
199   WorkerSession* sess = nullptr;
200   {
201     tf_shared_lock l(mu_);
202     if (!status_.ok()) return status_;
203     if (!is_initialized_locked()) {
204       return errors::Internal("ValidateDevices called before initialization.");
205     }
206     sess = session_;
207   }
208   if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
209     return errors::InvalidArgument(
210         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
211         sess->worker_name());
212   }
213   if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
214     return errors::InvalidArgument(
215         "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
216         sess->worker_name());
217   }
218   return Status::OK();
219 }
220 
SameWorkerRecvDone(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & send_args,const Rendezvous::Args & recv_args,const Tensor & in,Tensor * out,StatusCallback done)221 void BaseRemoteRendezvous::SameWorkerRecvDone(
222     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
223     const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
224     StatusCallback done) {
225   // Do a quick copy (sharing the underlying buffer) if both tensors
226   // are on host memory.
227   const bool src_host =
228       (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
229   const bool dst_host =
230       (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
231   if (src_host && dst_host) {
232     *out = in;
233     done(Status::OK());
234     return;
235   }
236 
237   // This copy must involve a GPU. Hence, "in" must support DMA
238   // (e.g., string tensors do not work on GPU).  Variant copy DMA
239   // checks happen inside CopyTensor::ViaDMA.
240   if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
241       in.dtype() != DT_RESOURCE) {
242     done(errors::InvalidArgument(
243         "Non-DMA-safe ", DataTypeString(in.dtype()),
244         " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
245     return;
246   }
247 
248   WorkerSession* sess = session();
249   Device* src_device;
250   Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
251   if (!s.ok()) {
252     done(s);
253     return;
254   }
255   Device* dst_device;
256   s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
257   if (!s.ok()) {
258     done(s);
259     return;
260   }
261 
262   ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", step_id_,
263                                             "dynamic", in.dtype(), &in.shape());
264   AllocatorAttributes attr = recv_args.alloc_attrs;
265   attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
266                           recv_args.alloc_attrs.gpu_compatible());
267   Allocator* out_allocator = dst_device->GetAllocator(attr);
268   AllocationAttributes allocation_attr;
269   uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
270   bool sync_dst_compute = (safe_alloc_frontier == 0);
271   std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
272     safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
273     return safe_alloc_frontier;
274   };
275   if (!sync_dst_compute) {
276     allocation_attr.freed_by_func = &freed_by_func;
277   }
278   if (in.dtype() != DT_VARIANT) {
279     // Variants are handled by CopyTensor::ViaDMA.
280     Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
281     *out = copy;
282   }
283 
284   // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
285   // etc.
286   CopyTensor::ViaDMA(
287       parsed.edge_name, send_args.device_context, recv_args.device_context,
288       src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
289       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
290 }
291 
IsSameWorker(DeviceNameUtils::ParsedName src,DeviceNameUtils::ParsedName dst)292 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
293                                         DeviceNameUtils::ParsedName dst) {
294   return DeviceNameUtils::IsSameAddressSpace(src, dst);
295 }
296 
RecvAsync(const ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)297 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
298                                      const Rendezvous::Args& recv_args,
299                                      DoneCallback done) {
300   VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
301   Status s = ValidateDevices(parsed, false /*!is_src*/);
302   if (!s.ok()) {
303     done(s, Args(), recv_args, Tensor(), false);
304     return;
305   }
306 
307   // ValidateDevices() returns an error status if the rendezvous is not
308   // initialized.
309   DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
310                            << parsed.FullKey() << ").";
311 
312   ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
313   // Are src and dst in the same worker?
314   if (IsSameWorker(parsed.src, parsed.dst)) {
315     // Recv the tensor from local_.
316     local_->RecvAsync(
317         parsed, recv_args,
318         [this, parsed, done](
319             const Status& status, const Rendezvous::Args& send_args,
320             const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
321           VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
322                   << parsed.FullKey();
323           Tensor* out = new Tensor;
324           StatusCallback final_callback = [done, send_args, recv_args, out,
325                                            is_dead](const Status& s) {
326             done(s, send_args, recv_args, *out, is_dead);
327             delete out;
328           };
329 
330           if (status.ok()) {
331             SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
332                                std::move(final_callback));
333           } else {
334             final_callback(status);
335           }
336         });
337     return;
338   } else {
339     RecvFromRemoteAsync(parsed, recv_args, std::move(done));
340   }
341 }
342 
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)343 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
344                                           DoneCallback done) {
345   // Test whether the rendezvous is initialized using a shared lock, to avoid
346   // the need for exclusive access in the common case.
347   if (TF_PREDICT_FALSE(!is_initialized())) {
348     mutex_lock l(mu_);
349     if (!is_initialized_locked()) {
350       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
351       // remote worker) before the RunStep (or PartialRunStep) RPC from the
352       // master arrives. RecvLocalAsync thus buffers the arguments until after
353       // the RemoteRendezvous is Initialize()'d, when it completes the
354       // rendezvous logic. At some point after Initialize() is called, a Tensor
355       // is produced locally that will then be sent in response to the incoming
356       // RPC.
357       DeferredCall call(parsed, std::move(done));
358       deferred_calls_.push_back(call);
359       return;
360     }
361   }
362   RecvLocalAsyncInternal(parsed, std::move(done));
363 }
364 
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)365 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
366                                                   DoneCallback done) {
367   Status s = ValidateDevices(parsed, true /* is_src */);
368   if (!s.ok()) {
369     done(s, Args(), Args(), Tensor(), false);
370     return;
371   }
372   local_->RecvAsync(parsed, Args(), std::move(done));
373 }
374 
StartAbort(const Status & s)375 void BaseRemoteRendezvous::StartAbort(const Status& s) {
376   CHECK(!s.ok());
377   // If the status passed in is a cancelled or aborted error, mark it as
378   // "derived" for the rendezvous. Derived status messages are ignored when
379   // aggregating errors across devices: this allows us to prefer our original
380   // status message over any cancellation related errors.
381   Status derived_status = s;
382   if (errors::IsCancelled(s) || errors::IsAborted(s)) {
383     derived_status = StatusGroup::MakeDerived(s);
384   }
385 
386   local_->StartAbort(derived_status);
387   {
388     // Aborts all active RecvTensor calls.
389     mutex_lock l(mu_);
390     if (status_.ok()) {
391       status_ = derived_status;
392       for (auto& entry : active_) {
393         entry.first->StartAbort(derived_status);
394         entry.second();
395       }
396       active_.clear();
397     }
398   }
399 }
400 
RegisterCall(BaseRecvTensorCall * call,const Rendezvous::Args & args)401 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
402                                         const Rendezvous::Args& args) {
403   CancellationManager* cm = args.cancellation_manager;
404   bool already_cancelled = false;
405   InactiveCallback callback = [] {};
406   {
407     mutex_lock l(mu_);
408     if (!status_.ok()) {
409       call->StartAbort(status_);
410       return;
411     }
412     if (cm != nullptr) {
413       auto token = cm->get_cancellation_token();
414       already_cancelled = !cm->RegisterCallback(token, [this, call] {
415         {
416           mutex_lock l(mu_);
417           if (active_.find(call) == active_.end()) return;
418           call->StartAbort(
419               errors::Cancelled("RecvFromRemoteAsync is cancelled."));
420         }
421       });
422       callback = [cm, token] { cm->TryDeregisterCallback(token); };
423     }
424     if (already_cancelled) {
425       call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
426     } else {
427       CHECK(active_.emplace(call, callback).second);
428     }
429   }
430 }
431 
DeregisterCall(BaseRecvTensorCall * call)432 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
433   mutex_lock l(mu_);
434   auto it = active_.find(call);
435   if (it != active_.end()) {
436     // Deregister the cancellation callback, if one was registered.
437     it->second();
438     active_.erase(it);
439   }
440 }
441 
DeferredCall(const ParsedKey & parsed,DoneCallback done)442 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
443                                                  DoneCallback done)
444     : parsed(parsed), done(std::move(done)) {}
445 
446 }  // end namespace tensorflow
447