• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"
17 
18 #include <functional>
19 
20 #include "absl/types/optional.h"
21 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
23 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/shape_inference.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/errors.h"
29 #include "tensorflow/core/platform/protobuf.h"
30 
31 namespace tensorflow {
32 namespace eager {
33 
34 namespace {
35 
PrepareRemoteOp(eager::Operation * remote_op,EagerOperation * op)36 void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
37   remote_op->set_name(op->Name());
38 
39   op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
40   remote_op->set_device(op->DeviceName());
41 }
42 
CreateUncachedKernelAndDeviceOp(EagerOperation * op,core::RefCountPtr<KernelAndDevice> * kernel)43 Status CreateUncachedKernelAndDeviceOp(
44     EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
45   EagerContext& ctx = op->EagerContext();
46   Device* device = absl::get<Device*>(op->Device());
47 
48   FunctionLibraryRuntime* flr = ctx.func_lib(device);
49   if (flr == nullptr) {
50     return errors::Unavailable(
51         "Unable to find a FunctionLibraryRuntime corresponding to device ",
52         device->name());
53   }
54 
55   auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
56   kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
57                                       runner, ctx.GetCollectiveExecutorHandle(),
58                                       ctx.HostCPU()));
59 
60   const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
61   return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
62                              /*graph_collector=*/nullptr);
63 }
64 
65 // This gets a unique wire ID. We add a random identifier so that if the
66 // worker has other clients that it is servicing, we don't have any collision.
GetUniqueWireID()67 string GetUniqueWireID() {
68   static tensorflow::uint64 random_seed = random::New64();
69   static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
70   static std::atomic<int64_t> wire_id;
71   return strings::StrCat(random_seed, "_", wire_id++);
72 }
73 
74 }  // namespace
75 
RemoteCopyNode(EagerContext * ctx,EagerExecutor * executor,TensorHandle * src,TensorHandle * dst,Device * recv_device,uint64 recv_op_id)76 RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
77                                TensorHandle* src, TensorHandle* dst,
78                                Device* recv_device, uint64 recv_op_id)
79     : AsyncEagerNode(),
80       src_(src),
81       ctx_(ctx),
82       executor_(executor),
83       send_device_(src->DeviceOrHostCPU(*ctx)),
84       recv_device_(recv_device),
85       wire_id_(GetUniqueWireID()),
86       recv_op_id_(recv_op_id),
87       captured_state_(std::make_shared<CapturedSharedState>(dst)),
88       started_(false) {
89   DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
90   src_->Ref();
91   ctx_->Ref();
92 }
93 
~RemoteCopyNode()94 RemoteCopyNode::~RemoteCopyNode() {
95   src_->Unref();
96   ctx_->Unref();
97 }
98 
RunLocalSend(EagerOperation * op)99 Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
100   TF_RETURN_IF_ERROR(executor_->status());
101 
102   TF_RETURN_IF_ERROR(op->AddInput(src_));
103 
104   core::RefCountPtr<KernelAndDevice> kernel;
105   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
106 
107   EagerKernelArgs args(1);
108   Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
109   TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
110 
111   return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
112                      /*cancellation_manager=*/nullptr,
113                      /*remote_func_params=*/absl::nullopt,
114                      /*stack_trace=*/absl::nullopt);
115 }
116 
StartSend()117 void RemoteCopyNode::StartSend() {
118   // TODO(gjn): We should consider just using the low-level SendOp::Compute()
119   // functionality here instead of constructing an Op.
120   EagerOperation op(ctx_);
121   Status status = op.Reset("_Send", /*raw_device_name=*/nullptr,
122                            /*remote=*/false, /*executor=*/nullptr);
123   if (!status.ok()) {
124     captured_state_->SetSendStatus(status);
125     return;
126   }
127 
128   op.SetDevice(send_device_);
129 
130   op.MutableAttrs()->Set("tensor_name", wire_id_);
131   op.MutableAttrs()->Set("send_device", send_device_->name());
132   op.MutableAttrs()->Set(
133       "send_device_incarnation",
134       static_cast<int64>(send_device_->attributes().incarnation()));
135   op.MutableAttrs()->Set("recv_device", recv_device_->name());
136   op.MutableAttrs()->Set("client_terminated", false);
137 
138   op.MutableAttrs()->Set("T", src_->dtype);
139 
140   DCHECK(send_device_ != nullptr);
141 
142   if (send_device_->IsLocal()) {
143     status = RunLocalSend(&op);
144     captured_state_->SetSendStatus(status);
145     return;
146   } else {
147     // Prepare the request
148     EnqueueRequest request;
149     request.set_context_id(ctx_->GetContextId());
150     auto* remote_op = request.add_queue()->mutable_operation();
151     status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
152         src_, /*wait_until_ready=*/false,
153         remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
154         src_->DeviceOrHostCPU(*ctx_)->name());
155     if (!status.ok()) {
156       captured_state_->SetSendStatus(status);
157       return;
158     }
159 
160     PrepareRemoteOp(remote_op, &op);
161     remote_op->set_id(ctx_->RemoteMgr()->NextOpId());
162 
163     // Issue the RPC
164     core::RefCountPtr<eager::EagerClient> eager_client;
165     status = ctx_->GetClient(send_device_, &eager_client);
166     if (!status.ok()) {
167       captured_state_->SetSendStatus(status);
168       return;
169     }
170 
171     const std::shared_ptr<CapturedSharedState>& captured_state =
172         captured_state_;
173     EnqueueResponse* response = new EnqueueResponse;
174     // If StartRecv fails very quickly, `this` can be destroyed before the
175     // callback below is executed. So, we can't capture `this`.
176     eager_client->StreamingEnqueueAsync(
177         /*call_opts=*/nullptr, &request, response,
178         [response, captured_state](const Status& s) {
179           captured_state->SetSendStatus(s);
180           if (!s.ok()) {
181             captured_state->recv_cancellation()->StartCancel();
182           }
183           delete response;
184         });
185   }
186 }
187 
RunLocalRecv(EagerOperation * op,std::vector<Tensor> * outputs)188 Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
189                                     std::vector<Tensor>* outputs) {
190   TF_RETURN_IF_ERROR(executor_->status());
191 
192   core::RefCountPtr<KernelAndDevice> kernel;
193   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
194 
195   EagerKernelArgs args;
196   std::vector<EagerKernelRet> rets;
197   TF_RETURN_IF_ERROR(kernel->Run(/*step_container*/ nullptr, args, &rets,
198                                  captured_state_->recv_cancellation(),
199                                  /*remote_func_params=*/absl::nullopt,
200                                  /*stack_trace=*/absl::nullopt));
201   outputs->clear();
202   for (const auto& ret : rets) {
203     if (ret.index() == 0) {
204       outputs->push_back(absl::get<Tensor>(ret));
205     } else {
206       return errors::Internal(
207           "Expect to receive a Tensor but got a TensorShape.");
208     }
209   }
210   return Status::OK();
211 }
212 
RunRemoteRecv(EagerOperation * op,StatusCallback done)213 void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
214   EnqueueRequest request;
215   uint64 context_id = ctx_->GetContextId();
216   request.set_context_id(context_id);
217   auto* remote_op = request.add_queue()->mutable_operation();
218   PrepareRemoteOp(remote_op, op);
219   remote_op->set_id(recv_op_id_);
220   uint64 context_view_id = ctx_->GetContextViewId();
221 
222   core::RefCountPtr<eager::EagerClient> eager_client;
223   Status status = ctx_->GetClient(recv_device_, &eager_client);
224   if (!status.ok()) {
225     captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
226     done(status);
227     return;
228   }
229 
230   // Don't issue the recv until send has completed.
231   //  - local send will complete very quickly.
232   //  - remote send will take some time, but remote->remote copy is
233   //    probably rare enough that we don't care much.
234   // Blocks until send has completed.
235   Status send_status = captured_state_->GetSendStatus();
236   if (!send_status.ok()) {
237     captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
238     done(send_status);
239     return;
240   }
241 
242   EnqueueResponse* response = new EnqueueResponse;
243   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
244   Device* recv_device = recv_device_;
245   eager_client->StreamingEnqueueAsync(
246       /*call_opts=*/nullptr, &request, response,
247       [captured_state, response, recv_device, context_view_id,
248        done](const Status& s) {
249         if (s.ok()) {
250           Status status = captured_state->dst()->SetRemoteShape(
251               response->queue_response(0).shape(0), recv_device,
252               context_view_id);
253           if (!status.ok()) {
254             LOG(ERROR) << "Ignoring an error encountered when setting remote "
255                           "shape of tensor received by remote Recv op: "
256                        << status.ToString()
257                        << "\nThis should never happen. "
258                           "Please file an issue with the TensorFlow Team.";
259           }
260         } else {
261           captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
262         }
263         done(s);
264         delete response;
265       });
266 }
267 
StartRecv(StatusCallback done)268 void RemoteCopyNode::StartRecv(StatusCallback done) {
269   // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
270   // functionality here instead of constructing an Op.
271   EagerOperation op(ctx_);
272   Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr,
273                            /*remote=*/false, /*executor=*/nullptr);
274   Device* recv_device = ctx_->CanonicalDevice(recv_device_);
275   if (!status.ok()) {
276     captured_state_->dst()->Poison(status, recv_device);
277     done(status);
278     return;
279   }
280 
281   op.SetDevice(recv_device_);
282 
283   op.MutableAttrs()->Set("tensor_name", wire_id_);
284   op.MutableAttrs()->Set("send_device", send_device_->name());
285   op.MutableAttrs()->Set(
286       "send_device_incarnation",
287       static_cast<int64>(send_device_->attributes().incarnation()));
288   op.MutableAttrs()->Set("recv_device", recv_device_->name());
289   op.MutableAttrs()->Set("client_terminated", false);
290 
291   op.MutableAttrs()->Set("tensor_type", src_->dtype);
292 
293   if (recv_device_->IsLocal()) {
294     std::vector<Tensor> outputs(1);
295     status = RunLocalRecv(&op, &outputs);
296     if (!status.ok()) {
297       captured_state_->dst()->Poison(status, recv_device);
298       done(status);
299       return;
300     }
301     status =
302         captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
303     done(status);
304   } else {
305     // Handles captured_state_->dst_ internally.
306     RunRemoteRecv(&op, std::move(done));
307   }
308 }
309 
SerializePackedHandle(const uint64 op_id,TensorHandle * packed_handle,const Device * target_device,EagerContext * ctx,SendPackedHandleOp * op)310 Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
311                              const Device* target_device, EagerContext* ctx,
312                              SendPackedHandleOp* op) {
313   op->set_op_id(op_id);
314   op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
315   for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
316     TensorHandle* h = nullptr;
317     TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
318     if (h->Type() == TensorHandle::LOCAL) {
319       // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
320       // copy it to the CPU before copying it out.
321       Tensor tensor;
322       TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
323       auto* local_handle = op->add_handles()->mutable_local_handle();
324       local_handle->set_device(h->op_device() ? h->op_device()->name()
325                                               : ctx->HostCPU()->name());
326       tensor.AsProtoTensorContent(local_handle->mutable_tensor());
327     } else if (h->Type() == TensorHandle::REMOTE) {
328       // Only serialize the resource dtype and shape of the first handle, since
329       // all handles are of the same resource dtype and shape.
330       // If src_device is on the same task of target_device, the handle is a
331       // local handle on the target device, which means the resource dtype and
332       // shape are known on the target device.
333       Device* src_device = h->device();
334       const bool serialize_resource_dtype_and_shape =
335           (i == 0) && (h->dtype == DT_RESOURCE) &&
336           (!ctx->OnSameTask(src_device, target_device));
337       // For a remote component function, a function execution request and an
338       // input generation request may come from different workers. We need to
339       // guarantee that the input generation request is processed before the
340       // function execution request, so wait until the underlying remote handles
341       // are ready before sending a packed handle to the function device.
342       TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
343           h, /*wait_until_ready=*/true,
344           op->add_handles()->mutable_remote_handle(), src_device,
345           h->DeviceOrHostCPU(*ctx)->name(),
346           serialize_resource_dtype_and_shape));
347     } else {
348       return errors::InvalidArgument("Nested packed handles are not supported");
349     }
350   }
351   return Status::OK();
352 }
353 
StartSendPackedHandle(StatusCallback done)354 void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
355   Status s;
356   const uint64 context_view_id = ctx_->GetContextViewId();
357   if (!send_device_->IsLocal()) {
358     s = errors::InvalidArgument(
359         "Copy a packed handle from a remote device is not supported");
360     captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
361     done(s);
362     return;
363   }
364 
365   EnqueueRequest request;
366   uint64 context_id = ctx_->GetContextId();
367   request.set_context_id(context_id);
368   s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
369                             request.add_queue()->mutable_send_packed_handle());
370   if (!s.ok()) {
371     captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
372     done(s);
373     return;
374   }
375 
376   TensorShape shape;
377   s = src_->Shape(&shape);
378   if (!s.ok()) {
379     captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
380     done(s);
381     return;
382   }
383   captured_state_->SetSrcShape(shape);
384 
385   core::RefCountPtr<eager::EagerClient> eager_client;
386   s = ctx_->GetClient(recv_device_, &eager_client);
387   if (!s.ok()) {
388     captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
389     done(s);
390     return;
391   }
392 
393   EnqueueResponse* response = new EnqueueResponse;
394   Device* recv_device = recv_device_;
395   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
396   eager_client->StreamingEnqueueAsync(
397       /*call_opts=*/nullptr, &request, response,
398       [captured_state, response, recv_device, context_view_id,
399        done](const Status& s) {
400         if (s.ok()) {
401           Status status = captured_state->dst()->SetRemoteShape(
402               captured_state->GetSrcShape(), recv_device, context_view_id);
403           if (!status.ok()) {
404             LOG(ERROR) << "Ignoring an error encountered when setting remote "
405                           "shape of tensor received by SendPackedHadnle rpc: "
406                        << status.ToString();
407           }
408         } else {
409           captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
410         }
411         done(s);
412         delete response;
413       });
414 }
415 
StartRemoteSendTensor(StatusCallback done)416 void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
417   Status s;
418   EnqueueRequest request;
419   uint64 context_id = ctx_->GetContextId();
420   request.set_context_id(context_id);
421   auto* send_tensor = request.add_queue()->mutable_send_tensor();
422   send_tensor->set_op_id(recv_op_id_);
423   send_tensor->set_device_name(recv_device_->name());
424   uint64 context_view_id = ctx_->GetContextViewId();
425 
426   // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
427   // copy it to the CPU before copying it out.
428   // TODO(b/110044833): this is currently slow, but can be fixed by making
429   // tensor handles aware of more than one device.
430   // TODO(fishx): Make CopyToDevice asynchronous.
431   Tensor tensor;
432   s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
433   if (!s.ok()) {
434     done(s);
435     return;
436   }
437   tensor.AsProtoTensorContent(send_tensor->add_tensors());
438 
439   core::RefCountPtr<eager::EagerClient> eager_client;
440   s = ctx_->GetClient(recv_device_, &eager_client);
441   if (!s.ok()) {
442     captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
443     done(s);
444     return;
445   }
446   EnqueueResponse* response = new EnqueueResponse;
447   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
448   captured_state->SetSrcShape(tensor.shape());
449   Device* recv_device = recv_device_;
450   eager_client->StreamingEnqueueAsync(
451       /*call_opts=*/nullptr, &request, response,
452       [captured_state, response, recv_device, context_view_id,
453        done](const Status& s) {
454         if (s.ok()) {
455           Status status = captured_state->dst()->SetRemoteShape(
456               captured_state->GetSrcShape(), recv_device, context_view_id);
457           if (!status.ok()) {
458             LOG(ERROR) << "Ignoring an error encountered when setting remote "
459                           "shape of tensor received by SendTensor rpc: "
460                        << status.ToString();
461           }
462         } else {
463           captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
464         }
465         done(s);
466         delete response;
467       });
468 }
469 
Prepare()470 Status RemoteCopyNode::Prepare() {
471   TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
472   return Status::OK();
473 }
474 
RunAsync(StatusCallback done)475 void RemoteCopyNode::RunAsync(StatusCallback done) {
476   started_ = true;
477   if (src_->Type() == TensorHandle::PACKED) {
478     return StartSendPackedHandle(std::move(done));
479   }
480 
481   if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
482       !recv_device_->IsLocal()) {
483     return StartRemoteSendTensor(std::move(done));
484   }
485   StartSend();
486 
487   const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
488   auto done_wrapper = [captured_state,
489                        done = std::move(done)](const Status& s) {
490     if (!s.ok() && errors::IsCancelled(s)) {
491       Status send_status = captured_state->GetSendStatus();
492       if (!send_status.ok()) {
493         // In this case, Recv is cancelled because the Send op failed.
494         // Return the status of the Send op instead.
495         done(send_status);
496       }
497     } else {
498       done(s);
499     }
500   };
501 
502   // StartRecv() takes care of doing the right thing to dst handle.
503   // No need to poison it after this point.
504   StartRecv(std::move(done_wrapper));
505 }
506 
Abort(Status status)507 void RemoteCopyNode::Abort(Status status) {
508   if (!started_) {
509     uint64 context_view_id = ctx_->GetContextViewId();
510     captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
511   }
512 }
513 
514 }  // namespace eager
515 }  // namespace tensorflow
516