1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 18 19 // Support for eager execution of TensorFlow kernels. 20 21 #include <memory> 22 #include <unordered_map> 23 24 // clang-format off 25 // Required for IS_MOBILE_PLATFORM 26 #include "absl/memory/memory.h" 27 #include "tensorflow/core/platform/platform.h" 28 // clang-format on 29 30 #include "absl/container/flat_hash_map.h" 31 #include "absl/types/optional.h" 32 #include "tensorflow/core/common_runtime/device.h" 33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 34 #include "tensorflow/core/framework/cancellation.h" 35 #include "tensorflow/core/framework/collective.h" 36 #include "tensorflow/core/framework/node_def.pb.h" 37 #include "tensorflow/core/framework/op_kernel.h" 38 #include "tensorflow/core/framework/types.h" 39 #include "tensorflow/core/lib/core/errors.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/lib/gtl/inlined_vector.h" 42 #include "tensorflow/core/platform/fingerprint.h" 43 #include "tensorflow/core/util/managed_stack_trace.h" 44 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 45 #if !defined(IS_MOBILE_PLATFORM) 46 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 47 #endif // IS_MOBILE_PLATFORM 48 49 namespace tensorflow { 50 51 static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice"; 52 53 class ProcessFunctionLibraryRuntime; 54 class FunctionLibraryRuntime; 55 56 const int64_t kInvalidOpId = -1; 57 58 // This struc is used for: 59 // 1. setting op_id and step_id, is_component_function for single-client 60 // remote function scenario, 61 // 2. setting step_id for multi-client parallel_device scenario. 62 struct EagerFunctionParams { 63 int64_t op_id = kInvalidOpId; 64 bool is_component_function; 65 absl::optional<int64_t> step_id = absl::nullopt; 66 }; 67 68 class EagerKernelArgs : public FunctionArgsInterface { 69 public: EagerKernelArgs()70 EagerKernelArgs() {} 71 EagerKernelArgs(int count)72 explicit EagerKernelArgs(int count) : tensor_args_(count) {} 73 EagerKernelArgs(gtl::InlinedVector<TensorValue,4> && tensor_args)74 explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args) 75 : tensor_args_(std::move(tensor_args)) {} 76 ~EagerKernelArgs()77 ~EagerKernelArgs() override{}; 78 HasRemoteOrPackedInputs()79 bool HasRemoteOrPackedInputs() const override { return false; }; MutableInput(int i)80 TensorValue* MutableInput(int i) { return &tensor_args_[i]; } 81 82 Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override; 83 84 std::vector<Tensor> GetLocalTensors() const override; 85 GetTensorValues()86 const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const { 87 return &tensor_args_; 88 } 89 90 protected: 91 gtl::InlinedVector<TensorValue, 4> tensor_args_; 92 }; 93 94 typedef absl::variant<Tensor, TensorShape> EagerKernelRet; 95 96 // KernelAndDevice encapsulates the logic needed to run a computation eagerly. 97 // The computation can be a single instantiated kernel (implemented by 98 // KernelAndDeviceOp below) or a multi-device function (implemented by 99 // KernelAndDeviceFunc below). 100 // 101 // Also see: 102 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h 103 // and 104 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h 105 class KernelAndDevice : public core::RefCounted { 106 public: 107 // Populates this with a kernel appropriate for 'ndef'. 108 // 109 // The provided FunctionLibraryRuntime MUST outlive all calls to 110 // Run() on the returned KernelAndDevice. 111 virtual Status Init(const bool log_device_placement, const NodeDef& ndef, 112 GraphCollector* graph_collector) = 0; 113 114 // Non-multi-device functions are run using regular CallOp and look like 115 // primitive operations from KernelAndDevice perspective. 116 // `flr` can be nullptr if the operation is not run on any specific device 117 // (currently can happen only for multi-device functions). KernelAndDevice(FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)118 KernelAndDevice( 119 FunctionLibraryRuntime* flr, 120 std::function<void(std::function<void()>)>* runner, 121 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 122 Device* host_cpu_device) 123 : device_(flr == nullptr ? nullptr : flr->device()), 124 host_cpu_device_(host_cpu_device), 125 flr_(flr), 126 collective_executor_(std::move(collective_executor)), 127 runner_(runner) {} 128 129 // Not thread safe. ~KernelAndDevice()130 ~KernelAndDevice() override {} 131 IsFunction()132 virtual bool IsFunction() { return false; } 133 IsCrossProcess()134 virtual bool IsCrossProcess() { return false; } 135 136 // TODO(ashankar): Handle list-valued inputs. 137 virtual Status Run( 138 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 139 std::vector<EagerKernelRet>* outputs, 140 CancellationManager* cancellation_manager, 141 const absl::optional<EagerFunctionParams>& eager_func_params, 142 const absl::optional<ManagedStackTrace>& stack_trace, 143 CoordinationServiceAgent* coordination_service_agent) = 0; 144 145 // Execute kernel asynchronously when applicable. Different from `Run` which 146 // blocks the caller thread and waits for the execution of the op/function, 147 // `RunAsync` could return before finishing the execution. The `done` callback 148 // will be triggered once the op/function execution finishes. 149 // Currently, calling RunAsync on ops might not honor the asynchronicity when 150 // it is called on an instance with only sync implementation, execute the 151 // kernel synchronously and then call the callback with the return status 152 // from sync execution. 153 virtual void RunAsync( 154 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 155 std::vector<EagerKernelRet>* outputs, 156 CancellationManager* cancellation_manager, 157 const absl::optional<EagerFunctionParams>& eager_func_params, 158 CoordinationServiceAgent* coordination_service_agent, 159 StatusCallback done) = 0; 160 161 virtual Device* InputDevice(int i) const = 0; 162 virtual Device* OutputDevice(int idx) const = 0; 163 // If idx'th output is a resource, returns the device backing the resource. 164 // Else, returns nullptr. 165 virtual Device* OutputResourceDevice(int idx) const = 0; 166 167 // Returns the kernel that will be used to run this. 168 // Returns nullptr if this will be run using function library runtime. 169 virtual const OpKernel* kernel() const = 0; 170 171 // Returns the device on which this kernel will run. In the case of 172 // multi-device functions, this is the default device that is passed to the 173 // placer but actual computation can happen on a different set of devices. 174 // Also, outputs can be produced on devices different from what this method 175 // returns. device()176 Device* device() const { return device_; } 177 178 virtual const DataTypeVector& input_dtypes() const = 0; 179 virtual const DataTypeVector& output_dtypes() const = 0; 180 181 virtual int num_inputs() const = 0; 182 virtual int num_outputs() const = 0; 183 virtual const string& name() const = 0; 184 185 protected: 186 std::function<void(std::function<void()>)>* get_runner() const; 187 188 Device* const device_; // can be null 189 Device* const host_cpu_device_; // non-null 190 FunctionLibraryRuntime* const flr_; // can be null 191 const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_; 192 193 private: 194 std::function<void(std::function<void()>)>* const runner_; // can be null 195 }; 196 197 // Represents an op kernel and the device it will be run on. 198 class KernelAndDeviceOp final : public KernelAndDevice { 199 public: KernelAndDeviceOp(tensorflow::Rendezvous * rendezvous,bool log_memory,FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)200 KernelAndDeviceOp( 201 tensorflow::Rendezvous* rendezvous, bool log_memory, 202 FunctionLibraryRuntime* flr, 203 std::function<void(std::function<void()>)>* runner, 204 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 205 Device* host_cpu_device) 206 : KernelAndDevice(flr, runner, std::move(collective_executor), 207 host_cpu_device), 208 rendezvous_(rendezvous), 209 log_memory_(log_memory) {} 210 ~KernelAndDeviceOp()211 ~KernelAndDeviceOp() override {} 212 213 Status Init(const bool log_device_placement, const NodeDef& ndef, 214 GraphCollector* graph_collector) override; 215 216 Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 217 std::vector<EagerKernelRet>* outputs, 218 CancellationManager* cancellation_manager, 219 const absl::optional<EagerFunctionParams>& eager_func_params, 220 const absl::optional<ManagedStackTrace>& stack_trace, 221 CoordinationServiceAgent* coordination_service_agent) override; 222 RunAsync(ScopedStepContainer * step_container,const EagerKernelArgs & inputs,std::vector<EagerKernelRet> * outputs,CancellationManager * cancellation_manager,const absl::optional<EagerFunctionParams> & eager_func_params,CoordinationServiceAgent * coordination_service_agent,StatusCallback done)223 void RunAsync(ScopedStepContainer* step_container, 224 const EagerKernelArgs& inputs, 225 std::vector<EagerKernelRet>* outputs, 226 CancellationManager* cancellation_manager, 227 const absl::optional<EagerFunctionParams>& eager_func_params, 228 CoordinationServiceAgent* coordination_service_agent, 229 StatusCallback done) override { 230 // Trivial async implementation on top of the sync version 231 done(Run(step_container, inputs, outputs, cancellation_manager, 232 eager_func_params, {}, coordination_service_agent)); 233 } 234 kernel()235 const OpKernel* kernel() const override { return kernel_.get(); } 236 237 Device* InputDevice(int i) const override; 238 Device* OutputDevice(int idx) const override; 239 Device* OutputResourceDevice(int idx) const override; 240 input_dtypes()241 const DataTypeVector& input_dtypes() const override { 242 return kernel_->input_types(); 243 } output_dtypes()244 const DataTypeVector& output_dtypes() const override { 245 return kernel_->output_types(); 246 } num_inputs()247 int num_inputs() const override { return kernel_->num_inputs(); } num_outputs()248 int num_outputs() const override { return kernel_->num_outputs(); } name()249 const string& name() const override { return kernel_->name(); } 250 251 private: 252 std::unique_ptr<OpKernel> kernel_; 253 bool is_distributed_communication_op_; 254 gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_; 255 std::vector<Device*> input_devices_; 256 gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_; 257 Rendezvous* const rendezvous_; 258 checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; 259 const bool log_memory_; 260 }; 261 262 // Represents a multi-device function. Functions can also be run using 263 // various function-calling kernels including CallOp and PartitionedCallOp. 264 // In such cases, KernelAndDeviceOp is used. 265 class KernelAndDeviceFunc : public KernelAndDevice { 266 public: 267 // `flr` can be nullptr. 268 // `pflr` must not be nullptr. 269 // `host_cpu_device` must not be nullptr. KernelAndDeviceFunc(FunctionLibraryRuntime * flr,ProcessFunctionLibraryRuntime * pflr,std::vector<Device * > input_devices,absl::flat_hash_map<string,const std::vector<string> * > composite_devices,std::unordered_map<int,DtypeAndPartialTensorShape> input_resource_dtypes_and_shapes,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device,const string & name,const bool outputs_on_op_device,const bool allow_small_function_optimizations,const bool allow_control_flow_sync_execution,const bool shape_inference_on_tfe_dialect_import,const bool int_args_and_retvals_on_device,absl::optional<string> xla_compile_device_type,std::function<Rendezvous * (const int64_t)> rendezvous_creator,std::function<int64_t ()> get_op_id)270 KernelAndDeviceFunc( 271 FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr, 272 std::vector<Device*> input_devices, 273 absl::flat_hash_map<string, const std::vector<string>*> composite_devices, 274 std::unordered_map<int, DtypeAndPartialTensorShape> 275 input_resource_dtypes_and_shapes, 276 std::function<void(std::function<void()>)>* runner, 277 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 278 Device* host_cpu_device, const string& name, 279 const bool outputs_on_op_device, 280 const bool allow_small_function_optimizations, 281 const bool allow_control_flow_sync_execution, 282 const bool shape_inference_on_tfe_dialect_import, 283 const bool int_args_and_retvals_on_device, 284 absl::optional<string> xla_compile_device_type, 285 std::function<Rendezvous*(const int64_t)> rendezvous_creator, 286 std::function<int64_t()> get_op_id) 287 : KernelAndDevice(flr, runner, std::move(collective_executor), 288 host_cpu_device), 289 pflr_(pflr), 290 handle_(kInvalidHandle), 291 outputs_on_op_device_(outputs_on_op_device), 292 allow_small_function_optimizations_(allow_small_function_optimizations), 293 allow_control_flow_sync_execution_(allow_control_flow_sync_execution), 294 shape_inference_on_tfe_dialect_import_( 295 shape_inference_on_tfe_dialect_import), 296 int_args_and_retvals_on_device_(int_args_and_retvals_on_device), 297 xla_compile_device_type_(xla_compile_device_type), 298 input_devices_(std::move(input_devices)), 299 composite_devices_(std::move(composite_devices)), 300 input_resource_dtypes_and_shapes_( 301 std::move(input_resource_dtypes_and_shapes)), 302 name_(name), 303 rendezvous_creator_(std::move(rendezvous_creator)), 304 get_op_id_(std::move(get_op_id)) {} 305 306 ~KernelAndDeviceFunc() override; 307 IsFunction()308 bool IsFunction() override { return true; }; 309 IsCrossProcess()310 bool IsCrossProcess() override { return is_cross_process_; } 311 312 Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef, 313 GraphCollector* graph_collector); 314 315 Status Init(const bool log_device_placement, const NodeDef& ndef, 316 GraphCollector* graph_collector) override; 317 318 Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 319 std::vector<EagerKernelRet>* outputs, 320 CancellationManager* cancellation_manager, 321 const absl::optional<EagerFunctionParams>& eager_func_params, 322 const absl::optional<ManagedStackTrace>& stack_trace, 323 CoordinationServiceAgent* coordination_service_agent) override; 324 325 void RunAsync(ScopedStepContainer* step_container, 326 const EagerKernelArgs& inputs, 327 std::vector<EagerKernelRet>* outputs, 328 CancellationManager* cancellation_manager, 329 const absl::optional<EagerFunctionParams>& eager_func_params, 330 CoordinationServiceAgent* coordination_service_agent, 331 StatusCallback done) override; 332 kernel()333 const OpKernel* kernel() const override { return nullptr; } 334 335 Device* InputDevice(int i) const override; 336 Device* OutputDevice(int idx) const override; 337 Device* OutputResourceDevice(int idx) const override; 338 input_dtypes()339 const DataTypeVector& input_dtypes() const override { return input_dtypes_; } output_dtypes()340 const DataTypeVector& output_dtypes() const override { 341 return output_dtypes_; 342 } num_inputs()343 int num_inputs() const override { return input_dtypes_.size(); } num_outputs()344 int num_outputs() const override { return output_dtypes_.size(); } name()345 const string& name() const override { return name_; }; 346 347 private: 348 std::shared_ptr<FunctionLibraryRuntime::Options> PrepareForRun( 349 ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs, 350 CancellationManager* cancellation_manager, 351 const absl::optional<EagerFunctionParams>& eager_func_params, 352 const absl::optional<ManagedStackTrace>& stack_trace, 353 CoordinationServiceAgent* coordination_service_agent); 354 355 ProcessFunctionLibraryRuntime* const pflr_; // non-null 356 FunctionLibraryRuntime::Handle handle_; 357 // Indicates whether the function needs to execute cross process. 358 bool is_cross_process_; 359 360 // If true, function outputs are explicitly assigned to the default device; 361 // if false, the output devices are inferred by pflr_. 362 bool outputs_on_op_device_; 363 364 // If True, allow optimizations which should be targeted at a limited 365 // set of small functions. (For example, running kernels synchronously can 366 // be faster under some conditions.) 367 const bool allow_small_function_optimizations_; 368 369 // If True, allows control nodes to run on the single threaded executor. 370 const bool allow_control_flow_sync_execution_; 371 372 // TODO(b/176491312): Remove this if shape inference on import flag is 373 // removed. If True, allows mlir roundtrip to run shape inference on import. 374 const bool shape_inference_on_tfe_dialect_import_; 375 376 const bool int_args_and_retvals_on_device_; 377 378 const absl::optional<string> xla_compile_device_type_; 379 380 // CPU devices are null. Resource handles' devices are actual backing 381 // devices. 382 std::vector<Device*> output_devices_; 383 // CPU devices are not null. Resource handles' devices are actual backing 384 // devices. 385 std::vector<Device*> input_devices_; 386 // Maps from a CompositeDevice name to a list of physical device names. 387 absl::flat_hash_map<string, const std::vector<string>*> composite_devices_; 388 std::unordered_map<int, DtypeAndPartialTensorShape> 389 input_resource_dtypes_and_shapes_; 390 391 DataTypeVector input_dtypes_; 392 DataTypeVector output_dtypes_; 393 string name_; 394 395 std::function<Rendezvous*(const int64_t)> rendezvous_creator_; 396 std::function<int64_t()> get_op_id_; 397 }; 398 399 } // namespace tensorflow 400 401 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 402