• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
16 #define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/cancellation.h"
22 #include "tensorflow/core/framework/dataset.h"
23 #include "tensorflow/core/framework/function.h"
24 #include "tensorflow/core/framework/model.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/array_slice.h"
29 #include "tensorflow/core/lib/random/random.h"
30 #include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
34 class Device;
35 class OpKernelContext;
36 class ResourceMgr;
37 
38 namespace data {
39 
40 class CapturedFunction;
41 class InstantiatedCapturedFunction;
42 
// Creates an iterator for a dataset which is created by applying the given
// function `inst_captured_func` to the given `input_element`. The resulting
// iterator is stored in `*out_iterator`; `thread_index` and `prefix` are used
// to construct a unique prefix for the new iterator's state.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator);

// Creates an iterator for a dataset which is created by applying the given
// function to the given input element. Pass non-null `node` to record
// processing time for modeling Iterator's GetNext() resource usage.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64_t thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node);
60 
// Describes a "short-circuit" function, i.e. a function each of whose outputs
// is simply one of its inputs (see `CapturedFunction::short_circuit_info()`).
struct ShortCircuitInfo {
  // For each function output, the index of the input it forwards. Empty when
  // the function is not short-circuitable.
  std::vector<int> indices;
  // For each output, whether the corresponding input tensor can be moved
  // rather than copied. NOTE(review): assumed from the name — the logic that
  // populates this lives in the .cc file; confirm there.
  std::vector<bool> can_move;
};
65 
// Metadata shared across all captures of the same function.
//
// Instances are created through the static `Create()` overloads and handed
// out as `std::shared_ptr<FunctionMetadata>` so that multiple
// `CapturedFunction` objects capturing the same function can share one copy.
class FunctionMetadata {
 public:
  // Options controlling how the function will be executed.
  struct Params {
    // Whether function ops may use inter-op parallelism.
    bool use_inter_op_parallelism = true;
    // Whether a default device should be used for executing function ops.
    bool use_default_device = true;
  };

  // Creates a new instance of the `FunctionMetadata` class, fetching the
  // function named `func_name` from an attribute of `ctx`.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       const string& func_name, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Creates a new instance of the `FunctionMetadata` class, using the provided
  // function.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       NameAttrList&& func, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return func_; }

  // Returns a borrowed pointer to the function library that contains the
  // transitive closure of definitions used by the function.
  const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }

  // Returns short-circuit information (see `ShortCircuitInfo`).
  const ShortCircuitInfo& short_circuit_info() const {
    return short_circuit_info_;
  }

  // Indicates whether a default device should be used for executing function
  // ops.
  bool use_default_device() const { return use_default_device_; }

  // Indicates whether to use inter-op parallelism for execution of the
  // function.
  bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }

  // Indicates whether the function should use a multi-device function backend.
  bool use_multi_device_function() const { return use_multi_device_function_; }

 private:
  // Only `Create()` may construct instances.
  FunctionMetadata(NameAttrList&& func, Params params)
      : func_(std::move(func)),
        use_default_device_(params.use_default_device),
        use_inter_op_parallelism_(params.use_inter_op_parallelism) {}

  NameAttrList func_;
  // Presumably populated by `Create()` in the .cc file — only the declaration
  // is visible here; confirm against the implementation.
  std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
  ShortCircuitInfo short_circuit_info_;
  bool use_default_device_ = true;
  bool use_inter_op_parallelism_ = true;
  bool use_multi_device_function_ = true;
};
122 
// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
// arguments that it closed over in the user program.
//
// The constant attributes (name, captured inputs, metadata) live here; the
// runtime state needed to actually execute the function lives in
// `InstantiatedCapturedFunction`, produced by `Instantiate()`.
class CapturedFunction {
 public:
  // Creates a new instance using a list of named attributes, fetching captured
  // inputs from the `argument_name` input list of `ctx`.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       const string& argument_name,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Creates a new instance using a list of named attributes, using provided
  // captured inputs.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       std::vector<Tensor>&& captured_inputs,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Adds the definition of this captured function into the given graph,
  // returning its captured inputs and types through the respective output
  // arguments.
  Status AddToGraph(SerializationContext* ctx,
                    DatasetBase::DatasetGraphDefBuilder* b,
                    std::vector<Node*>* other_arguments,
                    DataTypeVector* other_arguments_types) const;

