1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 17 #define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 18 19 #include "tensorflow/core/common_runtime/function.h" 20 #include "tensorflow/core/framework/dataset.h" 21 #include "tensorflow/core/framework/function_handle_cache.h" 22 #include "tensorflow/core/framework/metrics.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor_shape.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/kernels/data/dataset_utils.h" 27 #include "tensorflow/core/kernels/data/unbounded_thread_pool.h" 28 #include "tensorflow/core/kernels/ops_util.h" 29 30 namespace tensorflow { 31 namespace data { 32 33 class IteratorResource : public ResourceBase { 34 public: 35 IteratorResource(Env* env, const DataTypeVector& output_dtypes, 36 const std::vector<PartialTensorShape>& output_shapes, 37 std::unique_ptr<DeviceMgr> device_mgr, 38 std::unique_ptr<FunctionLibraryDefinition> flib_def, 39 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 40 FunctionLibraryRuntime* flr); 41 42 ~IteratorResource() override; 43 44 // Gets the next output from the iterator managed by this iterator resource. 45 // 46 // If at least one output remains, that output will be stored in 47 // `*out_tensors` and `false` will be stored in `*end_of_sequence`. 48 // 49 // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and 50 // the content of `*out_tensors` will be undefined. 51 Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors, 52 bool* end_of_sequence); 53 54 // Saves a checkpoint of the state of the iterator through the given `writer`. 55 Status Save(SerializationContext* ctx, IteratorStateWriter* writer); 56 57 // Restores the state of the iterator from a checkpoint created by `Save`. 58 Status Restore(OpKernelContext* ctx, IteratorStateReader* reader); 59 60 // Creates an iterator for `dataset`, and associates the iterator with this 61 // iterator resource. 62 // 63 // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, 64 // or `Restore`. 65 Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset); 66 DebugString()67 string DebugString() const override { return "Iterator resource"; } 68 output_dtypes()69 const DataTypeVector& output_dtypes() const { return output_dtypes_; } 70 output_shapes()71 const std::vector<PartialTensorShape>& output_shapes() const { 72 return output_shapes_; 73 } 74 75 private: 76 class State { 77 public: State(std::shared_ptr<FunctionLibraryDefinition> flib_def,std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr,std::unique_ptr<DatasetBaseIterator> iterator)78 State(std::shared_ptr<FunctionLibraryDefinition> flib_def, 79 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr, 80 FunctionLibraryRuntime* flr, 81 std::unique_ptr<DatasetBaseIterator> iterator) 82 : flib_def_(std::move(flib_def)), 83 flr_(flr), 84 pflr_(std::move(pflr)), 85 function_handle_cache_(absl::make_unique<FunctionHandleCache>(flr)), 86 iterator_(std::move(iterator)) {} 87 ~State()88 ~State() { cancellation_manager_.StartCancel(); } 89 90 // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses 91 // it to set the `iterator` field. DowncastAndSetIterator(std::unique_ptr<IteratorBase> it)92 void DowncastAndSetIterator(std::unique_ptr<IteratorBase> it) { 93 iterator_.reset(static_cast<DatasetBaseIterator*>(it.release())); 94 } 95 flib_def()96 std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; } 97 flr()98 FunctionLibraryRuntime* flr() { return flr_; } 99 pflr()100 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr() { return pflr_; } 101 function_handle_cache()102 FunctionHandleCache* function_handle_cache() { 103 return function_handle_cache_.get(); 104 } 105 resource_mgr()106 ResourceMgr* resource_mgr() { return &resource_mgr_; } 107 cancellation_manager()108 CancellationManager* cancellation_manager() { 109 return &cancellation_manager_; 110 } 111 iterator()112 DatasetBaseIterator* iterator() { return iterator_.get(); } 113 114 private: 115 std::shared_ptr<FunctionLibraryDefinition> flib_def_; 116 FunctionLibraryRuntime* flr_ = nullptr; // not owned 117 std::shared_ptr<ProcessFunctionLibraryRuntime> pflr_; 118 std::unique_ptr<FunctionHandleCache> function_handle_cache_; 119 ResourceMgr resource_mgr_; 120 CancellationManager cancellation_manager_; 121 std::unique_ptr<DatasetBaseIterator> iterator_; 122 }; 123 124 UnboundedThreadPool unbounded_thread_pool_; 125 mutex mu_; 126 // Records the number of currently active `GetNext()` calls. 127 uint64 num_get_next_calls_ TF_GUARDED_BY(mu_) = 0; 128 // Records the start time (in microseconds) of the first `GetNext()` call that 129 // followed the last period of inactivity. 130 uint64 get_next_start_time_us_ TF_GUARDED_BY(mu_) = 0; 131 // Records the end time (in microseconds) of the most recent `GetNext()` call. 132 uint64 get_next_end_time_us_ TF_GUARDED_BY(mu_) = 0; 133 const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_); 134 std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_); 135 const DataTypeVector output_dtypes_; 136 const std::vector<PartialTensorShape> output_shapes_; 137 const bool collect_metrics_; 138 }; 139 140 class IteratorHandleOp : public OpKernel { 141 public: 142 explicit IteratorHandleOp(OpKernelConstruction* ctx); 143 144 // The resource is deleted from the resource manager only when it is private 145 // to kernel. Ideally the resource should be deleted when it is no longer held 146 // by anyone, but it would break backward compatibility. 147 ~IteratorHandleOp() override; 148 149 void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_); 150 151 private: 152 // During the first Compute(), resource is either created or looked up using 153 // shared_name. In the latter case, the resource found should be verified if 154 // it is compatible with this op's configuration. The verification may fail in 155 // cases such as two graphs asking queues of the same shared name to have 156 // inconsistent capacities. 157 Status VerifyResource(IteratorResource* resource); 158 159 FunctionLibraryRuntime* CreatePrivateFLR( 160 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, 161 std::unique_ptr<FunctionLibraryDefinition>* flib_def, 162 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr); 163 164 mutex mu_; 165 ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. 166 IteratorResource* resource_ TF_GUARDED_BY(mu_) = nullptr; 167 DataTypeVector output_dtypes_; 168 std::vector<PartialTensorShape> output_shapes_; 169 const int graph_def_version_; 170 string name_; 171 }; 172 173 // Like IteratorHandleOp, but creates handles which are never shared, and does 174 // not hold a reference to these handles. The latter is important for eager 175 // execution, since OpKernel instances generally live as long as the program 176 // running them. 177 class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> { 178 public: 179 explicit AnonymousIteratorHandleOp(OpKernelConstruction* context); 180 181 private: 182 string name() override; 183 184 Status CreateResource(OpKernelContext* ctx, 185 std::unique_ptr<FunctionLibraryDefinition> flib_def, 186 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, 187 FunctionLibraryRuntime* lib, 188 IteratorResource** resource) override; 189 190 DataTypeVector output_dtypes_; 191 std::vector<PartialTensorShape> output_shapes_; 192 const int graph_def_version_; 193 }; 194 195 // A hybrid asynchronous-and-synchronous OpKernel with efficient support for 196 // both modes. 197 // 198 // Inherit from this class when the application logic of the kernel (i) is 199 // implemented synchronously, (ii) must run on a background thread when the 200 // kernel executes in the inter-op threadpool (typically because it depends on 201 // inter-op threadpool threads, e.g. for function execution), and (iii) can run 202 // synchronously on the calling thread when the caller donates a thread 203 // (typically in eager execution). The implementation avoids a thread-hop in 204 // case (iii). 205 // 206 // NOTE: Unlike typical OpKernel subclasses, the application logic is 207 // implemented in a method (DoCompute()) that returns Status. Use 208 // TF_RETURN_IF_ERROR for error-related control flow rather than 209 // OP_REQUIRES_OK(). 210 class HybridAsyncOpKernel : public AsyncOpKernel { 211 public: 212 HybridAsyncOpKernel(OpKernelConstruction* ctx, 213 const char* background_worker_name); 214 215 void Compute(OpKernelContext* ctx) final; 216 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final; 217 218 protected: 219 virtual Status DoCompute(OpKernelContext* ctx) = 0; 220 221 private: 222 BackgroundWorker background_worker_; 223 }; 224 225 class MakeIteratorOp : public HybridAsyncOpKernel { 226 public: MakeIteratorOp(OpKernelConstruction * ctx)227 explicit MakeIteratorOp(OpKernelConstruction* ctx) 228 : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {} 229 230 protected: 231 Status DoCompute(OpKernelContext* ctx) override; 232 }; 233 234 class IteratorGetNextOp : public HybridAsyncOpKernel { 235 public: IteratorGetNextOp(OpKernelConstruction * ctx)236 explicit IteratorGetNextOp(OpKernelConstruction* ctx) 237 : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") { 238 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 239 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 240 } 241 242 AsyncOpKernel* AsAsync() override; 243 244 protected: 245 Status DoCompute(OpKernelContext* ctx) override; 246 247 private: 248 DataTypeVector output_types_; 249 std::vector<PartialTensorShape> output_shapes_; 250 }; 251 252 class DeleteIteratorOp : public HybridAsyncOpKernel { 253 public: DeleteIteratorOp(OpKernelConstruction * ctx)254 explicit DeleteIteratorOp(OpKernelConstruction* ctx) 255 : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {} 256 257 protected: 258 Status DoCompute(OpKernelContext* ctx) override; 259 }; 260 261 class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel { 262 public: IteratorGetNextAsOptionalOp(OpKernelConstruction * ctx)263 explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) 264 : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next_as_optional") { 265 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 266 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 267 } 268 269 protected: 270 Status DoCompute(OpKernelContext* ctx) override; 271 272 private: 273 DataTypeVector output_types_; 274 std::vector<PartialTensorShape> output_shapes_; 275 }; 276 277 class IteratorToStringHandleOp : public OpKernel { 278 public: IteratorToStringHandleOp(OpKernelConstruction * ctx)279 explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) 280 : OpKernel(ctx) {} 281 282 void Compute(OpKernelContext* ctx) override; 283 }; 284 285 class IteratorFromStringHandleOp : public OpKernel { 286 public: 287 explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx); 288 289 void Compute(OpKernelContext* ctx) override; 290 291 private: 292 DataTypeVector output_dtypes_; 293 std::vector<PartialTensorShape> output_shapes_; 294 }; 295 296 class SerializeIteratorOp : public OpKernel { 297 public: 298 static constexpr const char* const kExternalStatePolicy = 299 "external_state_policy"; 300 301 explicit SerializeIteratorOp(OpKernelConstruction* ctx); 302 303 void Compute(OpKernelContext* ctx) override; 304 305 private: 306 SerializationContext::ExternalStatePolicy external_state_policy_ = 307 SerializationContext::ExternalStatePolicy::kWarn; 308 }; 309 310 class DeserializeIteratorOp : public OpKernel { 311 public: DeserializeIteratorOp(OpKernelConstruction * ctx)312 explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 313 314 void Compute(OpKernelContext* ctx) override; 315 }; 316 317 } // namespace data 318 } // namespace tensorflow 319 320 #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ 321