• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 namespace eager {
29 
30 // This node supports copying a tensor in the following way:
31 // - Remote -> Local:
32 //   We don't block on the remote _Send op and start executing the local
33 //   _Recv immediately after issuing the remote _Send. The local _Recv
34 //   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
35 //   blocks until the tensor is received. If the remote _Send (or some op
36 //   before it) fails, the local callback we give to EnqueueAsync will run
37 //   and call CancellationManager.StartCancel(). The blocked local _Recv will
38 //   get this notification and return with a cancelled error.
39 //
40 // - Local -> Remote:
41 //   The local _Send op is synchronous and non-blocking, thus it should complete
42 //   quickly. We issue remote _Recv RPC only after local _Send completes
43 //   successfully. At this point, the tensor to be sent is in the local
44 //   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
45 //   to appear.
46 //   When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue
47 //   SendTensor instead of _Send/_Recv.
48 //
49 // - Remote -> Remote:
50 //   We could issue both remote ops asynchronously, but if remote _Send (or some
51 //   op before it) fails, we don't have a good way of cancelling the remote
52 //   _Recv. The remote _Recv will deadlock in this case. The current approach
53 //   to deal with this issue is to wait for remote _Send to complete before
54 //   issuing remote _Recv RPC. Another option is to close the whole streaming
55 //   RPC that contains the deadlocked remote _Recv. This would not unblock the
56 //   deadlocked RPC on the remote machine without some extra code. Luckily, the
57 //   remote -> remote case seems to be fairly rare at this point. So, the
58 //   current partially synchronous approach seems fine.
59 //
60 // To copy a tensor within a host, please use copy_to_device_node instead.
61 class RemoteCopyNode : public AsyncEagerNode {
62  public:
63   RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
64                  TensorHandle* dst, Device* recv_device, uint64 recv_op_id);
65 
66   ~RemoteCopyNode() override;
67 
68   Status Prepare() override;
69 
70   void RunAsync(StatusCallback done) override;
71 
72   void Abort(Status status) override;
73 
DebugString()74   string DebugString() const override {
75     string out = "[RemoteCopyNode]";
76     strings::StrAppend(&out, " send_device: ", send_device_->name());
77     strings::StrAppend(&out, ", recv_device: ", recv_device_->name());
78     strings::StrAppend(&out, ", send_tensor: ", src_->DebugString());
79     strings::StrAppend(
80         &out, ", recv_tensor: ", captured_state_->dst()->DebugString());
81     return out;
82   }
83 
84  private:
85   // Runs the _Send operation locally or remotely.
86   // StartSend() makes sure that captured_state_->send_status_ is set to the
87   // final _Send status after captured_state->send_done_.WaitForNotification()
88   // returns.
89   void StartSend();
90 
91   // Synchronously runs local send `op` and returns its status.
92   Status RunLocalSend(EagerOperation* op);
93 
94   // Runs the _Recv operation locally or remotely.
95   // An error return value indicates that _Recv did not run successfully. It
96   // does not indicate that _Send op has completed since StartRecv could have
97   // encountered an error before waiting for _Send's completion.
98   // An OK return value does NOT necessarily indicate that _Recv has completed
99   // successfully (it does now, but won't when streaming RPCs are turned on).
100   // StartRecv() makes sure that dst_ tensor handle is handled correctly
101   // (potentially after this methods returns); a tensor is set in the local
102   // case, a remote shape is set in the remote case, the dst_ handle is
103   // poisoned in either case if there is an error.
104   void StartRecv(StatusCallback done);
105 
106   // Synchronously runs local receive `op` and returns its status.
107   // Does not wait for the send to complete before running receive.
108   Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
109 
110   // Waits for send to complete, then issues remote receive `op` and
111   // returns its status.
112   void RunRemoteRecv(EagerOperation* op, StatusCallback done);
113 
114   // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
115   // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
116   // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
117   //
118   // However, in some configurations the node that has the tensor to be copied
119   // isn't running a server (WorkerService RPC interface). For such cases,
120   // this function enables sending tensors using the EagerService.Enqueue
121   // SendTensor RPC *on the receiver*.
122   void StartRemoteSendTensor(StatusCallback done);
123 
124   // Send a local packed TensorHandle to a remote device.
125   void StartSendPackedHandle(StatusCallback done);
126 
127   // State that is captured by Send and/or Recv callbacks (depending on which
128   // one(s) is remote) and outlives this node in the case of remote->remote
129   // copy.
130   class CapturedSharedState {
131    public:
CapturedSharedState(TensorHandle * d)132     explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
~CapturedSharedState()133     ~CapturedSharedState() { dst_->Unref(); }
134 
SetSendStatus(Status status)135     void SetSendStatus(Status status) {
136       send_status_.Update(status);
137       send_done_.Notify();
138     }
139 
GetSendStatus()140     Status GetSendStatus() {
141       send_done_.WaitForNotification();
142       return send_status_;
143     }
144 
145     // src_shape_ is not thread-safe. It should only be set in one thread.
SetSrcShape(const TensorShape & shape)146     void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; }
147 
GetSrcShape()148     const TensorShape& GetSrcShape() { return src_shape_; }
149 
dst()150     TensorHandle* dst() { return dst_; }
recv_cancellation()151     CancellationManager* recv_cancellation() { return &recv_cancellation_; }
152 
153    private:
154     TensorHandle* const dst_;
155     CancellationManager recv_cancellation_;
156     // send_status_ is safe to read only after send_done_.WaitForNotification()
157     // has returned.
158     Status send_status_;
159     Notification send_done_;
160     TensorShape src_shape_;
161   };
162 
163   TensorHandle* const src_;
164   EagerContext* const ctx_;
165   EagerExecutor* const executor_;
166   Device* const send_device_;
167   Device* const recv_device_;
168   const string wire_id_;
169   const uint64 recv_op_id_;
170 
171   std::shared_ptr<CapturedSharedState> captured_state_;
172   bool started_;
173 };
174 
175 }  // namespace eager
176 }  // namespace tensorflow
177 
178 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
179