• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/cancellation.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/numbers.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 
// Aborts `rendez` with status `s` and releases the reference the caller
// holds on it; `rendez` must not be used after this call.
// (NOTE(review): "Rendevous" spelling is historical; renaming would touch
// every caller in this file, so it is kept as-is.)
static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
  rendez->StartAbort(s);
  rendez->Unref();
}
44 
// Stores a non-owning pointer to the worker environment; `worker_env` must
// outlive this manager.
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
    : worker_env_(worker_env) {}
47 
~BaseRendezvousMgr()48 BaseRendezvousMgr::~BaseRendezvousMgr() {
49   for (auto& p : table_) {
50     auto rendez = p.second;
51     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
52   }
53 }
54 
// Returns the rendezvous for `step_id`, creating it if it does not exist.
// The returned pointer carries a new reference (FindOrCreate calls Ref());
// the caller is responsible for Unref()-ing it.
RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
  return FindOrCreate(step_id);
}
58 
FindOrCreate(int64 step_id)59 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
60   mutex_lock l(mu_);
61   auto iter = table_.find(step_id);
62   if (iter == table_.end()) {
63     auto rr = Create(step_id, worker_env_);
64     iter = table_.insert({step_id, rr}).first;
65   }
66   iter->second->Ref();
67   return iter->second;
68 }
69 
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)70 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
71                                        const Rendezvous::ParsedKey& parsed,
72                                        Rendezvous::DoneCallback done) {
73   auto rendez = FindOrCreate(step_id);
74   auto done_cb = [rendez, done = std::move(done)](
75                      const Status& s, const Rendezvous::Args& send_args,
76                      const Rendezvous::Args& recv_args, const Tensor& v,
77                      bool dead) {
78     rendez->Unref();
79     done(s, send_args, recv_args, v, dead);
80   };
81   rendez->RecvLocalAsync(parsed, std::move(done_cb));
82 }
83 
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)84 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
85                                     const Rendezvous::ParsedKey& parsed,
86                                     Tensor* val, bool* is_dead) {
87   Status ret;
88   Notification n;
89   RecvLocalAsync(step_id, parsed,
90                  [val, is_dead, &ret, &n](const Status& s,
91                                           const Rendezvous::Args& send_args,
92                                           const Rendezvous::Args& recv_args,
93                                           const Tensor& v, const bool dead) {
94                    ret = s;
95                    *val = v;
96                    *is_dead = dead;
97                    n.Notify();
98                  });
99   n.WaitForNotification();
100   return ret;
101 }
102 
Cleanup(int64 step_id)103 void BaseRendezvousMgr::Cleanup(int64 step_id) {
104   Rendezvous* rendez = nullptr;
105   {
106     mutex_lock l(mu_);
107     auto iter = table_.find(step_id);
108     if (iter != table_.end()) {
109       rendez = iter->second;
110       table_.erase(iter);
111     }
112   }
113   if (rendez) {
114     StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
115   }
116 }
117 
CleanupAll()118 void BaseRendezvousMgr::CleanupAll() {
119   std::vector<Rendezvous*> rendezs;
120   {
121     mutex_lock l(mu_);
122     for (const auto& entry : table_) {
123       rendezs.push_back(entry.second);
124     }
125     table_.clear();
126   }
127   for (auto rendez : rendezs) {
128     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
129   }
130 }
131 
// Constructs an uninitialized remote rendezvous for `step_id`.  `local_`
// owns a reference to a freshly created local rendezvous (released in the
// destructor); `session_` stays null until Initialize() is called.
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
    : env_(env),
      step_id_(step_id),
      local_(NewLocalRendezvous()),
      session_(nullptr) {}
