1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 21 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 23 #include "tensorflow/core/framework/cancellation.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 namespace eager { 29 30 // This node supports copying a tensor in the following way: 31 // - Remote -> Local: 32 // We don't block on the remote _Send op and start executing the local 33 // _Recv immediately after issuing the remote _Send. The local _Recv 34 // kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run) 35 // blocks until the tensor is received. If the remote _Send (or some op 36 // before it) fails, the local callback we give to EnqueueAsync will run 37 // and call CancellationManager.StartCancel(). The blocked local _Recv will 38 // get this notification and return with a cancelled error. 39 // 40 // - Local -> Remote: 41 // The local _Send op is synchronous and non-blocking, thus it should complete 42 // quickly. We issue remote _Recv RPC only after local _Send completes 43 // successfully. At this point, the tensor to be sent is in the local 44 // Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor 45 // to appear. 46 // When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue 47 // SendTensor instead of _Send/_Recv. 48 // 49 // - Remote -> Remote: 50 // We could issue both remote ops asynchronously, but if remote _Send (or some 51 // op before it) fails, we don't have a good way of cancelling the remote 52 // _Recv. The remote _Recv will deadlock in this case. The current approach 53 // to deal with this issue is to wait for remote _Send to complete before 54 // issuing remote _Recv RPC. Another option is to close the whole streaming 55 // RPC that contains the deadlocked remote _Recv. This would not unblock the 56 // deadlocked RPC on the remote machine without some extra code. Luckily, the 57 // remote -> remote case seems to be fairly rare at this point. So, the 58 // current partially synchronous approach seems fine. 59 // 60 // To copy a tensor within a host, please use copy_to_device_node instead. 61 class RemoteCopyNode : public AsyncEagerNode { 62 public: 63 RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src, 64 TensorHandle* dst, Device* recv_device, uint64 recv_op_id); 65 66 ~RemoteCopyNode() override; 67 68 Status Prepare() override; 69 70 void RunAsync(StatusCallback done) override; 71 72 void Abort(Status status) override; 73 DebugString()74 string DebugString() const override { 75 string out = "[RemoteCopyNode]"; 76 strings::StrAppend(&out, " send_device: ", send_device_->name()); 77 strings::StrAppend(&out, ", recv_device: ", recv_device_->name()); 78 strings::StrAppend(&out, ", send_tensor: ", src_->DebugString()); 79 strings::StrAppend( 80 &out, ", recv_tensor: ", captured_state_->dst()->DebugString()); 81 return out; 82 } 83 84 private: 85 // Runs the _Send operation locally or remotely. 86 // StartSend() makes sure that captured_state_->send_status_ is set to the 87 // final _Send status after captured_state->send_done_.WaitForNotification() 88 // returns. 89 void StartSend(); 90 91 // Synchronously runs local send `op` and returns its status. 92 Status RunLocalSend(EagerOperation* op); 93 94 // Runs the _Recv operation locally or remotely. 95 // An error return value indicates that _Recv did not run successfully. It 96 // does not indicate that _Send op has completed since StartRecv could have 97 // encountered an error before waiting for _Send's completion. 98 // An OK return value does NOT necessarily indicate that _Recv has completed 99 // successfully (it does now, but won't when streaming RPCs are turned on). 100 // StartRecv() makes sure that dst_ tensor handle is handled correctly 101 // (potentially after this methods returns); a tensor is set in the local 102 // case, a remote shape is set in the remote case, the dst_ handle is 103 // poisoned in either case if there is an error. 104 void StartRecv(StatusCallback done); 105 106 // Synchronously runs local receive `op` and returns its status. 107 // Does not wait for the send to complete before running receive. 108 Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs); 109 110 // Waits for send to complete, then issues remote receive `op` and 111 // returns its status. 112 void RunRemoteRecv(EagerOperation* op, StatusCallback done); 113 114 // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote 115 // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the 116 // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). 117 // 118 // However, in some configurations the node that has the tensor to be copied 119 // isn't running a server (WorkerService RPC interface). For such cases, 120 // this function enables sending tensors using the EagerService.Enqueue 121 // SendTensor RPC *on the receiver*. 122 void StartRemoteSendTensor(StatusCallback done); 123 124 // Send a local packed TensorHandle to a remote device. 125 void StartSendPackedHandle(StatusCallback done); 126 127 // State that is captured by Send and/or Recv callbacks (depending on which 128 // one(s) is remote) and outlives this node in the case of remote->remote 129 // copy. 130 class CapturedSharedState { 131 public: CapturedSharedState(TensorHandle * d)132 explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); } ~CapturedSharedState()133 ~CapturedSharedState() { dst_->Unref(); } 134 SetSendStatus(Status status)135 void SetSendStatus(Status status) { 136 send_status_.Update(status); 137 send_done_.Notify(); 138 } 139 GetSendStatus()140 Status GetSendStatus() { 141 send_done_.WaitForNotification(); 142 return send_status_; 143 } 144 145 // src_shape_ is not thread-safe. It should only be set in one thread. SetSrcShape(const TensorShape & shape)146 void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; } 147 GetSrcShape()148 const TensorShape& GetSrcShape() { return src_shape_; } 149 dst()150 TensorHandle* dst() { return dst_; } recv_cancellation()151 CancellationManager* recv_cancellation() { return &recv_cancellation_; } 152 153 private: 154 TensorHandle* const dst_; 155 CancellationManager recv_cancellation_; 156 // send_status_ is safe to read only after send_done_.WaitForNotification() 157 // has returned. 158 Status send_status_; 159 Notification send_done_; 160 TensorShape src_shape_; 161 }; 162 163 TensorHandle* const src_; 164 EagerContext* const ctx_; 165 EagerExecutor* const executor_; 166 Device* const send_device_; 167 Device* const recv_device_; 168 const string wire_id_; 169 const uint64 recv_op_id_; 170 171 std::shared_ptr<CapturedSharedState> captured_state_; 172 bool started_; 173 }; 174 175 } // namespace eager 176 } // namespace tensorflow 177 178 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 179