• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/distributed_runtime/request_id.h"
25 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace tensorflow {
37 
38 namespace {
39 
// RemoteRendezvous specialization that fetches tensors from remote
// workers using the RecvTensor RPC.  Local exchanges are handled by
// BaseRemoteRendezvous; only the cross-worker receive path is
// implemented here (see RecvFromRemoteAsync below).
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
 public:
  RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
      : BaseRemoteRendezvous(env, step_id) {}

 protected:
  // Issues a RecvTensor RPC to the worker owning `parsed.src_device`
  // and invokes `done` with the resulting status/tensor.
  void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                           const Rendezvous::Args& args,
                           DoneCallback done) override;

 private:
  // Private destructor: instances are reference-counted (Ref()/Unref()
  // in RecvFromRemoteAsync), so they are never deleted directly.
  ~RpcRemoteRendezvous() override {}

  TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
55 
56 // Used only to retrieve tensors from remote processes.
// Used only to retrieve tensors from remote processes.
//
// Lifecycle: objects are obtained from the RpcRecvTensorFreeList,
// configured with Init(), driven via Start()/StartAbort(), and must
// have ReleaseWorker() called before Reset() returns them to the pool
// (both invariants are DCHECK/CHECK-enforced below).
class RpcRecvTensorCall : public BaseRecvTensorCall {
 public:
  RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}

  // Prepares this (possibly recycled) call for one RecvTensor RPC.
  // `wi` is borrowed from the worker cache and must later be returned
  // via ReleaseWorker().  `done` is invoked by the owning rendezvous,
  // not by this class.
  void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
            AllocatorAttributes alloc_attrs, Device* dst_device,
            const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
    wi_ = wi;
    alloc_attrs_ = alloc_attrs;
    dst_device_ = dst_device;
    recv_args_ = recv_args;
    done_ = std::move(done);
    req_.set_step_id(step_id);
    req_.set_rendezvous_key(key.data(), key.size());
    // Unique id lets the server de-duplicate retried requests.
    req_.set_request_id(GetUniqueRequestId());
  }

  // Clears per-call state so the object can be returned to the free
  // list and reused by a later Init().
  void Reset() {
    // The RpcRemoteRendezvous using this object is responsible for calling
    // ReleaseWorker() before Reset().
    DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
        << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";

    alloc_attrs_ = AllocatorAttributes();
    dst_device_ = nullptr;
    // We don't clear opts_ and assume that Init will set up the state for
    // opts_ appropriately.
    req_.Clear();
    resp_.Clear();
    {
      mutex_lock l(mu_);
      status_ = Status::OK();
    }
    done_ = nullptr;
  }

  ~RpcRecvTensorCall() override {
    // Since only the RpcRecvTensorFreeList will delete an
    // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
    // the user releases a Call object to the free list.
    CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
        << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
  }

  // Kicks off the RPC; `recv_done` runs when the RPC completes (with
  // any failure recorded in status_ first).
  void Start(std::function<void()> recv_done) override {
    StartRTCall(std::move(recv_done));
  }

  // Records the abort status and cancels the in-flight RPC, if any.
  // Safe to call concurrently with Start()/completion.
  void StartAbort(const Status& s) override {
    {
      mutex_lock l(mu_);
      status_.Update(s);
    }
    opts_.StartCancel();
  }

  // Current call status; reflects both RPC errors and StartAbort().
  Status status() const override {
    mutex_lock l(mu_);
    return status_;
  }

  // Returns the borrowed WorkerInterface to `worker_cache`.  Must be
  // called exactly once per Init() before Reset()/destruction.
  void ReleaseWorker(WorkerCacheInterface* worker_cache) {
    DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
        << "RpcRecvTensorCall::ReleaseWorker() called twice.";
    worker_cache->ReleaseWorker(src_worker_, wi_);
    wi_ = nullptr;
  }

  // Accessors for the RPC result; valid only after the call completed
  // successfully.
  const Tensor& tensor() const { return resp_.tensor(); }

  bool is_dead() const { return resp_.metadata().is_dead(); }

  Device* dst_device() const { return dst_device_; }
  const Rendezvous::Args& recv_args() const { return recv_args_; }
  const Rendezvous::DoneCallback& done() const { return done_; }

 private:
  friend class RpcRemoteRendezvous;

  // Start the main RecvTensor call, checking for an async abort.
  void StartRTCall(std::function<void()> recv_done) {
    // Let the response allocate directly on the destination device.
    resp_.InitAlloc(dst_device_, alloc_attrs_);
    using namespace std::placeholders;
    // Bind recv_done into the callback; the lone unbound argument is
    // the RPC's completion status.
    StatusCallback cb = std::bind(
        [this](std::function<void()> recv_done,
               // Begin unbound arguments.
               const Status& s) {
          if (!s.ok()) {
            mutex_lock l(mu_);
            status_.Update(s);
          }
          recv_done();
        },
        std::move(recv_done), _1);
    wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
  }

  string src_worker_;      // Filled in by RpcRemoteRendezvous from the key.
  string src_rel_device_;  // Device name relative to src_worker_.
  WorkerInterface* wi_;  // Not owned.
  AllocatorAttributes alloc_attrs_;
  Device* dst_device_;
  CallOptions opts_;       // Reused across Init() cycles (see Reset()).
  RecvTensorRequest req_;
  TensorResponse resp_;
  Rendezvous::Args recv_args_;
  Rendezvous::DoneCallback done_;

  mutable mutex mu_;
  Status status_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
};
170 
171 class RpcRecvTensorFreeList {
172  public:
RpcRecvTensorFreeList()173   RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()174   ~RpcRecvTensorFreeList() {
175     for (size_t i = 0; i < objects_.size(); i++) {
176       delete objects_[i];
177     }
178   }
179 
New()180   RpcRecvTensorCall* New() {
181     {
182       mutex_lock l(mu_);
183       if (!objects_.empty()) {
184         RpcRecvTensorCall* result = objects_.back();
185         objects_.pop_back();
186         return result;
187       }
188     }
189     return new RpcRecvTensorCall;
190   }
191 
Release(RpcRecvTensorCall * obj)192   void Release(RpcRecvTensorCall* obj) {
193     obj->Reset();
194     {
195       mutex_lock l(mu_);
196       if (objects_.size() < kMaxObjects) {
197         objects_.push_back(obj);
198         return;
199       }
200     }
201     delete obj;
202   }
203 
204  private:
205   static const int kMaxObjects = 1000;
206 
207   mutex mu_;
208   std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_);
209 };
210 
get_call_freelist()211 static RpcRecvTensorFreeList* get_call_freelist() {
212   static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
213   return call_freelist;
214 }
215 
// Fetches the tensor identified by `parsed` from the remote worker that
// owns `parsed.src_device`, invoking `done` exactly once on every path
// (setup failure, pre-aborted rendezvous, or RPC completion).
void RpcRemoteRendezvous::RecvFromRemoteAsync(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
    DoneCallback done) {
  CHECK(is_initialized());
  Status s;

  // Prepare a RecvTensor call that can handle being aborted.
  RpcRecvTensorCall* call = get_call_freelist()->New();

  // key.src_device identifies a remote device.
  if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
                                        &call->src_rel_device_)) {
    s = errors::Internal(parsed.src_device,
                         " is invalid remote source device.");
  }
  WorkerSession* sess = session();
  // The worker will be released in a subsequent call to
  // `sess->worker_cache->ReleaseWorker()` (if the call has not yet been
  // initialized) or `call->ReleaseWorker()` (if it has been initialized).
  WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
  if (s.ok() && rwi == nullptr) {
    s = errors::Internal("No worker known as ", call->src_worker_);
  }

  Device* dst_device;
  if (s.ok()) {
    s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
  }
  // Unified failure exit for all setup errors above: return the worker
  // and the call object before reporting the error.
  if (!s.ok()) {
    if (rwi != nullptr) {
      sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
    }
    get_call_freelist()->Release(call);
    done(s, Args(), recv_args, Tensor{}, false);
    return;
  }

  // From here on the call owns `rwi`; it is returned via
  // call->ReleaseWorker() on both remaining paths.
  call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
             recv_args, std::move(done));

  // Record "call" in active_ so that it can be aborted cleanly.
  RegisterCall(call);

  // RendezvousMgr already aborted, shouldn't send RPC call any more
  if (!call->status().ok()) {
    // NOTE: `*sess` can potentially be deleted before we return from
    // `call->done()(...)`, so we must release the worker before calling the
    // callback.
    call->ReleaseWorker(sess->worker_cache.get());
    call->done()(call->status(), Args(), Args(), Tensor(), false);
    get_call_freelist()->Release(call);
    return;
  }

  // Start "call".  The Ref()/Unref() pair keeps this rendezvous alive
  // until the RPC completion lambda below has run.
  Ref();
  call->Start([this, call]() {
    // Removes "call" from active_. Prevent StartAbort().
    DeregisterCall(call);
    // If StartAbort was called prior to DeregisterCall, then the
    // current status should be bad.
    Status s = call->status();
    // NOTE: `*session()` can potentially be deleted before we return from
    // `call->done()(...)`, so we must release the worker before calling the
    // callback.
    call->ReleaseWorker(session()->worker_cache.get());
    call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
    get_call_freelist()->Release(call);
    Unref();
  });
}
287 
288 }  // namespace
289 
// Trivial constructor: all bookkeeping lives in BaseRendezvousMgr.
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
    : BaseRendezvousMgr(env) {}
292 
// Factory hook called by BaseRendezvousMgr: builds the RPC-backed
// rendezvous for `step_id`.  The returned object is reference-counted
// and released via Unref(), never deleted directly.
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
                                               const WorkerEnv* worker_env) {
  return new RpcRemoteRendezvous(worker_env, step_id);
}
297 
298 }  // end namespace tensorflow
299