/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {

// Forward declarations for the proto classes NodeExecStats and StepStats so
// we do not need to include the proto header.
class NodeExecStats;
class StepStats;
class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  virtual Status Init(const NodeDef& ndef,
                      GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using a regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        runner_(runner),
        default_runner_([](std::function<void()> f) { f(); }),
        collective_executor_(std::move(collective_executor)) {}
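
  // For reference, a caller that wants kernels dispatched off the calling
  // thread could pass a thread-pool-backed `runner`; a minimal sketch,
  // assuming the caller owns a `thread::ThreadPool* pool` (illustrative, not
  // the canonical setup):
  //
  //   std::function<void(std::function<void()>)> runner =
  //       [pool](std::function<void()> f) { pool->Schedule(std::move(f)); };
  //
  // If `runner` is null, the inline `default_runner_` below is the natural
  // fallback.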
  // Not thread safe.
  virtual ~KernelAndDevice() {}

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
                     std::vector<Tensor>* outputs, NodeExecStats* stats,
                     StepStats* step_stats,
                     GraphCollector* graph_collector) = 0;

  virtual Status Run(ScopedStepContainer* step_container,
                     const gtl::InlinedVector<TensorValue, 4>& inputs,
                     std::vector<Tensor>* outputs, NodeExecStats* stats,
                     StepStats* step_stats,
                     GraphCollector* graph_collector) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If the idx'th output is a resource, returns the device backing the
  // resource. Else, returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;

  // Returns the kernel that will be used to run this.
  // Returns nullptr if this will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Also, outputs can be produced on devices different from what
  // this method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual DataType input_type(int i) const = 0;
  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;

 protected:
  // TODO(apassos) Consider a shared cancellation manager. Note that this
  // cancellation manager is not useful to actually cancel anything, and is
  // provided here only for the few kernels which can't handle one being
  // missing.
  CancellationManager cm_;
  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  std::function<void(std::function<void()>)>* const runner_;
  std::function<void(std::function<void()>)> default_runner_;
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
};
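
// A minimal usage sketch of this interface, via the KernelAndDeviceOp
// subclass declared below (assumes `rendez`, `flr`, `host_cpu_device`, an
// `ndef`, and `inputs` were obtained elsewhere; illustrative only, not the
// canonical calling convention):
//
//   KernelAndDeviceOp kad(rendez, /*log_memory=*/false, flr,
//                         /*runner=*/nullptr,
//                         /*collective_executor=*/nullptr, host_cpu_device);
//   TF_RETURN_IF_ERROR(kad.Init(ndef, /*graph_collector=*/nullptr));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(kad.Run(inputs, &outputs, /*stats=*/nullptr,
//                              /*step_stats=*/nullptr,
//                              /*graph_collector=*/nullptr));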
// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendez, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendez_(rendez),
        log_memory_(log_memory) {}

  virtual ~KernelAndDeviceOp();

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  Status Run(ScopedStepContainer* step_container,
             const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  Rendezvous* const rendez_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;

  // For deferred ops, the AsyncOpKernel::DoneCallback is called once the op
  // is enqueued to the device. The execution of the op may not have finished
  // when device_->Compute() returns. We rely on no_deferred_ops_cv_ to know
  // when the execution has finished.
  // Available via OpKernelContext to every OpKernel invocation.
  mutex num_deferred_ops_mu_;
  condition_variable no_deferred_ops_cv_;
  int64 num_deferred_ops_ GUARDED_BY(num_deferred_ops_mu_) = 0;
};
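
// A sketch of how the deferred-op members above can be used (illustrative;
// the real logic lives in the corresponding .cc file): each deferred op
// increments num_deferred_ops_ when it is handed off, and its DoneCallback
// decrements the count and signals no_deferred_ops_cv_, so Run() can block
// until all deferred work has drained:
//
//   mutex_lock lock(num_deferred_ops_mu_);
//   while (num_deferred_ops_ > 0) {
//     no_deferred_ops_cv_.wait(lock);
//   }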
// Represents a multi-device function. Functions can also be run using
// various function-calling kernels, including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
class KernelAndDeviceFunc final : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        input_devices_(std::move(input_devices)) {}

  virtual ~KernelAndDeviceFunc();

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  Status Run(ScopedStepContainer* step_container,
             const gtl::InlinedVector<TensorValue, 4>& inputs,
             std::vector<Tensor>* outputs, NodeExecStats* stats,
             StepStats* step_stats, GraphCollector* graph_collector) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }

 private:
  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // CPU devices are null. Resource handles' devices are the actual backing
  // devices.
  std::vector<Device*> output_devices_;
  // CPU devices are not null. Resource handles' devices are the actual
  // backing devices.
  std::vector<Device*> input_devices_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_