137 
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
  // All outstanding RecvTensor calls must have been deregistered before the
  // rendezvous is destroyed.
  CHECK(active_.empty());
  // Release the reference on the local rendezvous taken in the constructor.
  local_->Unref();
}
142 
143 // Returns true if "device_name" is a valid full name of local device
144 // of the "worker".  This helper is purely based on the worker name
145 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)146 static bool IsLocalDevice(const StringPiece worker_name,
147                           const StringPiece device_name) {
148   return absl::StartsWith(device_name, worker_name);
149 }
150 
Initialize(WorkerSession * session)151 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
152   CHECK_NE(session, nullptr) << "session must not be null!";
153   std::vector<DeferredCall> deferred_calls;
154   {
155     mutex_lock l(init_mu_);
156     if (session_ != nullptr) {
157       if (session_->worker_name() == session->worker_name()) {
158         VLOG(1) << "Skipping rendezvous re-initialization.";
159         return Status::OK();
160       }
161       Status s = errors::Internal(
162           "Double init! Worker names would have changed from: ",
163           session_->worker_name(), " -> ", session->worker_name());
164       LOG(WARNING) << s;
165       return s;
166     }
167     session_ = session;
168     std::swap(deferred_calls, deferred_calls_);
169   }
170   for (auto& call : deferred_calls) {
171     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
172   }
173   return Status::OK();
174 }
175 
// Returns the session bound by Initialize(), or nullptr if not yet
// initialized.  Takes a shared lock; safe to call concurrently.
WorkerSession* BaseRemoteRendezvous::session() {
  tf_shared_lock l(init_mu_);
  return session_;
}
180 
// Thread-safe wrapper around is_initialized_locked(); true once
// Initialize() has bound a session.
bool BaseRemoteRendezvous::is_initialized() {
  tf_shared_lock l(init_mu_);
  return is_initialized_locked();
}
185 
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)186 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
187                                   const Rendezvous::Args& args,
188                                   const Tensor& val, const bool is_dead) {
189   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
190   WorkerSession* sess = nullptr;
191   {
192     tf_shared_lock l(init_mu_);
193     if (!status_.ok()) return status_;
194     DCHECK(is_initialized_locked());
195     sess = session_;
196   }
197 
198   if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
199     return errors::InvalidArgument(
200         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
201         sess->worker_name());
202   }
203 
204   // Buffers "val" and "device_context" in local_.
205   return local_->Send(parsed, args, val, is_dead);
206 }
207 
ValidateDevices(const ParsedKey & parsed,bool is_src)208 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
209                                              bool is_src) {
210   // Cache session pointer to avoid repeatedly taking & releasing the lock
211   // (e.g. calling session())
212   WorkerSession* sess = nullptr;
213   {
214     tf_shared_lock l(init_mu_);
215     if (!status_.ok()) return status_;
216     if (!is_initialized_locked()) {
217       return errors::Internal("ValidateDevices called before initialization.");
218     }
219     sess = session_;
220   }
221   if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
222     return errors::InvalidArgument(
223         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
224         sess->worker_name());
225   }
226   if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
227     return errors::InvalidArgument(
228         "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
229         sess->worker_name());
230   }
231   return Status::OK();
232 }
233 
// Completes a same-worker recv: delivers `in` to `*out`, either by sharing
// the buffer (host->host) or by a device copy via CopyTensor::ViaDMA.
// `done` is invoked exactly once with the final status; on the DMA path it
// is forwarded into ViaDMA and may run asynchronously.
void BaseRemoteRendezvous::SameWorkerRecvDone(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
    const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
    StatusCallback done) {
  // Do a quick copy (sharing the underlying buffer) if both tensors
  // are on host memory.
  const bool src_host =
      (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
  const bool dst_host =
      (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
  if (src_host && dst_host) {
    *out = in;
    done(Status::OK());
    return;
  }

  // This copy must involve a GPU. Hence, "in" must support DMA
  // (e.g., string tensors do not work on GPU).  Variant copy DMA
  // checks happen inside CopyTensor::ViaDMA.
  if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
      in.dtype() != DT_RESOURCE) {
    done(errors::InvalidArgument(
        "Non-DMA-safe ", DataTypeString(in.dtype()),
        " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
    return;
  }

  // Resolve both device objects through the session's device manager;
  // either lookup failing aborts the recv with that status.
  WorkerSession* sess = session();
  Device* src_device;
  Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
  if (!s.ok()) {
    done(s);
    return;
  }
  Device* dst_device;
  s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  if (!s.ok()) {
    done(s);
    return;
  }

  MEMDEBUG_CACHE_STEPID(0);
  // Note that it would be nice to cache the step_id here, but it's not
  // available.
  MEMDEBUG_CACHE_OP("SameWorkerRecvDone");
  // Allocate the destination tensor with attributes merged from both sides
  // so the buffer is gpu_compatible if either endpoint needs it.
  AllocatorAttributes attr = recv_args.alloc_attrs;
  attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
                          recv_args.alloc_attrs.gpu_compatible());
  Allocator* out_allocator = dst_device->GetAllocator(attr);
  AllocationAttributes allocation_attr;
  // A zero frontier means the device offers no timestamped-allocation
  // support, so the copy must synchronize with dst compute instead.
  uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
  bool sync_dst_compute = (safe_alloc_frontier == 0);
  // freed_by_func is only consulted during the Tensor allocation below, so
  // capturing stack locals by reference is safe here.
  std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
    safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
    return safe_alloc_frontier;
  };
  if (!sync_dst_compute) {
    allocation_attr.freed_by_func = &freed_by_func;
  }
  if (in.dtype() != DT_VARIANT) {
    // Variants are handled by CopyTensor::ViaDMA.
    Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
    *out = copy;
  }

  // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
  // etc.
  CopyTensor::ViaDMA(
      parsed.edge_name, send_args.device_context, recv_args.device_context,
      src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
      out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
}
306 
// True iff `src` and `dst` are in the same address space (same worker
// process), as defined by DeviceNameUtils::IsSameAddressSpace.
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
                                        DeviceNameUtils::ParsedName dst) {
  return DeviceNameUtils::IsSameAddressSpace(src, dst);
}
311 
// Receives the tensor named by `parsed`.  Same-worker transfers are served
// from local_ and finished via SameWorkerRecvDone(); cross-worker transfers
// are delegated to the subclass's RecvFromRemoteAsync().  `done` runs
// exactly once on every path.
void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                     const Rendezvous::Args& recv_args,
                                     DoneCallback done) {
  VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
  Status s = ValidateDevices(parsed, false /*!is_src*/);
  if (!s.ok()) {
    done(s, Args(), recv_args, Tensor(), false);
    return;
  }

  // ValidateDevices() returns an error status if the rendezvous is not
  // initialized.
  DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
                           << parsed.FullKey() << ").";

  MEMDEBUG_CACHE_OP("RecvAsync");
  MEMDEBUG_CACHE_STEPID(0);
  // Are src and dst in the same worker?
  if (IsSameWorker(parsed.src, parsed.dst)) {
    // Recv the tensor from local_.
    local_->RecvAsync(
        parsed, recv_args,
        [this, parsed, done](
            const Status& status, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
          VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
                  << parsed.FullKey();
          // `out` is heap-allocated because SameWorkerRecvDone may complete
          // asynchronously; final_callback owns it and deletes it after
          // forwarding the result to `done`.
          Tensor* out = new Tensor;
          StatusCallback final_callback = [done, send_args, recv_args, out,
                                           is_dead](const Status& s) {
            done(s, send_args, recv_args, *out, is_dead);
            delete out;
          };

          if (status.ok()) {
            SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                               std::move(final_callback));
          } else {
            final_callback(status);
          }
        });
    return;
  } else {
    RecvFromRemoteAsync(parsed, recv_args, std::move(done));
  }
}
358 
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)359 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
360                                           DoneCallback done) {
361   // Test whether the rendezvous is initialized using a shared lock, to avoid
362   // the need for exclusive access in the common case.
363   if (TF_PREDICT_FALSE(!is_initialized())) {
364     mutex_lock l(init_mu_);
365     if (!is_initialized_locked()) {
366       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
367       // remote worker) before the RunStep (or PartialRunStep) RPC from the
368       // master arrives. RecvLocalAsync thus buffers the arguments until after
369       // the RemoteRendezvous is Initialize()'d, when it completes the
370       // rendezvous logic. At some point after Initialize() is called, a Tensor
371       // is produced locally that will then be sent in response to the incoming
372       // RPC.
373       DeferredCall call(parsed, std::move(done));
374       deferred_calls_.push_back(call);
375       return;
376     }
377   }
378   RecvLocalAsyncInternal(parsed, std::move(done));
379 }
380 
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)381 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
382                                                   DoneCallback done) {
383   Status s = ValidateDevices(parsed, true /* is_src */);
384   if (!s.ok()) {
385     done(s, Args(), Args(), Tensor(), false);
386     return;
387   }
388   local_->RecvAsync(parsed, Args(), std::move(done));
389 }
390 
StartAbort(const Status & s)391 void BaseRemoteRendezvous::StartAbort(const Status& s) {
392   CHECK(!s.ok());
393   // Use a "derived" status as the status for the rendezvous. Derived
394   // status messages are ignored when aggregating errors across devices: this
395   // allows us to prefer our original status message over any cancellation
396   // related errors.
397   Status derived_status = StatusGroup::MakeDerived(s);
398 
399   local_->StartAbort(derived_status);
400   {
401     // Aborts all active RecvTensor calls.
402     mutex_lock l(init_mu_);
403     mutex_lock l2(active_mu_);
404     if (status_.ok()) {
405       status_ = derived_status;
406       for (auto& entry : active_) {
407         entry.first->StartAbort(derived_status);
408         entry.second();
409       }
410       active_.clear();
411     }
412   }
413 }
414 
// Registers an outstanding RecvTensor `call` so StartAbort()/cancellation
// can reach it.  If the rendezvous is already aborted, or the cancellation
// manager was already cancelled, the call is aborted immediately and never
// enters active_; otherwise it is inserted together with a cleanup callback
// that deregisters its cancellation hook.
void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
                                        const Rendezvous::Args& args) {
  CancellationManager* cm = args.cancellation_manager;
  Status captured_status;
  {
    // Snapshot status_ under the shared lock; abort outside the lock.
    tf_shared_lock l(init_mu_);
    if (!status_.ok()) {
      captured_status = status_;
    }
  }
  if (!captured_status.ok()) {
    call->StartAbort(captured_status);
    return;
  }

  bool already_cancelled = false;
  InactiveCallback callback = [] {};
  if (cm != nullptr) {
    auto token = cm->get_cancellation_token();
    // RegisterCallback returns false if cancellation already happened.
    already_cancelled = !cm->RegisterCallback(token, [this, call] {
      {
        // Only abort calls that are still active; a call that has already
        // been deregistered must not be touched.
        tf_shared_lock l(active_mu_);
        if (active_.find(call) == active_.end()) return;
      }
      call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
    });
    // Stored in active_ and run on deregistration/abort to remove the
    // cancellation callback above.
    callback = [cm, token] { cm->TryDeregisterCallback(token); };
  }

  if (already_cancelled) {
    call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
  } else {
    mutex_lock l(active_mu_);
    // Each call may be registered at most once.
    CHECK(active_.emplace(call, callback).second);
  }
}
451 
DeregisterCall(BaseRecvTensorCall * call)452 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
453   mutex_lock l(active_mu_);
454   auto it = active_.find(call);
455   if (it != active_.end()) {
456     // Deregister the cancellation callback, if one was registered.
457     it->second();
458     active_.erase(it);
459   }
460 }
461 
// Captures the arguments of a RecvLocalAsync call that arrived before
// Initialize(); Initialize() replays these once a session is bound.
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
                                                 DoneCallback done)
    : parsed(parsed), done(std::move(done)) {}
465 
466 }  // end namespace tensorflow
467