• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
21 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
23 #include "tensorflow/core/framework/cancellation.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 namespace eager {
29 
30 // This node supports copying a tensor in the following way:
31 // - Remote -> Local:
32 //   We don't block on the remote _Send op and start executing the local
33 //   _Recv immediately after issuing the remote _Send. The local _Recv
34 //   kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run)
35 //   blocks until the tensor is received. If the remote _Send (or some op
36 //   before it) fails, the local callback we give to EnqueueAsync will run
37 //   and call CancellationManager.StartCancel(). The blocked local _Recv will
38 //   get this notification and return with a cancelled error.
39 //
40 // - Local -> Remote:
41 //   The local _Send op is synchronous and non-blocking, thus it should complete
42 //   quickly. We issue remote _Recv RPC only after local _Send completes
43 //   successfully. At this point, the tensor to be sent is in the local
44 //   Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor
45 //   to appear.
46 //   When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue
47 //   SendTensor instead of _Send/_Recv.
48 //
49 // - Remote -> Remote:
50 //   We could issue both remote ops asynchronously, but if remote _Send (or some
51 //   op before it) fails, we don't have a good way of cancelling the remote
52 //   _Recv. The remote _Recv will deadlock in this case. The current approach
53 //   to deal with this issue is to wait for remote _Send to complete before
54 //   issuing remote _Recv RPC. Another option is to close the whole streaming
55 //   RPC that contains the deadlocked remote _Recv. This would not unblock the
56 //   deadlocked RPC on the remote machine without some extra code. Luckily, the
57 //   remote -> remote case seems to be fairly rare at this point. So, the
58 //   current partially synchronous approach seems fine.
59 //
60 // To copy a tensor within a host, please use copy_to_device_node instead.
61 class RemoteCopyNode : public AsyncEagerNode {
62  public:
63   RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src,
64                  TensorHandle* dst, Device* recv_device, uint64 recv_op_id);
65 
66   ~RemoteCopyNode() override;
67 
68   Status Prepare() override;
69 
70   void RunAsync(StatusCallback done) override;
71 
72   void Abort(Status status) override;
73 
DebugString()74   string DebugString() const override {
75     string out = "[RemoteCopyNode]";
76     strings::StrAppend(&out, " send_device: ", send_device_->name());
77     strings::StrAppend(&out, ", recv_device: ", recv_device_->name());
78     strings::StrAppend(&out, ", send_tensor: ", src_->DebugString());
79     strings::StrAppend(
80         &out, ", recv_tensor: ", captured_state_->dst()->DebugString());
81     return out;
82   }
83 
84  private:
85   // Runs the _Send operation locally or remotely.
86   // StartSend() makes sure that captured_state_->send_status_ is set to the
87   // final _Send status after captured_state->send_done_.WaitForNotification()
88   // returns.
89   void StartSend();
90 
91   // Synchronously runs local send `op` and returns its status.
92   Status RunLocalSend(EagerOperation* op);
93 
94   // Runs the _Recv operation locally or remotely.
95   // An error return value indicates that _Recv did not run successfully. It
96   // does not indicate that _Send op has completed since StartRecv could have
97   // encountered an error before waiting for _Send's completion.
98   // An OK return value does NOT necessarily indicate that _Recv has completed
99   // successfully (it does now, but won't when streaming RPCs are turned on).
100   // StartRecv() makes sure that dst_ tensor handle is handled correctly
101   // (potentially after this methods returns); a tensor is set in the local
102   // case, a remote shape is set in the remote case, the dst_ handle is
103   // poisoned in either case if there is an error.
104   void StartRecv(StatusCallback done);
105 
106   // Synchronously runs local receive `op` and returns its status.
107   // Does not wait for the send to complete before running receive.
108   Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs);
109 
110   // Waits for send to complete, then issues remote receive `op` and
111   // returns its status.
112   void RunRemoteRecv(EagerOperation* op, StatusCallback done);
113 
114   // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote
115   // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the
116   // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel).
117   //
118   // However, in some configurations the node that has the tensor to be copied
119   // isn't running a server (WorkerService RPC interface). For such cases,
120   // this function enables sending tensors using the EagerService.Enqueue
121   // SendTensor RPC *on the receiver*.
122   void StartRemoteSendTensor(StatusCallback done);
123 
124   // State that is captured by Send and/or Recv callbacks (depending on which
125   // one(s) is remote) and outlives this node in the case of remote->remote
126   // copy.
127   class CapturedSharedState {
128    public:
CapturedSharedState(TensorHandle * d)129     explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); }
~CapturedSharedState()130     ~CapturedSharedState() { dst_->Unref(); }
131 
SetSendStatus(Status status)132     void SetSendStatus(Status status) {
133       send_status_.Update(status);
134       send_done_.Notify();
135     }
136 
GetSendStatus()137     Status GetSendStatus() {
138       send_done_.WaitForNotification();
139       return send_status_;
140     }
141 
142     // src_shape_ is not thread-safe. It should only be set in one thread.
SetSrcShape(const TensorShape & shape)143     void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; }
144 
GetSrcShape()145     const TensorShape& GetSrcShape() { return src_shape_; }
146 
dst()147     TensorHandle* dst() { return dst_; }
recv_cancellation()148     CancellationManager* recv_cancellation() { return &recv_cancellation_; }
149 
150    private:
151     TensorHandle* const dst_;
152     CancellationManager recv_cancellation_;
153     // send_status_ is safe to read only after send_done_.WaitForNotification()
154     // has returned.
155     Status send_status_;
156     Notification send_done_;
157     TensorShape src_shape_;
158   };
159 
160   TensorHandle* const src_;
161   EagerContext* const ctx_;
162   EagerExecutor* const executor_;
163   Device* const send_device_;
164   Device* const recv_device_;
165   const string wire_id_;
166   const uint64 recv_op_id_;
167 
168   std::shared_ptr<CapturedSharedState> captured_state_;
169   bool started_;
170 };
171 
172 }  // namespace eager
173 }  // namespace tensorflow
174 
175 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_
176