/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {

// Forward declarations for the proto classes NodeExecStats and StepStats so
// we do not need to include the proto header.
class NodeExecStats;
class StepStats;
class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
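//
// A minimal usage sketch of the Init/Run lifecycle (an illustration only, not
// taken from TensorFlow's runtime or test code; `flr`, `runner`, `cpu_device`,
// and `ndef` are assumed to come from the surrounding eager context):
//
//   KernelAndDeviceOp kernel(/*rendez=*/nullptr, /*log_memory=*/false, flr,
//                            runner, /*collective_executor=*/nullptr,
//                            cpu_device);
//   TF_RETURN_IF_ERROR(kernel.Init(ndef, /*graph_collector=*/nullptr));
//   gtl::InlinedVector<TensorValue, 4> inputs = ...;  // borrowed input tensors
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(kernel.Run(inputs, &outputs, /*stats=*/nullptr,
//                                 /*step_stats=*/nullptr,
//                                 /*graph_collector=*/nullptr));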
class KernelAndDevice {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on this KernelAndDevice.
  virtual Status Init(const NodeDef& ndef, GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using the regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        runner_(runner),
        default_runner_([](std::function<void()> f) { f(); }),
        collective_executor_(std::move(collective_executor)) {}

  // Not thread safe.
  virtual ~KernelAndDevice() {}

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
                     std::vector<Tensor>* outputs, NodeExecStats* stats,
                     StepStats* step_stats,
                     GraphCollector* graph_collector) = 0;

  virtual Status Run(ScopedStepContainer* step_container,
                     const gtl::InlinedVector<TensorValue, 4>& inputs,
                     std::vector<Tensor>* outputs, NodeExecStats* stats,
                     StepStats* step_stats,
                     GraphCollector* graph_collector) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If idx'th output is a resource, returns the device backing the resource.
  // Else, returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;

  // Returns the kernel that will be used to run this.
  // Returns nullptr if this will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Outputs can also be produced on devices different from what this
  // method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual DataType input_type(int i) const = 0;
  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;

 protected:
  // TODO(apassos) Consider a shared cancellation manager. Note that this
  // cancellation manager cannot be used to actually cancel anything; it is
  // provided here only for the few kernels that cannot handle a missing one.
  CancellationManager cm_;
  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  std::function<void(std::function<void()>)>* const runner_;
  std::function<void(std::function<void()>)> default_runner_;
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
};

// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendez, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendez_(rendez),
        log_memory_(log_memory) {}

  virtual ~KernelAndDeviceOp();

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  Status Run(ScopedStepContainer* step_container,
             const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  Rendezvous* const rendez_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;

  // For deferred ops, AsyncOpKernel::DoneCallback is called once the op is
  // enqueued to the device. The execution of the op may not have finished
  // when device_->Compute returns, so we rely on no_deferred_ops_cv_ to know
  // when the execution has finished (see the sketch after this class).
  // These are made available to every OpKernel invocation via OpKernelContext.
  mutex num_deferred_ops_mu_;
  condition_variable no_deferred_ops_cv_;
  int64 num_deferred_ops_ GUARDED_BY(num_deferred_ops_mu_) = 0;
};
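
// A sketch of the deferred-op accounting pattern referenced above. This is an
// illustration of the idea only, not the exact code in kernel_and_device.cc;
// it assumes OpKernelContext::Params exposes inc/dec hooks for deferred ops,
// as the comment on num_deferred_ops_ suggests:
//
//   params.inc_num_deferred_ops_function = [this] {
//     mutex_lock lock(num_deferred_ops_mu_);
//     num_deferred_ops_++;
//   };
//   params.dec_num_deferred_ops_function = [this] {
//     mutex_lock lock(num_deferred_ops_mu_);
//     if (--num_deferred_ops_ == 0) no_deferred_ops_cv_.notify_all();
//   };
//   device_->Compute(kernel_.get(), &context);
//   // Block until every deferred op has signaled completion.
//   mutex_lock lock(num_deferred_ops_mu_);
//   while (num_deferred_ops_ > 0) no_deferred_ops_cv_.wait(lock);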

// Represents a multi-device function. Functions can also be run using
// various function-calling kernels including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
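//
// A construction sketch (illustrative only; `flr`, `pflr`, `runner`, and
// `cpu_device` are hypothetical names standing in for values obtained from
// the eager context, and `fdef_node` is assumed to be a NodeDef whose op name
// names an instantiated function):
//
//   KernelAndDeviceFunc func_kernel(
//       flr, pflr, /*input_devices=*/{cpu_device, cpu_device}, runner,
//       /*collective_executor=*/nullptr, cpu_device);
//   TF_RETURN_IF_ERROR(
//       func_kernel.Init(fdef_node, /*graph_collector=*/nullptr));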
class KernelAndDeviceFunc final : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        input_devices_(std::move(input_devices)) {}

  virtual ~KernelAndDeviceFunc();

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;
  Status Run(ScopedStepContainer* step_container,
             const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }

 private:
  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // CPU devices are null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> output_devices_;
  // CPU devices are not null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> input_devices_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_