/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

struct EagerRemoteFunctionParams {
  int64 op_id;
  // Set when this function is a component function.
  absl::optional<int64> step_id = absl::nullopt;
};
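
// A minimal sketch of how these params might be populated (the values here
// are hypothetical, not taken from the runtime): a top-level remote function
// carries only an op id, while a component function also carries the step id
// it shares with its parent.
//
//   EagerRemoteFunctionParams top_level{/*op_id=*/42};
//   EagerRemoteFunctionParams component{/*op_id=*/43, /*step_id=*/7};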

class EagerKernelArgs : public FunctionArgsInterface {
 public:
  EagerKernelArgs() {}

  explicit EagerKernelArgs(int count) : tensor_args_(count) {}

  explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
      : tensor_args_(std::move(tensor_args)) {}

  ~EagerKernelArgs() override {}

  bool HasRemoteInputs() const override { return false; }

  Status GetLocalArg(const int index, Tensor* val) const override;

  std::vector<Tensor> GetLocalTensors() const override;

  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
    return &tensor_args_;
  }

 protected:
  gtl::InlinedVector<TensorValue, 4> tensor_args_;
};
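
// A minimal usage sketch, assuming `t0` and `t1` are live Tensors owned by
// the caller (TensorValue does not take ownership of the tensor it wraps):
//
//   gtl::InlinedVector<TensorValue, 4> args;
//   args.push_back(TensorValue(&t0));
//   args.push_back(TensorValue(&t1));
//   EagerKernelArgs kernel_args(std::move(args));
//   Tensor first;
//   TF_RETURN_IF_ERROR(kernel_args.GetLocalArg(0, &first));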

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice : public core::RefCounted {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  virtual Status Init(const NodeDef& ndef, GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using a regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        collective_executor_(std::move(collective_executor)),
        runner_(runner) {}

  // Not thread safe.
  ~KernelAndDevice() override {}

  virtual bool IsFunction() { return false; }

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(
      const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;

  virtual Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If the idx'th output is a resource, returns the device backing the
  // resource. Else, returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;

  // Returns the kernel that will be used to run this.
  // Returns nullptr if this will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Outputs, too, can be produced on devices different from what
  // this method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual DataType input_type(int i) const = 0;
  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;
  virtual const string& name() const = 0;

 protected:
  std::function<void(std::function<void()>)>* get_runner() const;

  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;

 private:
  std::function<void(std::function<void()>)>* const runner_;  // can be null
};
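
// A minimal lifecycle sketch, assuming `kad` points at a concrete subclass
// (e.g. the KernelAndDeviceOp declared below) and that `ndef` and `inputs`
// were obtained elsewhere:
//
//   KernelAndDevice* kad = ...;  // concrete subclass, ref-counted
//   TF_RETURN_IF_ERROR(kad->Init(ndef, /*graph_collector=*/nullptr));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(kad->Run(inputs, &outputs,
//                               /*cancellation_manager=*/nullptr,
//                               /*remote_func_params=*/absl::nullopt));
//   kad->Unref();  // release the core::RefCounted reference when done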

// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendez, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendez_(rendez),
        log_memory_(log_memory),
        step_container_(0, [this](const string& name) {
          device_->resource_manager()->Cleanup(name).IgnoreError();
        }) {}

  ~KernelAndDeviceOp() override {}

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
             std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }
  const string& name() const override { return kernel_->name(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
  gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
  Rendezvous* const rendez_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;
  ScopedStepContainer step_container_;
};
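
// A minimal construction sketch, assuming a live FunctionLibraryRuntime
// `flr`, a Rendezvous `rendez`, and a host CPU Device `cpu` obtained from
// the surrounding runtime (all hypothetical names). The null runner and
// collective executor rely on the "can be null" contracts noted above:
//
//   auto* kad = new KernelAndDeviceOp(
//       rendez, /*log_memory=*/false, flr, /*runner=*/nullptr,
//       /*collective_executor=*/nullptr, cpu);
//   TF_RETURN_IF_ERROR(kad->Init(ndef, /*graph_collector=*/nullptr));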

// Represents a multi-device function. Functions can also be run using
// various function-calling kernels, including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
class KernelAndDeviceFunc final : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      std::unordered_map<int, DtypeAndPartialTensorShape>
          input_resource_dtypes_and_shapes,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device, const string& name,
      std::function<Rendezvous*(const int64)> rendezvous_creator,
      std::function<int64()> get_op_id)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        input_devices_(std::move(input_devices)),
        input_resource_dtypes_and_shapes_(
            std::move(input_resource_dtypes_and_shapes)),
        name_(name),
        rendezvous_creator_(std::move(rendezvous_creator)),
        get_op_id_(std::move(get_op_id)),
        step_container_(0, [this](const string& name) {
          // TODO(b/139809335): This does not properly clean up remote
          // resources.
          const std::vector<Device*> devices =
              pflr_->device_mgr()->ListDevices();
          for (Device* device : devices) {
            device->resource_manager()->Cleanup(name).IgnoreError();
          }
        }) {}

  ~KernelAndDeviceFunc() override;

  bool IsFunction() override { return true; }

  Status InstantiateFunc(const NodeDef& ndef, GraphCollector* graph_collector);

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;
  Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
             std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }
  const string& name() const override { return name_; }

 private:
  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // Indicates whether the function needs to execute cross-process.
  bool is_cross_process_;
  // CPU devices are null. Resource handles' devices are the actual backing
  // devices.
  std::vector<Device*> output_devices_;
  // CPU devices are not null. Resource handles' devices are the actual
  // backing devices.
  std::vector<Device*> input_devices_;
  std::unordered_map<int, DtypeAndPartialTensorShape>
      input_resource_dtypes_and_shapes_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
  string name_;

  std::function<Rendezvous*(const int64)> rendezvous_creator_;
  std::function<int64()> get_op_id_;

  ScopedStepContainer step_container_;
};
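
// A minimal construction sketch for the multi-device path, assuming `pflr`,
// `cpu`, `input_devices`, and `ndef` were obtained from the surrounding
// eager context (all hypothetical names; the no-op callbacks are
// placeholders, not what the runtime actually installs):
//
//   auto* kad = new KernelAndDeviceFunc(
//       /*flr=*/nullptr, pflr, std::move(input_devices),
//       /*input_resource_dtypes_and_shapes=*/{}, /*runner=*/nullptr,
//       /*collective_executor=*/nullptr, cpu, /*name=*/ndef.op(),
//       [](const int64 step_id) -> Rendezvous* { return nullptr; },
//       []() -> int64 { return 0; });
//   TF_RETURN_IF_ERROR(kad->Init(ndef, /*graph_collector=*/nullptr));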

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_