1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 16 #define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/core/framework/dataset.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/lib/random/random.h" 27 #include "tensorflow/core/platform/macros.h" 28 29 namespace tensorflow { 30 31 class Device; 32 class OpKernelContext; 33 class ResourceMgr; 34 35 namespace data { 36 37 class CapturedFunction; 38 39 // An InstantiatedCapturedFunction encapsulates all the runtime support needed 40 // to execute a tensorflow function. 41 // 42 // While CapturedFunction (below) encapsulates the more permanent attributes 43 // of the function i.e. name, captured arguments etc., 44 // InstantiatedCapturedFunction encapsulates the more runtime aspects i.e. 45 // FunctionLibraryRuntime, function handle etc. 46 // 47 // The `Iterator-`related classes use `InstantiatedCapturedFunction` to execute 48 // functions outside a the normal `OpKernel::Compute()` context. 49 class InstantiatedCapturedFunction { 50 public: 51 ~InstantiatedCapturedFunction(); 52 53 // Runs the "Instantiated Captured function". This method takes ownership of 54 // the tensors in `args`, in order to be able to deallocate them as early as 55 // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain 56 // ownership of the `args`. 57 Status Run(IteratorContext* ctx, std::vector<Tensor>&& args, 58 std::vector<Tensor>* rets) const; 59 60 // Synchronously runs the captured function on the given `args`, and stores 61 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 62 // possible. 63 Status RunWithBorrowedArgs(IteratorContext* ctx, 64 const std::vector<Tensor>& args, 65 std::vector<Tensor>* rets) const; 66 67 // Synchronously runs the captured function on the given `args`, and stores 68 // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when 69 // possible. This can be useful for calling a captured 70 // function in cases where an `IteratorContext*` is not available 71 // (such as a destructor). 72 Status RunInstantiated(const std::vector<Tensor>& args, 73 std::vector<Tensor>* rets); 74 75 // Asynchronously runs the captured function on the given `args`, stores 76 // the results in `*rets`, and calls the given `done` callback when the 77 // function returns. This method takes ownership of the tensors in `args`, 78 // in order to be able to deallocate them as early as possible. 79 void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, 80 std::vector<Tensor>* rets, 81 FunctionLibraryRuntime::DoneCallback done, 82 const string& prefix) const; 83 84 // Returns a step ID for use when running an `InstantiatedCapturedFunction`. generate_step_id()85 static int64 generate_step_id() { 86 // Choose a step ID that is guaranteed not to clash with any 87 // Session-generated step ID. DirectSession only generates 88 // non-negative step IDs (contiguous, starting from 0), and 89 // MasterSession generates 56-bit random step IDs whose MSB is 90 // always 0, so a negative random step ID should suffice. 91 return -std::abs(static_cast<int64>(random::New64())); 92 } 93 94 private: 95 InstantiatedCapturedFunction( 96 FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, 97 DataTypeVector ret_types, 98 std::function<void(std::function<void()>)> runner, 99 CapturedFunction* captured_func); 100 101 friend class CapturedFunction; 102 103 FunctionLibraryRuntime* const lib_; 104 const FunctionLibraryRuntime::Handle f_handle_; 105 const DataTypeVector ret_types_; 106 std::function<void(std::function<void()>)> captured_runner_; 107 CapturedFunction* const captured_func_; 108 109 TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction); 110 }; 111 112 // A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured" 113 // arguments that it closed over in the user program. 114 class CapturedFunction { 115 public: 116 // Creates a new instance using a list of named attributes, fetching captured 117 // inputs from a context argument. 118 static Status Create(const NameAttrList& func, OpKernelContext* ctx, 119 const string& argument_name, 120 std::unique_ptr<CapturedFunction>* out_function); 121 122 // Creates a new instance using a list of named attributes, fetching captured 123 // inputs from a context argument. 124 // 125 // If `use_inter_op_parallelism` is false, the runtime may use an executor 126 // that is optimized for small functions. 127 static Status Create(const NameAttrList& func, OpKernelContext* ctx, 128 const string& argument_name, 129 bool use_inter_op_parallelism, 130 std::unique_ptr<CapturedFunction>* out_function); 131 132 // Creates a new instance using a list of named attributes, using provided 133 // captured inputs. 134 // 135 // If `use_inter_op_parallelism` is false, the runtime may use an executor 136 // that is optimized for small functions. 137 static Status Create(const NameAttrList& func, OpKernelContext* ctx, 138 std::vector<Tensor>&& captured_inputs, 139 bool use_inter_op_parallelism, 140 std::unique_ptr<CapturedFunction>* out_function); 141 142 // Instantiates this function for use in the given context, providing an 143 // InstantiatedCapturedFunction that can be used to execute functions. 144 Status Instantiate(IteratorContext* ctx, 145 std::unique_ptr<InstantiatedCapturedFunction>* 146 instantiated_captured_function); 147 148 // Returns the named list of function arguments. func()149 const NameAttrList& func() { return func_; } 150 151 // Returns that additional captured inputs that will be passed to the function captured_inputs()152 const std::vector<Tensor>& captured_inputs() { return captured_inputs_; } 153 154 private: 155 CapturedFunction(const NameAttrList& func, 156 std::vector<Tensor> captured_inputs, 157 bool use_inter_op_parallelism); 158 159 const NameAttrList func_; 160 const std::vector<Tensor> captured_inputs_; 161 const bool use_inter_op_parallelism_; 162 163 TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); 164 }; 165 } // namespace data 166 167 // TODO(b/114112161): Remove these aliases when all users have moved over to the 168 // `tensorflow::data` namespace. 169 using data::CapturedFunction; 170 171 } // namespace tensorflow 172 173 #endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ 174