1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 21 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 23 #include "tensorflow/core/framework/cancellation.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 namespace eager { 29 30 // This node supports copying a tensor in the following way: 31 // - Remote -> Local: 32 // We don't block on the remote _Send op and start executing the local 33 // _Recv immediately after issuing the remote _Send. The local _Recv 34 // kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run) 35 // blocks until the tensor is received. If the remote _Send (or some op 36 // before it) fails, the local callback we give to EnqueueAsync will run 37 // and call CancellationManager.StartCancel(). The blocked local _Recv will 38 // get this notification and return with a cancelled error. 39 // 40 // - Local -> Remote: 41 // The local _Send op is synchronous and non-blocking, thus it should complete 42 // quickly. We issue remote _Recv RPC only after local _Send completes 43 // successfully. At this point, the tensor to be sent is in the local 44 // Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor 45 // to appear. 46 // When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue 47 // SendTensor instead of _Send/_Recv. 48 // 49 // - Remote -> Remote: 50 // We could issue both remote ops asynchronously, but if remote _Send (or some 51 // op before it) fails, we don't have a good way of cancelling the remote 52 // _Recv. The remote _Recv will deadlock in this case. The current approach 53 // to deal with this issue is to wait for remote _Send to complete before 54 // issuing remote _Recv RPC. Another option is to close the whole streaming 55 // RPC that contains the deadlocked remote _Recv. This would not unblock the 56 // deadlocked RPC on the remote machine without some extra code. Luckily, the 57 // remote -> remote case seems to be fairly rare at this point. So, the 58 // current partially synchronous approach seems fine. 59 // 60 // To copy a tensor within a host, please use copy_to_device_node instead. 61 class RemoteCopyNode : public AsyncEagerNode { 62 public: 63 RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src, 64 TensorHandle* dst, Device* recv_device, uint64 recv_op_id); 65 66 ~RemoteCopyNode() override; 67 68 Status Prepare() override; 69 70 void RunAsync(StatusCallback done) override; 71 72 void Abort(Status status) override; 73 DebugString()74 string DebugString() const override { 75 string out = "[RemoteCopyNode]"; 76 strings::StrAppend(&out, " send_device: ", send_device_->name()); 77 strings::StrAppend(&out, ", recv_device: ", recv_device_->name()); 78 strings::StrAppend(&out, ", send_tensor: ", src_->DebugString()); 79 strings::StrAppend( 80 &out, ", recv_tensor: ", captured_state_->dst()->DebugString()); 81 return out; 82 } 83 84 private: 85 // Runs the _Send operation locally or remotely. 86 // StartSend() makes sure that captured_state_->send_status_ is set to the 87 // final _Send status after captured_state->send_done_.WaitForNotification() 88 // returns. 89 void StartSend(); 90 91 // Synchronously runs local send `op` and returns its status. 92 Status RunLocalSend(EagerOperation* op); 93 94 // Runs the _Recv operation locally or remotely. 95 // An error return value indicates that _Recv did not run successfully. It 96 // does not indicate that _Send op has completed since StartRecv could have 97 // encountered an error before waiting for _Send's completion. 98 // An OK return value does NOT necessarily indicate that _Recv has completed 99 // successfully (it does now, but won't when streaming RPCs are turned on). 100 // StartRecv() makes sure that dst_ tensor handle is handled correctly 101 // (potentially after this methods returns); a tensor is set in the local 102 // case, a remote shape is set in the remote case, the dst_ handle is 103 // poisoned in either case if there is an error. 104 void StartRecv(StatusCallback done); 105 106 // Synchronously runs local receive `op` and returns its status. 107 // Does not wait for the send to complete before running receive. 108 Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs); 109 110 // Waits for send to complete, then issues remote receive `op` and 111 // returns its status. 112 void RunRemoteRecv(EagerOperation* op, StatusCallback done); 113 114 // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote 115 // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the 116 // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). 117 // 118 // However, in some configurations the node that has the tensor to be copied 119 // isn't running a server (WorkerService RPC interface). For such cases, 120 // this function enables sending tensors using the EagerService.Enqueue 121 // SendTensor RPC *on the receiver*. 122 void StartRemoteSendTensor(StatusCallback done); 123 124 // State that is captured by Send and/or Recv callbacks (depending on which 125 // one(s) is remote) and outlives this node in the case of remote->remote 126 // copy. 127 class CapturedSharedState { 128 public: CapturedSharedState(TensorHandle * d)129 explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); } ~CapturedSharedState()130 ~CapturedSharedState() { dst_->Unref(); } 131 SetSendStatus(Status status)132 void SetSendStatus(Status status) { 133 send_status_.Update(status); 134 send_done_.Notify(); 135 } 136 GetSendStatus()137 Status GetSendStatus() { 138 send_done_.WaitForNotification(); 139 return send_status_; 140 } 141 142 // src_shape_ is not thread-safe. It should only be set in one thread. SetSrcShape(const TensorShape & shape)143 void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; } 144 GetSrcShape()145 const TensorShape& GetSrcShape() { return src_shape_; } 146 dst()147 TensorHandle* dst() { return dst_; } recv_cancellation()148 CancellationManager* recv_cancellation() { return &recv_cancellation_; } 149 150 private: 151 TensorHandle* const dst_; 152 CancellationManager recv_cancellation_; 153 // send_status_ is safe to read only after send_done_.WaitForNotification() 154 // has returned. 155 Status send_status_; 156 Notification send_done_; 157 TensorShape src_shape_; 158 }; 159 160 TensorHandle* const src_; 161 EagerContext* const ctx_; 162 EagerExecutor* const executor_; 163 Device* const send_device_; 164 Device* const recv_device_; 165 const string wire_id_; 166 const uint64 recv_op_id_; 167 168 std::shared_ptr<CapturedSharedState> captured_state_; 169 bool started_; 170 }; 171 172 } // namespace eager 173 } // namespace tensorflow 174 175 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 176