/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_

// Support for eager execution of TensorFlow kernels.

#include <memory>
#include <unordered_map>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
#endif  // IS_MOBILE_PLATFORM

namespace tensorflow {

static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice";

class ProcessFunctionLibraryRuntime;
class FunctionLibraryRuntime;

struct EagerRemoteFunctionParams {
  int64 op_id;
  // Set when this function is a component function.
  absl::optional<int64> step_id = absl::nullopt;
};

class EagerKernelArgs : public FunctionArgsInterface {
 public:
  EagerKernelArgs() {}

  explicit EagerKernelArgs(int count) : tensor_args_(count) {}

  explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
      : tensor_args_(std::move(tensor_args)) {}

  ~EagerKernelArgs() override {}

  bool HasRemoteOrPackedInputs() const override { return false; }
  TensorValue* MutableInput(int i) { return &tensor_args_[i]; }

  Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;

  std::vector<Tensor> GetLocalTensors() const override;

  const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
    return &tensor_args_;
  }

 protected:
  gtl::InlinedVector<TensorValue, 4> tensor_args_;
};
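
// Example (hypothetical sketch): wrapping two caller-owned tensors in
// EagerKernelArgs. Note that EagerKernelArgs stores only non-owning
// TensorValue pointers, so `t0` and `t1` must outlive `args`.
//
//   Tensor t0(DT_FLOAT, TensorShape({2}));
//   Tensor t1(DT_FLOAT, TensorShape({2}));
//   gtl::InlinedVector<TensorValue, 4> tensor_args;
//   tensor_args.push_back(TensorValue(&t0));
//   tensor_args.push_back(TensorValue(&t1));
//   EagerKernelArgs args(std::move(tensor_args));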

typedef absl::variant<Tensor, TensorShape> EagerKernelRet;
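
// Example (hypothetical sketch): an execution may produce only the shape of
// an output rather than the tensor itself (e.g. when the value lives on a
// remote host), so callers typically branch on the active alternative.
//
//   void InspectRet(const EagerKernelRet& ret) {
//     if (absl::holds_alternative<Tensor>(ret)) {
//       const Tensor& t = absl::get<Tensor>(ret);
//       // The full tensor value is available locally.
//     } else {
//       const TensorShape& shape = absl::get<TensorShape>(ret);
//       // Only the shape is known; the data lives elsewhere.
//     }
//   }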

// KernelAndDevice encapsulates the logic needed to run a computation eagerly.
// The computation can be a single instantiated kernel (implemented by
// KernelAndDeviceOp below) or a multi-device function (implemented by
// KernelAndDeviceFunc below).
//
// Also see:
// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
// and
// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
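//
// Example (hypothetical sketch): the typical lifecycle is to construct a
// concrete subclass, call Init() once with the NodeDef, and then call Run()
// any number of times. Here `flr`, `runner`, `host_cpu_device`, `ndef`, and
// `args` are assumed to come from the surrounding eager context.
//
//   KernelAndDeviceOp* kernel = new KernelAndDeviceOp(
//       /*rendezvous=*/nullptr, /*log_memory=*/false, flr, runner,
//       /*collective_executor=*/nullptr, host_cpu_device);
//   TF_RETURN_IF_ERROR(kernel->Init(/*log_device_placement=*/false, ndef,
//                                   /*graph_collector=*/nullptr));
//   std::vector<EagerKernelRet> outputs;
//   TF_RETURN_IF_ERROR(kernel->Run(
//       /*step_container=*/nullptr, args, &outputs,
//       /*cancellation_manager=*/nullptr,
//       /*remote_func_params=*/absl::nullopt,
//       /*stack_trace=*/absl::nullopt,
//       /*coordination_service_agent=*/nullptr));
//   kernel->Unref();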
class KernelAndDevice : public core::RefCounted {
 public:
  // Populates this with a kernel appropriate for 'ndef'.
  //
  // The provided FunctionLibraryRuntime MUST outlive all calls to
  // Run() on the returned KernelAndDevice.
  virtual Status Init(const bool log_device_placement, const NodeDef& ndef,
                      GraphCollector* graph_collector) = 0;

  // Non-multi-device functions are run using a regular CallOp and look like
  // primitive operations from the KernelAndDevice perspective.
  // `flr` can be nullptr if the operation is not run on any specific device
  // (currently this can happen only for multi-device functions).
  KernelAndDevice(
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : device_(flr == nullptr ? nullptr : flr->device()),
        host_cpu_device_(host_cpu_device),
        flr_(flr),
        collective_executor_(std::move(collective_executor)),
        runner_(runner) {}

  // Not thread safe.
  ~KernelAndDevice() override {}

  virtual bool IsFunction() { return false; }

  virtual bool IsCrossProcess() { return false; }

  // TODO(ashankar): Handle list-valued inputs.
  virtual Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) = 0;

  // Executes the kernel asynchronously when applicable. Unlike `Run`, which
  // blocks the caller thread until the op/function finishes executing,
  // `RunAsync` may return before the execution finishes. The `done` callback
  // is invoked once the op/function execution finishes.
  // Note that calling RunAsync on an instance that has only a synchronous
  // implementation does not honor the asynchrony: such an instance executes
  // the kernel synchronously and then invokes the callback with the status
  // returned by the synchronous execution.
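  //
  // Example (hypothetical sketch): the caller passes a StatusCallback that
  // fires when execution completes; with a sync-only kernel the callback runs
  // before RunAsync returns.
  //
  //   kernel->RunAsync(
  //       /*step_container=*/nullptr, args, &outputs,
  //       /*cancellation_manager=*/nullptr,
  //       /*remote_func_params=*/absl::nullopt,
  //       /*coordination_service_agent=*/nullptr,
  //       [](const Status& s) { LOG(INFO) << "done: " << s.ToString(); });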
  virtual void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) = 0;

  virtual Device* InputDevice(int i) const = 0;
  virtual Device* OutputDevice(int idx) const = 0;
  // If the idx'th output is a resource, returns the device backing the
  // resource. Otherwise, returns nullptr.
  virtual Device* OutputResourceDevice(int idx) const = 0;

  // Returns the kernel that will be used to run this computation.
  // Returns nullptr if it will be run using the function library runtime.
  virtual const OpKernel* kernel() const = 0;

  // Returns the device on which this kernel will run. In the case of
  // multi-device functions, this is the default device that is passed to the
  // placer, but the actual computation can happen on a different set of
  // devices. Outputs can also be produced on devices different from what this
  // method returns.
  Device* device() const { return device_; }

  virtual const DataTypeVector& input_dtypes() const = 0;
  virtual const DataTypeVector& output_dtypes() const = 0;

  virtual int num_inputs() const = 0;
  virtual int num_outputs() const = 0;
  virtual const string& name() const = 0;

 protected:
  std::function<void(std::function<void()>)>* get_runner() const;

  Device* const device_;               // can be null
  Device* const host_cpu_device_;      // non-null
  FunctionLibraryRuntime* const flr_;  // can be null
  const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;

 private:
  std::function<void(std::function<void()>)>* const runner_;  // can be null
};

// Represents an op kernel and the device it will be run on.
class KernelAndDeviceOp final : public KernelAndDevice {
 public:
  KernelAndDeviceOp(
      tensorflow::Rendezvous* rendezvous, bool log_memory,
      FunctionLibraryRuntime* flr,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        rendezvous_(rendezvous),
        log_memory_(log_memory) {}

  ~KernelAndDeviceOp() override {}

  Status Init(const bool log_device_placement, const NodeDef& ndef,
              GraphCollector* graph_collector) override;

  Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) override;

  void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) override {
    // Trivial async implementation on top of the sync version.
    done(Run(step_container, inputs, outputs, cancellation_manager,
             remote_func_params, {}, coordination_service_agent));
  }

  const OpKernel* kernel() const override { return kernel_.get(); }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  const DataTypeVector& input_dtypes() const override {
    return kernel_->input_types();
  }
  const DataTypeVector& output_dtypes() const override {
    return kernel_->output_types();
  }
  int num_inputs() const override { return kernel_->num_inputs(); }
  int num_outputs() const override { return kernel_->num_outputs(); }
  const string& name() const override { return kernel_->name(); }

 private:
  std::unique_ptr<OpKernel> kernel_;
  bool is_distributed_communication_op_;
  gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
  std::vector<Device*> input_devices_;
  gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
  Rendezvous* const rendezvous_;
  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
  const bool log_memory_;
};

// Represents a multi-device function. Functions can also be run using
// various function-calling kernels, including CallOp and PartitionedCallOp.
// In such cases, KernelAndDeviceOp is used.
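//
// Example (hypothetical sketch): constructing and initializing a multi-device
// function mirrors the op path; only the constructor differs. `pflr`,
// `input_devices`, `function_name`, and the remaining arguments are assumed
// to come from the surrounding eager context.
//
//   core::RefCountPtr<KernelAndDeviceFunc> kernel(new KernelAndDeviceFunc(
//       flr, pflr, std::move(input_devices), std::move(composite_devices),
//       std::move(input_resource_dtypes_and_shapes), runner,
//       /*collective_executor=*/nullptr, host_cpu_device, function_name,
//       /*outputs_on_op_device=*/false, std::move(rendezvous_creator),
//       std::move(get_op_id)));
//   TF_RETURN_IF_ERROR(kernel->Init(/*log_device_placement=*/false, ndef,
//                                   /*graph_collector=*/nullptr));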
class KernelAndDeviceFunc : public KernelAndDevice {
 public:
  // `flr` can be nullptr.
  // `pflr` must not be nullptr.
  // `host_cpu_device` must not be nullptr.
  KernelAndDeviceFunc(
      FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
      std::vector<Device*> input_devices,
      absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
      std::unordered_map<int, DtypeAndPartialTensorShape>
          input_resource_dtypes_and_shapes,
      std::function<void(std::function<void()>)>* runner,
      std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
      Device* host_cpu_device, const string& name,
      const bool outputs_on_op_device,
      std::function<Rendezvous*(const int64_t)> rendezvous_creator,
      std::function<int64()> get_op_id)
      : KernelAndDevice(flr, runner, std::move(collective_executor),
                        host_cpu_device),
        pflr_(pflr),
        handle_(kInvalidHandle),
        outputs_on_op_device_(outputs_on_op_device),
        input_devices_(std::move(input_devices)),
        composite_devices_(std::move(composite_devices)),
        input_resource_dtypes_and_shapes_(
            std::move(input_resource_dtypes_and_shapes)),
        name_(name),
        rendezvous_creator_(std::move(rendezvous_creator)),
        get_op_id_(std::move(get_op_id)) {}

  ~KernelAndDeviceFunc() override;

  bool IsFunction() override { return true; }

  bool IsCrossProcess() override { return is_cross_process_; }

  Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef,
                         GraphCollector* graph_collector);

  Status Init(const bool log_device_placement, const NodeDef& ndef,
              GraphCollector* graph_collector) override;

  Status Run(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent) override;

  void RunAsync(
      ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
      std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done) override;

  const OpKernel* kernel() const override { return nullptr; }

  Device* InputDevice(int i) const override;
  Device* OutputDevice(int idx) const override;
  Device* OutputResourceDevice(int idx) const override;

  const DataTypeVector& input_dtypes() const override { return input_dtypes_; }
  const DataTypeVector& output_dtypes() const override {
    return output_dtypes_;
  }
  int num_inputs() const override { return input_dtypes_.size(); }
  int num_outputs() const override { return output_dtypes_.size(); }
  const string& name() const override { return name_; }

 private:
  std::shared_ptr<FunctionLibraryRuntime::Options> PrepareForRun(
      ScopedStepContainer* step_container, std::vector<EagerKernelRet>* outputs,
      CancellationManager* cancellation_manager,
      const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
      const absl::optional<ManagedStackTrace>& stack_trace,
      CoordinationServiceAgent* coordination_service_agent);

  ProcessFunctionLibraryRuntime* const pflr_;  // non-null
  FunctionLibraryRuntime::Handle handle_;
  // Indicates whether the function needs to execute cross-process.
  bool is_cross_process_;

  // If true, function outputs are explicitly assigned to the default device;
  // if false, the output devices are inferred by pflr_.
  bool outputs_on_op_device_;

  // CPU devices are null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> output_devices_;
  // CPU devices are not null. Resource handles' devices are actual backing
  // devices.
  std::vector<Device*> input_devices_;
  // Maps from a CompositeDevice name to a list of physical device names.
  absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
  std::unordered_map<int, DtypeAndPartialTensorShape>
      input_resource_dtypes_and_shapes_;

  DataTypeVector input_dtypes_;
  DataTypeVector output_dtypes_;
  string name_;

  std::function<Rendezvous*(const int64_t)> rendezvous_creator_;
  std::function<int64()> get_op_id_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_