1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
17
18 #include <functional>
19
20 #include "absl/types/optional.h"
21 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
23 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/shape_inference.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/protobuf.h"
30
31 namespace tensorflow {
32 namespace eager {
33
34 namespace {
35
PrepareRemoteOp(eager::Operation * remote_op,EagerOperation * op)36 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
37 remote_op->set_name(op->Name());
38
39 op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
40 remote_op->set_device(op->DeviceName());
41 }
42
CreateUncachedKernelAndDeviceOp(EagerOperation * op,core::RefCountPtr<KernelAndDevice> * kernel)43 Status CreateUncachedKernelAndDeviceOp(
44 EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
45 EagerContext& ctx = op->EagerContext();
46 Device* device = absl::get<Device*>(op->Device());
47
48 FunctionLibraryRuntime* flr = ctx.func_lib(device);
49 if (flr == nullptr) {
50 return errors::Unavailable(
51 "Unable to find a FunctionLibraryRuntime corresponding to device ",
52 device->name());
53 }
54
55 auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
56 kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
57 runner, ctx.GetCollectiveExecutorHandle(),
58 ctx.HostCPU()));
59
60 const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
61 return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
62 /*graph_collector=*/nullptr);
63 }
64
65 // This gets a unique wire ID. We add a random identifier so that if the
66 // worker has other clients that it is servicing, we don't have any collision.
GetUniqueWireID()67 string GetUniqueWireID() {
68 static tensorflow::uint64 random_seed = random::New64();
69 static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
70 static std::atomic<int64_t> wire_id;
71 return strings::StrCat(random_seed, "_", wire_id++);
72 }
73
74 } // namespace
75
// Constructs a node that copies `src` (on the device returned by
// DeviceOrHostCPU) into `dst` on `recv_device`. Takes a reference on both
// `src` and `ctx` for the node's lifetime; `dst` is kept alive through the
// shared CapturedSharedState, which outlives `this` in async callbacks.
RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                               TensorHandle* src, TensorHandle* dst,
                               Device* recv_device, uint64 recv_op_id)
    : AsyncEagerNode(),
      src_(src),
      ctx_(ctx),
      executor_(executor),
      send_device_(src->DeviceOrHostCPU(*ctx)),
      recv_device_(recv_device),
      wire_id_(GetUniqueWireID()),
      recv_op_id_(recv_op_id),
      captured_state_(std::make_shared<CapturedSharedState>(dst)),
      started_(false) {
  // A remote copy only makes sense when at least one endpoint is remote.
  DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
  src_->Ref();
  ctx_->Ref();
}
93
// Releases the references on the source handle and the context taken in the
// constructor.
RemoteCopyNode::~RemoteCopyNode() {
  src_->Unref();
  ctx_->Unref();
}
98
// Runs the _Send kernel synchronously on the local send device, feeding
// `src_` as its single input. Bails out early if the executor is already in
// an error state.
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
  TF_RETURN_IF_ERROR(executor_->status());

  TF_RETURN_IF_ERROR(op->AddInput(src_));

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args(1);
  Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));

  // _Send produces no outputs: the tensor is handed off to the rendezvous,
  // hence `outputs` is nullptr.
  return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
                     /*cancellation_manager=*/nullptr,
                     /*remote_func_params=*/absl::nullopt,
                     /*stack_trace=*/absl::nullopt);
}
116
// Issues the send half of the copy: runs a local _Send kernel when the send
// device is local, otherwise enqueues a remote _Send op via RPC. The result
// is recorded in the captured state's send status, which the recv path
// blocks on before issuing its own op.
void RemoteCopyNode::StartSend() {
  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Send", /*raw_device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  if (!status.ok()) {
    captured_state_->SetSendStatus(status);
    return;
  }

  op.SetDevice(send_device_);

  // Rendezvous-key attributes; `wire_id_` pairs this _Send with the _Recv
  // issued in StartRecv.
  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("T", src_->dtype);

  DCHECK(send_device_ != nullptr);

  if (send_device_->IsLocal()) {
    status = RunLocalSend(&op);
    captured_state_->SetSendStatus(status);
    return;
  } else {
    // Prepare the request
    EnqueueRequest request;
    request.set_context_id(ctx_->GetContextId());
    auto* remote_op = request.add_queue()->mutable_operation();
    // Serialize the source handle without waiting for it to become ready;
    // the remote worker resolves it when executing the send.
    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
        src_, /*wait_until_ready=*/false,
        remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
        src_->DeviceOrHostCPU(*ctx_)->name());
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    PrepareRemoteOp(remote_op, &op);
    remote_op->set_id(ctx_->RemoteMgr()->NextOpId());

    // Issue the RPC
    core::RefCountPtr<eager::EagerClient> eager_client;
    status = ctx_->GetClient(send_device_, &eager_client);
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    const std::shared_ptr<CapturedSharedState>& captured_state =
        captured_state_;
    EnqueueResponse* response = new EnqueueResponse;
    // If StartRecv fails very quickly, `this` can be destroyed before the
    // callback below is executed. So, we can't capture `this`.
    eager_client->StreamingEnqueueAsync(
        /*call_opts=*/nullptr, &request, response,
        [response, captured_state](const Status& s) {
          captured_state->SetSendStatus(s);
          if (!s.ok()) {
            // A failed send means the matching recv can never complete, so
            // cancel it.
            captured_state->recv_cancellation()->StartCancel();
          }
          delete response;
        });
  }
}
187
RunLocalRecv(EagerOperation * op,std::vector<Tensor> * outputs)188 Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
189 std::vector<Tensor>* outputs) {
190 TF_RETURN_IF_ERROR(executor_->status());
191
192 core::RefCountPtr<KernelAndDevice> kernel;
193 TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
194
195 EagerKernelArgs args;
196 std::vector<EagerKernelRet> rets;
197 TF_RETURN_IF_ERROR(kernel->Run(/*step_container*/ nullptr, args, &rets,
198 captured_state_->recv_cancellation(),
199 /*remote_func_params=*/absl::nullopt,
200 /*stack_trace=*/absl::nullopt));
201 outputs->clear();
202 for (const auto& ret : rets) {
203 if (ret.index() == 0) {
204 outputs->push_back(absl::get<Tensor>(ret));
205 } else {
206 return errors::Internal(
207 "Expect to receive a Tensor but got a TensorShape.");
208 }
209 }
210 return Status::OK();
211 }
212
RunRemoteRecv(EagerOperation * op,StatusCallback done)213 void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
214 EnqueueRequest request;
215 uint64 context_id = ctx_->GetContextId();
216 request.set_context_id(context_id);
217 auto* remote_op = request.add_queue()->mutable_operation();
218 PrepareRemoteOp(remote_op, op);
219 remote_op->set_id(recv_op_id_);
220 uint64 context_view_id = ctx_->GetContextViewId();
221
222 core::RefCountPtr<eager::EagerClient> eager_client;
223 Status status = ctx_->GetClient(recv_device_, &eager_client);
224 if (!status.ok()) {
225 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
226 done(status);
227 return;
228 }
229
230 // Don't issue the recv until send has completed.
231 // - local send will complete very quickly.
232 // - remote send will take some time, but remote->remote copy is
233 // probably rare enough that we don't care much.
234 // Blocks until send has completed.
235 Status send_status = captured_state_->GetSendStatus();
236 if (!send_status.ok()) {
237 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
238 done(send_status);
239 return;
240 }
241
242 EnqueueResponse* response = new EnqueueResponse;
243 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
244 Device* recv_device = recv_device_;
245 eager_client->StreamingEnqueueAsync(
246 /*call_opts=*/nullptr, &request, response,
247 [captured_state, response, recv_device, context_view_id,
248 done](const Status& s) {
249 if (s.ok()) {
250 Status status = captured_state->dst()->SetRemoteShape(
251 response->queue_response(0).shape(0), recv_device,
252 context_view_id);
253 if (!status.ok()) {
254 LOG(ERROR) << "Ignoring an error encountered when setting remote "
255 "shape of tensor received by remote Recv op: "
256 << status.ToString()
257 << "\nThis should never happen. "
258 "Please file an issue with the TensorFlow Team.";
259 }
260 } else {
261 captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
262 }
263 done(s);
264 delete response;
265 });
266 }
267
// Issues the receive half of the copy. When the recv device is local, runs
// the _Recv kernel synchronously and stores the result into the destination
// handle; otherwise delegates to RunRemoteRecv. `done` is invoked with the
// final status on every path.
void RemoteCopyNode::StartRecv(StatusCallback done) {
  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  Device* recv_device = ctx_->CanonicalDevice(recv_device_);
  if (!status.ok()) {
    captured_state_->dst()->Poison(status, recv_device);
    done(status);
    return;
  }

  op.SetDevice(recv_device_);

  // Rendezvous-key attributes; these must mirror the ones set in StartSend
  // so the _Recv pairs with the matching _Send.
  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("tensor_type", src_->dtype);

  if (recv_device_->IsLocal()) {
    std::vector<Tensor> outputs(1);
    status = RunLocalRecv(&op, &outputs);
    if (!status.ok()) {
      captured_state_->dst()->Poison(status, recv_device);
      done(status);
      return;
    }
    status =
        captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
    done(status);
  } else {
    // Handles captured_state_->dst_ internally.
    RunRemoteRecv(&op, std::move(done));
  }
}
309
// Serializes `packed_handle` (a TensorHandle composed of LOCAL and/or REMOTE
// component handles) into the SendPackedHandleOp proto `op`, assigning it
// `op_id`. Local components are copied to host CPU and embedded as tensor
// protos; remote components are serialized as remote handles. Returns
// InvalidArgument for nested packed handles.
Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    if (h->Type() == TensorHandle::LOCAL) {
      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
      // copy it to the CPU before copying it out.
      Tensor tensor;
      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
      auto* local_handle = op->add_handles()->mutable_local_handle();
      local_handle->set_device(h->op_device() ? h->op_device()->name()
                                              : ctx->HostCPU()->name());
      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
    } else if (h->Type() == TensorHandle::REMOTE) {
      // Only serialize the resource dtype and shape of the first handle, since
      // all handles are of the same resource dtype and shape.
      // If src_device is on the same task of target_device, the handle is a
      // local handle on the target device, which means the resource dtype and
      // shape are known on the target device.
      Device* src_device = h->device();
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (!ctx->OnSameTask(src_device, target_device));
      // For a remote component function, a function execution request and an
      // input generation request may come from different workers. We need to
      // guarantee that the input generation request is processed before the
      // function execution request, so wait until the underlying remote handles
      // are ready before sending a packed handle to the function device.
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, /*wait_until_ready=*/true,
          op->add_handles()->mutable_remote_handle(), src_device,
          h->DeviceOrHostCPU(*ctx)->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");
    }
  }
  return Status::OK();
}
353
StartSendPackedHandle(StatusCallback done)354 void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
355 Status s;
356 const uint64 context_view_id = ctx_->GetContextViewId();
357 if (!send_device_->IsLocal()) {
358 s = errors::InvalidArgument(
359 "Copy a packed handle from a remote device is not supported");
360 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
361 done(s);
362 return;
363 }
364
365 EnqueueRequest request;
366 uint64 context_id = ctx_->GetContextId();
367 request.set_context_id(context_id);
368 s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
369 request.add_queue()->mutable_send_packed_handle());
370 if (!s.ok()) {
371 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
372 done(s);
373 return;
374 }
375
376 TensorShape shape;
377 s = src_->Shape(&shape);
378 if (!s.ok()) {
379 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
380 done(s);
381 return;
382 }
383 captured_state_->SetSrcShape(shape);
384
385 core::RefCountPtr<eager::EagerClient> eager_client;
386 s = ctx_->GetClient(recv_device_, &eager_client);
387 if (!s.ok()) {
388 captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
389 done(s);
390 return;
391 }
392
393 EnqueueResponse* response = new EnqueueResponse;
394 Device* recv_device = recv_device_;
395 const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
396 eager_client->StreamingEnqueueAsync(
397 /*call_opts=*/nullptr, &request, response,
398 [captured_state, response, recv_device, context_view_id,
399 done](const Status& s) {
400 if (s.ok()) {
401 Status status = captured_state->dst()->SetRemoteShape(
402 captured_state->GetSrcShape(), recv_device, context_view_id);
403 if (!status.ok()) {
404 LOG(ERROR) << "Ignoring an error encountered when setting remote "
405 "shape of tensor received by SendPackedHadnle rpc: "
406 << status.ToString();
407 }
408 } else {
409 captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
410 }
411 done(s);
412 delete response;
413 });
414 }
415
// Pushes the source tensor's bytes to the remote recv device with a single
// SendTensor RPC, bypassing the _Send/_Recv rendezvous. The tensor is first
// copied to host CPU so it can be serialized. On RPC completion the
// destination handle's remote shape is set from the cached source shape, or
// the handle is poisoned on failure.
void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
  Status s;
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* send_tensor = request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(recv_op_id_);
  send_tensor->set_device_name(recv_device_->name());
  uint64 context_view_id = ctx_->GetContextViewId();

  // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
  // copy it to the CPU before copying it out.
  // TODO(b/110044833): this is currently slow, but can be fixed by making
  // tensor handles aware of more than one device.
  // TODO(fishx): Make CopyToDevice asynchronous.
  Tensor tensor;
  s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
  if (!s.ok()) {
    done(s);
    return;
  }
  tensor.AsProtoTensorContent(send_tensor->add_tensors());

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  EnqueueResponse* response = new EnqueueResponse;
  // Capture the shared state rather than `this`, which may be destroyed
  // before the callback runs.
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  captured_state->SetSrcShape(tensor.shape());
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendTensor rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}
469
Prepare()470 Status RemoteCopyNode::Prepare() {
471 TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
472 return Status::OK();
473 }
474
// Kicks off the copy, choosing among three strategies:
//  1. packed source handle        -> StartSendPackedHandle,
//  2. local->remote w/ SendTensor -> StartRemoteSendTensor,
//  3. otherwise                   -> _Send/_Recv pair via StartSend+StartRecv.
void RemoteCopyNode::RunAsync(StatusCallback done) {
  started_ = true;
  if (src_->Type() == TensorHandle::PACKED) {
    return StartSendPackedHandle(std::move(done));
  }

  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
      !recv_device_->IsLocal()) {
    return StartRemoteSendTensor(std::move(done));
  }
  StartSend();

  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  auto done_wrapper = [captured_state,
                       done = std::move(done)](const Status& s) {
    if (!s.ok() && errors::IsCancelled(s)) {
      Status send_status = captured_state->GetSendStatus();
      if (!send_status.ok()) {
        // In this case, Recv is cancelled because the Send op failed.
        // Return the status of the Send op instead.
        done(send_status);
      }
      // NOTE(review): if the recv is cancelled while the send status is OK,
      // `done` is never invoked on this path — confirm callers tolerate
      // this, or that the recv can only be cancelled by a failed send.
    } else {
      done(s);
    }
  };

  // StartRecv() takes care of doing the right thing to dst handle.
  // No need to poison it after this point.
  StartRecv(std::move(done_wrapper));
}
506
Abort(Status status)507 void RemoteCopyNode::Abort(Status status) {
508 if (!started_) {
509 uint64 context_view_id = ctx_->GetContextViewId();
510 captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
511 }
512 }
513
514 } // namespace eager
515 } // namespace tensorflow
516