/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/unbounded_thread_pool.h"
#include "tensorflow/core/kernels/ops_util.h"

namespace tensorflow {
namespace data {

class IteratorResource : public ResourceBase {
 public:
  IteratorResource(Env* env, const DataTypeVector& output_dtypes,
                   const std::vector<PartialTensorShape>& output_shapes,
                   std::unique_ptr<DeviceMgr> device_mgr,
                   std::unique_ptr<FunctionLibraryDefinition> flib_def,
                   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                   FunctionLibraryRuntime* flr);

  ~IteratorResource() override;

  // Gets the next output from the iterator managed by this iterator resource.
  //
  // If at least one output remains, that output will be stored in
  // `*out_tensors` and `false` will be stored in `*end_of_sequence`.
  //
  // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and
  // the content of `*out_tensors` will be undefined.
  Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence);
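
  // Example use of GetNext() (a minimal sketch, assuming `resource` points at
  // an IteratorResource whose iterator has already been set via
  // SetIteratorFromDataset() below, and `ctx` is the calling kernel's
  // OpKernelContext):
  //
  //   std::vector<Tensor> out_tensors;
  //   bool end_of_sequence = false;
  //   TF_RETURN_IF_ERROR(
  //       resource->GetNext(ctx, &out_tensors, &end_of_sequence));
  //   if (!end_of_sequence) {
  //     // `out_tensors` holds the components of the next element.
  //   }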

  // Saves a checkpoint of the state of the iterator through the given `writer`.
  Status Save(SerializationContext* ctx, IteratorStateWriter* writer);

  // Restores the state of the iterator from a checkpoint created by `Save`.
  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader);
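
  // Example use of Save() and Restore() (a minimal sketch; `serialization_ctx`,
  // `writer`, and `reader` stand for caller-provided SerializationContext,
  // IteratorStateWriter, and IteratorStateReader instances, such as the ones
  // used by SerializeIteratorOp and DeserializeIteratorOp below):
  //
  //   TF_RETURN_IF_ERROR(resource->Save(serialization_ctx, writer));
  //   ...
  //   TF_RETURN_IF_ERROR(resource->Restore(ctx, reader));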

  // Creates an iterator for `dataset`, and associates the iterator with this
  // iterator resource.
  //
  // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
  // or `Restore`.
  Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset);
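
  // Example use of SetIteratorFromDataset() (a minimal sketch of the required
  // call order, assuming `dataset` is a DatasetBase* obtained by the caller,
  // e.g. via GetDatasetFromVariantTensor()):
  //
  //   TF_RETURN_IF_ERROR(resource->SetIteratorFromDataset(ctx, dataset));
  //   // Only after this call may GetNext(), Save(), or Restore() be used.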

  string DebugString() const override { return "Iterator resource"; }

  const DataTypeVector& output_dtypes() const { return output_dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

 private:
  class State {
   public:
    State(std::shared_ptr<FunctionLibraryDefinition> flib_def,
          std::shared_ptr<ProcessFunctionLibraryRuntime> pflr,
          FunctionLibraryRuntime* flr,
          std::unique_ptr<DatasetBaseIterator> iterator)
        : flib_def_(std::move(flib_def)),
          flr_(flr),
          pflr_(std::move(pflr)),
          function_handle_cache_(absl::make_unique<FunctionHandleCache>(flr)),
          iterator_(std::move(iterator)) {}

    ~State() { cancellation_manager_.StartCancel(); }

    // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses
    // it to set the `iterator` field.
    void DowncastAndSetIterator(std::unique_ptr<IteratorBase> it) {
      iterator_.reset(static_cast<DatasetBaseIterator*>(it.release()));
    }

    std::shared_ptr<FunctionLibraryDefinition> flib_def() { return flib_def_; }

    FunctionLibraryRuntime* flr() { return flr_; }

    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr() { return pflr_; }

    FunctionHandleCache* function_handle_cache() {
      return function_handle_cache_.get();
    }

    ResourceMgr* resource_mgr() { return &resource_mgr_; }

    CancellationManager* cancellation_manager() {
      return &cancellation_manager_;
    }

    DatasetBaseIterator* iterator() { return iterator_.get(); }

   private:
    std::shared_ptr<FunctionLibraryDefinition> flib_def_;
    FunctionLibraryRuntime* flr_ = nullptr;  // not owned
    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr_;
    std::unique_ptr<FunctionHandleCache> function_handle_cache_;
    ResourceMgr resource_mgr_;
    CancellationManager cancellation_manager_;
    std::unique_ptr<DatasetBaseIterator> iterator_;
  };

  UnboundedThreadPool unbounded_thread_pool_;
  mutex mu_;
  // Records the number of currently active `GetNext()` calls.
  uint64 num_get_next_calls_ TF_GUARDED_BY(mu_) = 0;
  // Records the start time (in microseconds) of the first `GetNext()` call that
  // followed the last period of inactivity.
  uint64 get_next_start_time_us_ TF_GUARDED_BY(mu_) = 0;
  // Records the end time (in microseconds) of the most recent `GetNext()` call.
  uint64 get_next_end_time_us_ TF_GUARDED_BY(mu_) = 0;
  const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_);
  std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_);
  const DataTypeVector output_dtypes_;
  const std::vector<PartialTensorShape> output_shapes_;
  const bool collect_metrics_;
};

class IteratorHandleOp : public OpKernel {
 public:
  explicit IteratorHandleOp(OpKernelConstruction* ctx);

  // The resource is deleted from the resource manager only when it is private
  // to the kernel. Ideally the resource should be deleted when it is no longer
  // held by anyone, but doing so would break backward compatibility.
  ~IteratorHandleOp() override;

  void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_);

 private:
  // During the first Compute(), the resource is either created or looked up
  // using shared_name. In the latter case, the resource found should be
  // verified to be compatible with this op's configuration. The verification
  // may fail in cases such as two graphs asking queues of the same shared name
  // to have inconsistent capacities.
  Status VerifyResource(IteratorResource* resource);

  FunctionLibraryRuntime* CreatePrivateFLR(
      OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
      std::unique_ptr<FunctionLibraryDefinition>* flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr);

  mutex mu_;
  ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
  IteratorResource* resource_ TF_GUARDED_BY(mu_) = nullptr;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
  string name_;
};

// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
class AnonymousIteratorHandleOp : public AnonymousResourceOp<IteratorResource> {
 public:
  explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);

 private:
  string name() override;

  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                        FunctionLibraryRuntime* lib,
                        IteratorResource** resource) override;

  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
};

// A hybrid asynchronous-and-synchronous OpKernel with efficient support for
// both modes.
//
// Inherit from this class when the application logic of the kernel (i) is
// implemented synchronously, (ii) must run on a background thread when the
// kernel executes in the inter-op threadpool (typically because it depends on
// inter-op threadpool threads, e.g. for function execution), and (iii) can run
// synchronously on the calling thread when the caller donates a thread
// (typically in eager execution). The implementation avoids a thread-hop in
// case (iii).
//
// NOTE: Unlike typical OpKernel subclasses, the application logic is
// implemented in a method (DoCompute()) that returns Status. Use
// TF_RETURN_IF_ERROR for error-related control flow rather than
// OP_REQUIRES_OK().
class HybridAsyncOpKernel : public AsyncOpKernel {
 public:
  HybridAsyncOpKernel(OpKernelConstruction* ctx,
                      const char* background_worker_name);

  void Compute(OpKernelContext* ctx) final;
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final;

 protected:
  virtual Status DoCompute(OpKernelContext* ctx) = 0;

 private:
  BackgroundWorker background_worker_;
};
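
// Example subclass (a minimal sketch; `ExampleIteratorOpKernel` and the
// background worker name "tf_data_example" are hypothetical, and the resource
// lookup only illustrates a typical DoCompute() body, mirroring the iterator
// kernels declared below):
//
//   class ExampleIteratorOpKernel : public HybridAsyncOpKernel {
//    public:
//     explicit ExampleIteratorOpKernel(OpKernelConstruction* ctx)
//         : HybridAsyncOpKernel(ctx, "tf_data_example") {}
//
//    protected:
//     Status DoCompute(OpKernelContext* ctx) override {
//       IteratorResource* iterator;
//       TF_RETURN_IF_ERROR(
//           LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
//       core::ScopedUnref unref_iterator(iterator);
//       // Kernel logic goes here; propagate errors with TF_RETURN_IF_ERROR
//       // rather than OP_REQUIRES_OK().
//       return Status::OK();
//     }
//   };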

class MakeIteratorOp : public HybridAsyncOpKernel {
 public:
  explicit MakeIteratorOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {}

 protected:
  Status DoCompute(OpKernelContext* ctx) override;
};

class IteratorGetNextOp : public HybridAsyncOpKernel {
 public:
  explicit IteratorGetNextOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  AsyncOpKernel* AsAsync() override;

 protected:
  Status DoCompute(OpKernelContext* ctx) override;

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class DeleteIteratorOp : public HybridAsyncOpKernel {
 public:
  explicit DeleteIteratorOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {}

 protected:
  Status DoCompute(OpKernelContext* ctx) override;
};

class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel {
 public:
  explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next_as_optional") {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  Status DoCompute(OpKernelContext* ctx) override;

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

class IteratorToStringHandleOp : public OpKernel {
 public:
  explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;
};

class IteratorFromStringHandleOp : public OpKernel {
 public:
  explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

class SerializeIteratorOp : public OpKernel {
 public:
  static constexpr const char* const kExternalStatePolicy =
      "external_state_policy";

  explicit SerializeIteratorOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  SerializationContext::ExternalStatePolicy external_state_policy_ =
      SerializationContext::ExternalStatePolicy::kWarn;
};

class DeserializeIteratorOp : public OpKernel {
 public:
  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_