1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 16 #define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/core/framework/cancellation.h" 22 #include "tensorflow/core/framework/dataset.h" 23 #include "tensorflow/core/framework/function.h" 24 #include "tensorflow/core/framework/model.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/lib/random/random.h" 30 #include "tensorflow/core/platform/macros.h" 31 32 namespace tensorflow { 33 34 class Device; 35 class OpKernelContext; 36 class ResourceMgr; 37 38 namespace data { 39 40 class CapturedFunction; 41 class InstantiatedCapturedFunction; 42 43 // Creates an iterator for a dataset which is created by applying the given 44 // function to the given input element. 
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator);

// Creates an iterator for a dataset which is created by applying the given
// function to the given input element. Pass non-null `node` to record
// processing time for modeling Iterator's GetNext() resource usage.
Status MakeIteratorFromInputElement(
    IteratorContext* ctx, const IteratorBase* parent,
    const std::vector<Tensor>& input_element, int64 thread_index,
    const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
    std::unique_ptr<IteratorBase>* out_iterator,
    const std::shared_ptr<model::Node>& node);

// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
                      const NodeDef& node);

// Describes how function outputs map back to function inputs, enabling
// "short-circuit" evaluation that forwards inputs to outputs without actually
// running the function.
struct ShortCircuitInfo {
  // For each function output, the index of the input it forwards. Empty when
  // the function cannot be short-circuited.
  std::vector<int> indices;
  // For each forwarded output, whether the input tensor may be moved (rather
  // than copied) into the output position.
  std::vector<bool> can_move;
};

// Metadata shared across all captures of the same function.
class FunctionMetadata {
 public:
  struct Params {
    bool use_inter_op_parallelism = true;
    bool use_default_device = true;
  };

  // Creates a new instance of the `FunctionMetadata` class, fetching the
  // function from a context argument.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       const string& func_name, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Creates a new instance of the `FunctionMetadata` class, using the provided
  // function.
  static Status Create(tensorflow::OpKernelConstruction* ctx,
                       NameAttrList&& func, Params params,
                       std::shared_ptr<FunctionMetadata>* out_metadata);

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return func_; }

  // Returns a borrowed pointer to the function library that contains the
  // transitive closure of definitions used by the function.
  const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }

  // Returns short-circuit information.
  const ShortCircuitInfo& short_circuit_info() const {
    return short_circuit_info_;
  }

  // Indicates whether a default device should be used for executing function
  // ops.
  bool use_default_device() const { return use_default_device_; }

  // Indicates whether to use inter-op parallelism for execution of the
  // function.
  bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }

  // Indicates whether the function should use a multi-device function backend.
  bool use_multi_device_function() const { return use_multi_device_function_; }

 private:
  // Note: `use_multi_device_function_` is not configurable through `Params`;
  // it keeps its default and is presumably adjusted by `Create()` — confirm
  // against the .cc implementation.
  FunctionMetadata(NameAttrList&& func, Params params)
      : func_(std::move(func)),
        use_default_device_(params.use_default_device),
        use_inter_op_parallelism_(params.use_inter_op_parallelism) {}

  NameAttrList func_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
  ShortCircuitInfo short_circuit_info_;
  bool use_default_device_ = true;
  bool use_inter_op_parallelism_ = true;
  bool use_multi_device_function_ = true;
};

// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
// arguments that it closed over in the user program.
class CapturedFunction {
 public:
  // Creates a new instance using a list of named attributes, fetching captured
  // inputs from a context argument.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       const string& argument_name,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Creates a new instance using a list of named attributes, using provided
  // captured inputs.
  static Status Create(OpKernelContext* ctx,
                       std::shared_ptr<const FunctionMetadata> metadata,
                       std::vector<Tensor>&& captured_inputs,
                       std::unique_ptr<CapturedFunction>* out_function);

  // Adds the definition of this captured function into the given graph,
  // returning its captured inputs and types through the respective output
  // arguments.
  Status AddToGraph(SerializationContext* ctx,
                    DatasetBase::DatasetGraphDefBuilder* b,
                    std::vector<Node*>* other_arguments,
                    DataTypeVector* other_arguments_types) const;

