/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
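
// For example, a sketch of dispatching on a tensor's runtime dtype; the
// `HANDLE_TYPE` macro name and the `ProcessTensor<T>` helper are hypothetical:
//
//   #define HANDLE_TYPE(T)              \
//     case DataTypeToEnum<T>::value: {  \
//       ProcessTensor<T>(tensor);       \
//       break;                          \
//     }
//
//   switch (tensor.dtype()) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//     default:
//       return errors::Unimplemented("Unsupported dtype.");
//   }
//   #undef HANDLE_TYPE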

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

class DatasetBase;
class SerializationContext;

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// Please see the comment on IteratorStateWriter for guidance on using the
// Read*(key, val) vs. Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece key, tstring* val) = 0;
  virtual Status ReadTensor(StringPiece key, Tensor* val) = 0;

  virtual Status ReadScalar(StringPiece name, StringPiece key, int64* val) = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key, Tensor* val) = 0;

  virtual bool Contains(StringPiece key) = 0;
  virtual bool Contains(StringPiece name, StringPiece key) = 0;

  virtual ~IteratorStateReader() {}
};
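
// For example, a sketch of an iterator restoring a scalar counter that was
// previously saved under `full_name("counter")`; the `counter_` member is
// hypothetical:
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("counter"), &counter_));
//     return Status::OK();
//   }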

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, as keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed the 2 GB limit, you can pass an explicit name in via the
// Write*(name, key, val) APIs, allowing you to further split up the state
// into more manageable chunks.
class IteratorStateWriter {
 public:
  virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;

  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64 val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
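
// For example, a sketch of saving a buffer of elements where each buffered
// tensor gets its own name (and hence its own backing tensor, keeping each
// one under the 2 GB limit); the `buffer_` member is hypothetical:
//
//   Status SaveInternal(IteratorStateWriter* writer) override {
//     TF_RETURN_IF_ERROR(writer->WriteScalar(
//         full_name("buffer_size"), static_cast<int64>(buffer_.size())));
//     for (size_t i = 0; i < buffer_.size(); ++i) {
//       TF_RETURN_IF_ERROR(writer->WriteTensor(
//           full_name(strings::StrCat("buffer[", i, "]")), "value",
//           buffer_[i]));
//     }
//     return Status::OK();
//   }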

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output) {
    return AddDataset(dataset, inputs, {}, output);
  }

  // Adds a node corresponding to the `DatasetType` to the Graph.
  // The return value of `DatasetType::type_string()` is used as the op type
  // for the node.
  // Values for the output_types and output_shapes node attributes are also
  // written if those attributes are defined in the OpDef.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output) {
    std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      enumerated_inputs[i] = std::make_pair(i, inputs[i]);
    }
    return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
  }

  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a
  // matching name has already been added, returns with OK status. If a
  // user-defined function with name `function_name` is not found in the
  // context's function library, returns an InvalidArgument error. If the
  // function with name `function_name` or any of its dependent functions are
  // stateful, and the context does not explicitly permit stateful functions,
  // returns an InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (const auto& attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;
class FunctionHandleCache;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what
// we are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : env(ctx->env()), flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        runner_threadpool_size = port::MaxParallelism();
      }

      // NOTE: Wrap every runner invocation in a call to `Runner::get()->Run()`,
      // so that a symbol in the tensorflow::data namespace is always on the
      // stack when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // The `StatsAggregator` object to record statistics about the iterator.
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  Params params() { return params_; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool`
      // (which is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the
      // newly created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
                                                   name, num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

 private:
  Params params_;
};
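
// For example, a sketch of an iterator implementation launching a background
// prefetching thread via its `IteratorContext`; the `thread_` member, the
// thread name, and `PrefetchLoop()` are hypothetical:
//
//   thread_ = ctx->StartThread(
//       "tf_data_prefetch", [this]() { PrefetchLoop(); });
//
// Using `ctx->StartThread()` rather than calling `Env::Default()` directly
// respects a custom `ThreadFactory` if one was supplied in
// `IteratorContext::Params`.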

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  struct Params {
    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether an attempt to serialize a dataset that does not
    // implement serialization should result in an error. If set to `false`,
    // the serialized graph will replace the dataset with a placeholder
    // returned in `input_list`.
    bool fail_if_unimplemented = true;

    // Indicates whether (potentially large) data tensors should be
    // serialized, or replaced with a placeholder returned in `input_list`.
    // The latter makes sense when performing data-agnostic graph rewrites to
    // reduce memory usage.
    bool serialize_data_tensors = true;

    // Indicates whether datasets that use random seeds should have the values
    // of random seeds serialized or not. If the values of random seeds are
    // serialized, the deserialized dataset will have the same seeds as the
    // original dataset. Otherwise, the deserialized dataset will use different
    // seeds. This param does not affect datasets that use fixed seeds; fixed
    // seeds will always be preserved.
    bool preserve_random_seeds = true;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

  bool serialize_data_tensors() const { return params_.serialize_data_tensors; }

  bool preserve_random_seeds() const { return params_.preserve_random_seeds; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};
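
// For example, a sketch of requesting strict serialization, where any dataset
// that depends on external state causes serialization to fail:
//
//   SerializationContext::Params params;
//   params.external_state_policy =
//       SerializationContext::ExternalStatePolicy::kFail;
//   SerializationContext serialization_ctx(std::move(params));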

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will
  // be stored in `*end_of_sequence`, and the content of
  // `*out_tensors` will be undefined.
  //
  // Implementations should never return `OutOfRange` errors. If at end of
  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }
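
  // For example, a sketch of how a caller typically drains an iterator:
  //
  //   std::vector<Tensor> out_tensors;
  //   bool end_of_sequence = false;
  //   while (!end_of_sequence) {
  //     TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &out_tensors,
  //                                          &end_of_sequence));
  //     if (!end_of_sequence) {
  //       // Process `out_tensors`.
  //     }
  //   }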

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    return SaveInternal(writer);
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    return RestoreInternal(ctx, reader);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->SaveInternal(writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->RestoreInternal(ctx, reader);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  //
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
  // implementations have an override.
  virtual Status SaveInternal(IteratorStateWriter* writer) {
    return errors::Unimplemented("SaveInternal");
  }

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a
  // checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That
  // is, its `Initialize` method has been called, but its `GetNext` method has
  // never been called.
  //
  // TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
  // implementations have an override.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) {
    return errors::Unimplemented("RestoreInternal");
  }

  // Returns the number of elements produced by this iterator.
  int64 num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  // Registers a cleanup function to be called upon object destruction.
  //
  // Registered functions are invoked in the reverse order of registration.
  void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
    cleanup_fns_.push_back(std::move(cleanup_fn));
  }

  // Associates the given performance modeling `Node` with this iterator.
  void SetNode(std::shared_ptr<model::Node> node) { node_ = node.get(); }

  std::vector<std::function<void()>> cleanup_fns_;
  model::Node* node_ = nullptr;  // Not owned.
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the tensors of the given element.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the tensors of the given
// element.
int64 GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
//
// The retrieved pointer is a borrowed reference to the dataset, which is
// owned by the tensor. The consumer must either acquire its own reference to
// the dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is
// not destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
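
// For example, a sketch of unwrapping a dataset passed as a variant input to
// an op kernel, taking a reference so the dataset outlives the input tensor:
//
//   DatasetBase* dataset = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
//   dataset->Ref();  // Released with `dataset->Unref()` when done.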

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    *iterator = MakeIteratorInternal(output_prefix);
    if (const auto& model = ctx->model()) {
      const string& prefix = (*iterator)->prefix();
      (*iterator)->SetNode(model->AddNode(MakeNodeFactory(ctx, iterator->get()),
                                          prefix, output_prefix));
      (*iterator)->AddCleanupFunction(
          [model, prefix]() { model->RemoveNode(prefix); });
    }
    Status s = (*iterator)->Initialize(ctx);
    if (!s.ok()) {
      // Reset the iterator to avoid returning an uninitialized iterator.
      iterator->reset();
    }
    return s;
  }

  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, output_prefix, iterator);
  }

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    TF_RETURN_IF_ERROR(MakeIterator(ctx, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
    *iterator = std::move(it);
    return Status::OK();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }
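
  // For example, a sketch of creating an iterator over a dataset and pulling
  // its first element; the "Iterator" prefix is illustrative:
  //
  //   std::unique_ptr<IteratorBase> iterator;
  //   TF_RETURN_IF_ERROR(dataset->MakeIterator(ctx, "Iterator", &iterator));
  //   std::vector<Tensor> element;
  //   bool end_of_sequence = false;
  //   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &element, &end_of_sequence));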

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64 TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // If the dataset is stateful it will not be possible to save its graph or
  // checkpoint the state of its iterators.
  //
  // TODO(jsimsa): Remove this method once all `DatasetBase` implementations
  // are migrated over to `CheckExternalState`.
  virtual bool IsStateful() const { return false; }

  // Indicates whether the dataset depends on any external state. If so, the
  // method returns `errors::FailedPrecondition` with a message that
  // identifies the external state. Otherwise, the method returns
  // `Status::OK()`.
  //
  // TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
  // implementations have an override.
  virtual Status CheckExternalState() const {
    if (IsStateful()) {
      return errors::FailedPrecondition("Dataset cannot be serialized.");
    }
    return Status::OK();
  }

 protected:
  friend Status AsGraphDef(
      OpKernelContext* ctx, const DatasetBase* dataset,
      SerializationContext&& serialization_ctx,
      GraphDef* graph_def);  // For access to graph related members.
  friend class CapturedFunction;

  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
  };

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes the
  // dataset graph, applies graph rewrites, and then deserializes the graph.
  // If a subclass of `DatasetBase` does not implement this method, then it
  // will be excluded from static optimizations (and so will any upstream
  // datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase`
  // does not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;
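
  // For example, a sketch of a typical override for a dataset with one input
  // dataset and one scalar parameter; `input_` and `count_` are hypothetical
  // members:
  //
  //   Status AsGraphDefInternal(SerializationContext* ctx,
  //                             DatasetGraphDefBuilder* b,
  //                             Node** output) const override {
  //     Node* input_graph_node = nullptr;
  //     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
  //     Node* count = nullptr;
  //     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
  //     return b->AddDataset(this, {input_graph_node, count}, output);
  //   }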

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

 private:
  // Returns a factory for nodes that represent the given iterator.
  static model::Node::Factory MakeNodeFactory(IteratorContext* ctx,
                                              IteratorBase* iterator) {
    return [ctx, iterator](model::Node::Args args) {
      return iterator->CreateNode(ctx, std::move(args));
    };
  }

  const string type_string_;
  const string node_name_;
};

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
    params_.dataset->Ref();
    VLOG(2) << prefix() << " constructor";
  }

  ~DatasetBaseIterator() override {
    VLOG(2) << prefix() << " destructor";
    params_.dataset->Unref();
  }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  // The sequence of iterators leading up to this iterator.
  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format: "name#arg_1=value_1,...,arg_n=value_n".
  virtual string BuildTraceMeName() { return params_.prefix; }

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    Status s = params_.dataset->CheckExternalState();
    if (!s.ok()) {
      if (ctx->external_state_policy() ==
          SerializationContext::ExternalStatePolicy::kWarn) {
        LOG(WARNING) << "Dataset contains external state: " << s.ToString();
      }
      if (ctx->external_state_policy() ==
          SerializationContext::ExternalStatePolicy::kFail) {
        return s;
      }
    }
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  // Internal implementation of GetNext that is wrapped in tracing logic.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  string full_name(const string& name) const {
    if (str_util::StrContains(name, kColon)) {
      LOG(ERROR) << name << " should not contain " << kColon;
    }

    return strings::StrCat(kFullNameRandomHex, kPipe, params_.prefix, kColon,
                           name);
  }
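
  // For example, full_name("buffer_size") for an iterator whose prefix is
  // "Iterator::Batch" produces the key
  // "60d899aa0d8ce4351e7c3b419e92d25b|Iterator::Batch:buffer_size", which is
  // the form expected by the Write*(key, val) and Read*(key, val) overloads
  // above.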

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element.
  void RecordElement(IteratorContext* ctx) {
    if (node_) {
      node_->record_element();
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx, bool stop_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      if (stop_output && node_->output()) {
        node_->output()->record_stop(now_nanos);
      }
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx, bool start_output = false) {
    if (collect_resource_usage(ctx)) {
      int64 now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
      if (start_output && node_->output()) {
        node_->output()->record_start(now_nanos);
      }
    }
  }
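
  // For example, a sketch of a worker thread bracketing a blocking wait with
  // these calls, so that time spent waiting is not attributed to useful work;
  // `buffer_`, `cancelled_`, `cond_var_`, and the lock `l` are hypothetical:
  //
  //   RecordStart(ctx);
  //   while (buffer_.empty() && !cancelled_) {
  //     RecordStop(ctx);
  //     cond_var_.wait(l);
  //     RecordStart(ctx);
  //   }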

 private:
  inline bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return Status::OK();
}
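
// For example, a sketch of reading a scalar op input inside a
// `DatasetOpKernel::MakeDataset` override; the "count" input name is
// hypothetical:
//
//   int64 count = 0;
//   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));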

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return Status::OK();
}

// Encapsulates the work required to plug a DatasetBase into the core
// TensorFlow graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) final;

  // Indicates whether the given op corresponds to an op whose kernels
  // subclass the `DatasetOpKernel` class.
  static bool IsDatasetOp(const OpDef* op_def);

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};
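
// For example, a sketch of a unary dataset op kernel; `MyDatasetOp`, its
// nested `Dataset` class, and the "count" input are hypothetical:
//
//   class MyDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count = 0;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new Dataset(DatasetContext(ctx), input, count);
//     }
//   };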

// A simple background worker that executes closures asynchronously and
// without blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an
// `AsyncOpKernel` to avoid blocking an executor thread that may be required
// by the blocking work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for
// this purpose because its current implementation (in Eigen) uses a
// finite-length queue and will block the caller when full. This can lead to
// deadlock under heavy load. Since the number of concurrent work items in
// each user of a `BackgroundWorker` is at most one per op invocation, the
// dynamic allocation overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
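
// For example, a sketch of offloading blocking work from an `AsyncOpKernel`;
// the `background_worker_` member and `DoBlockingWork()` are hypothetical:
//
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     background_worker_.Schedule([this, ctx, done]() {
//       DoBlockingWork(ctx);
//       done();
//     });
//   }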

// Registry of names of ops whose kernels subclass the `DatasetOpKernel`
// class.
class DatasetOpRegistry {
 public:
  // Registers the op name.
  static void Register(const string& op_name);

  // Checks whether the given op name has been registered.
  static bool IsRegistered(const string& op_name);
};

// Helper class to register dataset op name.
class DatasetOpRegistrar {
 public:
  explicit DatasetOpRegistrar(const string& op_name) {
    DatasetOpRegistry::Register(op_name);
  }
};

// Macro that can be used to register an op name of a dataset op.
#define REGISTER_DATASET_OP_NAME(op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name)

#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name) \
  static ::tensorflow::data::DatasetOpRegistrar     \
      registrar__body__##ctr##__object(op_name)
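
// For example, registering a hypothetical op name so that
// `DatasetOpRegistry::IsRegistered("MyDataset")` returns true:
//
//   REGISTER_DATASET_OP_NAME("MyDataset");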

}  // namespace data

// TODO(b/114112161): Remove these aliases when all users have moved over to
// the `tensorflow::data` namespace.
using data::DatasetBase;
using data::DatasetContext;
using data::DatasetIterator;
using data::DatasetOpKernel;
using data::IteratorBase;
using data::IteratorContext;
using data::IteratorStateReader;
using data::IteratorStateWriter;
using data::SerializationContext;
using data::UnaryDatasetOpKernel;

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_