• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
#define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
34 class Device;
35 class OpKernelContext;
36 class ResourceMgr;
37 
38 namespace data {
39 
40 class CapturedFunction;
41 class InstantiatedCapturedFunction;
42 
43 // Creates an iterator for a dataset which is created by applying the given
44 // function to the given input element.
45 Status MakeIteratorFromInputElement(
46     IteratorContext* ctx, const IteratorBase* parent,
47     const std::vector<Tensor>& input_element, int64 thread_index,
48     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
49     std::unique_ptr<IteratorBase>* out_iterator);
50 
51 // Creates an iterator for a dataset which is created by applying the given
52 // function to the given input element. Pass non-null `node` to record
53 // processing time for modeling Iterator's GetNext() resource usage.
54 Status MakeIteratorFromInputElement(
55     IteratorContext* ctx, const IteratorBase* parent,
56     const std::vector<Tensor>& input_element, int64 thread_index,
57     const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
58     std::unique_ptr<IteratorBase>* out_iterator,
59     const std::shared_ptr<model::Node>& node);
60 
61 // Determines whether the given node is stateful.
62 Status IsNodeStateful(const FunctionLibraryDefinition& library,
63                       const NodeDef& node);
64 
// Describes how a function can be "short circuited": when every function
// output corresponds to one of its inputs, `indices[i]` is the input index
// that produces output `i`. `can_move[i]` indicates whether that input tensor
// may be moved (rather than copied) into the output — presumably true only
// when the input is not used by any other output; confirm against the
// implementation that populates this struct.
struct ShortCircuitInfo {
  std::vector<int> indices;
  std::vector<bool> can_move;
};
69 
70 // Metadata shared across all captures of the same function.
71 class FunctionMetadata {
72  public:
73   struct Params {
74     bool use_inter_op_parallelism = true;
75     bool use_default_device = true;
76   };
77 
78   // Creates a new instance of the `FunctionMetadata` class, fetching function
79   // from a context argument.
80   static Status Create(tensorflow::OpKernelConstruction* ctx,
81                        const string& func_name, Params params,
82                        std::shared_ptr<FunctionMetadata>* out_metadata);
83 
84   // Creates a new instance of the `FunctionMetadata` class, using the provided
85   // function.
86   static Status Create(tensorflow::OpKernelConstruction* ctx,
87                        NameAttrList&& func, Params params,
88                        std::shared_ptr<FunctionMetadata>* out_metadata);
89 
90   // Returns the named list of function arguments.
func()91   const NameAttrList& func() const { return func_; }
92 
93   // Returns a borrowed pointer to the function library that contains the
94   // transitive closure of definitions used by the function.
lib_def()95   const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }
96 
97   // Returns short-circuit information.
short_circuit_info()98   const ShortCircuitInfo& short_circuit_info() const {
99     return short_circuit_info_;
100   }
101 
102   // Indicates whether a default device should be used for executing function
103   // ops.
use_default_device()104   bool use_default_device() const { return use_default_device_; }
105 
106   // Indicates whether to use inter-op parallelism for execution of the
107   // function.
use_inter_op_parallelism()108   bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }
109 
110   // Indicates whether the function should a multi-device function backend.
use_multi_device_function()111   bool use_multi_device_function() const { return use_multi_device_function_; }
112 
113  private:
FunctionMetadata(NameAttrList && func,Params params)114   FunctionMetadata(NameAttrList&& func, Params params)
115       : func_(std::move(func)),
116         use_default_device_(params.use_default_device),
117         use_inter_op_parallelism_(params.use_inter_op_parallelism) {}
118 
119   NameAttrList func_;
120   std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
121   ShortCircuitInfo short_circuit_info_;
122   bool use_default_device_ = true;
123   bool use_inter_op_parallelism_ = true;
124   bool use_multi_device_function_ = true;
125 };
126 
127 // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
128 // arguments that it closed over in the user program.
129 class CapturedFunction {
130  public:
131   // Creates a new instance using a list of named attributes, fetching captured
132   // inputs from a context argument.
133   static Status Create(OpKernelContext* ctx,
134                        std::shared_ptr<const FunctionMetadata> metadata,
135                        const string& argument_name,
136                        std::unique_ptr<CapturedFunction>* out_function);
137 
138   // Creates a new instance using a list of named attributes, using provided
139   // captured inputs.
140   static Status Create(OpKernelContext* ctx,
141                        std::shared_ptr<const FunctionMetadata> metadata,
142                        std::vector<Tensor>&& captured_inputs,
143                        std::unique_ptr<CapturedFunction>* out_function);
144 
145   // Adds the definition of this captured function into the given graph,
146   // returning its captured inputs and types through the respective output
147   // arguments.
148   Status AddToGraph(SerializationContext* ctx,
149                     DatasetBase::DatasetGraphDefBuilder* b,
150                     std::vector<Node*>* other_arguments,
151                     DataTypeVector* other_arguments_types) const;
152 
153   // Instantiates this function for use in the given context, providing an
154   // InstantiatedCapturedFunction that can be used to execute functions.
155   Status Instantiate(IteratorContext* ctx,
156                      std::unique_ptr<InstantiatedCapturedFunction>*
157                          instantiated_captured_function);
158 
159   // Determines whether the captured function is stateful.
160   Status CheckExternalState() const;
161 
162   // Returns the additional captured inputs that will be passed to the function.
captured_inputs()163   const std::vector<Tensor>& captured_inputs() const {
164     return captured_inputs_;
165   }
166 
167   // Returns the named list of function arguments.
func()168   const NameAttrList& func() const { return metadata_->func(); }
169 
170   // Returns the transitive set of function definition required to instantiate
171   // this function.
lib_def()172   const FunctionLibraryDefinition* lib_def() const {
173     return metadata_->lib_def();
174   }
175 
176   // If every function output corresponds to one of its inputs, the method
177   // returns the mapping from output indices to input indices. Otherwise, it
178   // returns an empty list.
short_circuit_info()179   const ShortCircuitInfo& short_circuit_info() const {
180     return metadata_->short_circuit_info();
181   }
182 
183   // Indicates whether the function should use inter op parallelism.
use_inter_op_parallelism()184   bool use_inter_op_parallelism() const {
185     return metadata_->use_inter_op_parallelism();
186   }
187 
188  private:
189   CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
190                    std::vector<Tensor> captured_inputs);
191 
192   Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const;
193 
194   const std::shared_ptr<const FunctionMetadata> metadata_;
195   const std::vector<Tensor> captured_inputs_;
196 
197   TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
198 };
199 
200 // `InstantiatedCapturedFunction` encapsulates all the runtime support needed
201 // to execute a tensorflow function.
202 //
203 // While `CapturedFunction` encapsulates constant attributes of the function,
204 // such as its name and captured arguments, `InstantiatedCapturedFunction`
205 // encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
206 // handle.
207 //
208 // The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
209 // functions outside of the normal `OpKernel::Compute()` context.
210 class InstantiatedCapturedFunction {
211  public:
212   // Creates a new instance of the `InstantiatedCapturedFunction` class from the
213   // given inputs.
214   static Status Create(
215       FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
216       DataTypeVector ret_types,
217       std::function<void(std::function<void()>)> runner,
218       CapturedFunction* captured_func, bool is_multi_device,
219       std::unique_ptr<InstantiatedCapturedFunction>* out_function);
220 
221   // Runs the instantiated captured function. This method takes ownership of
222   // the tensors in `args`, in order to be able to deallocate them as early as
223   // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
224   // ownership of the `args`.
225   Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
226              std::vector<Tensor>* rets) const;
227 
228   // Runs the instantiated captured function. This method takes ownership of
229   // the tensors in `args`, in order to be able to deallocate them as early as
230   // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
231   // ownership of the `args`. Pass non-null `node` to record processing time
232   // for modeling Iterator's GetNext() resource usage.
233   Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
234              std::vector<Tensor>* rets,
235              const std::shared_ptr<model::Node>& node) const;
236 
237   // Synchronously runs the captured function on the given `args`, and stores
238   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
239   // possible.
240   Status RunWithBorrowedArgs(IteratorContext* ctx,
241                              const std::vector<Tensor>& args,
242                              std::vector<Tensor>* rets) const;
243 
244   // Synchronously runs the captured function on the given `args`, and stores
245   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
246   // possible. Pass non-null `node` to record processing time for modeling
247   // Iterator's GetNext() resource usage.
248   Status RunWithBorrowedArgs(IteratorContext* ctx,
249                              const std::vector<Tensor>& args,
250                              std::vector<Tensor>* rets,
251                              const std::shared_ptr<model::Node>& node) const;
252 
253   // Synchronously runs the captured function on the given `args`, and stores
254   // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
255   // possible. This can be useful for calling a captured function in cases where
256   // an `IteratorContext*` is not available (such as a destructor).
257   //
258   // TODO(b/144278100): Avoid running functions without IteratorContext.
259   Status RunInstantiated(const std::vector<Tensor>& args,
260                          std::vector<Tensor>* rets);
261 
262   // Asynchronously runs the captured function on the given `args`, stores the
263   // results in `*rets`, and calls the given `done` callback when the function
264   // returns. This method takes ownership of the tensors in `args`, in order to
265   // be able to deallocate them as early as possible. Pass non-null `node` to
266   // record processing time for modeling Iterator's GetNext() resource usage.
267   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
268                 std::vector<Tensor>* rets,
269                 FunctionLibraryRuntime::DoneCallback done,
270                 const std::shared_ptr<model::Node>& node) const;
271 
272  private:
273   InstantiatedCapturedFunction(
274       FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
275       DataTypeVector ret_types,
276       std::function<void(std::function<void()>)> runner,
277       CapturedFunction* captured_func, bool is_multi_device);
278 
279   // Determines whether a rendezvous object should be created when running the
280   // instantiated function.
281   bool ShouldCreateRendezvous() const;
282 
283   FunctionLibraryRuntime* const lib_;  // Not owned.
284   const FunctionLibraryRuntime::Handle f_handle_;
285   const DataTypeVector ret_types_;
286   // Note: We capture the runner at function instantiation time to be able to
287   // run the function without `IteratorContext` via `RunInstantiated`.
288   std::function<void(std::function<void()>)> captured_runner_;
289   CapturedFunction* const captured_func_;  // Not owned.
290   const bool is_multi_device_;
291 
292   TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
293 };
294 
295 }  // namespace data
296 }  // namespace tensorflow
297 
298 #endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
299