  // Instantiates this function for use in the given context, providing an
  // InstantiatedCapturedFunction that can be used to execute functions.
  Status Instantiate(IteratorContext* ctx,
                     std::unique_ptr<InstantiatedCapturedFunction>*
                         instantiated_captured_function);

  // Determines whether the captured function is stateful.
  Status CheckExternalState() const;

  // Returns the additional captured inputs that will be passed to the
  // function.
  const std::vector<Tensor>& captured_inputs() const {
    return captured_inputs_;
  }

  // Returns the named list of function arguments.
  const NameAttrList& func() const { return metadata_->func(); }

  // Returns the transitive set of function definitions required to
  // instantiate this function.
  const FunctionLibraryDefinition* lib_def() const {
    return metadata_->lib_def();
  }

  // If every function output corresponds to one of its inputs, the method
  // returns the mapping from output indices to input indices. Otherwise, it
  // returns an empty list.
  const ShortCircuitInfo& short_circuit_info() const {
    return metadata_->short_circuit_info();
  }

  // Indicates whether the function should use inter op parallelism.
  bool use_inter_op_parallelism() const {
    return metadata_->use_inter_op_parallelism();
  }

 private:
  CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata,
                   std::vector<Tensor> captured_inputs);

  // Determines whether the function needs the multi-device function backend
  // for execution in the given context, reporting the result via
  // `is_multi_device`.
  Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const;

  const std::shared_ptr<const FunctionMetadata> metadata_;
  const std::vector<Tensor> captured_inputs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};

// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
class InstantiatedCapturedFunction {
 public:
  // Creates a new instance of the `InstantiatedCapturedFunction` class from
  // the given inputs.
  static Status Create(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device,
      std::unique_ptr<InstantiatedCapturedFunction>* out_function);

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets) const;

  // Runs the instantiated captured function. This method takes ownership of
  // the tensors in `args`, in order to be able to deallocate them as early as
  // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
  // ownership of the `args`. Pass non-null `node` to record processing time
  // for modeling Iterator's GetNext() resource usage.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets,
             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. Pass non-null `node` to record processing time for modeling
  // Iterator's GetNext() resource usage.
  Status RunWithBorrowedArgs(IteratorContext* ctx,
                             const std::vector<Tensor>& args,
                             std::vector<Tensor>* rets,
                             const std::shared_ptr<model::Node>& node) const;

  // Synchronously runs the captured function on the given `args`, and stores
  // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
  // possible. This can be useful for calling a captured function in cases
  // where an `IteratorContext*` is not available (such as a destructor).
  //
  // TODO(b/144278100): Avoid running functions without IteratorContext.
  Status RunInstantiated(const std::vector<Tensor>& args,
                         std::vector<Tensor>* rets);

  // Asynchronously runs the captured function on the given `args`, stores the
  // results in `*rets`, and calls the given `done` callback when the function
  // returns. This method takes ownership of the tensors in `args`, in order to
  // be able to deallocate them as early as possible. Pass non-null `node` to
  // record processing time for modeling Iterator's GetNext() resource usage.
  void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                std::vector<Tensor>* rets,
                FunctionLibraryRuntime::DoneCallback done,
                const std::shared_ptr<model::Node>& node) const;

 private:
  InstantiatedCapturedFunction(
      FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
      DataTypeVector ret_types,
      std::function<void(std::function<void()>)> runner,
      CapturedFunction* captured_func, bool is_multi_device);

  // Determines whether a rendezvous object should be created when running the
  // instantiated function.
  bool ShouldCreateRendezvous() const;

  FunctionLibraryRuntime* const lib_;  // Not owned.
  const FunctionLibraryRuntime::Handle f_handle_;
  const DataTypeVector ret_types_;
  // Note: We capture the runner at function instantiation time to be able to
  // run the function without `IteratorContext` via `RunInstantiated`.
  std::function<void(std::function<void()>)> captured_runner_;
  CapturedFunction* const captured_func_;  // Not owned.
  const bool is_multi_device_;

  TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_