1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/common_runtime/process_util.h"
26 #include "tensorflow/core/distributed_runtime/worker_cache.h"
27 #include "tensorflow/core/distributed_runtime/worker_interface.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/strings/numbers.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace tensorflow {
37
StartAbortRendevous(Rendezvous * rendez,const Status & s)38 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
39 rendez->StartAbort(s);
40 rendez->Unref();
41 }
42
BaseRendezvousMgr(const WorkerEnv * worker_env)43 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
44 : worker_env_(worker_env) {}
45
~BaseRendezvousMgr()46 BaseRendezvousMgr::~BaseRendezvousMgr() {
47 for (auto& p : table_) {
48 auto rendez = p.second;
49 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
50 }
51 }
52
Find(int64 step_id)53 RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
54 return FindOrCreate(step_id);
55 }
56
FindOrCreate(int64 step_id)57 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
58 mutex_lock l(mu_);
59 auto iter = table_.find(step_id);
60 if (iter == table_.end()) {
61 auto rr = Create(step_id, worker_env_);
62 iter = table_.insert({step_id, rr}).first;
63 }
64 iter->second->Ref();
65 return iter->second;
66 }
67
RecvLocalAsync(int64 step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)68 void BaseRendezvousMgr::RecvLocalAsync(int64 step_id,
69 const Rendezvous::ParsedKey& parsed,
70 Rendezvous::DoneCallback done) {
71 auto rendez = FindOrCreate(step_id);
72 using namespace std::placeholders;
73 Rendezvous::DoneCallback done_cb = std::bind(
74 [rendez](Rendezvous::DoneCallback done,
75 // Begin unbound arguments.
76 const Status& s, const Rendezvous::Args& send_args,
77 const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
78 rendez->Unref();
79 done(s, send_args, recv_args, v, dead);
80 },
81 std::move(done), _1, _2, _3, _4, _5);
82 rendez->RecvLocalAsync(parsed, std::move(done_cb));
83 }
84
RecvLocal(int64 step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)85 Status BaseRendezvousMgr::RecvLocal(int64 step_id,
86 const Rendezvous::ParsedKey& parsed,
87 Tensor* val, bool* is_dead) {
88 Status ret;
89 Notification n;
90 RecvLocalAsync(step_id, parsed,
91 [val, is_dead, &ret, &n](const Status& s,
92 const Rendezvous::Args& send_args,
93 const Rendezvous::Args& recv_args,
94 const Tensor& v, const bool dead) {
95 ret = s;
96 *val = v;
97 *is_dead = dead;
98 n.Notify();
99 });
100 n.WaitForNotification();
101 return ret;
102 }
103
Cleanup(int64 step_id)104 void BaseRendezvousMgr::Cleanup(int64 step_id) {
105 Rendezvous* rendez = nullptr;
106 {
107 mutex_lock l(mu_);
108 auto iter = table_.find(step_id);
109 if (iter != table_.end()) {
110 rendez = iter->second;
111 table_.erase(iter);
112 }
113 }
114 if (rendez) {
115 StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
116 }
117 }
118
CleanupAll()119 void BaseRendezvousMgr::CleanupAll() {
120 std::vector<Rendezvous*> rendezs;
121 {
122 mutex_lock l(mu_);
123 for (const auto& entry : table_) {
124 rendezs.push_back(entry.second);
125 }
126 table_.clear();
127 }
128 for (auto rendez : rendezs) {
129 StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
130 }
131 }
132
BaseRemoteRendezvous(const WorkerEnv * env,int64 step_id)133 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id)
134 : env_(env),
135 step_id_(step_id),
136 local_(NewLocalRendezvous()),
137 session_(nullptr) {}
138
~BaseRemoteRendezvous()139 BaseRemoteRendezvous::~BaseRemoteRendezvous() {
140 CHECK(active_.empty());
141 local_->Unref();
142 }
143
144 // Returns true if "device_name" is a valid full name of local device
145 // of the "worker". This helper is purely based on the worker name
146 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)147 static bool IsLocalDevice(const StringPiece worker_name,
148 const StringPiece device_name) {
149 return str_util::StartsWith(device_name, worker_name);
150 }
151
Initialize(WorkerSession * session)152 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
153 CHECK_NE(session, nullptr) << "session must not be null!";
154 std::vector<DeferredCall> deferred_calls;
155 {
156 mutex_lock l(mu_);
157 if (session_ != nullptr) {
158 if (session_->worker_name == session->worker_name) {
159 LOG(INFO) << "Skipping rendezvous re-initialization.";
160 return Status::OK();
161 }
162 Status s = errors::Internal(
163 "Double init! Worker names would have changed from: ",
164 session_->worker_name, " -> ", session->worker_name);
165 LOG(WARNING) << s;
166 return s;
167 }
168 session_ = session;
169 std::swap(deferred_calls, deferred_calls_);
170 }
171 for (auto& call : deferred_calls) {
172 RecvLocalAsyncInternal(call.parsed, std::move(call.done));
173 }
174 return Status::OK();
175 }
176
session()177 WorkerSession* BaseRemoteRendezvous::session() {
178 mutex_lock l(mu_);
179 return session_;
180 }
181
is_initialized()182 bool BaseRemoteRendezvous::is_initialized() {
183 mutex_lock l(mu_);
184 return is_initialized_locked();
185 }
186
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)187 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
188 const Rendezvous::Args& args,
189 const Tensor& val, const bool is_dead) {
190 VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
191 {
192 mutex_lock l(mu_);
193 if (!status_.ok()) return status_;
194 DCHECK(is_initialized_locked());
195 if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
196 return errors::InvalidArgument(
197 "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
198 session_->worker_name);
199 }
200 }
201 // Buffers "val" and "device_context" in local_.
202 return local_->Send(parsed, args, val, is_dead);
203 }
204
ValidateDevices(const ParsedKey & parsed,bool is_src)205 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
206 bool is_src) {
207 // Cache session pointer to avoid repeatedly taking & releasing the lock
208 // (e.g. calling session())
209 WorkerSession* sess = nullptr;
210 {
211 mutex_lock l(mu_);
212 if (!status_.ok()) return status_;
213 if (!is_initialized_locked()) {
214 return errors::Internal("ValidateDevices called before initialization.");
215 }
216 sess = session_;
217 }
218 if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
219 return errors::InvalidArgument("Invalid rendezvous key (src): ",
220 parsed.FullKey(), " @ ", sess->worker_name);
221 }
222 if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
223 return errors::InvalidArgument("Invalid rendezvous key (dst): ",
224 parsed.FullKey(), " @ ", sess->worker_name);
225 }
226 return Status::OK();
227 }
228
SameWorkerRecvDone(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & send_args,const Rendezvous::Args & recv_args,const Tensor & in,Tensor * out,StatusCallback done)229 void BaseRemoteRendezvous::SameWorkerRecvDone(
230 const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
231 const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
232 StatusCallback done) {
233 // Do a quick copy (sharing the underlying buffer) if both tensors
234 // are on host memory.
235 const bool src_host =
236 (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
237 const bool dst_host =
238 (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
239 if (src_host && dst_host) {
240 *out = in;
241 done(Status::OK());
242 return;
243 }
244
245 // This copy must involve a GPU. Hence, "in" must support DMA
246 // (e.g., string tensors do not work on GPU). Variant copy DMA
247 // checks happen inside CopyTensor::ViaDMA.
248 if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT) {
249 done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
250 " tensor may not be copied from/to a GPU."));
251 return;
252 }
253
254 WorkerSession* sess = session();
255 Device* src_device;
256 Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
257 if (!s.ok()) {
258 done(s);
259 return;
260 }
261 Device* dst_device;
262 s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
263 if (!s.ok()) {
264 done(s);
265 return;
266 }
267
268 AllocatorAttributes attr = recv_args.alloc_attrs;
269 attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
270 recv_args.alloc_attrs.gpu_compatible());
271 Allocator* out_allocator = dst_device->GetAllocator(attr);
272
273 if (in.dtype() != DT_VARIANT) {
274 // Variants are handled by CopyTensor::ViaDMA.
275 Tensor copy(out_allocator, in.dtype(), in.shape());
276 *out = copy;
277 }
278
279 // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
280 // etc.
281 CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
282 recv_args.device_context, src_device, dst_device,
283 send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
284 0 /*dev_to_dev_stream_index*/, std::move(done));
285 }
286
IsSameWorker(DeviceNameUtils::ParsedName src,DeviceNameUtils::ParsedName dst)287 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
288 DeviceNameUtils::ParsedName dst) {
289 return DeviceNameUtils::IsSameAddressSpace(src, dst);
290 }
291
RecvAsync(const ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)292 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
293 const Rendezvous::Args& recv_args,
294 DoneCallback done) {
295 VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
296 Status s = ValidateDevices(parsed, false /*!is_src*/);
297 if (s.ok() && !is_initialized()) {
298 s.Update(errors::Internal(
299 "RecvAsync called when uninitialized (key:", parsed.FullKey(), ")."));
300 }
301 if (!s.ok()) {
302 done(s, Args(), recv_args, Tensor(), false);
303 return;
304 }
305
306 // Are src and dst in the same worker?
307 if (IsSameWorker(parsed.src, parsed.dst)) {
308 // Recv the tensor from local_.
309 local_->RecvAsync(
310 parsed, recv_args,
311 [this, parsed, done](
312 const Status& status, const Rendezvous::Args& send_args,
313 const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
314 Tensor* out = new Tensor;
315 StatusCallback final_callback = [done, send_args, recv_args, out,
316 is_dead](const Status& s) {
317 done(s, send_args, recv_args, *out, is_dead);
318 delete out;
319 };
320
321 if (status.ok()) {
322 SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
323 std::move(final_callback));
324 } else {
325 final_callback(status);
326 }
327 });
328 return;
329 } else {
330 RecvFromRemoteAsync(parsed, recv_args, std::move(done));
331 }
332 }
333
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)334 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
335 DoneCallback done) {
336 {
337 mutex_lock l(mu_);
338 if (!is_initialized_locked()) {
339 // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
340 // remote worker) before the RunStep (or PartialRunStep) RPC from the
341 // master arrives. RecvLocalAsync thus buffers the arguments until after
342 // the RemoteRendezvous is Initialize()'d, when it completes the
343 // rendezvous logic. At some point after Initialize() is called, a Tensor
344 // is produced locally that will then be sent in response to the incoming
345 // RPC.
346 DeferredCall call(parsed, std::move(done));
347 deferred_calls_.push_back(call);
348 return;
349 }
350 }
351 RecvLocalAsyncInternal(parsed, std::move(done));
352 }
353
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)354 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
355 DoneCallback done) {
356 Status s = ValidateDevices(parsed, true /* is_src */);
357 if (!s.ok()) {
358 done(s, Args(), Args(), Tensor(), false);
359 return;
360 }
361 local_->RecvAsync(parsed, Args(), std::move(done));
362 }
363
StartAbort(const Status & s)364 void BaseRemoteRendezvous::StartAbort(const Status& s) {
365 CHECK(!s.ok());
366 local_->StartAbort(s);
367 {
368 // Aborts all active RecvTensor calls.
369 mutex_lock l(mu_);
370 if (status_.ok()) {
371 status_ = s;
372 for (BaseRecvTensorCall* call : active_) {
373 call->StartAbort(s);
374 }
375 active_.clear();
376 }
377 }
378 }
379
RegisterCall(BaseRecvTensorCall * call)380 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
381 mutex_lock l(mu_);
382 if (!status_.ok()) {
383 call->StartAbort(status_);
384 } else {
385 CHECK(active_.insert(call).second);
386 }
387 }
388
DeregisterCall(BaseRecvTensorCall * call)389 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
390 mutex_lock l(mu_);
391 active_.erase(call);
392 }
393
DeferredCall(const ParsedKey & parsed,DoneCallback done)394 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
395 DoneCallback done)
396 : parsed(parsed), done(std::move(done)) {}
397
398 } // end namespace tensorflow
399