/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"

#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
  rendez->StartAbort(s);
  rendez->Unref();
}

BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
    : worker_env_(worker_env) {}

BaseRendezvousMgr::~BaseRendezvousMgr() {
  for (auto& p : table_) {
    auto rendez = p.second;
    StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
  }
}

RemoteRendezvous* BaseRendezvousMgr::Find(int64_t step_id) {
  return FindOrCreate(step_id);
}

BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64_t step_id) {
  mutex_lock l(mu_);
  auto iter = table_.find(step_id);
  if (iter == table_.end()) {
    auto rr = Create(step_id, worker_env_);
    iter = table_.insert({step_id, rr}).first;
  }
  iter->second->Ref();
  return iter->second;
}
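
// Illustrative note (not part of the original source): Find()/FindOrCreate()
// take an extra reference on the returned rendezvous, so every lookup must be
// paired with an Unref() once the caller is done. A hypothetical caller might
// look roughly like this (names are assumptions for illustration only):
//
//   RemoteRendezvous* rendez = rendezvous_mgr->Find(step_id);
//   Status s = rendez->Initialize(worker_session);
//   ... use rendez ...
//   rendez->Unref();  // releases the reference taken by Find().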

void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
                                       const Rendezvous::ParsedKey& parsed,
                                       Rendezvous::DoneCallback done) {
  auto rendez = FindOrCreate(step_id);
  auto done_cb = [rendez, done = std::move(done)](
                     const Status& s, const Rendezvous::Args& send_args,
                     const Rendezvous::Args& recv_args, const Tensor& v,
                     bool dead) {
    rendez->Unref();
    done(s, send_args, recv_args, v, dead);
  };
  rendez->RecvLocalAsync(parsed, std::move(done_cb));
}

Status BaseRendezvousMgr::RecvLocal(int64_t step_id,
                                    const Rendezvous::ParsedKey& parsed,
                                    Tensor* val, bool* is_dead) {
  Status ret;
  Notification n;
  RecvLocalAsync(step_id, parsed,
                 [val, is_dead, &ret, &n](const Status& s,
                                          const Rendezvous::Args& send_args,
                                          const Rendezvous::Args& recv_args,
                                          const Tensor& v, const bool dead) {
                   ret = s;
                   *val = v;
                   *is_dead = dead;
                   n.Notify();
                 });
  n.WaitForNotification();
  return ret;
}
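
// Illustrative note (not part of the original source): RecvLocal() is a thin
// synchronous wrapper that blocks on a Notification until the async callback
// fires. A hypothetical blocking receive might look roughly like this (the
// key string and manager pointer are assumptions for illustration only):
//
//   Rendezvous::ParsedKey parsed;
//   TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
//   Tensor val;
//   bool is_dead = false;
//   Status s = rendezvous_mgr->RecvLocal(step_id, parsed, &val, &is_dead);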

void BaseRendezvousMgr::Cleanup(int64_t step_id) {
  Rendezvous* rendez = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(step_id);
    if (iter != table_.end()) {
      rendez = iter->second;
      table_.erase(iter);
    }
  }
  if (rendez) {
    StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
  }
}

BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
                                           int64_t step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}

BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  CHECK(active_.empty());
  local_->Unref();
}

// Returns true if "device_name" is a valid full name of a local device
// of the "worker".  This helper is purely based on the worker name
// and device name and does no lookups in the worker->device_mgr.
static bool IsLocalDevice(const StringPiece worker_name,
                          const StringPiece device_name) {
  return absl::StartsWith(device_name, worker_name);
}
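
// Illustrative example (not part of the original source): the check above is a
// plain string-prefix test, so with an assumed worker name of
// "/job:worker/replica:0/task:0" it classifies
// "/job:worker/replica:0/task:0/device:GPU:0" as local and
// "/job:worker/replica:0/task:1/device:GPU:0" as remote.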

Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
  CHECK_NE(session, nullptr) << "session must not be null!";
  std::vector<DeferredCall> deferred_calls;
  {
    mutex_lock l(mu_);
    if (session_ != nullptr) {
      if (session_->worker_name() == session->worker_name()) {
        VLOG(1) << "Skipping rendezvous re-initialization.";
        return Status::OK();
      }
      Status s = errors::Internal(
          "Double init! Worker names would have changed from: ",
          session_->worker_name(), " -> ", session->worker_name());
      LOG(WARNING) << s;
      return s;
    }
    session_ = session;
    std::swap(deferred_calls, deferred_calls_);
  }
  for (auto& call : deferred_calls) {
    RecvLocalAsyncInternal(call.parsed, std::move(call.done));
  }
  return Status::OK();
}

WorkerSession* BaseRemoteRendezvous::session() {
  tf_shared_lock l(mu_);
  return session_;
}

bool BaseRemoteRendezvous::is_initialized() {
  tf_shared_lock l(mu_);
  return is_initialized_locked();
}

Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
                                  const Rendezvous::Args& args,
                                  const Tensor& val, const bool is_dead) {
  VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
  WorkerSession* sess = nullptr;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) return status_;
    DCHECK(is_initialized_locked());
    sess = session_;
  }

  if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }

  // Buffers "val" and "device_context" in local_.
  return local_->Send(parsed, args, val, is_dead);
}
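
// Illustrative note (not part of the original source): the ParsedKey consumed
// by Send()/RecvAsync() is built with Rendezvous::CreateKey() and parsed with
// Rendezvous::ParseKey(). A hypothetical key for a same-worker CPU-to-GPU
// transfer might be constructed roughly like this (device names, edge name,
// and incarnation value are assumptions for illustration only):
//
//   string key = Rendezvous::CreateKey(
//       "/job:worker/replica:0/task:0/device:CPU:0", /*src_incarnation=*/1,
//       "/job:worker/replica:0/task:0/device:GPU:0", "edge_17/tensor_a",
//       FrameAndIter(0, 0));
//   Rendezvous::ParsedKey parsed;
//   Status s = Rendezvous::ParseKey(key, &parsed);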

Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
                                             bool is_src) {
  // Cache session pointer to avoid repeatedly taking & releasing the lock
  // (e.g. calling session())
  WorkerSession* sess = nullptr;
  {
    tf_shared_lock l(mu_);
    if (!status_.ok()) return status_;
    if (!is_initialized_locked()) {
      return errors::Internal("ValidateDevices called before initialization.");
    }
    sess = session_;
  }
  if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }
  if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
    return errors::InvalidArgument(
        "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
        sess->worker_name());
  }
  return Status::OK();
}

void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  ScopedMemoryDebugAnnotation op_annotation("SameWorkerRecvDone", step_id_,
                                            "dynamic", in.dtype(), &in.shape());
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  bool sync_dst_compute = (safe_alloc_frontier == 0);
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
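
// Illustrative note (not part of the original source): two concrete cases of
// the logic above. A CPU-to-CPU recv (both alloc_attrs on host) takes the
// `*out = in` fast path and shares the underlying buffer without copying. A
// CPU-to-GPU recv of a DT_STRING tensor fails the DMAHelper::CanUseDMA()
// check and reports InvalidArgument, while a DT_FLOAT tensor is allocated
// from the destination GPU's allocator and copied via CopyTensor::ViaDMA().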

bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}

void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
  Status s = ValidateDevices(parsed, false /*!is_src*/);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // ValidateDevices() returns an error status if the rendezvous is not
  // initialized.
  DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
                           << parsed.FullKey() << ").";

  ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
          VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
                  << parsed.FullKey();
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            final_callback(status);
          }
        });
    return;
  } else {
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}
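
// Illustrative note (not part of the original source): the dispatch above
// means that when src and dst devices share an address space (same
// job/replica/task), the tensor is fetched from local_ and finished via
// SameWorkerRecvDone(); otherwise the subclass-provided RecvFromRemoteAsync()
// (e.g. an RPC-based implementation) takes over. A hypothetical caller only
// ever sees the single RecvAsync() entry point, roughly:
//
//   rendez->RecvAsync(parsed, recv_args,
//                     [](const Status& s, const Rendezvous::Args&,
//                        const Rendezvous::Args&, const Tensor& t, bool dead) {
//                       // consume t if s.ok() && !dead
//                     });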

void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
                                          DoneCallback done) {
  // Test whether the rendezvous is initialized using a shared lock, to avoid
  // the need for exclusive access in the common case.
  if (TF_PREDICT_FALSE(!is_initialized())) {
    mutex_lock l(mu_);
    if (!is_initialized_locked()) {
      // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
      // remote worker) before the RunStep (or PartialRunStep) RPC from the
      // master arrives. RecvLocalAsync thus buffers the arguments until after
      // the RemoteRendezvous is Initialize()'d, when it completes the
      // rendezvous logic. At some point after Initialize() is called, a Tensor
      // is produced locally that will then be sent in response to the incoming
      // RPC.
      DeferredCall call(parsed, std::move(done));
      deferred_calls_.push_back(call);
      return;
    }
  }
  RecvLocalAsyncInternal(parsed, std::move(done));
}

void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
                                                  DoneCallback done) {
  Status s = ValidateDevices(parsed, true /* is_src */);
  if (!s.ok()) {
    done(s, Args(), Args(), Tensor(), false);
    return;
  }
  local_->RecvAsync(parsed, Args(), std::move(done));
}

void BaseRemoteRendezvous::StartAbort(const Status& s) {
  CHECK(!s.ok());
  // If the status passed in is a cancelled or aborted error, mark it as
  // "derived" for the rendezvous. Derived status messages are ignored when
  // aggregating errors across devices: this allows us to prefer our original
  // status message over any cancellation related errors.
  Status derived_status = s;
  if (errors::IsCancelled(s) || errors::IsAborted(s)) {
    derived_status = StatusGroup::MakeDerived(s);
  }

  local_->StartAbort(derived_status);
  {
    // Aborts all active RecvTensor calls.
    mutex_lock l(mu_);
    if (status_.ok()) {
      status_ = derived_status;
      for (auto& entry : active_) {
        entry.first->StartAbort(derived_status);
        entry.second();
      }
      active_.clear();
    }
  }
}
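
// Illustrative note (not part of the original source): StartAbort() is the
// fan-out point for errors. For example, aborting with
// errors::Aborted("Shutdown") (as the manager's destructor does) marks the
// status as derived via StatusGroup::MakeDerived(), aborts the local
// rendezvous, records status_ so later Send()/RecvAsync() calls fail fast,
// and invokes StartAbort() on every registered in-flight BaseRecvTensorCall.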

void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  bool already_cancelled = false;
  InactiveCallback callback = [] {};
  {
    mutex_lock l(mu_);
    if (!status_.ok()) {
      call->StartAbort(status_);
      return;
    }
    if (cm != nullptr) {
      auto token = cm->get_cancellation_token();
      already_cancelled = !cm->RegisterCallback(token, [this, call] {
        {
          mutex_lock l(mu_);
          if (active_.find(call) == active_.end()) return;
          call->StartAbort(
              errors::Cancelled("RecvFromRemoteAsync is cancelled."));
        }
      });
      callback = [cm, token] { cm->TryDeregisterCallback(token); };
    }
    if (already_cancelled) {
      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
    } else {
      CHECK(active_.emplace(call, callback).second);
    }
  }
}

void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
  mutex_lock l(mu_);
  auto it = active_.find(call);
  if (it != active_.end()) {
    // Deregister the cancellation callback, if one was registered.
    it->second();
    active_.erase(it);
  }
}
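
// Illustrative note (not part of the original source): a remote-recv
// implementation is expected to bracket each call with this pair, roughly
// (names and callback shape are assumptions for illustration only):
//
//   RegisterCall(call, recv_args);   // hooks the cancellation manager
//   call->Start([this, call, done] {
//     DeregisterCall(call);          // drops the cancellation hook
//     done(call->status(), ...);
//   });
//
// If the step is cancelled in between, the registered callback aborts the
// call with errors::Cancelled, and DeregisterCall() later removes it from
// active_.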

BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
                                                 DoneCallback done)
    : parsed(parsed), done(std::move(done)) {}

}  // end namespace tensorflow