/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice";

class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

struct EagerRemoteFunctionParams {
  int64 op_id;
  // Set when this function is a component function.
  absl::optional<int64> step_id = absl::nullopt;
};

class EagerKernelArgs : public FunctionArgsInterface {
 public:
  EagerKernelArgs() {}

  explicit EagerKernelArgs(int count) : tensor_args_(count) {}

  explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
      : tensor_args_(std::move(tensor_args)) {}

  ~EagerKernelArgs() override {}

  bool HasRemoteOrPackedInputs() const override { return false; }
  TensorValue* MutableInput(int i) { return &tensor_args_[i]; }

  Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;

  std::vector<Tensor> GetLocalTensors() const override;

  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
    return &tensor_args_;
  }

 protected:
  gtl::InlinedVector<TensorValue, 4> tensor_args_;
};
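
// A minimal usage sketch for EagerKernelArgs, for illustration only: `t0` and
// `t1` are hypothetical tensors that must outlive whatever consumes the
// arguments.
//
//   gtl::InlinedVector<TensorValue, 4> tensor_args;
//   tensor_args.push_back(TensorValue(&t0));
//   tensor_args.push_back(TensorValue(&t1));
//   EagerKernelArgs args(std::move(tensor_args));
//   // `args` can now be passed as the `inputs` of KernelAndDevice::Run.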

typedef absl::variant<Tensor, TensorShape> EagerKernelRet;

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
class KernelAndDevice : public core::RefCounted {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  virtual Status Init(const bool log_device_placement, const NodeDef& ndef,
                      GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using a regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        collective_executor_(std::move(collective_executor)),
        runner_(runner) {}

  // Not thread safe.
  ~KernelAndDevice() override {}

  virtual bool IsFunction() { return false; }

  virtual bool IsCrossProcess() { return false; }

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) = 0;

  // Executes the kernel asynchronously when applicable. Unlike `Run`, which
  // blocks the calling thread until the op/function finishes executing,
  // `RunAsync` may return before execution finishes; the `done` callback is
  // triggered once the op/function execution completes. Note that calling
  // RunAsync on an instance that only has a synchronous implementation does
  // not honor the asynchronicity: the kernel is executed synchronously, and
  // `done` is then invoked with the status of that synchronous execution.
  virtual void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If the idx'th output is a resource, returns the device backing the
  // resource; else, returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;
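
  // For orientation, a rough sketch of the Init/Run contract above. This is
  // illustrative only: `kernel_and_device`, `ndef`, and `args` are
  // hypothetical, and error handling beyond TF_RETURN_IF_ERROR is omitted.
  //
  //   TF_RETURN_IF_ERROR(kernel_and_device->Init(
  //       /*log_device_placement=*/false, ndef, /*graph_collector=*/nullptr));
  //   std::vector<EagerKernelRet> outputs;
  //   TF_RETURN_IF_ERROR(kernel_and_device->Run(
  //       /*step_container=*/nullptr, args, &outputs,
  //       /*cancellation_manager=*/nullptr,
  //       /*remote_func_params=*/absl::nullopt, /*stack_trace=*/absl::nullopt,
  //       /*coordination_service_agent=*/nullptr));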

  // Returns the kernel that will be used to run this.
  // Returns nullptr if this will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Also, outputs can be produced on devices different from what this
  // method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& input_dtypes() const = 0;
  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;
  virtual const string& name() const = 0;

 protected:
  std::function<void(std::function<void()>)>* get_runner() const;

  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;

 private:
  std::function<void(std::function<void()>)>* const runner_;  // can be null
};

// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendezvous, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendezvous_(rendezvous),
        log_memory_(log_memory) {}

  ~KernelAndDeviceOp() override {}

  Status Init(const bool log_device_placement, const NodeDef& ndef,
              GraphCollector* graph_collector) override;

  Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) override;

  void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) override {
    // Trivial async implementation on top of the sync version.
    done(Run(step_container, inputs, outputs, cancellation_manager,
             remote_func_params, {}, coordination_service_agent));
  }
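
  // Illustrative only: the `outputs` filled in by Run()/RunAsync() above are
  // EagerKernelRet variants. A hypothetical caller might unpack them like
  // this; the TensorShape alternative is used when only the output's shape is
  // available rather than a locally materialized Tensor.
  //
  //   for (const EagerKernelRet& ret : outputs) {
  //     if (const Tensor* t = absl::get_if<Tensor>(&ret)) {
  //       // A full tensor was produced; use *t.
  //     } else {
  //       const TensorShape& shape = absl::get<TensorShape>(ret);
  //       // Only the shape is known here.
  //     }
  //   }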

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  const DataTypeVector& input_dtypes() const override {
    return kernel_->input_types();
  }
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }
  const string& name() const override { return kernel_->name(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  bool is_distributed_communication_op_;
  gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
  std::vector<Device*> input_devices_;
  gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
  Rendezvous* const rendezvous_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;
};

// Represents a multi-device function. Functions can also be run using
// various function-calling kernels including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
class KernelAndDeviceFunc : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
      std::unordered_map<int, DtypeAndPartialTensorShape>
          input_resource_dtypes_and_shapes,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device, const string& name,
      const bool outputs_on_op_device,
      std::function<Rendezvous*(const int64_t)> rendezvous_creator,
      std::function<int64()> get_op_id)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        outputs_on_op_device_(outputs_on_op_device),
        input_devices_(std::move(input_devices)),
        composite_devices_(std::move(composite_devices)),
        input_resource_dtypes_and_shapes_(
            std::move(input_resource_dtypes_and_shapes)),
        name_(name),
        rendezvous_creator_(std::move(rendezvous_creator)),
        get_op_id_(std::move(get_op_id)) {}

  ~KernelAndDeviceFunc() override;

  bool IsFunction() override { return true; }

  bool IsCrossProcess() override { return is_cross_process_; }

  Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef,
                         GraphCollector* graph_collector);

  Status Init(const bool log_device_placement, const NodeDef& ndef,
              GraphCollector* graph_collector) override;

  Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) override;

  void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  const DataTypeVector& input_dtypes() const override { return input_dtypes_; }
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }
  const string& name() const override { return name_; }

 private:
  std::shared_ptr<FunctionLibraryRuntime::Options> PrepareForRun(
      ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent);

  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // Indicates whether the function needs to execute cross process.
  bool is_cross_process_;

  // If true, function outputs are explicitly assigned to the default device;
  // if false, the output devices are inferred by pflr_.
  bool outputs_on_op_device_;

  // CPU devices are null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> output_devices_;
  // CPU devices are not null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> input_devices_;
  // Maps from a CompositeDevice name to a list of physical device names.
  absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
  std::unordered_map<int, DtypeAndPartialTensorShape>
      input_resource_dtypes_and_shapes_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
  string name_;

  std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
  std::function<int64()> get_op_id_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_