1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/distributed_runtime/request_id.h"
25 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace tensorflow {
37
38 namespace {
39
40 class RpcRemoteRendezvous : public BaseRemoteRendezvous {
41 public:
RpcRemoteRendezvous(const WorkerEnv * env,int64 step_id)42 RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
43 : BaseRemoteRendezvous(env, step_id) {}
44
45 protected:
46 void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
47 const Rendezvous::Args& args,
48 DoneCallback done) override;
49
50 private:
~RpcRemoteRendezvous()51 ~RpcRemoteRendezvous() override {}
52
53 TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
54 };
55
56 // Used only to retrieve tensors from remote processes.
57 class RpcRecvTensorCall : public BaseRecvTensorCall {
58 public:
RpcRecvTensorCall()59 RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
60
Init(WorkerInterface * wi,int64 step_id,StringPiece key,AllocatorAttributes alloc_attrs,Device * dst_device,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)61 void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
62 AllocatorAttributes alloc_attrs, Device* dst_device,
63 const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
64 wi_ = wi;
65 alloc_attrs_ = alloc_attrs;
66 dst_device_ = dst_device;
67 recv_args_ = recv_args;
68 done_ = std::move(done);
69 req_.set_step_id(step_id);
70 req_.set_rendezvous_key(key.data(), key.size());
71 req_.set_request_id(GetUniqueRequestId());
72 }
73
Reset()74 void Reset() {
75 // The RpcRemoteRendezvous using this object is responsible for calling
76 // ReleaseWorker() before Reset().
77 DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
78 << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";
79
80 alloc_attrs_ = AllocatorAttributes();
81 dst_device_ = nullptr;
82 // We don't clear opts_ and assume that Init will set up the state for
83 // opts_ appropriately.
84 req_.Clear();
85 resp_.Clear();
86 {
87 mutex_lock l(mu_);
88 status_ = Status::OK();
89 }
90 done_ = nullptr;
91 }
92
~RpcRecvTensorCall()93 ~RpcRecvTensorCall() override {
94 // Since only the RpcRecvTensorFreeList will delete an
95 // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
96 // the user releases a Call object to the free list.
97 CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
98 << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
99 }
100
Start(std::function<void ()> recv_done)101 void Start(std::function<void()> recv_done) override {
102 StartRTCall(std::move(recv_done));
103 }
104
StartAbort(const Status & s)105 void StartAbort(const Status& s) override {
106 {
107 mutex_lock l(mu_);
108 status_.Update(s);
109 }
110 opts_.StartCancel();
111 }
112
status() const113 Status status() const override {
114 mutex_lock l(mu_);
115 return status_;
116 }
117
ReleaseWorker(WorkerCacheInterface * worker_cache)118 void ReleaseWorker(WorkerCacheInterface* worker_cache) {
119 DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
120 << "RpcRecvTensorCall::ReleaseWorker() called twice.";
121 worker_cache->ReleaseWorker(src_worker_, wi_);
122 wi_ = nullptr;
123 }
124
tensor() const125 const Tensor& tensor() const { return resp_.tensor(); }
126
is_dead() const127 bool is_dead() const { return resp_.metadata().is_dead(); }
128
dst_device() const129 Device* dst_device() const { return dst_device_; }
recv_args() const130 const Rendezvous::Args& recv_args() const { return recv_args_; }
done() const131 const Rendezvous::DoneCallback& done() const { return done_; }
132
133 private:
134 friend class RpcRemoteRendezvous;
135
136 // Start the main RecvTensor call, checking for an async abort.
StartRTCall(std::function<void ()> recv_done)137 void StartRTCall(std::function<void()> recv_done) {
138 resp_.InitAlloc(dst_device_, alloc_attrs_);
139 using namespace std::placeholders;
140 StatusCallback cb = std::bind(
141 [this](std::function<void()> recv_done,
142 // Begin unbound arguments.
143 const Status& s) {
144 if (!s.ok()) {
145 mutex_lock l(mu_);
146 status_.Update(s);
147 }
148 recv_done();
149 },
150 std::move(recv_done), _1);
151 wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
152 }
153
154 string src_worker_;
155 string src_rel_device_;
156 WorkerInterface* wi_; // Not owned.
157 AllocatorAttributes alloc_attrs_;
158 Device* dst_device_;
159 CallOptions opts_;
160 RecvTensorRequest req_;
161 TensorResponse resp_;
162 Rendezvous::Args recv_args_;
163 Rendezvous::DoneCallback done_;
164
165 mutable mutex mu_;
166 Status status_ GUARDED_BY(mu_);
167
168 TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
169 };
170
171 class RpcRecvTensorFreeList {
172 public:
RpcRecvTensorFreeList()173 RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()174 ~RpcRecvTensorFreeList() {
175 for (size_t i = 0; i < objects_.size(); i++) {
176 delete objects_[i];
177 }
178 }
179
New()180 RpcRecvTensorCall* New() {
181 {
182 mutex_lock l(mu_);
183 if (!objects_.empty()) {
184 RpcRecvTensorCall* result = objects_.back();
185 objects_.pop_back();
186 return result;
187 }
188 }
189 return new RpcRecvTensorCall;
190 }
191
Release(RpcRecvTensorCall * obj)192 void Release(RpcRecvTensorCall* obj) {
193 obj->Reset();
194 {
195 mutex_lock l(mu_);
196 if (objects_.size() < kMaxObjects) {
197 objects_.push_back(obj);
198 return;
199 }
200 }
201 delete obj;
202 }
203
204 private:
205 static const int kMaxObjects = 1000;
206
207 mutex mu_;
208 std::vector<RpcRecvTensorCall*> objects_ GUARDED_BY(mu_);
209 };
210
get_call_freelist()211 static RpcRecvTensorFreeList* get_call_freelist() {
212 static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
213 return call_freelist;
214 }
215
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)216 void RpcRemoteRendezvous::RecvFromRemoteAsync(
217 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
218 DoneCallback done) {
219 CHECK(is_initialized());
220 Status s;
221
222 // Prepare a RecvTensor call that can handle being aborted.
223 RpcRecvTensorCall* call = get_call_freelist()->New();
224
225 // key.src_device identifies a remote device.
226 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
227 &call->src_rel_device_)) {
228 s = errors::Internal(parsed.src_device,
229 " is invalid remote source device.");
230 }
231 WorkerSession* sess = session();
232 // The worker will be released in a subsequent call to
233 // `sess->worker_cache->ReleaseWorker()` (if the call has not yet been
234 // initialized) or `call->ReleaseWorker()` (if it has been initialized).
235 WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
236 if (s.ok() && rwi == nullptr) {
237 s = errors::Internal("No worker known as ", call->src_worker_);
238 }
239
240 Device* dst_device;
241 if (s.ok()) {
242 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
243 }
244 if (!s.ok()) {
245 if (rwi != nullptr) {
246 sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
247 }
248 get_call_freelist()->Release(call);
249 done(s, Args(), recv_args, Tensor{}, false);
250 return;
251 }
252
253 call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
254 recv_args, std::move(done));
255
256 // Record "call" in active_ so that it can be aborted cleanly.
257 RegisterCall(call);
258
259 // RendezvousMgr already aborted, shouldn't send RPC call any more
260 if (!call->status().ok()) {
261 // NOTE: `*sess` can potentially be deleted before we return from
262 // `call->done()(...)`, so we must release the worker before calling the
263 // callback.
264 call->ReleaseWorker(sess->worker_cache.get());
265 call->done()(call->status(), Args(), Args(), Tensor(), false);
266 get_call_freelist()->Release(call);
267 return;
268 }
269
270 // Start "call".
271 Ref();
272 call->Start([this, call]() {
273 // Removes "call" from active_. Prevent StartAbort().
274 DeregisterCall(call);
275 // If StartAbort was called prior to DeregisterCall, then the
276 // current status should be bad.
277 Status s = call->status();
278 // NOTE: `*session()` can potentially be deleted before we return from
279 // `call->done()(...)`, so we must release the worker before calling the
280 // callback.
281 call->ReleaseWorker(session()->worker_cache.get());
282 call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
283 get_call_freelist()->Release(call);
284 Unref();
285 });
286 }
287
288 } // namespace
289
RpcRendezvousMgr(const WorkerEnv * env)290 RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
291 : BaseRendezvousMgr(env) {}
292
Create(int64 step_id,const WorkerEnv * worker_env)293 BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
294 const WorkerEnv* worker_env) {
295 return new RpcRemoteRendezvous(worker_env, step_id);
296 }
297
298 } // end namespace tensorflow
299