/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
18 
19 // Support for eager execution of TensorFlow kernels.
20 
21 #include <memory>
22 #include <unordered_map>
23 
24 // clang-format off
25 // Required for IS_MOBILE_PLATFORM
26 #include "absl/memory/memory.h"
27 #include "tensorflow/core/platform/platform.h"
28 // clang-format on
29 
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/types/optional.h"
32 #include "tensorflow/core/common_runtime/device.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/gtl/inlined_vector.h"
42 #include "tensorflow/core/platform/fingerprint.h"
43 #include "tensorflow/core/util/managed_stack_trace.h"
44 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
45 #if !defined(IS_MOBILE_PLATFORM)
46 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
47 #endif  // IS_MOBILE_PLATFORM
48 
49 namespace tensorflow {
50 
// The literal string "_OutputsOnOpDevice".
// NOTE(review): presumably the name of the attribute that requests function
// outputs to be placed on the op's device -- confirm against its users.
static constexpr const char* kOutputsOnOpDevice = "_OutputsOnOpDevice";

// Forward declarations of runtime classes referenced by pointer below.
class FunctionLibraryRuntime;
class ProcessFunctionLibraryRuntime;

// Sentinel op id; this is the value EagerFunctionParams::op_id defaults to.
constexpr int64_t kInvalidOpId = -1;
57 
58 // This struc is used for:
59 // 1. setting op_id and step_id, is_component_function for single-client
60 // remote function scenario,
61 // 2. setting step_id for multi-client parallel_device scenario.
62 struct EagerFunctionParams {
63   int64_t op_id = kInvalidOpId;
64   bool is_component_function;
65   absl::optional<int64_t> step_id = absl::nullopt;
66 };
67 
68 class EagerKernelArgs : public FunctionArgsInterface {
69  public:
EagerKernelArgs()70   EagerKernelArgs() {}
71 
EagerKernelArgs(int count)72   explicit EagerKernelArgs(int count) : tensor_args_(count) {}
73 
EagerKernelArgs(gtl::InlinedVector<TensorValue,4> && tensor_args)74   explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
75       : tensor_args_(std::move(tensor_args)) {}
76 
~EagerKernelArgs()77   ~EagerKernelArgs() override{};
78 
HasRemoteOrPackedInputs()79   bool HasRemoteOrPackedInputs() const override { return false; };
MutableInput(int i)80   TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
81 
82   Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
83 
84   std::vector<Tensor> GetLocalTensors() const override;
85 
GetTensorValues()86   const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
87     return &tensor_args_;
88   }
89 
90  protected:
91   gtl::InlinedVector<TensorValue, 4> tensor_args_;
92 };
93 
94 typedef absl::variant<Tensor, TensorShape> EagerKernelRet;
95 
96 // KernelAndDevice encapsulates the logic needed to run a computation eagerly.
97 // The computation can be a single instantiated kernel (implemented by
98 // KernelAndDeviceOp below) or a multi-device function (implemented by
99 // KernelAndDeviceFunc below).
100 //
101 // Also see:
102 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
103 // and
104 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
105 class KernelAndDevice : public core::RefCounted {
106  public:
107   // Populates this with a kernel appropriate for 'ndef'.
108   //
109   // The provided FunctionLibraryRuntime MUST outlive all calls to
110   // Run() on the returned KernelAndDevice.
111   virtual Status Init(const bool log_device_placement, const NodeDef& ndef,
112                       GraphCollector* graph_collector) = 0;
113 
114   // Non-multi-device functions are run using regular CallOp and look like
115   // primitive operations from KernelAndDevice perspective.
116   // `flr` can be nullptr if the operation is not run on any specific device
117   // (currently can happen only for multi-device functions).
KernelAndDevice(FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)118   KernelAndDevice(
119       FunctionLibraryRuntime* flr,
120       std::function<void(std::function<void()>)>* runner,
121       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
122       Device* host_cpu_device)
123       : device_(flr == nullptr ? nullptr : flr->device()),
124         host_cpu_device_(host_cpu_device),
125         flr_(flr),
126         collective_executor_(std::move(collective_executor)),
127         runner_(runner) {}
128 
129   // Not thread safe.
~KernelAndDevice()130   ~KernelAndDevice() override {}
131 
IsFunction()132   virtual bool IsFunction() { return false; }
133 
IsCrossProcess()134   virtual bool IsCrossProcess() { return false; }
135 
136   // TODO(ashankar): Handle list-valued inputs.
137   virtual Status Run(
138       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
139       std::vector<EagerKernelRet>* outputs,
140       CancellationManager* cancellation_manager,
141       const absl::optional<EagerFunctionParams>& eager_func_params,
142       const absl::optional<ManagedStackTrace>& stack_trace,
143       CoordinationServiceAgent* coordination_service_agent) = 0;
144 
145   // Execute kernel asynchronously when applicable. Different from `Run` which
146   // blocks the caller thread and waits for the execution of the op/function,
147   // `RunAsync` could return before finishing the execution. The `done` callback
148   // will be triggered once the op/function execution finishes.
149   // Currently, calling RunAsync on ops might not honor the asynchronicity when
150   // it is called on an instance with only sync implementation, execute the
151   // kernel synchronously and then call the callback with the return status
152   // from sync execution.
153   virtual void RunAsync(
154       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
155       std::vector<EagerKernelRet>* outputs,
156       CancellationManager* cancellation_manager,
157       const absl::optional<EagerFunctionParams>& eager_func_params,
158       CoordinationServiceAgent* coordination_service_agent,
159       StatusCallback done) = 0;
160 
161   virtual Device* InputDevice(int i) const = 0;
162   virtual Device* OutputDevice(int idx) const = 0;
163   // If idx'th output is a resource, returns the device backing the resource.
164   // Else, returns nullptr.
165   virtual Device* OutputResourceDevice(int idx) const = 0;
166 
167   // Returns the kernel that will be used to run this.
168   // Returns nullptr if this will be run using function library runtime.
169   virtual const OpKernel* kernel() const = 0;
170 
171   // Returns the device on which this kernel will run. In the case of
172   // multi-device functions, this is the default device that is passed to the
173   // placer but actual computation can happen on a different set of devices.
174   // Also, outputs can be produced on devices different from what this method
175   // returns.
device()176   Device* device() const { return device_; }
177 
178   virtual const DataTypeVector& input_dtypes() const = 0;
179   virtual const DataTypeVector& output_dtypes() const = 0;
180 
181   virtual int num_inputs() const = 0;
182   virtual int num_outputs() const = 0;
183   virtual const string& name() const = 0;
184 
185  protected:
186   std::function<void(std::function<void()>)>* get_runner() const;
187 
188   Device* const device_;               // can be null
189   Device* const host_cpu_device_;      // non-null
190   FunctionLibraryRuntime* const flr_;  // can be null
191   const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
192 
193  private:
194   std::function<void(std::function<void()>)>* const runner_;  // can be null
195 };
196 
197 // Represents an op kernel and the device it will be run on.
198 class KernelAndDeviceOp final : public KernelAndDevice {
199  public:
KernelAndDeviceOp(tensorflow::Rendezvous * rendezvous,bool log_memory,FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)200   KernelAndDeviceOp(
201       tensorflow::Rendezvous* rendezvous, bool log_memory,
202       FunctionLibraryRuntime* flr,
203       std::function<void(std::function<void()>)>* runner,
204       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
205       Device* host_cpu_device)
206       : KernelAndDevice(flr, runner, std::move(collective_executor),
207                         host_cpu_device),
208         rendezvous_(rendezvous),
209         log_memory_(log_memory) {}
210 
~KernelAndDeviceOp()211   ~KernelAndDeviceOp() override {}
212 
213   Status Init(const bool log_device_placement, const NodeDef& ndef,
214               GraphCollector* graph_collector) override;
215 
216   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
217              std::vector<EagerKernelRet>* outputs,
218              CancellationManager* cancellation_manager,
219              const absl::optional<EagerFunctionParams>& eager_func_params,
220              const absl::optional<ManagedStackTrace>& stack_trace,
221              CoordinationServiceAgent* coordination_service_agent) override;
222 
RunAsync(ScopedStepContainer * step_container,const EagerKernelArgs & inputs,std::vector<EagerKernelRet> * outputs,CancellationManager * cancellation_manager,const absl::optional<EagerFunctionParams> & eager_func_params,CoordinationServiceAgent * coordination_service_agent,StatusCallback done)223   void RunAsync(ScopedStepContainer* step_container,
224                 const EagerKernelArgs& inputs,
225                 std::vector<EagerKernelRet>* outputs,
226                 CancellationManager* cancellation_manager,
227                 const absl::optional<EagerFunctionParams>& eager_func_params,
228                 CoordinationServiceAgent* coordination_service_agent,
229                 StatusCallback done) override {
230     // Trivial async implementation on top of the sync version
231     done(Run(step_container, inputs, outputs, cancellation_manager,
232              eager_func_params, {}, coordination_service_agent));
233   }
234 
kernel()235   const OpKernel* kernel() const override { return kernel_.get(); }
236 
237   Device* InputDevice(int i) const override;
238   Device* OutputDevice(int idx) const override;
239   Device* OutputResourceDevice(int idx) const override;
240 
input_dtypes()241   const DataTypeVector& input_dtypes() const override {
242     return kernel_->input_types();
243   }
output_dtypes()244   const DataTypeVector& output_dtypes() const override {
245     return kernel_->output_types();
246   }
num_inputs()247   int num_inputs() const override { return kernel_->num_inputs(); }
num_outputs()248   int num_outputs() const override { return kernel_->num_outputs(); }
name()249   const string& name() const override { return kernel_->name(); }
250 
251  private:
252   std::unique_ptr<OpKernel> kernel_;
253   bool is_distributed_communication_op_;
254   gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
255   std::vector<Device*> input_devices_;
256   gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
257   Rendezvous* const rendezvous_;
258   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
259   const bool log_memory_;
260 };
261 
262 // Represents a multi-device function. Functions can also be run using
263 // various function-calling kernels including CallOp and PartitionedCallOp.
264 // In such cases, KernelAndDeviceOp is used.
265 class KernelAndDeviceFunc : public KernelAndDevice {
266  public:
267   // `flr` can be nullptr.
268   // `pflr` must not be nullptr.
269   // `host_cpu_device` must not be nullptr.
KernelAndDeviceFunc(FunctionLibraryRuntime * flr,ProcessFunctionLibraryRuntime * pflr,std::vector<Device * > input_devices,absl::flat_hash_map<string,const std::vector<string> * > composite_devices,std::unordered_map<int,DtypeAndPartialTensorShape> input_resource_dtypes_and_shapes,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device,const string & name,const bool outputs_on_op_device,const bool allow_small_function_optimizations,const bool allow_control_flow_sync_execution,const bool shape_inference_on_tfe_dialect_import,const bool int_args_and_retvals_on_device,absl::optional<string> xla_compile_device_type,std::function<Rendezvous * (const int64_t)> rendezvous_creator,std::function<int64_t ()> get_op_id)270   KernelAndDeviceFunc(
271       FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
272       std::vector<Device*> input_devices,
273       absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
274       std::unordered_map<int, DtypeAndPartialTensorShape>
275           input_resource_dtypes_and_shapes,
276       std::function<void(std::function<void()>)>* runner,
277       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
278       Device* host_cpu_device, const string& name,
279       const bool outputs_on_op_device,
280       const bool allow_small_function_optimizations,
281       const bool allow_control_flow_sync_execution,
282       const bool shape_inference_on_tfe_dialect_import,
283       const bool int_args_and_retvals_on_device,
284       absl::optional<string> xla_compile_device_type,
285       std::function<Rendezvous*(const int64_t)> rendezvous_creator,
286       std::function<int64_t()> get_op_id)
287       : KernelAndDevice(flr, runner, std::move(collective_executor),
288                         host_cpu_device),
289         pflr_(pflr),
290         handle_(kInvalidHandle),
291         outputs_on_op_device_(outputs_on_op_device),
292         allow_small_function_optimizations_(allow_small_function_optimizations),
293         allow_control_flow_sync_execution_(allow_control_flow_sync_execution),
294         shape_inference_on_tfe_dialect_import_(
295             shape_inference_on_tfe_dialect_import),
296         int_args_and_retvals_on_device_(int_args_and_retvals_on_device),
297         xla_compile_device_type_(xla_compile_device_type),
298         input_devices_(std::move(input_devices)),
299         composite_devices_(std::move(composite_devices)),
300         input_resource_dtypes_and_shapes_(
301             std::move(input_resource_dtypes_and_shapes)),
302         name_(name),
303         rendezvous_creator_(std::move(rendezvous_creator)),
304         get_op_id_(std::move(get_op_id)) {}
305 
306   ~KernelAndDeviceFunc() override;
307 
IsFunction()308   bool IsFunction() override { return true; };
309 
IsCrossProcess()310   bool IsCrossProcess() override { return is_cross_process_; }
311 
312   Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef,
313                          GraphCollector* graph_collector);
314 
315   Status Init(const bool log_device_placement, const NodeDef& ndef,
316               GraphCollector* graph_collector) override;
317 
318   Status Run(ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
319              std::vector<EagerKernelRet>* outputs,
320              CancellationManager* cancellation_manager,
321              const absl::optional<EagerFunctionParams>& eager_func_params,
322              const absl::optional<ManagedStackTrace>& stack_trace,
323              CoordinationServiceAgent* coordination_service_agent) override;
324 
325   void RunAsync(ScopedStepContainer* step_container,
326                 const EagerKernelArgs& inputs,
327                 std::vector<EagerKernelRet>* outputs,
328                 CancellationManager* cancellation_manager,
329                 const absl::optional<EagerFunctionParams>& eager_func_params,
330                 CoordinationServiceAgent* coordination_service_agent,
331                 StatusCallback done) override;
332 
kernel()333   const OpKernel* kernel() const override { return nullptr; }
334 
335   Device* InputDevice(int i) const override;
336   Device* OutputDevice(int idx) const override;
337   Device* OutputResourceDevice(int idx) const override;
338 
input_dtypes()339   const DataTypeVector& input_dtypes() const override { return input_dtypes_; }
output_dtypes()340   const DataTypeVector& output_dtypes() const override {
341     return output_dtypes_;
342   }
num_inputs()343   int num_inputs() const override { return input_dtypes_.size(); }
num_outputs()344   int num_outputs() const override { return output_dtypes_.size(); }
name()345   const string& name() const override { return name_; };
346 
347  private:
348   std::shared_ptr<FunctionLibraryRuntime::Options> PrepareForRun(
349       ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs,
350       CancellationManager* cancellation_manager,
351       const absl::optional<EagerFunctionParams>& eager_func_params,
352       const absl::optional<ManagedStackTrace>& stack_trace,
353       CoordinationServiceAgent* coordination_service_agent);
354 
355   ProcessFunctionLibraryRuntime* const pflr_;  // non-null
356   FunctionLibraryRuntime::Handle handle_;
357   // Indicates whether the function needs to execute cross process.
358   bool is_cross_process_;
359 
360   // If true, function outputs are explicitly assigned to the default device;
361   // if false, the output devices are inferred by pflr_.
362   bool outputs_on_op_device_;
363 
364   // If True, allow optimizations which should be targeted at a limited
365   // set of small functions.  (For example, running kernels synchronously can
366   // be faster under some conditions.)
367   const bool allow_small_function_optimizations_;
368 
369   // If True, allows control nodes to run on the single threaded executor.
370   const bool allow_control_flow_sync_execution_;
371 
372   // TODO(b/176491312): Remove this if shape inference on import flag is
373   // removed. If True, allows mlir roundtrip to run shape inference on import.
374   const bool shape_inference_on_tfe_dialect_import_;
375 
376   const bool int_args_and_retvals_on_device_;
377 
378   const absl::optional<string> xla_compile_device_type_;
379 
380   // CPU devices are null. Resource handles' devices are actual backing
381   // devices.
382   std::vector<Device*> output_devices_;
383   // CPU devices are not null. Resource handles' devices are actual backing
384   // devices.
385   std::vector<Device*> input_devices_;
386   // Maps from a CompositeDevice name to a list of physical device names.
387   absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
388   std::unordered_map<int, DtypeAndPartialTensorShape>
389       input_resource_dtypes_and_shapes_;
390 
391   DataTypeVector input_dtypes_;
392   DataTypeVector output_dtypes_;
393   string name_;
394 
395   std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
396   std::function<int64_t()> get_op_id_;
397 };
398 
399 }  // namespace tensorflow
400 
401 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
402