  // Instantiates this function for use in the given context, providing an
  // InstantiatedCapturedFunction that can be used to execute functions.
  Status Instantiate(IteratorContext* ctx,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  // Determines whether the captured function is stateful; returns an error
  // status if it holds state that cannot be serialized.
  Status CheckExternalState() const;

  // Returns the additional captured inputs that will be passed to the function.
  const std::vector<Tensor>& captured_inputs() const {
    return captured_inputs_;
  }

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return metadata_->func(); }

  // Returns the transitive set of function definition required to instantiate
  // this function.
  const FunctionLibraryDefinition* lib_def() const {
    return metadata_->lib_def();
  }

  // If every function output corresponds to one of its inputs, the method
  // returns the mapping from output indices to input indices. Otherwise, it
  // returns an empty list.
  const ShortCircuitInfo& short_circuit_info() const {
    return metadata_->short_circuit_info();
  }

  // Indicates whether the function should use inter op parallelism.
  bool use_inter_op_parallelism() const {
    return metadata_->use_inter_op_parallelism();
  }

 private:
  // Instances are created only through the static `Create()` overloads.
  CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
                   std::vector<Tensor> captured_inputs);

  // Determines whether the function needs the multi-device function backend
  // in the given context, storing the result in `*is_multi_device`.
  Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const;

  const std::shared_ptr<const FunctionMetadata> metadata_;
  const std::vector<Tensor> captured_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
195 
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
 public:
  // Creates a new instance of the `InstantiatedCapturedFunction` class from the
  // given inputs. `lib` and `captured_func` are borrowed and must outlive the
  // returned object (see the "Not owned" members below).
  static Status Create(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device,
      std::unique_ptr<InstantiatedCapturedFunction>* out_function);

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets) const;

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`. Pass non-null `node` to record processing time
  // for modeling Iterator's GetNext() resource usage.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets,
             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. Pass non-null `node` to record processing time for modeling
  // Iterator's GetNext() resource usage.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets,
                             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. This can be useful for calling a captured function in cases where
  // an `IteratorContext*` is not available (such as a destructor); it relies on
  // the runner captured at instantiation time (see `captured_runner_`).
  //
  // TODO(b/144278100): Avoid running functions without IteratorContext.
  Status RunInstantiated(const std::vector<Tensor>& args,
                         std::vector<Tensor>* rets);

  // Asynchronously runs the captured function on the given `args`, stores the
  // results in `*rets`, and calls the given `done` callback when the function
  // returns. This method takes ownership of the tensors in `args`, in order to
  // be able to deallocate them as early as possible. Pass non-null `node` to
  // record processing time for modeling Iterator's GetNext() resource usage.
  void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                std::vector<Tensor>* rets,
                FunctionLibraryRuntime::DoneCallback done,
                const std::shared_ptr<model::Node>& node) const;

 private:
  // Instances are created only through the static `Create()` method.
  InstantiatedCapturedFunction(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device);

  // Determines whether a rendezvous object should be created when running the
  // instantiated function.
  bool ShouldCreateRendezvous() const;

  FunctionLibraryRuntime* const lib_;  // Not owned.
  const FunctionLibraryRuntime::Handle f_handle_;
  // Expected types of the function's return values.
  const DataTypeVector ret_types_;
  // Note: We capture the runner at function instantiation time to be able to
  // run the function without `IteratorContext` via `RunInstantiated`.
  std::function<void(std::function<void()>)> captured_runner_;
  CapturedFunction* const captured_func_;  // Not owned.
  // Whether the function was instantiated with the multi-device backend.
  const bool is_multi_device_;

  TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};
290 
291 }  // namespace data
292 }  // namespace tensorflow
293 
294 #endif  // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_
295