• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
StartAbortRendevous(Rendezvous * rendez,const Status & s)38 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
39   rendez->StartAbort(s);
40   rendez->Unref();
41 }
42 
BaseRendezvousMgr(const WorkerEnv * worker_env)43 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
44     : worker_env_(worker_env) {}
45 
~BaseRendezvousMgr()46 BaseRendezvousMgr::~BaseRendezvousMgr() {
47   for (auto& p : table_) {
48     auto rendez = p.second;
49     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
50   }
51 }
52 
Find(int64 step_id)53 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
54   return FindOrCreate(step_id);
55 }
56 
FindOrCreate(int64 step_id)57 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
58   mutex_lock l(mu_);
59   auto iter = table_.find(step_id);
60   if (iter == table_.end()) {
61     auto rr = Create(step_id, worker_env_);
62     iter = table_.insert({step_id, rr}).first;
63   }
64   iter->second->Ref();
65   return iter->second;
66 }
67 
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)68 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
69                                        const Rendezvous::ParsedKey& parsed,
70                                        Rendezvous::DoneCallback done) {
71   auto rendez = FindOrCreate(step_id);
72   using namespace std::placeholders;
73   Rendezvous::DoneCallback done_cb = std::bind(
74       [rendez](Rendezvous::DoneCallback done,
75                // Begin unbound arguments.
76                const Status& s, const Rendezvous::Args& send_args,
77                const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
78         rendez->Unref();
79         done(s, send_args, recv_args, v, dead);
80       },
81       std::move(done), _1, _2, _3, _4, _5);
82   rendez->RecvLocalAsync(parsed, std::move(done_cb));
83 }
84 
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)85 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
86                                     const Rendezvous::ParsedKey& parsed,
87                                     Tensor* val, bool* is_dead) {
88   Status ret;
89   Notification n;
90   RecvLocalAsync(step_id, parsed,
91                  [val, is_dead, &ret, &n](const Status& s,
92                                           const Rendezvous::Args& send_args,
93                                           const Rendezvous::Args& recv_args,
94                                           const Tensor& v, const bool dead) {
95                    ret = s;
96                    *val = v;
97                    *is_dead = dead;
98                    n.Notify();
99                  });
100   n.WaitForNotification();
101   return ret;
102 }
103 
Cleanup(int64 step_id)104 void BaseRendezvousMgr::Cleanup(int64 step_id) {
105   Rendezvous* rendez = nullptr;
106   {
107     mutex_lock l(mu_);
108     auto iter = table_.find(step_id);
109     if (iter != table_.end()) {
110       rendez = iter->second;
111       table_.erase(iter);
112     }
113   }
114   if (rendez) {
115     StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
116   }
117 }
118 
CleanupAll()119 void BaseRendezvousMgr::CleanupAll() {
120   std::vector<Rendezvous*> rendezs;
121   {
122     mutex_lock l(mu_);
123     for (const auto& entry : table_) {
124       rendezs.push_back(entry.second);
125     }
126     table_.clear();
127   }
128   for (auto rendez : rendezs) {
129     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
130   }
131 }
132 
BaseRemoteRendezvous(const WorkerEnv * env,int64 step_id)133 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
134     : env_(env),
135       step_id_(step_id),
136       local_(NewLocalRendezvous()),
137       session_(nullptr) {}
138 
~BaseRemoteRendezvous()139 BaseRemoteRendezvous::~BaseRemoteRendezvous() {
140   CHECK(active_.empty());
141   local_->Unref();
142 }
143 
144 // Returns true if "device_name" is a valid full name of local device
145 // of the "worker".  This helper is purely based on the worker name
146 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)147 static bool IsLocalDevice(const StringPiece worker_name,
148                           const StringPiece device_name) {
149   return str_util::StartsWith(device_name, worker_name);
150 }
151 
Initialize(WorkerSession * session)152 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
153   CHECK_NE(session, nullptr) << "session must not be null!";
154   std::vector<DeferredCall> deferred_calls;
155   {
156     mutex_lock l(mu_);
157     if (session_ != nullptr) {
158       if (session_->worker_name == session->worker_name) {
159         LOG(INFO) << "Skipping rendezvous re-initialization.";
160         return Status::OK();
161       }
162       Status s = errors::Internal(
163           "Double init! Worker names would have changed from: ",
164           session_->worker_name, " -> ", session->worker_name);
165       LOG(WARNING) << s;
166       return s;
167     }
168     session_ = session;
169     std::swap(deferred_calls, deferred_calls_);
170   }
171   for (auto& call : deferred_calls) {
172     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
173   }
174   return Status::OK();
175 }
176 
session()177 WorkerSession* BaseRemoteRendezvous::session() {
178   mutex_lock l(mu_);
179   return session_;
180 }
181 
is_initialized()182 bool BaseRemoteRendezvous::is_initialized() {
183   mutex_lock l(mu_);
184   return is_initialized_locked();
185 }
186 
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)187 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
188                                   const Rendezvous::Args& args,
189                                   const Tensor& val, const bool is_dead) {
190   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
191   {
192     mutex_lock l(mu_);
193     if (!status_.ok()) return status_;
194     DCHECK(is_initialized_locked());
195     if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
196       return errors::InvalidArgument(
197           "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
198           session_->worker_name);
199     }
200   }
201   // Buffers "val" and "device_context" in local_.
202   return local_->Send(parsed, args, val, is_dead);
203 }
204 
ValidateDevices(const ParsedKey & parsed,bool is_src)205 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
206                                              bool is_src) {
207   // Cache session pointer to avoid repeatedly taking & releasing the lock
208   // (e.g. calling session())
209   WorkerSession* sess = nullptr;
210   {
211     mutex_lock l(mu_);
212     if (!status_.ok()) return status_;
213     if (!is_initialized_locked()) {
214       return errors::Internal("ValidateDevices called before initialization.");
215     }
216     sess = session_;
217   }
218   if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
219     return errors::InvalidArgument("Invalid rendezvous key (src): ",
220                                    parsed.FullKey(), " @ ", sess->worker_name);
221   }
222   if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
223     return errors::InvalidArgument("Invalid rendezvous key (dst): ",
224                                    parsed.FullKey(), " @ ", sess->worker_name);
225   }
226   return Status::OK();
227 }
228 
SameWorkerRecvDone(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & send_args,const Rendezvous::Args & recv_args,const Tensor & in,Tensor * out,StatusCallback done)229 void BaseRemoteRendezvous::SameWorkerRecvDone(
230     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
231     const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
232     StatusCallback done) {
233   // Do a quick copy (sharing the underlying buffer) if both tensors
234   // are on host memory.
235   const bool src_host =
236       (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
237   const bool dst_host =
238       (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
239   if (src_host && dst_host) {
240     *out = in;
241     done(Status::OK());
242     return;
243   }
244 
245   // This copy must involve a GPU. Hence, "in" must support DMA
246   // (e.g., string tensors do not work on GPU).  Variant copy DMA
247   // checks happen inside CopyTensor::ViaDMA.
248   if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT) {
249     done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
250                                  " tensor may not be copied from/to a GPU."));
251     return;
252   }
253 
254   WorkerSession* sess = session();
255   Device* src_device;
256   Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
257   if (!s.ok()) {
258     done(s);
259     return;
260   }
261   Device* dst_device;
262   s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
263   if (!s.ok()) {
264     done(s);
265     return;
266   }
267 
268   AllocatorAttributes attr = recv_args.alloc_attrs;
269   attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
270                           recv_args.alloc_attrs.gpu_compatible());
271   Allocator* out_allocator = dst_device->GetAllocator(attr);
272 
273   if (in.dtype() != DT_VARIANT) {
274     // Variants are handled by CopyTensor::ViaDMA.
275     Tensor copy(out_allocator, in.dtype(), in.shape());
276     *out = copy;
277   }
278 
279   // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
280   // etc.
281   CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
282                      recv_args.device_context, src_device, dst_device,
283                      send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
284                      0 /*dev_to_dev_stream_index*/, std::move(done));
285 }
286 
IsSameWorker(DeviceNameUtils::ParsedName src,DeviceNameUtils::ParsedName dst)287 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
288                                         DeviceNameUtils::ParsedName dst) {
289   return DeviceNameUtils::IsSameAddressSpace(src, dst);
290 }
291 
RecvAsync(const ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)292 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
293                                      const Rendezvous::Args& recv_args,
294                                      DoneCallback done) {
295   VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
296   Status s = ValidateDevices(parsed, false /*!is_src*/);
297   if (s.ok() && !is_initialized()) {
298     s.Update(errors::Internal(
299         "RecvAsync called when uninitialized (key:", parsed.FullKey(), ")."));
300   }
301   if (!s.ok()) {
302     done(s, Args(), recv_args, Tensor(), false);
303     return;
304   }
305 
306   // Are src and dst in the same worker?
307   if (IsSameWorker(parsed.src, parsed.dst)) {
308     // Recv the tensor from local_.
309     local_->RecvAsync(
310         parsed, recv_args,
311         [this, parsed, done](
312             const Status& status, const Rendezvous::Args& send_args,
313             const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
314           Tensor* out = new Tensor;
315           StatusCallback final_callback = [done, send_args, recv_args, out,
316                                            is_dead](const Status& s) {
317             done(s, send_args, recv_args, *out, is_dead);
318             delete out;
319           };
320 
321           if (status.ok()) {
322             SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
323                                std::move(final_callback));
324           } else {
325             final_callback(status);
326           }
327         });
328     return;
329   } else {
330     RecvFromRemoteAsync(parsed, recv_args, std::move(done));
331   }
332 }
333 
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)334 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
335                                           DoneCallback done) {
336   {
337     mutex_lock l(mu_);
338     if (!is_initialized_locked()) {
339       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
340       // remote worker) before the RunStep (or PartialRunStep) RPC from the
341       // master arrives. RecvLocalAsync thus buffers the arguments until after
342       // the RemoteRendezvous is Initialize()'d, when it completes the
343       // rendezvous logic. At some point after Initialize() is called, a Tensor
344       // is produced locally that will then be sent in response to the incoming
345       // RPC.
346       DeferredCall call(parsed, std::move(done));
347       deferred_calls_.push_back(call);
348       return;
349     }
350   }
351   RecvLocalAsyncInternal(parsed, std::move(done));
352 }
353 
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)354 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
355                                                   DoneCallback done) {
356   Status s = ValidateDevices(parsed, true /* is_src */);
357   if (!s.ok()) {
358     done(s, Args(), Args(), Tensor(), false);
359     return;
360   }
361   local_->RecvAsync(parsed, Args(), std::move(done));
362 }
363 
StartAbort(const Status & s)364 void BaseRemoteRendezvous::StartAbort(const Status& s) {
365   CHECK(!s.ok());
366   local_->StartAbort(s);
367   {
368     // Aborts all active RecvTensor calls.
369     mutex_lock l(mu_);
370     if (status_.ok()) {
371       status_ = s;
372       for (BaseRecvTensorCall* call : active_) {
373         call->StartAbort(s);
374       }
375       active_.clear();
376     }
377   }
378 }
379 
RegisterCall(BaseRecvTensorCall * call)380 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
381   mutex_lock l(mu_);
382   if (!status_.ok()) {
383     call->StartAbort(status_);
384   } else {
385     CHECK(active_.insert(call).second);
386   }
387 }
388 
DeregisterCall(BaseRecvTensorCall * call)389 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
390   mutex_lock l(mu_);
391   active_.erase(call);
392 }
393 
DeferredCall(const ParsedKey & parsed,DoneCallback done)394 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
395                                                  DoneCallback done)
396     : parsed(parsed), done(std::move(done)) {}
397 
398 }  // end namespace tensorflow
399