1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ 17 18 // clang-format off 19 // Required for IS_MOBILE_PLATFORM 20 #include <cstddef> 21 #include <memory> 22 #include "tensorflow/core/platform/platform.h" 23 // clang-format on 24 25 #include "absl/memory/memory.h" 26 #include "absl/types/span.h" 27 #include "tensorflow/core/common_runtime/device.h" 28 #include "tensorflow/core/common_runtime/eager/context.h" 29 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 30 #include "tensorflow/core/common_runtime/eager/execute.h" 31 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" 32 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 33 #include "tensorflow/core/framework/step_stats.pb.h" 34 #include "tensorflow/core/framework/tensor.h" 35 #include "tensorflow/core/framework/types.h" 36 #include "tensorflow/core/lib/core/status.h" 37 #include "tensorflow/core/lib/gtl/inlined_vector.h" 38 #include "tensorflow/core/lib/strings/strcat.h" 39 #if !defined(IS_MOBILE_PLATFORM) 40 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" 41 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 42 #endif // IS_MOBILE_PLATFORM 43 44 namespace tensorflow { 45 46 class ExecuteNodeArgs : public EagerKernelArgs { 47 public: ExecuteNodeArgs(int count)48 explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {} 49 ~ExecuteNodeArgs() override; 50 51 Status Init(EagerContext* ctx, 52 const gtl::InlinedVector<TensorHandle*, 4>& op_inputs); 53 HasRemoteInputs()54 bool HasRemoteInputs() const override { return has_remote_inputs_; }; 55 56 #if !defined(IS_MOBILE_PLATFORM) GetRemoteArg(const int index,eager::RemoteTensorHandle * val)57 Status GetRemoteArg(const int index, 58 eager::RemoteTensorHandle* val) const override { 59 return serialize_remote_handle_(index, val); 60 } 61 #endif // IS_MOBILE_PLATFORM 62 63 private: 64 bool has_remote_inputs_ = false; 65 TensorReferenceVector protected_tensors_; 66 #if !defined(IS_MOBILE_PLATFORM) 67 std::function<Status(const int, eager::RemoteTensorHandle*)> 68 serialize_remote_handle_; 69 #endif // IS_MOBILE_PLATFORM 70 }; 71 72 class ExecuteNode : public EagerNode { 73 public: ExecuteNode(EagerContext * ctx,const gtl::InlinedVector<TensorHandle *,4> & inputs,const absl::optional<EagerRemoteFunctionParams> & remote_func_params,core::RefCountPtr<KernelAndDevice> kernel,GraphCollector * graph_collector,const DataTypeVector & output_dtypes,CancellationManager * cancellation_manager,bool async,absl::Span<TensorHandle * > retvals)74 ExecuteNode( 75 EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& inputs, 76 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 77 core::RefCountPtr<KernelAndDevice> kernel, 78 GraphCollector* graph_collector, const DataTypeVector& output_dtypes, 79 CancellationManager* cancellation_manager, bool async, 80 absl::Span<TensorHandle*> retvals) 81 : EagerNode(), 82 ctx_(ctx), 83 inputs_(inputs), 84 remote_func_params_(remote_func_params), 85 kernel_(std::move(kernel)), 86 graph_collector_(graph_collector), 87 cancellation_manager_(cancellation_manager), 88 async_(async) { 89 // Copy the output handles, since the container for them might get 90 // destroyed. 91 for (auto handle : retvals) { 92 retvals_.push_back(handle); 93 } 94 95 if (async_) { 96 // This is required to ensure that the tensor handles stay alive across 97 // the execution. 98 for (auto handle : inputs_) { 99 handle->Ref(); 100 } 101 102 for (auto handle : retvals_) { 103 handle->Ref(); 104 } 105 } 106 } 107 ~ExecuteNode()108 ~ExecuteNode() override { 109 if (async_) { 110 for (auto handle : retvals_) { 111 handle->Unref(); 112 } 113 114 for (auto handle : inputs_) { 115 handle->Unref(); 116 } 117 } 118 } 119 Run()120 Status Run() override { 121 const Status status = EagerKernelExecute( 122 ctx_, inputs_, remote_func_params_, kernel_, graph_collector_, 123 cancellation_manager_, absl::MakeSpan(retvals_)); 124 if (!status.ok()) { 125 Abort(status); 126 return status; 127 } 128 // If status is ok, EagerKernelExecute would have called SetTensor on 129 // all the output handles. 130 return Status::OK(); 131 } 132 Abort(Status status)133 void Abort(Status status) override { 134 for (auto handle : retvals_) { 135 handle->Poison(status); 136 } 137 } 138 DebugString()139 string DebugString() const override { 140 string out = "[ExecuteNode]"; 141 strings::StrAppend(&out, " kernel: ", kernel_->name()); 142 return out; 143 } 144 145 private: 146 EagerContext* ctx_; 147 gtl::InlinedVector<TensorHandle*, 4> inputs_; 148 const absl::optional<EagerRemoteFunctionParams> remote_func_params_; 149 core::RefCountPtr<KernelAndDevice> kernel_; 150 GraphCollector* graph_collector_; 151 CancellationManager* const cancellation_manager_; 152 const bool async_; 153 gtl::InlinedVector<TensorHandle*, 2> retvals_; 154 }; 155 156 } // namespace tensorflow 157 158 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ 159