1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ 16 #define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/core/framework/cancellation.h" 22 #include "tensorflow/core/framework/dataset.h" 23 #include "tensorflow/core/framework/function.h" 24 #include "tensorflow/core/framework/model.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/lib/random/random.h" 30 #include "tensorflow/core/platform/macros.h" 31 32 namespace tensorflow { 33 34 class Device; 35 class OpKernelContext; 36 class ResourceMgr; 37 38 namespace data { 39 40 class CapturedFunction; 41 class InstantiatedCapturedFunction; 42 43 // Creates an iterator for a dataset which is created by applying the given 44 // function to the given input element. 45 Status MakeIteratorFromInputElement( 46 IteratorContext* ctx, const IteratorBase* parent, 47 const std::vector<Tensor>& input_element, int64_t thread_index, 48 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, 49 std::unique_ptr<IteratorBase>* out_iterator); 50 51 // Creates an iterator for a dataset which is created by applying the given 52 // function to the given input element. Pass non-null `node` to record 53 // processing time for modeling Iterator's GetNext() resource usage. 54 Status MakeIteratorFromInputElement( 55 IteratorContext* ctx, const IteratorBase* parent, 56 const std::vector<Tensor>& input_element, int64_t thread_index, 57 const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, 58 std::unique_ptr<IteratorBase>* out_iterator, 59 const std::shared_ptr<model::Node>& node); 60 61 struct ShortCircuitInfo { 62 std::vector<int> indices; 63 std::vector<bool> can_move; 64 }; 65 66 // Metadata shared across all captures of the same function. 67 class FunctionMetadata { 68 public: 69 struct Params { 70 bool use_inter_op_parallelism = true; 71 bool use_default_device = true; 72 }; 73 74 // Creates a new instance of the `FunctionMetadata` class, fetching function 75 // from a context argument. 76 static Status Create(tensorflow::OpKernelConstruction* ctx, 77 const string& func_name, Params params, 78 std::shared_ptr<FunctionMetadata>* out_metadata); 79 80 // Creates a new instance of the `FunctionMetadata` class, using the provided 81 // function. 82 static Status Create(tensorflow::OpKernelConstruction* ctx, 83 NameAttrList&& func, Params params, 84 std::shared_ptr<FunctionMetadata>* out_metadata); 85 86 // Returns the named list of function arguments. func()87 const NameAttrList& func() const { return func_; } 88 89 // Returns a borrowed pointer to the function library that contains the 90 // transitive closure of definitions used by the function. lib_def()91 const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); } 92 93 // Returns short-circuit information. short_circuit_info()94 const ShortCircuitInfo& short_circuit_info() const { 95 return short_circuit_info_; 96 } 97 98 // Indicates whether a default device should be used for executing function 99 // ops. use_default_device()100 bool use_default_device() const { return use_default_device_; } 101 102 // Indicates whether to use inter-op parallelism for execution of the 103 // function. use_inter_op_parallelism()104 bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; } 105 106 // Indicates whether the function should a multi-device function backend. use_multi_device_function()107 bool use_multi_device_function() const { return use_multi_device_function_; } 108 109 private: FunctionMetadata(NameAttrList && func,Params params)110 FunctionMetadata(NameAttrList&& func, Params params) 111 : func_(std::move(func)), 112 use_default_device_(params.use_default_device), 113 use_inter_op_parallelism_(params.use_inter_op_parallelism) {} 114 115 NameAttrList func_; 116 std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr; 117 ShortCircuitInfo short_circuit_info_; 118 bool use_default_device_ = true; 119 bool use_inter_op_parallelism_ = true; 120 bool use_multi_device_function_ = true; 121 }; 122 123 // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured" 124 // arguments that it closed over in the user program. 125 class CapturedFunction { 126 public: 127 // Creates a new instance using a list of named attributes, fetching captured 128 // inputs from a context argument. 129 static Status Create(OpKernelContext* ctx, 130 std::shared_ptr<const FunctionMetadata> metadata, 131 const string& argument_name, 132 std::unique_ptr<CapturedFunction>* out_function); 133 134 // Creates a new instance using a list of named attributes, using provided 135 // captured inputs. 136 static Status Create(OpKernelContext* ctx, 137 std::shared_ptr<const FunctionMetadata> metadata, 138 std::vector<Tensor>&& captured_inputs, 139 std::unique_ptr<CapturedFunction>* out_function); 140 141 // Adds the definition of this captured function into the given graph, 142 // returning its captured inputs and types through the respective output 143 // arguments. 144 Status AddToGraph(SerializationContext* ctx, 145 DatasetBase::DatasetGraphDefBuilder* b, 146 std::vector<Node*>* other_arguments, 147 DataTypeVector* other_arguments_types) const; 148 149 // Instantiates this function for use in the given context, providing an 150 // InstantiatedCapturedFunction that can be used to execute functions. 151 Status Instantiate(IteratorContext* ctx, 152 std::unique_ptr<InstantiatedCapturedFunction>* 153 instantiated_captured_function); 154 155 // Determines whether the captured function is stateful. 156 Status CheckExternalState() const; 157 158 // Returns the additional captured inputs that will be passed to the function. captured_inputs()159 const std::vector<Tensor>& captured_inputs() const { 160 return captured_inputs_; 161 } 162 163 // Returns the named list of function arguments. func()164 const NameAttrList& func() const { return metadata_->func(); } 165 166 // Returns the transitive set of function definition required to instantiate 167 // this function. lib_def()168 const FunctionLibraryDefinition* lib_def() const { 169 return metadata_->lib_def(); 170 } 171 172 // If every function output corresponds to one of its inputs, the method 173 // returns the mapping from output indices to input indices. Otherwise, it 174 // returns an empty list. short_circuit_info()175 const ShortCircuitInfo& short_circuit_info() const { 176 return metadata_->short_circuit_info(); 177 } 178 179 // Indicates whether the function should use inter op parallelism. use_inter_op_parallelism()180 bool use_inter_op_parallelism() const { 181 return metadata_->use_inter_op_parallelism(); 182 } 183 184 private: 185 CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata, 186 std::vector<Tensor> captured_inputs); 187 188 Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device) const; 189 190 const std::shared_ptr<const FunctionMetadata> metadata_; 191 const std::vector<Tensor> captured_inputs_; 192 193 TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); 194 }; 195 196 // `InstantiatedCapturedFunction` encapsulates all the runtime support needed 197 // to execute a tensorflow function. 198 // 199 // While `CapturedFunction` encapsulates constant attributes of the function, 200 // such as its name and captured arguments, `InstantiatedCapturedFunction` 201 // encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function 202 // handle. 203 // 204 // The `Iterator` related classes use `InstantiatedCapturedFunction` to execute 205 // functions outside of the normal `OpKernel::Compute()` context. 206 class InstantiatedCapturedFunction { 207 public: 208 // Creates a new instance of the `InstantiatedCapturedFunction` class from the 209 // given inputs. 210 static Status Create( 211 FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, 212 DataTypeVector ret_types, 213 std::function<void(std::function<void()>)> runner, 214 CapturedFunction* captured_func, bool is_multi_device, 215 std::unique_ptr<InstantiatedCapturedFunction>* out_function); 216 217 // Runs the instantiated captured function. This method takes ownership of 218 // the tensors in `args`, in order to be able to deallocate them as early as 219 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 220 // ownership of the `args`. 221 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 222 std::vector<Tensor>* rets) const; 223 224 // Runs the instantiated captured function. This method takes ownership of 225 // the tensors in `args`, in order to be able to deallocate them as early as 226 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 227 // ownership of the `args`. Pass non-null `node` to record processing time 228 // for modeling Iterator's GetNext() resource usage. 229 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 230 std::vector<Tensor>* rets, 231 const std::shared_ptr<model::Node>& node) const; 232 233 // Synchronously runs the captured function on the given `args`, and stores 234 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 235 // possible. 236 Status RunWithBorrowedArgs(IteratorContext* ctx, 237 const std::vector<Tensor>& args, 238 std::vector<Tensor>* rets) const; 239 240 // Synchronously runs the captured function on the given `args`, and stores 241 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 242 // possible. Pass non-null `node` to record processing time for modeling 243 // Iterator's GetNext() resource usage. 244 Status RunWithBorrowedArgs(IteratorContext* ctx, 245 const std::vector<Tensor>& args, 246 std::vector<Tensor>* rets, 247 const std::shared_ptr<model::Node>& node) const; 248 249 // Synchronously runs the captured function on the given `args`, and stores 250 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 251 // possible. This can be useful for calling a captured function in cases where 252 // an `IteratorContext*` is not available (such as a destructor). 253 // 254 // TODO(b/144278100): Avoid running functions without IteratorContext. 255 Status RunInstantiated(const std::vector<Tensor>& args, 256 std::vector<Tensor>* rets); 257 258 // Asynchronously runs the captured function on the given `args`, stores the 259 // results in `*rets`, and calls the given `done` callback when the function 260 // returns. This method takes ownership of the tensors in `args`, in order to 261 // be able to deallocate them as early as possible. Pass non-null `node` to 262 // record processing time for modeling Iterator's GetNext() resource usage. 263 void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, 264 std::vector<Tensor>* rets, 265 FunctionLibraryRuntime::DoneCallback done, 266 const std::shared_ptr<model::Node>& node) const; 267 268 private: 269 InstantiatedCapturedFunction( 270 FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, 271 DataTypeVector ret_types, 272 std::function<void(std::function<void()>)> runner, 273 CapturedFunction* captured_func, bool is_multi_device); 274 275 // Determines whether a rendezvous object should be created when running the 276 // instantiated function. 277 bool ShouldCreateRendezvous() const; 278 279 FunctionLibraryRuntime* const lib_; // Not owned. 280 const FunctionLibraryRuntime::Handle f_handle_; 281 const DataTypeVector ret_types_; 282 // Note: We capture the runner at function instantiation time to be able to 283 // run the function without `IteratorContext` via `RunInstantiated`. 284 std::function<void(std::function<void()>)> captured_runner_; 285 CapturedFunction* const captured_func_; // Not owned. 286 const bool is_multi_device_; 287 288 TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction); 289 }; 290 291 } // namespace data 292 } // namespace tensorflow 293 294 #endif // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ 295