1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/common_runtime/device.h"
21 #include "tensorflow/core/common_runtime/device_mgr.h"
22 #include "tensorflow/core/common_runtime/dma_helper.h"
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/distributed_runtime/request_id.h"
25 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/notification.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace tensorflow {
38
39 namespace {
40
41 class RpcRemoteRendezvous : public BaseRemoteRendezvous {
42 public:
RpcRemoteRendezvous(const WorkerEnv * env,int64 step_id)43 RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
44 : BaseRemoteRendezvous(env, step_id) {}
45
46 protected:
47 void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
48 const Rendezvous::Args& args,
49 DoneCallback done) override;
50
51 private:
~RpcRemoteRendezvous()52 ~RpcRemoteRendezvous() override {}
53
54 TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
55 };
56
57 // Used only to retrieve tensors from remote processes.
58 class RpcRecvTensorCall : public BaseRecvTensorCall {
59 public:
RpcRecvTensorCall()60 RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
61
Init(WorkerInterface * wi,int64 step_id,StringPiece key,AllocatorAttributes alloc_attrs,Device * dst_device,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)62 void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
63 AllocatorAttributes alloc_attrs, Device* dst_device,
64 const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
65 wi_ = wi;
66 alloc_attrs_ = alloc_attrs;
67 dst_device_ = dst_device;
68 recv_args_ = recv_args;
69 done_ = std::move(done);
70 req_.set_step_id(step_id);
71 req_.set_rendezvous_key(key.data(), key.size());
72 req_.set_request_id(GetUniqueRequestId());
73 }
74
Reset()75 void Reset() {
76 // The RpcRemoteRendezvous using this object is responsible for calling
77 // ReleaseWorker() before Reset().
78 DCHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
79 << "Leaking WorkerInterface in RpcRecvTensorCall::Reset().";
80
81 alloc_attrs_ = AllocatorAttributes();
82 dst_device_ = nullptr;
83 // We don't clear opts_ and assume that Init will set up the state for
84 // opts_ appropriately.
85 req_.Clear();
86 resp_.Clear();
87 {
88 mutex_lock l(mu_);
89 status_ = Status::OK();
90 }
91 done_ = nullptr;
92 }
93
~RpcRecvTensorCall()94 ~RpcRecvTensorCall() override {
95 // Since only the RpcRecvTensorFreeList will delete an
96 // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
97 // the user releases a Call object to the free list.
98 CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
99 << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
100 }
101
Start(std::function<void ()> recv_done)102 void Start(std::function<void()> recv_done) override {
103 StartRTCall(std::move(recv_done));
104 }
105
StartAbort(const Status & s)106 void StartAbort(const Status& s) override {
107 {
108 mutex_lock l(mu_);
109 status_.Update(s);
110 }
111 opts_.StartCancel();
112 }
113
status() const114 Status status() const override {
115 mutex_lock l(mu_);
116 return status_;
117 }
118
ReleaseWorker(WorkerCacheInterface * worker_cache)119 void ReleaseWorker(WorkerCacheInterface* worker_cache) {
120 DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
121 << "RpcRecvTensorCall::ReleaseWorker() called twice.";
122 worker_cache->ReleaseWorker(src_worker_, wi_);
123 wi_ = nullptr;
124 }
125
tensor() const126 const Tensor& tensor() const { return resp_.tensor(); }
127
is_dead() const128 bool is_dead() const { return resp_.metadata().is_dead(); }
129
dst_device() const130 Device* dst_device() const { return dst_device_; }
recv_args() const131 const Rendezvous::Args& recv_args() const { return recv_args_; }
done() const132 const Rendezvous::DoneCallback& done() const { return done_; }
133
134 private:
135 friend class RpcRemoteRendezvous;
136
137 // Start the main RecvTensor call, checking for an async abort.
StartRTCall(std::function<void ()> recv_done)138 void StartRTCall(std::function<void()> recv_done) {
139 resp_.InitAlloc(dst_device_, alloc_attrs_);
140 auto abort_checked = std::make_shared<Notification>();
141 auto cb = [this, abort_checked,
142 recv_done = std::move(recv_done)](const Status& s) {
143 // Make sure the Rendezvous abort checking is finished before running the
144 // callback, which might destroy the current call object.
145 abort_checked->WaitForNotification();
146 if (!s.ok()) {
147 mutex_lock l(mu_);
148 status_.Update(s);
149 }
150 recv_done();
151 };
152 wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
153
154 // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
155 // ordering is important because `StartAbort` could be called right before
156 // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`.
157 // In that case, the previous `StartAbort` would not trigger the
158 // cancellation of this call.
159 Status s;
160 {
161 mutex_lock l(mu_);
162 s = status_;
163 }
164 if (!s.ok()) {
165 opts_.StartCancel();
166 }
167 // Notify that the abort check has finished.
168 abort_checked->Notify();
169 }
170
171 string src_worker_;
172 string src_rel_device_;
173 WorkerInterface* wi_; // Not owned.
174 AllocatorAttributes alloc_attrs_;
175 Device* dst_device_;
176 CallOptions opts_;
177 RecvTensorRequest req_;
178 TensorResponse resp_;
179 Rendezvous::Args recv_args_;
180 Rendezvous::DoneCallback done_;
181
182 mutable mutex mu_;
183 Status status_ TF_GUARDED_BY(mu_);
184
185 TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
186 };
187
188 class RpcRecvTensorFreeList {
189 public:
RpcRecvTensorFreeList()190 RpcRecvTensorFreeList() {}
~RpcRecvTensorFreeList()191 ~RpcRecvTensorFreeList() {
192 for (size_t i = 0; i < objects_.size(); i++) {
193 delete objects_[i];
194 }
195 }
196
New()197 RpcRecvTensorCall* New() {
198 {
199 mutex_lock l(mu_);
200 if (!objects_.empty()) {
201 RpcRecvTensorCall* result = objects_.back();
202 objects_.pop_back();
203 return result;
204 }
205 }
206 return new RpcRecvTensorCall;
207 }
208
Release(RpcRecvTensorCall * obj)209 void Release(RpcRecvTensorCall* obj) {
210 obj->Reset();
211 {
212 mutex_lock l(mu_);
213 if (objects_.size() < kMaxObjects) {
214 objects_.push_back(obj);
215 return;
216 }
217 }
218 delete obj;
219 }
220
221 private:
222 static constexpr int kMaxObjects = 1000;
223
224 mutex mu_;
225 std::vector<RpcRecvTensorCall*> objects_ TF_GUARDED_BY(mu_);
226 };
227
get_call_freelist()228 static RpcRecvTensorFreeList* get_call_freelist() {
229 static RpcRecvTensorFreeList* call_freelist = new RpcRecvTensorFreeList();
230 return call_freelist;
231 }
232
RecvFromRemoteAsync(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)233 void RpcRemoteRendezvous::RecvFromRemoteAsync(
234 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
235 DoneCallback done) {
236 CHECK(is_initialized());
237 Status s;
238
239 // Prepare a RecvTensor call that can handle being aborted.
240 RpcRecvTensorCall* call = get_call_freelist()->New();
241
242 // key.src_device identifies a remote device.
243 if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
244 &call->src_rel_device_)) {
245 s = errors::Internal(parsed.src_device,
246 " is invalid remote source device.");
247 }
248 WorkerSession* sess = session();
249 std::shared_ptr<WorkerCacheInterface> worker_cache =
250 sess->GetSharedWorkerCache();
251 // The worker will be released in a subsequent call to
252 // `sess->worker_cache()->ReleaseWorker()` (if the call has not yet been
253 // initialized) or `call->ReleaseWorker()` (if it has been initialized).
254 WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
255 if (s.ok() && rwi == nullptr) {
256 s = errors::Internal("No worker known as ", call->src_worker_);
257 }
258
259 Device* dst_device;
260 if (s.ok()) {
261 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
262 }
263 if (!s.ok()) {
264 if (rwi != nullptr) {
265 sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
266 }
267 get_call_freelist()->Release(call);
268 done(s, Args(), recv_args, Tensor{}, false);
269 return;
270 }
271
272 call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
273 recv_args, std::move(done));
274
275 // Record "call" in active_ so that it can be aborted cleanly.
276 RegisterCall(call, recv_args);
277
278 // RendezvousMgr already aborted, shouldn't send RPC call any more
279 if (!call->status().ok()) {
280 DeregisterCall(call);
281 // NOTE: `*sess` can potentially be deleted before we return from
282 // `call->done()(...)`, so we must release the worker before calling the
283 // callback.
284 call->ReleaseWorker(sess->worker_cache());
285 call->done()(call->status(), Args(), Args(), Tensor(), false);
286 get_call_freelist()->Release(call);
287 return;
288 }
289
290 // Start "call".
291 Ref();
292 call->Start([this, call, worker_cache]() {
293 // Removes "call" from active_. Prevent StartAbort().
294 DeregisterCall(call);
295 // If StartAbort was called prior to DeregisterCall, then the
296 // current status should be bad.
297 Status s = call->status();
298 // NOTE: `*session()` can potentially be deleted before we return from
299 // `call->done()(...)`, so we must release the worker before calling the
300 // callback.
301 call->ReleaseWorker(session()->worker_cache());
302 call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
303 get_call_freelist()->Release(call);
304 Unref();
305 });
306 }
307
308 } // namespace
309
RpcRendezvousMgr(const WorkerEnv * env)310 RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
311 : BaseRendezvousMgr(env) {}
312
Create(int64 step_id,const WorkerEnv * worker_env)313 BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
314 const WorkerEnv* worker_env) {
315 return new RpcRemoteRendezvous(worker_env, step_id);
316 }
317
318 } // end namespace tensorflow
319