/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

struct EagerRemoteFunctionParams {
  int64 op_id;
  // Set when this function is a component function.
  absl::optional<int64> step_id = absl::nullopt;
};

class EagerKernelArgs : public FunctionArgsInterface {
 public:
  EagerKernelArgs() {}

  explicit EagerKernelArgs(int count) : tensor_args_(count) {}

  explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
      : tensor_args_(std::move(tensor_args)) {}

  ~EagerKernelArgs() override {}

  bool HasRemoteInputs() const override { return false; }

  Status GetLocalArg(const int index, Tensor* val) const override;

  std::vector<Tensor> GetLocalTensors() const override;

  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const override {
    return &tensor_args_;
  }

 protected:
  gtl::InlinedVector<TensorValue, 4> tensor_args_;
};
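
// Usage sketch (illustrative only, not part of the API declared in this
// header): wrapping already-allocated tensors as eager kernel arguments.
// The tensor names below are hypothetical, and EagerKernelArgs does not take
// ownership of the tensors, so `a` and `b` must outlive `args`.
//
//   Tensor a(DT_FLOAT, TensorShape({2, 2}));
//   Tensor b(DT_FLOAT, TensorShape({2, 2}));
//   gtl::InlinedVector<TensorValue, 4> tensor_args;
//   tensor_args.push_back(TensorValue(&a));
//   tensor_args.push_back(TensorValue(&b));
//   EagerKernelArgs args(std::move(tensor_args));
//   Tensor first_arg;
//   TF_RETURN_IF_ERROR(args.GetLocalArg(0, &first_arg));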

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice : public core::RefCounted {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  virtual Status Init(const NodeDef& ndef, GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using a regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        collective_executor_(std::move(collective_executor)),
        runner_(runner) {}

  // Not thread safe.
  ~KernelAndDevice() override {}

  virtual bool IsFunction() { return false; }

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(
      const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;

  virtual Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<Tensor>* outputs, CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If the idx'th output is a resource, returns the device backing the
  // resource; otherwise returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;

  // Returns the kernel that will be used to run this.
  // Returns nullptr if this will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Also, outputs can be produced on devices different from what
  // this method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual DataType input_type(int i) const = 0;
  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;
  virtual const string& name() const = 0;

 protected:
  std::function<void(std::function<void()>)>* get_runner() const;

  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;

 private:
  std::function<void(std::function<void()>)>* const runner_;  // can be null
};

// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendez, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendez_(rendez),
        log_memory_(log_memory),
        step_container_(0, [this](const string& name) {
          device_->resource_manager()->Cleanup(name).IgnoreError();
        }) {}

  ~KernelAndDeviceOp() override {}

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
             std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }
  const string& name() const override { return kernel_->name(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
  gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
  Rendezvous* const rendez_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;
  ScopedStepContainer step_container_;
};
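
// Usage sketch (illustrative only; `rendezvous`, `flr`, `host_cpu_device`,
// `ndef`, and `inputs` are assumed to already exist and be valid):
//
//   core::RefCountPtr<KernelAndDeviceOp> kernel(new KernelAndDeviceOp(
//       rendezvous, /*log_memory=*/false, flr, /*runner=*/nullptr,
//       /*collective_executor=*/nullptr, host_cpu_device));
//   TF_RETURN_IF_ERROR(kernel->Init(ndef, /*graph_collector=*/nullptr));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(kernel->Run(inputs, &outputs,
//                                  /*cancellation_manager=*/nullptr,
//                                  /*remote_func_params=*/absl::nullopt));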

// Represents a multi-device function. Functions can also be run using
// various function-calling kernels, including CallOp and PartitionedCallOp;
// in such cases, KernelAndDeviceOp is used.
class KernelAndDeviceFunc final : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      std::unordered_map<int, DtypeAndPartialTensorShape>
          input_resource_dtypes_and_shapes,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device, const string& name,
      std::function<Rendezvous*(const int64)> rendezvous_creator,
      std::function<int64()> get_op_id)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        input_devices_(std::move(input_devices)),
        input_resource_dtypes_and_shapes_(
            std::move(input_resource_dtypes_and_shapes)),
        name_(name),
        rendezvous_creator_(std::move(rendezvous_creator)),
        get_op_id_(std::move(get_op_id)),
        step_container_(0, [this](const string& name) {
          // TODO(b/139809335): This does not properly clean up remote
          // resources.
          const std::vector<Device*> devices =
              pflr_->device_mgr()->ListDevices();
          for (Device* device : devices) {
            device->resource_manager()->Cleanup(name).IgnoreError();
          }
        }) {}

  ~KernelAndDeviceFunc() override;

  bool IsFunction() override { return true; }

  Status InstantiateFunc(const NodeDef& ndef, GraphCollector* graph_collector);

  Status Init(const NodeDef& ndef, GraphCollector* graph_collector) override;

  Status Run(const EagerKernelArgs& inputs, std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;
  Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
             std::vector<Tensor>* outputs,
             CancellationManager* cancellation_manager,
             const absl::optional<EagerRemoteFunctionParams>&
                 remote_func_params) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  DataType input_type(int i) const override;
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }
  const string& name() const override { return name_; }

 private:
  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // Indicates whether the function needs to execute across processes.
  bool is_cross_process_;
  // Output devices. Entries for outputs placed on CPU are null; entries for
  // resource-handle outputs are the actual backing devices.
  std::vector<Device*> output_devices_;
  // Input devices. Entries for inputs placed on CPU are not null; entries for
  // resource-handle inputs are the actual backing devices.
  std::vector<Device*> input_devices_;
  std::unordered_map<int, DtypeAndPartialTensorShape>
      input_resource_dtypes_and_shapes_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
  string name_;

  std::function<Rendezvous*(const int64)> rendezvous_creator_;
  std::function<int64()> get_op_id_;

  ScopedStepContainer step_container_;
};
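
// Usage sketch (illustrative only; `flr`, `pflr`, `input_devices`,
// `host_cpu_device`, `fdef_name`, `ndef`, and `inputs` are assumed to exist,
// and the trivial rendezvous_creator/get_op_id callbacks are placeholders):
//
//   core::RefCountPtr<KernelAndDeviceFunc> func(new KernelAndDeviceFunc(
//       flr, pflr, std::move(input_devices),
//       /*input_resource_dtypes_and_shapes=*/{}, /*runner=*/nullptr,
//       /*collective_executor=*/nullptr, host_cpu_device, fdef_name,
//       [](const int64 step_id) -> Rendezvous* { return nullptr; },
//       []() -> int64 { return -1; }));
//   TF_RETURN_IF_ERROR(func->Init(ndef, /*graph_collector=*/nullptr));
//   std::vector<Tensor> outputs;
//   TF_RETURN_IF_ERROR(func->Run(inputs, &outputs,
//                                /*cancellation_manager=*/nullptr,
//                                /*remote_func_params=*/absl::nullopt));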

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_