1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 18 19 #include <cstddef> 20 21 #include "absl/types/span.h" 22 #include "tensorflow/core/common_runtime/device.h" 23 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 24 #include "tensorflow/core/common_runtime/eager/shape_inference.h" 25 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 26 #include "tensorflow/core/distributed_runtime/eager/eager_client.h" 27 #include "tensorflow/core/framework/cancellation.h" 28 #include "tensorflow/core/framework/function.h" 29 #include "tensorflow/core/framework/node_def.pb.h" 30 #include "tensorflow/core/lib/gtl/inlined_vector.h" 31 #include "tensorflow/core/protobuf/eager_service.pb.h" 32 33 namespace tensorflow { 34 namespace eager { 35 36 // RemoteExecuteNode is an implementation of EagerNode which enqueues 37 // an operation via RPC in a remote EagerService. 38 class RemoteExecuteNode : public AsyncRemoteExecuteNode { 39 public: RemoteExecuteNode(EagerContext * eager_context,std::unique_ptr<EnqueueRequest> request,Device * device,uint64 context_view_id,EagerClient * eager_client,CancellationManager * cancellation_manager,const NodeDef & ndef,FunctionLibraryDefinition * lib_def,const gtl::InlinedVector<TensorHandle *,4> & inputs,absl::Span<TensorHandle * > retvals)40 RemoteExecuteNode(EagerContext* eager_context, 41 std::unique_ptr<EnqueueRequest> request, Device* device, 42 uint64 context_view_id, EagerClient* eager_client, 43 CancellationManager* cancellation_manager, 44 const NodeDef& ndef, FunctionLibraryDefinition* lib_def, 45 const gtl::InlinedVector<TensorHandle*, 4>& inputs, 46 absl::Span<TensorHandle*> retvals) 47 : AsyncRemoteExecuteNode(), 48 eager_context_(eager_context), 49 request_(std::move(request)), 50 device_(device), 51 context_view_id_(context_view_id), 52 eager_client_(eager_client), 53 cancellation_manager_(cancellation_manager), 54 ndef_(ndef), 55 lib_def_(lib_def), 56 inputs_(inputs) { 57 // Copy the output handles, since the container for them might get 58 // destroyed. 59 for (auto handle : retvals) { 60 handle->Ref(); 61 retvals_.push_back(handle); 62 } 63 64 // This is required to ensure that the tensor handles stay alive across the 65 // execution. 66 for (auto handle : inputs_) { 67 handle->Ref(); 68 } 69 eager_client_->Ref(); 70 71 needs_remote_inputs_ = false; 72 for (const TensorHandle* input : inputs_) { 73 // TODO(bramandia): Should this be op_device() instead? 74 if (input->resource_device() != nullptr && 75 input->resource_device() != device_) { 76 needs_remote_inputs_ = true; 77 break; 78 } 79 } 80 } 81 ~RemoteExecuteNode()82 ~RemoteExecuteNode() override { 83 for (auto handle : retvals_) { 84 handle->Unref(); 85 } 86 87 for (auto handle : inputs_) { 88 handle->Unref(); 89 } 90 eager_client_->Unref(); 91 } 92 Prepare()93 Status Prepare() override { 94 return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_); 95 } 96 97 void RunAsync(StatusCallback done) override; 98 SyncExecutors()99 Status SyncExecutors() override { return eager_context_->SyncExecutors(); } 100 Abort(Status status)101 void Abort(Status status) override { 102 int i = 0; 103 for (auto handle : retvals_) { 104 handle->PoisonRemote(status, device_, context_view_id_); 105 ++i; 106 } 107 } 108 eager_client()109 const EagerClient* eager_client() const override { return eager_client_; } 110 needs_remote_inputs()111 bool needs_remote_inputs() const override { return needs_remote_inputs_; } 112 allow_multiple_pending_requests()113 bool allow_multiple_pending_requests() const override { 114 return eager_client_->allow_multiple_pending_requests(); 115 } 116 DebugString()117 string DebugString() const override { 118 string out = "[RemoteExecuteNode]"; 119 strings::StrAppend(&out, " request: ", request_->DebugString()); 120 strings::StrAppend(&out, ", target_device: ", device_->name()); 121 return out; 122 } 123 124 private: 125 EagerContext* eager_context_; // Not owned, and must outlive this node. 126 std::unique_ptr<EnqueueRequest> request_; 127 Device* device_; // Not owned 128 uint64 context_view_id_; 129 bool needs_remote_inputs_; 130 EagerClient* eager_client_; // Not owned, and must outlive this node. 131 CancellationManager* cancellation_manager_; 132 const NodeDef ndef_; 133 const FunctionLibraryDefinition* lib_def_; 134 gtl::InlinedVector<TensorHandle*, 4> inputs_; 135 gtl::InlinedVector<TensorHandle*, 2> retvals_; 136 }; 137 138 } // namespace eager 139 } // namespace tensorflow 140 141 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ 142