1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
17
18 #include <functional>
19
20 #include "absl/types/optional.h"
21 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
23 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/shape_inference.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/protobuf.h"
30
31 namespace tensorflow {
32 namespace eager {
33
34 namespace {
35
PrepareRemoteOp(eager::Operation * remote_op,EagerOperation * op)36 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
37 remote_op->set_name(op->Name());
38
39 op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
40 remote_op->set_device(op->DeviceName());
41 }
42
CreateUncachedKernelAndDeviceOp(EagerOperation * op,core::RefCountPtr<KernelAndDevice> * kernel)43 Status CreateUncachedKernelAndDeviceOp(
44 EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
45 EagerContext& ctx = op->EagerContext();
46 Device* device = absl::get<Device*>(op->Device());
47
48 FunctionLibraryRuntime* flr = ctx.func_lib(device);
49 if (flr == nullptr) {
50 return errors::Unavailable(
51 "Unable to find a FunctionLibraryRuntime corresponding to device ",
52 device->name());
53 }
54
55 auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
56 kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
57 runner, ctx.GetCollectiveExecutorHandle(),
58 ctx.HostCPU()));
59
60 const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
61 return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
62 /*graph_collector=*/nullptr);
63 }
64
65 // This gets a unique wire ID. We add a random identifier so that if the
66 // worker has other clients that it is servicing, we don't have any collision.
GetUniqueWireID()67 string GetUniqueWireID() {
68 static tensorflow::uint64 random_seed = random::New64();
69 static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
70 static std::atomic<int64_t> wire_id;
71 return strings::StrCat(random_seed, "_", wire_id++);
72 }
73
74 } // namespace
75
RemoteCopyNode(EagerContext * ctx,EagerExecutor * executor,TensorHandle * src,TensorHandle * dst,Device * recv_device,uint64 recv_op_id)76 RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
77 TensorHandle* src, TensorHandle* dst,
78 Device* recv_device, uint64 recv_op_id)
79 : AsyncEagerNode(),
80 src_(src),
81 ctx_(ctx),
82 executor_(executor),
83 send_device_(src->DeviceOrHostCPU(*ctx)),
84 recv_device_(recv_device),
85 wire_id_(GetUniqueWireID()),
86 recv_op_id_(recv_op_id),
87 captured_state_(std::make_shared<CapturedSharedState>(dst)),
88 started_(false) {
89 DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
90 src_->Ref();
91 ctx_->Ref();
92 }
93
~RemoteCopyNode()94 RemoteCopyNode::~RemoteCopyNode() {
95 src_->Unref();
96 ctx_->Unref();
97 }
98
RunLocalSend(EagerOperation * op)99 Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
100 TF_RETURN_IF_ERROR(executor_->status());
101
102 TF_RETURN_IF_ERROR(op->AddInput(src_));
103
104 core::RefCountPtr<KernelAndDevice> kernel;
105 TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
106
107 EagerKernelArgs args(1);
108 Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
109 TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
110 CoordinationServiceAgent* coord_agent = nullptr;
111 if (ctx_->GetDistributedManager() != nullptr)
112 coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
113
114 return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
115 /*cancellation_manager=*/nullptr,
116 /*remote_func_params=*/absl::nullopt,
117 /*stack_trace=*/absl::nullopt, coord_agent);
118 }
119
StartSend()120 void RemoteCopyNode::StartSend() {
121 // TODO(gjn): We should consider just using the low-level SendOp::Compute()
122 // functionality here instead of constructing an Op.
123 EagerOperation op(ctx_);
124 Status status = op.Reset("_Send", /*raw_device_name=*/nullptr,
125 /*remote=*/false, /*executor=*/nullptr);
126 if (!status.ok()) {
127 captured_state_->SetSendStatus(status);
128 return;
129 }
130
131 op.SetDevice(send_device_);
132
133 op.MutableAttrs()->Set("tensor_name", wire_id_);
134 op.MutableAttrs()->Set("send_device", send_device_->name());
135 op.MutableAttrs()->Set(
136 "send_device_incarnation",
137 static_cast<int64>(send_device_->attributes().incarnation()));
138 op.MutableAttrs()->Set("recv_device", recv_device_->name());
139 op.MutableAttrs()->Set("client_terminated", false);
140
141 op.MutableAttrs()->Set("T", src_->dtype);
142
143 DCHECK(send_device_ != nullptr);
144
145 if (send_device_->IsLocal()) {
146 status = RunLocalSend(&op);
147 captured_state_->SetSendStatus(status);
148 return;
149 } else {
150 // Prepare the request
151 EnqueueRequest request;
152 request.set_context_id(ctx_->GetContextId());
153 auto* remote_op = request.add_queue()->mutable_operation();
154 status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
155 src_, /*wait_until_ready=*/false,
156 remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
157 src_->DeviceOrHostCPU(*ctx_)->name());
158 if (!status.ok()) {
159 captured_state_->SetSendStatus(status);
160 return;
161 }
162
163 PrepareRemoteOp(remote_op, &op);
164 remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
165
166 // Issue the RPC
167 core::RefCountPtr<eager::EagerClient> eager_client;
168 status = ctx_->GetClient(send_device_, &eager_client);
169 if (!status.ok()) {
170 captured_state_->SetSendStatus(status);
171 return;
172 }
173
174 const std::shared_ptr<CapturedSharedState>& captured_state =
175 captured_state_;
176 EnqueueResponse* response = new EnqueueResponse;
177 // If StartRecv fails very quickly, `this` can be destroyed before the
178 // callback below is executed. So, we can't capture `this`.
179 eager_client->StreamingEnqueueAsync(
180 /*call_opts=*/nullptr, &request, response,
181 [response, captured_state](const Status& s) {
182 captured_state->SetSendStatus(s);
183 if (!s.ok()) {
184 captured_state->recv_cancellation()->StartCancel();
185 }
186 delete response;
187 });
188 }
189 }
190
RunLocalRecv(EagerOperation * op,std::vector<Tensor> * outputs)191 Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
192 std::vector<Tensor>* outputs) {
193 TF_RETURN_IF_ERROR(executor_->status());
194
195 core::RefCountPtr<KernelAndDevice> kernel;
196 TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
197
198 EagerKernelArgs args;
199 std::vector<EagerKernelRet> rets;
200 CoordinationServiceAgent* coord_agent = nullptr;
201 if (ctx_->GetDistributedManager() != nullptr)
202 coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
203 TF_RETURN_IF_ERROR(kernel->Run(/*step_container*/ nullptr, args, &rets,
204 captured_state_->recv_cancellation(),
205 /*remote_func_params=*/absl::nullopt,
206 /*stack_trace=*/absl::nullopt, coord_agent));
207 outputs->clear();
208 for (const auto& ret : rets) {
209 if (ret.index() == 0) {
210 outputs->push_back(absl::get<Tensor>(ret));
211 } else {
212 return errors::Internal(
213 "Expect to receive a Tensor but got a TensorShape.");
214 }
215 }
216 return Status::OK();
217 }
218
RunRemoteRecv(EagerOperation * op,StatusCallback done)219 void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
220 EnqueueRequest request;
221 uint64 context_id = ctx_->GetContextId();
222 request.set_context_id(context_id);
223 auto* remote_op = request.add_queue()->mutable_operation();
224 PrepareRemoteOp(remote_op, op);
225 remote_op->set_id(recv_op_id_);
226 uint64 context_view_id = ctx_->GetContextViewId();
227
228 core::RefCountPtr<eager::EagerClient> eager_client;
229 Status status = ctx_->GetClient(recv_device_, &eager_client);
230 if (!status.ok()) {
231 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
232 done(status);
233 return;
234 }
235
236 // Don't issue the recv until send has completed.
237 // - local send will complete very quickly.
238 // - remote send will take some time, but remote->remote copy is
239 // probably rare enough that we don't care much.
240 // Blocks until send has completed.
241 Status send_status = captured_state_->GetSendStatus();
242 if (!send_status.ok()) {
243 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
244 done(send_status);
245 return;
246 }
247
248 EnqueueResponse* response = new EnqueueResponse;
249 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
250 Device* recv_device = recv_device_;
251 eager_client->StreamingEnqueueAsync(
252 /*call_opts=*/nullptr, &request, response,
253 [captured_state, response, recv_device, context_view_id,
254 done](const Status& s) {
255 if (s.ok()) {
256 Status status = captured_state->dst()->SetRemoteShape(
257 response->queue_response(0).shape(0), recv_device,
258 context_view_id);
259 if (!status.ok()) {
260 LOG(ERROR) << "Ignoring an error encountered when setting remote "
261 "shape of tensor received by remote Recv op: "
262 << status.ToString()
263 << "\nThis should never happen. "
264 "Please file an issue with the TensorFlow Team.";
265 }
266 } else {
267 captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
268 }
269 done(s);
270 delete response;
271 });
272 }
273
StartRecv(StatusCallback done)274 void RemoteCopyNode::StartRecv(StatusCallback done) {
275 // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
276 // functionality here instead of constructing an Op.
277 EagerOperation op(ctx_);
278 Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr,
279 /*remote=*/false, /*executor=*/nullptr);
280 Device* recv_device = ctx_->CanonicalDevice(recv_device_);
281 if (!status.ok()) {
282 captured_state_->dst()->Poison(status, recv_device);
283 done(status);
284 return;
285 }
286
287 op.SetDevice(recv_device_);
288
289 op.MutableAttrs()->Set("tensor_name", wire_id_);
290 op.MutableAttrs()->Set("send_device", send_device_->name());
291 op.MutableAttrs()->Set(
292 "send_device_incarnation",
293 static_cast<int64>(send_device_->attributes().incarnation()));
294 op.MutableAttrs()->Set("recv_device", recv_device_->name());
295 op.MutableAttrs()->Set("client_terminated", false);
296
297 op.MutableAttrs()->Set("tensor_type", src_->dtype);
298
299 if (recv_device_->IsLocal()) {
300 std::vector<Tensor> outputs(1);
301 status = RunLocalRecv(&op, &outputs);
302 if (!status.ok()) {
303 captured_state_->dst()->Poison(status, recv_device);
304 done(status);
305 return;
306 }
307 status =
308 captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
309 done(status);
310 } else {
311 // Handles captured_state_->dst_ internally.
312 RunRemoteRecv(&op, std::move(done));
313 }
314 }
315
SerializePackedHandle(const uint64 op_id,TensorHandle * packed_handle,const Device * target_device,EagerContext * ctx,SendPackedHandleOp * op)316 Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
317 const Device* target_device, EagerContext* ctx,
318 SendPackedHandleOp* op) {
319 op->set_op_id(op_id);
320 op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
321 for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
322 TensorHandle* h = nullptr;
323 TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
324 if (h->Type() == TensorHandle::LOCAL) {
325 // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
326 // copy it to the CPU before copying it out.
327 Tensor tensor;
328 TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
329 auto* local_handle = op->add_handles()->mutable_local_handle();
330 local_handle->set_device(h->op_device() ? h->op_device()->name()
331 : ctx->HostCPU()->name());
332 tensor.AsProtoTensorContent(local_handle->mutable_tensor());
333 } else if (h->Type() == TensorHandle::REMOTE) {
334 // Only serialize the resource dtype and shape of the first handle, since
335 // all handles are of the same resource dtype and shape.
336 // If src_device is on the same task of target_device, the handle is a
337 // local handle on the target device, which means the resource dtype and
338 // shape are known on the target device.
339 Device* src_device = h->device();
340 const bool serialize_resource_dtype_and_shape =
341 (i == 0) && (h->dtype == DT_RESOURCE) &&
342 (!ctx->OnSameTask(src_device, target_device));
343 // For a remote component function, a function execution request and an
344 // input generation request may come from different workers. We need to
345 // guarantee that the input generation request is processed before the
346 // function execution request, so wait until the underlying remote handles
347 // are ready before sending a packed handle to the function device.
348 TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
349 h, /*wait_until_ready=*/true,
350 op->add_handles()->mutable_remote_handle(), src_device,
351 h->DeviceOrHostCPU(*ctx)->name(),
352 serialize_resource_dtype_and_shape));
353 } else {
354 return errors::InvalidArgument("Nested packed handles are not supported");
355 }
356 }
357 return Status::OK();
358 }
359
StartSendPackedHandle(StatusCallback done)360 void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
361 Status s;
362 const uint64 context_view_id = ctx_->GetContextViewId();
363 if (!send_device_->IsLocal()) {
364 s = errors::InvalidArgument(
365 "Copy a packed handle from a remote device is not supported");
366 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
367 done(s);
368 return;
369 }
370
371 EnqueueRequest request;
372 uint64 context_id = ctx_->GetContextId();
373 request.set_context_id(context_id);
374 s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
375 request.add_queue()->mutable_send_packed_handle());
376 if (!s.ok()) {
377 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
378 done(s);
379 return;
380 }
381
382 TensorShape shape;
383 s = src_->Shape(&shape);
384 if (!s.ok()) {
385 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
386 done(s);
387 return;
388 }
389 captured_state_->SetSrcShape(shape);
390
391 core::RefCountPtr<eager::EagerClient> eager_client;
392 s = ctx_->GetClient(recv_device_, &eager_client);
393 if (!s.ok()) {
394 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
395 done(s);
396 return;
397 }
398
399 EnqueueResponse* response = new EnqueueResponse;
400 Device* recv_device = recv_device_;
401 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
402 eager_client->StreamingEnqueueAsync(
403 /*call_opts=*/nullptr, &request, response,
404 [captured_state, response, recv_device, context_view_id,
405 done](const Status& s) {
406 if (s.ok()) {
407 Status status = captured_state->dst()->SetRemoteShape(
408 captured_state->GetSrcShape(), recv_device, context_view_id);
409 if (!status.ok()) {
410 LOG(ERROR) << "Ignoring an error encountered when setting remote "
411 "shape of tensor received by SendPackedHadnle rpc: "
412 << status.ToString();
413 }
414 } else {
415 captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
416 }
417 done(s);
418 delete response;
419 });
420 }
421
StartRemoteSendTensor(StatusCallback done)422 void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
423 Status s;
424 EnqueueRequest request;
425 uint64 context_id = ctx_->GetContextId();
426 request.set_context_id(context_id);
427 auto* send_tensor = request.add_queue()->mutable_send_tensor();
428 send_tensor->set_op_id(recv_op_id_);
429 send_tensor->set_device_name(recv_device_->name());
430 uint64 context_view_id = ctx_->GetContextViewId();
431
432 // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
433 // copy it to the CPU before copying it out.
434 // TODO(fishx): Make CopyToDevice asynchronous.
435 Tensor tensor;
436 s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
437 if (!s.ok()) {
438 done(s);
439 return;
440 }
441 tensor.AsProtoTensorContent(send_tensor->add_tensors());
442
443 core::RefCountPtr<eager::EagerClient> eager_client;
444 s = ctx_->GetClient(recv_device_, &eager_client);
445 if (!s.ok()) {
446 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
447 done(s);
448 return;
449 }
450 EnqueueResponse* response = new EnqueueResponse;
451 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
452 captured_state->SetSrcShape(tensor.shape());
453 Device* recv_device = recv_device_;
454 eager_client->StreamingEnqueueAsync(
455 /*call_opts=*/nullptr, &request, response,
456 [captured_state, response, recv_device, context_view_id,
457 done](const Status& s) {
458 if (s.ok()) {
459 Status status = captured_state->dst()->SetRemoteShape(
460 captured_state->GetSrcShape(), recv_device, context_view_id);
461 if (!status.ok()) {
462 LOG(ERROR) << "Ignoring an error encountered when setting remote "
463 "shape of tensor received by SendTensor rpc: "
464 << status.ToString();
465 }
466 } else {
467 captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
468 }
469 done(s);
470 delete response;
471 });
472 }
473
Prepare()474 Status RemoteCopyNode::Prepare() {
475 TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
476 return Status::OK();
477 }
478
RunAsync(StatusCallback done)479 void RemoteCopyNode::RunAsync(StatusCallback done) {
480 started_ = true;
481 if (src_->Type() == TensorHandle::PACKED) {
482 return StartSendPackedHandle(std::move(done));
483 }
484
485 if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
486 !recv_device_->IsLocal()) {
487 return StartRemoteSendTensor(std::move(done));
488 }
489 StartSend();
490
491 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
492 auto done_wrapper = [captured_state,
493 done = std::move(done)](const Status& s) {
494 if (!s.ok() && errors::IsCancelled(s)) {
495 Status send_status = captured_state->GetSendStatus();
496 if (!send_status.ok()) {
497 // In this case, Recv is cancelled because the Send op failed.
498 // Return the status of the Send op instead.
499 done(send_status);
500 }
501 } else {
502 done(s);
503 }
504 };
505
506 // StartRecv() takes care of doing the right thing to dst handle.
507 // No need to poison it after this point.
508 StartRecv(std::move(done_wrapper));
509 }
510
Abort(Status status)511 void RemoteCopyNode::Abort(Status status) {
512 if (!started_) {
513 uint64 context_view_id = ctx_->GetContextViewId();
514 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
515 }
516 }
517
518 } // namespace eager
519 } // namespace tensorflow
520