• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/distributed_runtime/request_id.h"
25 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/notification.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
41 class RpcRemoteRendezvous : public BaseRemoteRendezvous {
42  public:
RpcRemoteRendezvous(const WorkerEnv * env,int64 step_id)43   RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
44       : BaseRemoteRendezvous(env, step_id) {}
45 
46  protected:
47   void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
48                            const Rendezvous::Args& args,
49                            DoneCallback done) override;
50 
51  private:
~RpcRemoteRendezvous()52   ~RpcRemoteRendezvous() override {}
53 
54   TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
55 };
56 
57 // Used only to retrieve tensors from remote processes.
58 class RpcRecvTensorCall : public BaseRecvTensorCall {
59  public:
RpcRecvTensorCall()60   RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
61 
Init(WorkerInterface * wi,int64 step_id,StringPiece key,AllocatorAttributes alloc_attrs,Device * dst_device,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)62   void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
63             AllocatorAttributes alloc_attrs, Device* dst_device,
64             const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
65     wi_ = wi;
66     alloc_attrs_ = alloc_attrs;
67     dst_device_ = dst_device;
68     recv_args_ = recv_args;
69     done_ = std::move(done);
70     req_.set_step_id(step_id);
71     req_.set_rendezvous_key(key.data(), key.size());
72     req_.set_request_id(GetUniqueRequestId());
73   }
74 
Reset()75   void Reset() {
76     // The RpcRemoteRendezvous using this object is responsible for calling
77     // ReleaseWorker() before Reset().
78     DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
79         << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";
80 
81     alloc_attrs_ = AllocatorAttributes();
82     dst_device_ = nullptr;
83     // We don't clear opts_ and assume that Init will set up the state for
84     // opts_ appropriately.
85     req_.Clear();
86     resp_.Clear();
87     {
88       mutex_lock l(mu_);
89       status_ = Status::OK();
90     }
91     done_ = nullptr;
92   }
93 
~RpcRecvTensorCall()94   ~RpcRecvTensorCall() override {
95     // Since only the RpcRecvTensorFreeList will delete an
96     // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
97     // the user releases a Call object to the free list.
98     CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
99         << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
100   }
101 
Start(std::function<void ()> recv_done)102   void Start(std::function<void()> recv_done) override {
103     StartRTCall(std::move(recv_done));
104   }
105 
StartAbort(const Status & s)106   void StartAbort(const Status& s) override {
107     {
108       mutex_lock l(mu_);
109       status_.Update(s);
110     }
111     opts_.StartCancel();
112   }
113 
status() const114   Status status() const override {
115     mutex_lock l(mu_);
116     return status_;
117   }
118 
ReleaseWorker(WorkerCacheInterface * worker_cache)119   void ReleaseWorker(WorkerCacheInterface* worker_cache) {
120     DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
121         << "RpcRecvTensorCall::ReleaseWorker() called twice.";
122     worker_cache->ReleaseWorker(src_worker_, wi_);
123     wi_ = nullptr;
124   }
125 
tensor() const126   const Tensor& tensor() const { return resp_.tensor(); }
127 
is_dead() const128   bool is_dead() const { return resp_.metadata().is_dead(); }
129 
dst_device() const130   Device* dst_device() const { return dst_device_; }
recv_args() const131   const Rendezvous::Args& recv_args() const { return recv_args_; }
done() const132   const Rendezvous::DoneCallback& done() const { return done_; }
133 
134  private:
135   friend class RpcRemoteRendezvous;
136 
137   // Start the main RecvTensor call, checking for an async abort.
StartRTCall(std::function<void ()> recv_done)138   void StartRTCall(std::function<void()> recv_done) {
139     resp_.InitAlloc(dst_device_, alloc_attrs_);
140     auto abort_checked = std::make_shared<Notification>();
141     auto cb = [this, abort_checked,
142                recv_done = std::move(recv_done)](const Status& s) {
143       // Make sure the Rendezvous abort checking is finished before running the
144       // callback, which might destroy the current call object.
145       abort_checked->WaitForNotification();
146       if (!s.ok()) {
147         mutex_lock l(mu_);
148         status_.Update(s);
149       }
150       recv_done();
151     };
152     wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
153 
154     // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
155     // ordering is important because `StartAbort` could be called right before
156     // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`.
157     // In that case, the previous `StartAbort` would not trigger the
158     // cancellation of this call.
159     Status s;
160     {
161       mutex_lock l(mu_);
162       s = status_;
163     }
164     if (!s.ok()) {
165       opts_.StartCancel();
166     }
167     // Notify that the abort check has finished.
168     abort_checked->Notify();
169   }
170 
171   string src_worker_;
172   string src_rel_device_;
173   WorkerInterface* wi_;  // Not owned.
174   AllocatorAttributes alloc_attrs_;
175   Device* dst_device_;
176   CallOptions opts_;
177   RecvTensorRequest req_;
178   TensorResponse resp_;
179   Rendezvous::Args recv_args_;
180   Rendezvous::DoneCallback done_;
181 
182   mutable mutex mu_;
183   Status status_ TF_GUARDED_BY(mu_);
184 
185   TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
186 };
187 
188 class RpcRecvTensorFreeList {
189  public:
RpcRecvTensorFreeList()190   RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()191   ~RpcRecvTensorFreeList() {
192     for (size_t i = 0; i < objects_.size(); i++) {
193       delete objects_[i];
194     }
195   }
196 
New()197   RpcRecvTensorCall* New() {
198     {
199       mutex_lock l(mu_);
200       if (!objects_.empty()) {
201         RpcRecvTensorCall* result = objects_.back();
202         objects_.pop_back();
203         return result;
204       }
205     }
206     return new RpcRecvTensorCall;
207   }
208 
Release(RpcRecvTensorCall * obj)209   void Release(RpcRecvTensorCall* obj) {
210     obj->Reset();
211     {
212       mutex_lock l(mu_);
213       if (objects_.size() < kMaxObjects) {
214         objects_.push_back(obj);
215         return;
216       }
217     }
218     delete obj;
219   }
220 
221  private:
222   static constexpr int kMaxObjects = 1000;
223 
224   mutex mu_;
225   std::vector<RpcRecvTensorCall*> objects_ TF_GUARDED_BY(mu_);
226 };
227 
get_call_freelist()228 static RpcRecvTensorFreeList* get_call_freelist() {
229   static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
230   return call_freelist;
231 }
232 
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)233 void RpcRemoteRendezvous::RecvFromRemoteAsync(
234     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
235     DoneCallback done) {
236   CHECK(is_initialized());
237   Status s;
238 
239   // Prepare a RecvTensor call that can handle being aborted.
240   RpcRecvTensorCall* call = get_call_freelist()->New();
241 
242   // key.src_device identifies a remote device.
243   if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
244                                         &call->src_rel_device_)) {
245     s = errors::Internal(parsed.src_device,
246                          " is invalid remote source device.");
247   }
248   WorkerSession* sess = session();
249   std::shared_ptr<WorkerCacheInterface> worker_cache =
250       sess->GetSharedWorkerCache();
251   // The worker will be released in a subsequent call to
252   // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
253   // initialized) or `call->ReleaseWorker()` (if it has been initialized).
254   WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
255   if (s.ok() && rwi == nullptr) {
256     s = errors::Internal("No worker known as ", call->src_worker_);
257   }
258 
259   Device* dst_device;
260   if (s.ok()) {
261     s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
262   }
263   if (!s.ok()) {
264     if (rwi != nullptr) {
265       sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
266     }
267     get_call_freelist()->Release(call);
268     done(s, Args(), recv_args, Tensor{}, false);
269     return;
270   }
271 
272   call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
273              recv_args, std::move(done));
274 
275   // Record "call" in active_ so that it can be aborted cleanly.
276   RegisterCall(call, recv_args);
277 
278   // RendezvousMgr already aborted, shouldn't send RPC call any more
279   if (!call->status().ok()) {
280     DeregisterCall(call);
281     // NOTE: `*sess` can potentially be deleted before we return from
282     // `call->done()(...)`, so we must release the worker before calling the
283     // callback.
284     call->ReleaseWorker(sess->worker_cache());
285     call->done()(call->status(), Args(), Args(), Tensor(), false);
286     get_call_freelist()->Release(call);
287     return;
288   }
289 
290   // Start "call".
291   Ref();
292   call->Start([this, call, worker_cache]() {
293     // Removes "call" from active_. Prevent StartAbort().
294     DeregisterCall(call);
295     // If StartAbort was called prior to DeregisterCall, then the
296     // current status should be bad.
297     Status s = call->status();
298     // NOTE: `*session()` can potentially be deleted before we return from
299     // `call->done()(...)`, so we must release the worker before calling the
300     // callback.
301     call->ReleaseWorker(session()->worker_cache());
302     call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
303     get_call_freelist()->Release(call);
304     Unref();
305   });
306 }
307 
308 }  // namespace
309 
RpcRendezvousMgr(const WorkerEnv * env)310 RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
311     : BaseRendezvousMgr(env) {}
312 
Create(int64 step_id,const WorkerEnv * worker_env)313 BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
314                                                const WorkerEnv* worker_env) {
315   return new RpcRemoteRendezvous(worker_env, step_id);
316 }
317 
318 }  // end namespace tensorflow
319