/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include <cstddef>
#include <memory>
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

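// ExecuteNodeArgs extends EagerKernelArgs with support for remote inputs: it
// tracks whether any of the op's inputs live on a remote device and, on
// non-mobile builds, how to serialize such inputs into
// eager::RemoteTensorHandle protos for the remote execution path.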
class ExecuteNodeArgs : public EagerKernelArgs {
 public:
  explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {}
  ~ExecuteNodeArgs() override;

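  // Initializes the kernel arguments from `op_inputs`. Detects whether any
  // input handle is remote (see HasRemoteInputs below) and, on non-mobile
  // builds, captures how to serialize such handles into
  // eager::RemoteTensorHandle protos. Defined out of line.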
  Status Init(EagerContext* ctx,
              const gtl::InlinedVector<TensorHandle*, 4>& op_inputs);

  bool HasRemoteInputs() const override { return has_remote_inputs_; };

#if !defined(IS_MOBILE_PLATFORM)
  Status GetRemoteArg(const int index,
                      eager::RemoteTensorHandle* val) const override {
    return serialize_remote_handle_(index, val);
  }
#endif  // IS_MOBILE_PLATFORM

 private:
  bool has_remote_inputs_ = false;
  TensorReferenceVector protected_tensors_;
#if !defined(IS_MOBILE_PLATFORM)
  std::function<Status(const int, eager::RemoteTensorHandle*)>
      serialize_remote_handle_;
#endif  // IS_MOBILE_PLATFORM
};
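// A minimal usage sketch (illustrative only; `ctx` and `op_inputs` are assumed
// to come from the surrounding eager-execute code path, and the exact call
// site may differ):
//
//   ExecuteNodeArgs args(op_inputs.size());
//   TF_RETURN_IF_ERROR(args.Init(ctx, op_inputs));
//   // `args` can now be handed to the kernel as its EagerKernelArgs inputs.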
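// ExecuteNode is the EagerNode that runs a single op (or function) kernel on
// behalf of the eager executor. When constructed for async execution it takes
// a reference on every input and output TensorHandle so they stay alive until
// the node has run; those references are released in the destructor.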
class ExecuteNode : public EagerNode {
 public:
  ExecuteNode(
      EagerContext* ctx, const gtl::InlinedVector<TensorHandle*, 4>& inputs,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      core::RefCountPtr<KernelAndDevice> kernel,
      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
      CancellationManager* cancellation_manager, bool async,
      absl::Span<TensorHandle*> retvals)
      : EagerNode(),
        ctx_(ctx),
        inputs_(inputs),
        remote_func_params_(remote_func_params),
        kernel_(std::move(kernel)),
        graph_collector_(graph_collector),
        cancellation_manager_(cancellation_manager),
        async_(async) {
    // Copy the output handles, since the container for them might get
    // destroyed.
    for (auto handle : retvals) {
      retvals_.push_back(handle);
    }

    if (async_) {
      // This is required to ensure that the tensor handles stay alive across
      // the execution.
      for (auto handle : inputs_) {
        handle->Ref();
      }

      for (auto handle : retvals_) {
        handle->Ref();
      }
    }
  }

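  // For async execution, drops the references taken on the input and output
  // handles in the constructor.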
  ~ExecuteNode() override {
    if (async_) {
      for (auto handle : retvals_) {
        handle->Unref();
      }

      for (auto handle : inputs_) {
        handle->Unref();
      }
    }
  }

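  // Executes the kernel via EagerKernelExecute. On failure the error is
  // propagated to every output handle through Abort().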
  Status Run() override {
    const Status status = EagerKernelExecute(
        ctx_, inputs_, remote_func_params_, kernel_, graph_collector_,
        cancellation_manager_, absl::MakeSpan(retvals_));
    if (!status.ok()) {
      Abort(status);
      return status;
    }
    // If status is ok, EagerKernelExecute would have called SetTensor on
    // all the output handles.
    return Status::OK();
  }

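  // Poisons every output handle with `status` so that any consumer waiting on
  // them observes the error.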
  void Abort(Status status) override {
    for (auto handle : retvals_) {
      handle->Poison(status);
    }
  }

  string DebugString() const override {
    string out = "[ExecuteNode]";
    strings::StrAppend(&out, " kernel: ", kernel_->name());
    return out;
  }

 private:
  EagerContext* ctx_;
  gtl::InlinedVector<TensorHandle*, 4> inputs_;
  const absl::optional<EagerRemoteFunctionParams> remote_func_params_;
  core::RefCountPtr<KernelAndDevice> kernel_;
  GraphCollector* graph_collector_;
  CancellationManager* const cancellation_manager_;
  const bool async_;
  gtl::InlinedVector<TensorHandle*, 2> retvals_;
};
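// A minimal scheduling sketch (illustrative only; `ctx`, `inputs`, `kernel`,
// `output_dtypes`, and `retvals` are assumed to have been prepared by the
// caller, and EagerContext::Executor / EagerExecutor::AddOrExecute are
// assumed from the surrounding eager runtime rather than this header):
//
//   auto node = absl::make_unique<ExecuteNode>(
//       ctx, inputs, /*remote_func_params=*/absl::nullopt, std::move(kernel),
//       /*graph_collector=*/nullptr, output_dtypes,
//       /*cancellation_manager=*/nullptr, ctx->Executor().Async(),
//       absl::MakeSpan(retvals));
//   Status s = ctx->Executor().AddOrExecute(std::move(node));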

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_