/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h"

#include <functional>

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace eager {

namespace {

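// Copies the name, attributes, and device name of the local EagerOperation
// `op` into the eager::Operation proto `remote_op` that will be enqueued on
// the remote eager service.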
void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
  remote_op->set_name(op->Name());

  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(op->DeviceName());
}

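// Builds a KernelAndDeviceOp for `op` directly, bypassing the kernel cache.
// Used by RunLocalSend/RunLocalRecv below to run the local _Send/_Recv
// kernels.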
Status CreateUncachedKernelAndDeviceOp(
    EagerOperation* op, core::RefCountPtr<KernelAndDevice>* kernel) {
  EagerContext& ctx = op->EagerContext();
  Device* device = absl::get<Device*>(op->Device());

  FunctionLibraryRuntime* flr = ctx.func_lib(device);
  if (flr == nullptr) {
    return errors::Unavailable(
        "Unable to find a FunctionLibraryRuntime corresponding to device ",
        device->name());
  }

  auto runner = (flr->runner() != nullptr) ? flr->runner() : ctx.runner();
  kernel->reset(new KernelAndDeviceOp(ctx.GetRendezvous(), ctx.LogMemory(), flr,
                                      runner, ctx.GetCollectiveExecutorHandle(),
                                      ctx.HostCPU()));

  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
  return kernel->get()->Init(ctx.LogDevicePlacement(), ndef,
                             /*graph_collector=*/nullptr);
}

// This gets a unique wire ID. We add a random identifier so that if the
// worker has other clients that it is servicing, we don't have any collision.
string GetUniqueWireID() {
  static tensorflow::uint64 random_seed = random::New64();
  static tensorflow::mutex wireid_mutex(tensorflow::LINKER_INITIALIZED);
  static std::atomic<int64_t> wire_id(0);
  return strings::StrCat(random_seed, "_", wire_id++);
}

}  // namespace

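// A RemoteCopyNode copies `src` to `recv_device` by pairing a send on the
// source device with a recv on the destination device, or, for the cases
// handled in RunAsync below, by a single SendTensor/SendPackedHandle RPC.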
RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
                               TensorHandle* src, TensorHandle* dst,
                               Device* recv_device, uint64 recv_op_id)
    : AsyncEagerNode(),
      src_(src),
      ctx_(ctx),
      executor_(executor),
      send_device_(src->DeviceOrHostCPU(*ctx)),
      recv_device_(recv_device),
      wire_id_(GetUniqueWireID()),
      recv_op_id_(recv_op_id),
      captured_state_(std::make_shared<CapturedSharedState>(dst)),
      started_(false) {
  DCHECK(!send_device_->IsLocal() || !recv_device_->IsLocal());
  src_->Ref();
  ctx_->Ref();
}

RemoteCopyNode::~RemoteCopyNode() {
  src_->Unref();
  ctx_->Unref();
}

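// Runs the _Send kernel synchronously on the local send device, feeding it
// the source tensor as its single input.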
Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
  TF_RETURN_IF_ERROR(executor_->status());

  TF_RETURN_IF_ERROR(op->AddInput(src_));

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args(1);
  Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr)
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();

  return kernel->Run(/*step_container=*/nullptr, args, /*outputs=*/nullptr,
                     /*cancellation_manager=*/nullptr,
                     /*remote_func_params=*/absl::nullopt,
                     /*stack_trace=*/absl::nullopt, coord_agent);
}

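// Starts the send half of the copy: runs _Send locally when the source
// device is on this task, otherwise enqueues the op on the source device's
// remote eager service. The outcome is recorded in the captured state so the
// recv side can wait on it.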
void RemoteCopyNode::StartSend() {
  // TODO(gjn): We should consider just using the low-level SendOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Send", /*raw_device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  if (!status.ok()) {
    captured_state_->SetSendStatus(status);
    return;
  }

  op.SetDevice(send_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("T", src_->dtype);

  DCHECK(send_device_ != nullptr);

  if (send_device_->IsLocal()) {
    status = RunLocalSend(&op);
    captured_state_->SetSendStatus(status);
    return;
  } else {
    // Prepare the request
    EnqueueRequest request;
    request.set_context_id(ctx_->GetContextId());
    auto* remote_op = request.add_queue()->mutable_operation();
    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
        src_, /*wait_until_ready=*/false,
        remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
        src_->DeviceOrHostCPU(*ctx_)->name());
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    PrepareRemoteOp(remote_op, &op);
    remote_op->set_id(ctx_->RemoteMgr()->NextOpId());

    // Issue the RPC
    core::RefCountPtr<eager::EagerClient> eager_client;
    status = ctx_->GetClient(send_device_, &eager_client);
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
    }

    const std::shared_ptr<CapturedSharedState>& captured_state =
        captured_state_;
    EnqueueResponse* response = new EnqueueResponse;
    // If StartRecv fails very quickly, `this` can be destroyed before the
    // callback below is executed. So, we can't capture `this`.
    eager_client->StreamingEnqueueAsync(
        /*call_opts=*/nullptr, &request, response,
        [response, captured_state](const Status& s) {
          captured_state->SetSendStatus(s);
          if (!s.ok()) {
            captured_state->recv_cancellation()->StartCancel();
          }
          delete response;
        });
  }
}

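// Runs the _Recv kernel synchronously on the local recv device and returns
// the received tensor through `outputs`.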
Status RemoteCopyNode::RunLocalRecv(EagerOperation* op,
                                    std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(executor_->status());

  core::RefCountPtr<KernelAndDevice> kernel;
  TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

  EagerKernelArgs args;
  std::vector<EagerKernelRet> rets;
  CoordinationServiceAgent* coord_agent = nullptr;
  if (ctx_->GetDistributedManager() != nullptr)
    coord_agent = ctx_->GetDistributedManager()->GetCoordinationServiceAgent();
  TF_RETURN_IF_ERROR(kernel->Run(/*step_container*/ nullptr, args, &rets,
                                 captured_state_->recv_cancellation(),
                                 /*remote_func_params=*/absl::nullopt,
                                 /*stack_trace=*/absl::nullopt, coord_agent));
  outputs->clear();
  for (const auto& ret : rets) {
    if (ret.index() == 0) {
      outputs->push_back(absl::get<Tensor>(ret));
    } else {
      return errors::Internal(
          "Expected to receive a Tensor but got a TensorShape.");
    }
  }
  return Status::OK();
}

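// Enqueues the _Recv as a remote operation on the recv device's eager
// service. Blocks until the send has completed, then sets (or poisons) the
// destination handle's remote shape from the response.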
void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) {
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* remote_op = request.add_queue()->mutable_operation();
  PrepareRemoteOp(remote_op, op);
  remote_op->set_id(recv_op_id_);
  uint64 context_view_id = ctx_->GetContextViewId();

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx_->GetClient(recv_device_, &eager_client);
  if (!status.ok()) {
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
    done(status);
    return;
  }

  // Don't issue the recv until send has completed.
  //  - local send will complete very quickly.
  //  - remote send will take some time, but remote->remote copy is
  //    probably rare enough that we don't care much.
  // Blocks until send has completed.
  Status send_status = captured_state_->GetSendStatus();
  if (!send_status.ok()) {
    captured_state_->dst()->PoisonRemote(send_status, recv_device_,
                                         context_view_id);
    done(send_status);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              response->queue_response(0).shape(0), recv_device,
              context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by remote Recv op: "
                       << status.ToString()
                       << "\nThis should never happen. "
                          "Please file an issue with the TensorFlow Team.";
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

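// Starts the recv half of the copy: runs _Recv locally and stores the result
// in the destination handle when the recv device is on this task, otherwise
// delegates to RunRemoteRecv.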
void RemoteCopyNode::StartRecv(StatusCallback done) {
  // TODO(gjn): We should consider just using the low-level RecvOp::Compute()
  // functionality here instead of constructing an Op.
  EagerOperation op(ctx_);
  Status status = op.Reset("_Recv", /*raw_device_name=*/nullptr,
                           /*remote=*/false, /*executor=*/nullptr);
  Device* recv_device = ctx_->CanonicalDevice(recv_device_);
  if (!status.ok()) {
    captured_state_->dst()->Poison(status, recv_device);
    done(status);
    return;
  }

  op.SetDevice(recv_device_);

  op.MutableAttrs()->Set("tensor_name", wire_id_);
  op.MutableAttrs()->Set("send_device", send_device_->name());
  op.MutableAttrs()->Set(
      "send_device_incarnation",
      static_cast<int64>(send_device_->attributes().incarnation()));
  op.MutableAttrs()->Set("recv_device", recv_device_->name());
  op.MutableAttrs()->Set("client_terminated", false);

  op.MutableAttrs()->Set("tensor_type", src_->dtype);

  if (recv_device_->IsLocal()) {
    std::vector<Tensor> outputs(1);
    status = RunLocalRecv(&op, &outputs);
    if (!status.ok()) {
      captured_state_->dst()->Poison(status, recv_device);
      done(status);
      return;
    }
    status =
        captured_state_->dst()->SetTensor(std::move(outputs[0]), recv_device);
    done(status);
  } else {
    // Handles captured_state_->dst_ internally.
    RunRemoteRecv(&op, std::move(done));
  }
}

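// Serializes every component of `packed_handle` into `op`: local components
// are copied to host CPU and embedded as tensor protos, remote components
// are serialized as remote handles. Nested packed handles are rejected.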
Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    if (h->Type() == TensorHandle::LOCAL) {
      // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
      // copy it to the CPU before copying it out.
      Tensor tensor;
      TF_RETURN_IF_ERROR(h->CopyToDevice(*ctx, ctx->HostCPU(), &tensor));
      auto* local_handle = op->add_handles()->mutable_local_handle();
      local_handle->set_device(h->op_device() ? h->op_device()->name()
                                              : ctx->HostCPU()->name());
      tensor.AsProtoTensorContent(local_handle->mutable_tensor());
    } else if (h->Type() == TensorHandle::REMOTE) {
      // Only serialize the resource dtype and shape of the first handle, since
      // all handles are of the same resource dtype and shape.
      // If src_device is on the same task as target_device, the handle is a
      // local handle on the target device, which means the resource dtype and
      // shape are known on the target device.
      Device* src_device = h->device();
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (!ctx->OnSameTask(src_device, target_device));
      // For a remote component function, a function execution request and an
      // input generation request may come from different workers. We need to
      // guarantee that the input generation request is processed before the
      // function execution request, so wait until the underlying remote handles
      // are ready before sending a packed handle to the function device.
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, /*wait_until_ready=*/true,
          op->add_handles()->mutable_remote_handle(), src_device,
          h->DeviceOrHostCPU(*ctx)->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");
    }
  }
  return Status::OK();
}

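// Sends a packed handle that lives on the local task to the recv device via
// a SendPackedHandle request. The destination handle's remote shape is set
// from the source handle's shape once the RPC succeeds.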
void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) {
  Status s;
  const uint64 context_view_id = ctx_->GetContextViewId();
  if (!send_device_->IsLocal()) {
    s = errors::InvalidArgument(
        "Copying a packed handle from a remote device is not supported");
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  s = SerializePackedHandle(recv_op_id_, src_, recv_device_, ctx_,
                            request.add_queue()->mutable_send_packed_handle());
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  TensorShape shape;
  s = src_->Shape(&shape);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  captured_state_->SetSrcShape(shape);

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }

  EnqueueResponse* response = new EnqueueResponse;
  Device* recv_device = recv_device_;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  eager_client->StreamingEnqueueAsync(
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendPackedHandle rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

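// Pushes a locally produced tensor to the remote recv device with a single
// SendTensor request instead of a _Send/_Recv pair. The tensor is first
// copied to host CPU so it can be serialized into the request.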
void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) {
  Status s;
  EnqueueRequest request;
  uint64 context_id = ctx_->GetContextId();
  request.set_context_id(context_id);
  auto* send_tensor = request.add_queue()->mutable_send_tensor();
  send_tensor->set_op_id(recv_op_id_);
  send_tensor->set_device_name(recv_device_->name());
  uint64 context_view_id = ctx_->GetContextViewId();

  // AsProtoTensorContent doesn't work when the tensor is on the GPU, hence
  // copy it to the CPU before copying it out.
  // TODO(fishx): Make CopyToDevice asynchronous.
  Tensor tensor;
  s = src_->CopyToDevice(*ctx_, ctx_->HostCPU(), &tensor);
  if (!s.ok()) {
    done(s);
    return;
  }
  tensor.AsProtoTensorContent(send_tensor->add_tensors());

  core::RefCountPtr<eager::EagerClient> eager_client;
  s = ctx_->GetClient(recv_device_, &eager_client);
  if (!s.ok()) {
    captured_state_->dst()->PoisonRemote(s, recv_device_, context_view_id);
    done(s);
    return;
  }
  EnqueueResponse* response = new EnqueueResponse;
  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  captured_state->SetSrcShape(tensor.shape());
  Device* recv_device = recv_device_;
  eager_client->StreamingEnqueueAsync(
      /*call_opts=*/nullptr, &request, response,
      [captured_state, response, recv_device, context_view_id,
       done](const Status& s) {
        if (s.ok()) {
          Status status = captured_state->dst()->SetRemoteShape(
              captured_state->GetSrcShape(), recv_device, context_view_id);
          if (!status.ok()) {
            LOG(ERROR) << "Ignoring an error encountered when setting remote "
                          "shape of tensor received by SendTensor rpc: "
                       << status.ToString();
          }
        } else {
          captured_state->dst()->PoisonRemote(s, recv_device, context_view_id);
        }
        done(s);
        delete response;
      });
}

Status RemoteCopyNode::Prepare() {
  TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_));
  return Status::OK();
}

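// Kicks off the copy. Packed handles and (optionally) local->remote tensor
// pushes take dedicated RPC paths; otherwise a send is started and the recv
// is chained behind it, reporting the send error when the recv was cancelled
// because the send failed.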
void RemoteCopyNode::RunAsync(StatusCallback done) {
  started_ = true;
  if (src_->Type() == TensorHandle::PACKED) {
    return StartSendPackedHandle(std::move(done));
  }

  if ((ctx_->UseSendTensorRPC()) && send_device_->IsLocal() &&
      !recv_device_->IsLocal()) {
    return StartRemoteSendTensor(std::move(done));
  }
  StartSend();

  const std::shared_ptr<CapturedSharedState>& captured_state = captured_state_;
  auto done_wrapper = [captured_state,
                       done = std::move(done)](const Status& s) {
    if (!s.ok() && errors::IsCancelled(s)) {
      Status send_status = captured_state->GetSendStatus();
      if (!send_status.ok()) {
        // In this case, Recv is cancelled because the Send op failed.
        // Return the status of the Send op instead.
        done(send_status);
      }
    } else {
      done(s);
    }
  };

  // StartRecv() takes care of doing the right thing to dst handle.
  // No need to poison it after this point.
  StartRecv(std::move(done_wrapper));
}

void RemoteCopyNode::Abort(Status status) {
  if (!started_) {
    uint64 context_view_id = ctx_->GetContextViewId();
    captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id);
  }
}

}  // namespace eager
}  // namespace tensorflow