1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
17 
18 #include <deque>
19 #include <memory>
20 #include <unordered_map>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/collective.h"
27 #include "tensorflow/core/framework/dataset_metadata.pb.h"
28 #include "tensorflow/core/framework/dataset_options.pb.h"
29 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/framework/function_handle_cache.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/model.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/op_kernel.h"
36 #include "tensorflow/core/framework/register_types.h"
37 #include "tensorflow/core/framework/thread_factory.h"
38 #include "tensorflow/core/framework/types.pb.h"
39 #include "tensorflow/core/framework/variant_encode_decode.h"
40 #include "tensorflow/core/framework/variant_tensor_data.h"
41 #include "tensorflow/core/lib/core/errors.h"
42 #include "tensorflow/core/lib/core/threadpool.h"
43 #include "tensorflow/core/lib/core/threadpool_interface.h"
44 #include "tensorflow/core/lib/strings/str_util.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/platform/cpu_info.h"
47 #include "tensorflow/core/platform/env.h"
48 #include "tensorflow/core/platform/refcount.h"
49 #include "tensorflow/core/platform/tracing.h"
50 
// Expands `m(T)` once for every primitive TensorFlow type `T` that a
// polymorphic dataset is expected to support: first every type covered by
// `TF_CALL_ALL_TYPES`, then every type covered by `TF_CALL_QUANTIZED_TYPES`.
// Typical use is generating the `case` labels of a `switch` over a dtype.
#define TF_CALL_DATASET_TYPES(m) \
  TF_CALL_ALL_TYPES(m)           \
  TF_CALL_QUANTIZED_TYPES(m)
55 
56 namespace tensorflow {
57 
58 // Forward declarations to avoid introducing a dependency on headers in
59 // "tensorflow/core/graph/...".
60 class GraphDefBuilder;
61 class Node;
62 
63 namespace data {
64 
namespace internal {
// Merges the Options message `source` into `*destination`. When a field is
// set in both messages, the value from `source` takes precedence.
//
// There is one overload for full `protobuf::Message`s and one for
// `protobuf::MessageLite`s.
// NOTE(review): presumably the overload chosen depends on whether protos are
// built against the full or the lite runtime — confirm in the .cc file.
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination);
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination);
}  // namespace internal
73 
// An ordered list of (name, value) pairs. The names are non-owning
// `StringPiece`s, so they must outlive the list; the values are owned strings.
// NOTE(review): by its name this is presumably metadata attached to profiler
// `TraceMe` annotations — consumers are not visible here, confirm.
using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;
75 
// Name of the boolean FunctionDef attribute that marks a function as a
// tf.data function (inspected by `IsTFDataFunction`).
constexpr char kTFDataFunction[]{"_tf_data_function"};

// Special cardinality values: -1 means infinite, -2 means unknown.
constexpr int kInfiniteCardinality{-1};
constexpr int kUnknownCardinality{-2};

// Magic value used as a prefix to recognize keys that belong to serialized
// iterator state.
constexpr char kFullNameRandomHex[]{"60d899aa0d8ce4351e7c3b419e92d25b"};
// Single-character separator strings.
// NOTE(review): presumably combined with kFullNameRandomHex when building
// state keys — confirm in the .cc file.
constexpr char kPipe[]{"|"};
constexpr char kColon[]{":"};

// Miscellaneous tags and names.
constexpr char kTFDataResourceTag[]{"tfdata"};
constexpr char kTraceInfoUnavailable[]{"unavailable"};
constexpr char kMetadata[]{"metadata"};

// Unregistered node attribute that carries a dataset's cardinality when
// serializing for graph rewrites.
constexpr char kCardinalityAttrForRewrite[]{"_cardinality"};
92 
93 class DatasetBase;
94 class SerializationContext;
95 
IsTFDataFunction(const FunctionDef & func)96 inline bool IsTFDataFunction(const FunctionDef& func) {
97   auto iter = func.attr().find(data::kTFDataFunction);
98   return (iter != func.attr().end() && iter->second.b());
99 }
100 
101 // Interface for reading values from a key-value store.
102 // Used for restoring iterator state. This class is thread safe.
103 // Please see comment on IteratorStateWriter for guidance around using the
104 // Read*(key, val) vs Read*(name, key, val).
105 class IteratorStateReader {
106  public:
107   // Determines whether the iterator state contains the given key.
108   virtual bool Contains(StringPiece key) const = 0;
109   virtual bool Contains(StringPiece name, StringPiece key) const = 0;
110 
111   // Reads an integer for the given key.
112   virtual Status ReadScalar(StringPiece key, int64_t* val) const = 0;
113   virtual Status ReadScalar(StringPiece name, StringPiece key,
114                             int64_t* val) const = 0;
115 
116   // Reads a string for the given key.
117   virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
118   virtual Status ReadScalar(StringPiece name, StringPiece key,
119                             tstring* val) const = 0;
120 
121   // Reads a tensor for the given key.
122   // TODO(jsimsa): Remove non-FLR overrides once all callers are updated.
123   virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
124   virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
125                             Tensor* val) const = 0;
126   virtual Status ReadTensor(StringPiece name, StringPiece key,
127                             Tensor* val) const = 0;
128   virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
129                             StringPiece key, Tensor* val) const = 0;
130 
~IteratorStateReader()131   virtual ~IteratorStateReader() {}
132 };
133 
134 // Interface for writing values to a key-value store.
135 // Used for saving iterator state. Not thread safe.
136 // The IteratorStateWriter creates a tensor for each unique iterator name it
137 // sees. For the Write*(key, val) API's the key is expected to encode this
138 // name as keys are required to be produced using the full_name() method.
139 // Each tensor has an upper limit of 2 GB and so if the state for an iterator
140 // might exceed the 2 GB limit, you can pass an explicit name in via the
141 // Write*(name, key, val) APIs allowing you to further split up the state
142 // into more manageable chunks.
143 class IteratorStateWriter {
144  public:
145   // Writes an integer for the given key.
146   virtual Status WriteScalar(StringPiece key, const int64_t val) = 0;
147   virtual Status WriteScalar(StringPiece name, StringPiece key,
148                              const int64_t val) = 0;
149 
150   // Writes a string for the given key.
151   virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
152   virtual Status WriteScalar(StringPiece name, StringPiece key,
153                              const tstring& val) = 0;
154 
155   // Writes a tensor for the given key.
156   virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
157   virtual Status WriteTensor(StringPiece name, StringPiece key,
158                              const Tensor& val) = 0;
159 
~IteratorStateWriter()160   virtual ~IteratorStateWriter() {}
161 };
162 
// Generates the full-name key for an iterator checkpoint entry from `prefix`
// and `name`. All keys generated for iterator checkpoints should go through
// this function.
// NOTE(review): the concrete key format is defined in the .cc file —
// presumably built from kFullNameRandomHex / kPipe / kColon; confirm before
// depending on it.
std::string FullName(const std::string& prefix, const std::string& name);
166 
167 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
168 class GraphDefBuilderWrapper {
169  public:
GraphDefBuilderWrapper(GraphDefBuilder * b)170   explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
171 
172   // Adds a Const node with scalar value to the Graph.
173   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
174   // non-null if the method returns with an OK status.
175   // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
176   template <typename T>
AddScalar(const T & val,Node ** output)177   Status AddScalar(const T& val, Node** output) {
178     Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
179     val_t.scalar<T>()() = val;
180     AddTensorInternal(val_t, output);
181     if (*output == nullptr) {
182       return errors::Internal("AddScalar: Failed to build Const op.");
183     }
184     return OkStatus();
185   }
186 
187   // Adds a Const node with vector value to the Graph.
188   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
189   // non-null if the method returns with an OK status.
190   // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
191   // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
192   template <typename T>
AddVector(const std::vector<T> & val,Node ** output)193   Status AddVector(const std::vector<T>& val, Node** output) {
194     Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
195                           TensorShape({static_cast<int64_t>(val.size())}));
196     for (size_t i = 0; i < val.size(); i++) {
197       val_t.flat<T>()(i) = val[i];
198     }
199     AddTensorInternal(val_t, output);
200     if (*output == nullptr) {
201       return errors::Internal("AddVector: Failed to build Const op.");
202     }
203     return OkStatus();
204   }
205 
AddVector(const std::vector<string> & val,Node ** output)206   Status AddVector(const std::vector<string>& val, Node** output) {
207     Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
208                           TensorShape({static_cast<int64_t>(val.size())}));
209     for (size_t i = 0; i < val.size(); i++) {
210       val_t.flat<tstring>()(i) = val[i];
211     }
212     AddTensorInternal(val_t, output);
213     if (*output == nullptr) {
214       return errors::Internal("AddVector: Failed to build Const op.");
215     }
216     return OkStatus();
217   }
218 
219   // Adds a `Const` node for the given tensor value to the graph.
220   //
221   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
222   // non-null if the method returns with an OK status. The returned `Node`
223   // pointer is owned by the backing graph of `GraphDefBuilder`.
AddTensor(const Tensor & val,Node ** output)224   Status AddTensor(const Tensor& val, Node** output) {
225     AddTensorInternal(val, output);
226     if (*output == nullptr) {
227       return errors::Internal("AddTensor: Failed to build Const op.");
228     }
229     return OkStatus();
230   }
231 
232   // Adds a `Placeholder` node for the given tensor value to the graph.
233   //
234   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
235   // non-null if the method returns with an OK status. The returned `Node`
236   // pointer is owned by the backing graph of `GraphDefBuilder`.
AddPlaceholder(const Tensor & val,Node ** output)237   Status AddPlaceholder(const Tensor& val, Node** output) {
238     AddPlaceholderInternal(val, output);
239     if (*output == nullptr) {
240       return errors::Internal(
241           "AddPlaceholder: Failed to build Placeholder op.");
242     }
243     return OkStatus();
244   }
245 
246   // Adds a node for the given dataset to the `Graph`. The value of
247   // `DatasetBase::type_string()` is used as the op type for the node. Values
248   // for the `output_types` and `output_shapes` node attributes are also written
249   // if those attributes are defined in the `OpDef`.
250   //
251   // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is
252   // used as the op name for the node. This argument should only be set when
253   // serializing `DatasetBase` instances which might not have been created
254   // through op kernel execution to make sure the dataset op name is preserved
255   // across serialization boundaries, which is in turn needed to make sure
256   // iterator checkpoints are valid across serialization boundaries. When
257   // `use_dataset_name` is set, the caller is responsible for making sure that
258   // the op name is unique across the graph.
259   //
260   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
261   // non-null if the method returns with an OK status. The returned `Node`
262   // pointer is owned by the backing `Graph` of `GraphDefBuilder`.
263   Status AddDataset(const DatasetBase* dataset,
264                     const std::vector<Node*>& inputs, Node** output);
265   Status AddDataset(const DatasetBase* dataset,
266                     const std::vector<Node*>& inputs,
267                     const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
268                     Node** output);
269   Status AddDataset(
270       const DatasetBase* dataset,
271       const std::vector<std::pair<size_t, Node*>>& inputs,
272       const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
273       const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
274       Node** output);
275   Status AddDataset(
276       const DatasetBase* dataset,
277       const std::vector<std::pair<size_t, Node*>>& inputs,
278       const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
279       const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
280       bool use_dataset_name, Node** output);
281 
282   // Adds a user-defined function with name `function_name` to the graph and
283   // recursively adds all functions it references. If a function with a matching
284   // name has already been added, returns with OK status. If a user-defined with
285   // name `function_name` is not found in the context's function library,
286   // returns an InvalidArgumentError. If the function with name `function_name`
287   // or any of its dependent functions are stateful, and the context does not
288   // explicitly permit stateful functions, returns an InvalidArgument error.
289   Status AddFunction(SerializationContext* ctx, const string& function_name,
290                      const FunctionLibraryDefinition& lib_def);
291 
292   template <typename T>
BuildAttrValue(const T & value,AttrValue * attr)293   void BuildAttrValue(const T& value, AttrValue* attr) {
294     SetAttrValue(value, attr);
295   }
296 
297   template <typename T>
BuildAttrValue(const T & value)298   AttrValue BuildAttrValue(const T& value) {
299     AttrValue attr;
300     SetAttrValue(value, &attr);
301     return attr;
302   }
303 
304  protected:
builder()305   GraphDefBuilder* builder() { return b_; }
306 
307  private:
308   void AddPlaceholderInternal(const Tensor& val, Node** output);
309   void AddTensorInternal(const Tensor& val, Node** output);
310   bool HasAttr(const string& op_type_name, const string& attr_name) const;
311 
HasAttr(const OpDef * op_def,const string & attr_name)312   bool HasAttr(const OpDef* op_def, const string& attr_name) const {
313     for (const auto& attr : op_def->attr()) {
314       if (attr.name() == attr_name) {
315         return true;
316       }
317     }
318     return false;
319   }
320 
AddAttrFunctions(SerializationContext * ctx,const AttrValue & attr_value,const FunctionLibraryDefinition & lib_def)321   Status AddAttrFunctions(SerializationContext* ctx,
322                           const AttrValue& attr_value,
323                           const FunctionLibraryDefinition& lib_def) {
324     if (attr_value.has_func()) {
325       TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
326     } else if (attr_value.has_list()) {
327       for (const NameAttrList& name_attr_list : attr_value.list().func()) {
328         TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
329       }
330     }
331     return OkStatus();
332   }
333 
334   GraphDefBuilder* b_;
335 };
336 
337 class StatsAggregator;
338 
// Executes functions through a dedicated object so that a symbol from the
// `tensorflow::data` namespace is always on the call stack while they run.
class Runner {
 public:
  virtual ~Runner() = default;

  // Executes `f`.
  virtual void Run(const std::function<void()>& f) = 0;

  // Accessor for the global singleton `Runner`.
  static Runner* get();
};
351 
352 // A class which provides a sequence of splits. Splits represent subdivisions of
353 // a dataset, e.g. filenames or ranges within files. We use splitting to
354 // partition input data into smaller pieces for distributed processing (see
355 // go/tf-data-splitting-design).
356 //
357 // Datasets provide a `MakeSplitProvider` method to expose a listing of their
358 // splits.
359 //
360 // Iterators created with a split provider will only iterate over the splits
361 // provided by the split provider.
362 class SplitProvider {
363  public:
~SplitProvider()364   virtual ~SplitProvider() {}
365   // Stores the next split in `*split`, setting `*end_of_splits` to indicate
366   // whether there were any splits left.
367   virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
368   // Resets the split provider to its beginning.
369   virtual Status Reset() = 0;
370   // Saves the state of this split provider.
371   virtual Status Save(std::function<std::string(std::string)> full_name,
372                       IteratorStateWriter* writer) = 0;
373   // Restores the state of this split provider.
374   virtual Status Restore(std::function<std::string(std::string)> full_name,
375                          IteratorStateReader* reader) = 0;
376 };
377 
// Returns the runner threadpool size from an OpKernelContext.
// `IteratorContext::Params(OpKernelContext*)` stores this value as
// `runner_threadpool_size`, i.e. the number of threads used for executing
// user-defined functions.
int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx);
380 
381 // A cut-down version of `OpKernelContext` for running computations in
382 // iterators. Note that we cannot simply use `OpKernelContext` here because we
383 // might run computation in an iterator whose lifetime is not nested within the
384 // lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
385 //
386 // TODO(mrry): We're making some daring assumptions about the lifetime of the
387 // runner passed in here. A runner will be deleted when the original step ends,
388 // but all existing runners only close over session-lifetime (or longer-lived)
389 // state, so we can make a copy of the function. There's nothing in the
390 // definition of the API from which we took the runner to guarantee that what we
391 // are doing is safe. We should formalize the properties here.
392 class IteratorContext {
393  public:
394   struct Params {
ParamsParams395     explicit Params(IteratorContext* ctx)
396         : allocator_getter(ctx->allocator_getter()),
397           cancellation_manager(ctx->cancellation_manager()),
398           collective_executor(ctx->collective_executor()),
399           env(ctx->env()),
400           flr(ctx->flr()),
401           function_handle_cache(ctx->function_handle_cache()),
402           interleave_depth(ctx->interleave_depth()),
403           is_restoring(ctx->is_restoring()),
404           model(ctx->model()),
405           options(ctx->options()),
406           resource_mgr(ctx->resource_mgr()),
407           runner(*(ctx->runner())),
408           runner_threadpool_size(ctx->runner_threadpool_size()),
409           split_providers(ctx->split_providers()),
410           stats_aggregator(ctx->stats_aggregator()),
411           thread_factory(ctx->thread_factory()),
412           thread_pool(ctx->thread_pool()) {}
413 
ParamsParams414     explicit Params(OpKernelContext* ctx)
415         : collective_executor(ctx->collective_executor()),
416           env(ctx->env()),
417           flr(ctx->function_library()) {
418       // NOTE: need reinterpret_cast because function.h forward-declares Device.
419       DeviceBase* device =
420           reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
421       allocator_getter = [device](AllocatorAttributes attrs) {
422         return device->GetAllocator(attrs);
423       };
424 
425       runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
426 
427       // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
428       // that a symbol in the tensorflow::data namespace is always on the stack
429       // when executing a function inside a Dataset.
430       runner = std::bind(
431           [](
432               // Note: `runner` is a const reference to avoid copying it.
433               const std::function<void(std::function<void()>)>& ctx_runner,
434               std::function<void()> fn) {
435             std::function<void()> wrapped_fn = std::bind(
436                 [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
437                 std::move(fn));
438             ctx_runner(std::move(wrapped_fn));
439           },
440           *ctx->runner(), std::placeholders::_1);
441     }
442 
443     // The Allocator to be used to allocate the output of an iterator.
444     std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
445 
446     // The CancellationManager to be used to cancel execution of ops.
447     CancellationManager* cancellation_manager;
448 
449     // Collective support.
450     CollectiveExecutor* collective_executor = nullptr;
451 
452     // Interface to operating system functionality.
453     Env* env = nullptr;
454 
455     // The FunctionLibraryRuntime object to be used to make function calls.
456     FunctionLibraryRuntime* flr = nullptr;
457 
458     // A FunctionHandleCache that owns all the function handles. Not owned.
459     FunctionHandleCache* function_handle_cache = nullptr;
460 
461     // Records the number of ParallelInterleave operations in the path from the
462     // root node to this node (not including this node) in the input pipeline
463     // tree.
464     int64 interleave_depth = 0;
465 
466     // Marks whether the iterator is restored from a checkpoint.
467     bool is_restoring = false;
468 
469     // If non-null, identifies the object used for performance modeling.
470     std::shared_ptr<model::Model> model = nullptr;
471 
472     // The input pipeline options.
473     const Options* options = nullptr;
474 
475     // A resource manager for storing dataset-related state, e.g. random
476     // seeds or cached tensors. Not owned.
477     ResourceMgr* resource_mgr = nullptr;
478 
479     // Function call support.
480     std::function<void(std::function<void()>)> runner = nullptr;
481 
482     // Number of threads used for executing user-defined functions.
483     int32 runner_threadpool_size = 0;
484 
485     // Split providers indicating which splits to process. May be empty,
486     // indicating that the iterator should process all splits.
487     std::vector<std::shared_ptr<SplitProvider>> split_providers;
488 
489     // The `StatsAggregator` object to record statistics about the iterator.
490     //
491     // TODO(b/147325552): Remove this API and any of its uses after we switch to
492     // using C++ based implementation for tf.data options (on 4/12/2021).
493     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
494 
495     // A factory for creating threads to perform blocking work.
496     std::shared_ptr<ThreadFactory> thread_factory = nullptr;
497 
498     // A shared thread pool to schedule computation into.
499     thread::ThreadPoolInterface* thread_pool = nullptr;
500   };
501 
IteratorContext(IteratorContext * ctx)502   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
503 
IteratorContext(OpKernelContext * ctx)504   explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}
505 
IteratorContext(Params params)506   explicit IteratorContext(Params params) : params_(std::move(params)) {}
507 
allocator(AllocatorAttributes attrs)508   Allocator* allocator(AllocatorAttributes attrs) {
509     return params_.allocator_getter(attrs);
510   }
511 
allocator_getter()512   std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
513     return params_.allocator_getter;
514   }
515 
cancellation_manager()516   CancellationManager* cancellation_manager() {
517     return params_.cancellation_manager;
518   }
519 
collective_executor()520   CollectiveExecutor* collective_executor() {
521     return params_.collective_executor;
522   }
523 
env()524   Env* env() const { return params_.env; }
525 
flr()526   FunctionLibraryRuntime* flr() { return params_.flr; }
527 
function_handle_cache()528   FunctionHandleCache* function_handle_cache() {
529     return params_.function_handle_cache;
530   }
531 
interleave_depth()532   int64 interleave_depth() { return params_.interleave_depth; }
533 
is_restoring()534   bool is_restoring() { return params_.is_restoring; }
535 
model()536   const std::shared_ptr<model::Model>& model() { return params_.model; }
537 
options()538   const Options* options() { return params_.options; }
539 
resource_mgr()540   ResourceMgr* resource_mgr() { return params_.resource_mgr; }
541 
runner()542   std::function<void(std::function<void()>)>* runner() {
543     return ¶ms_.runner;
544   }
545 
runner_threadpool_size()546   int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
547 
split_providers()548   std::vector<std::shared_ptr<SplitProvider>> split_providers() {
549     return params_.split_providers;
550   }
551 
stats_aggregator()552   std::shared_ptr<StatsAggregator> stats_aggregator() {
553     return params_.stats_aggregator;
554   }
555 
thread_factory()556   const std::shared_ptr<ThreadFactory>& thread_factory() {
557     return params_.thread_factory;
558   }
559 
thread_pool()560   thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
561 
CreateThreadPool(const string & name,int num_threads)562   std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
563                                                        int num_threads) {
564     if (params_.thread_pool) {
565       // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
566       // is an instance of `thread::ThreadPoolInterface`). Notably, the
567       // ownership of `params_.thread_pool` is *not* transferred onto the newly
568       // created `ThreadPool` instance.
569       return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
570     } else {
571       return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
572                                                    name, num_threads,
573                                                    /*low_latency_hint=*/false);
574     }
575   }
576 
StartThread(const string & name,std::function<void ()> fn)577   std::unique_ptr<Thread> StartThread(const string& name,
578                                       std::function<void()> fn) {
579     if (params_.thread_factory) {
580       return params_.thread_factory->StartThread(name, std::move(fn));
581     } else {
582       return absl::WrapUnique(
583           Env::Default()->StartThread({}, name, std::move(fn)));
584     }
585   }
586 
587  private:
588   Params params_;
589 };
590 
591 // Aggregates runtime support needed for dataset and iterator serialization.
592 class SerializationContext {
593  public:
594   // Enum describing what to do during serialization when external state is
595   // encountered.
596   enum class ExternalStatePolicy : int64 {
597     // Proceed with serialization, but log a warning about what state will be
598     // lost.
599     kWarn = 0,
600     // Proceed with serialization without logging any warning.
601     kIgnore = 1,
602     // Fail the serialization with an error.
603     kFail = 2,
604   };
605 
606   // Handles the CheckExternalState status according to the external state
607   // policy.
HandleCheckExternalStateStatus(Status s)608   Status HandleCheckExternalStateStatus(Status s) {
609     if (s.ok()) {
610       return s;
611     }
612     switch (params_.external_state_policy) {
613       case ExternalStatePolicy::kWarn:
614         LOG(WARNING) << s.ToString();
615         return OkStatus();
616       case ExternalStatePolicy::kIgnore:
617         VLOG(2) << "Ignoring error status: " << s.ToString();
618         return OkStatus();
619       case ExternalStatePolicy::kFail:
620         return s;
621     }
622     LOG(FATAL) << "Control should never reach here";
623   }
624 
625   struct Params {
ParamsParams626     explicit Params() {}
627 
ParamsParams628     explicit Params(OpKernelContext* ctx)
629         : resource_mgr(ctx->resource_manager()),
630           device_name(ctx->device()->attributes().name()) {}
631 
632     std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
633 
634     // Indicates what to do if the dataset depends on external state.
635     ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;
636 
637     // Indicates whether the serialization is for rewrites.
638     //
639     // If true:
640     //   * A dataset that doesn't implement serialization is replaced with a
641     //     placeholder returned in `input_list`.
642     //   * Data tensors are replaced with a placeholder returned in
643     //     `input_list`.
644     //   * Datasets that use random seeds should not serialize the random seeds.
645     //     This doesn't affect datasets that use fixed seeds; fixed seeds will
646     //     always be preserved.
647     //   * Cardinality is serialized as an unregistered attribute
648     //     `_cardinality`.
649     // If false:
650     //   * A dataset that doesn't implement serialization should result in an
651     //     error.
652     //   * Data tensors (potentially large) should be serialized.
653     //   * Datasets that use random seeds should serialize the random seeds.
654     bool is_graph_rewrite = false;
655 
656     // A resource manager for looking up resources during serialization.
657     ResourceMgr* resource_mgr;
658 
659     // The name of the device doing the serialization.
660     std::string device_name;
661   };
662 
SerializationContext(Params params)663   explicit SerializationContext(Params params) : params_(params) {}
664 
input_list()665   std::vector<std::pair<string, Tensor>>* input_list() {
666     return params_.input_list;
667   }
668 
external_state_policy()669   ExternalStatePolicy external_state_policy() const {
670     return params_.external_state_policy;
671   }
672 
is_graph_rewrite()673   bool is_graph_rewrite() const { return params_.is_graph_rewrite; }
674 
resource_mgr()675   const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }
676 
device_name()677   const std::string& device_name() const { return params_.device_name; }
678 
679  private:
680   Params params_;
681 
682   TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
683 };
684 
685 // Represents the current position in a range of outputs, where the
686 // range of outputs is typically represented by an `DatasetBase`,
687 // defined below.
688 class IteratorBase {
689  public:
~IteratorBase()690   virtual ~IteratorBase() {
691     for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
692       (*rit)();
693     }
694   }
695 
696   // Gets the next output from the range that this iterator is traversing.
697   //
698   // If at least one output remains in this iterator's range, that
699   // output will be stored in `*out_tensors` and `false` will be
700   // stored in `*end_of_sequence`.
701   //
702   // If no more outputs remain in this iterator's range, `true` will be stored
703   // in `*end_of_sequence`, and `*out_tensors` will be empty.
704   //
705   // Implementations should never return `OutOfRange` error. If at end of
706   // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
707   // Internally raised `OutOfRange` errors that do not imply end of sequence
708   // should be converted to a different error type before being propagated to
709   // the caller.
710   //
711   // Implementations must explicitly set `*end_of_sequence = false` if an
712   // `Status::OK()` status is returned and the iterator is not at the end of the
713   // sequence.
714   //
715   // This method is thread-safe.
716   //
717   // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
718   // potentially remove this method.
719   virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
720                          bool* end_of_sequence) = 0;
721 
GetNext(IteratorContext && ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)722   Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
723                  bool* end_of_sequence) {
724     return GetNext(&ctx, out_tensors, end_of_sequence);
725   }
726 
727   // Skips the next `num_to_skip` outputs from the range that this iterator
728   // is traversing.
729   //
730   // If there are not enough outputs to skip, it will set
731   // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
732   // store the number of outputs that are skipped. When `*end_of_sequence` is
733   // `false`, `*num_skipped` should equal to `num_to_skip`.
734   virtual Status Skip(IteratorContext* ctx, int num_to_skip,
735                       bool* end_of_sequence, int* num_skipped) = 0;
736 
Skip(IteratorContext && ctx,int num_to_skip,bool * end_of_sequence,int * num_skipped)737   virtual Status Skip(IteratorContext&& ctx, int num_to_skip,
738                       bool* end_of_sequence, int* num_skipped) {
739     return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped);
740   }
741 
742   // Returns a vector of DataType values, representing the respective
743   // element types of each tuple component in the outputs of this
744   // iterator.
745   virtual const DataTypeVector& output_dtypes() const = 0;
746 
747   // Returns a vector of tensor shapes, representing the respective
748   // (and possibly partially defined) shapes of each tuple component
749   // in the outputs of this iterator.
750   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
751 
752   // Returns a string that identifies the sequence of iterators leading up to
753   // this iterator.
754   virtual const string& prefix() const = 0;
755 
756   // Performs initialization that needs to happen outside of a constructor to
757   // properly propagate errors.
Initialize(IteratorContext * ctx)758   virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); }
759 
760   // Performs initialization of the base iterator.
761   Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
762 
763   // Saves the state of this iterator.
Save(SerializationContext * ctx,IteratorStateWriter * writer)764   virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
765     int64_t start_us = EnvTime::NowMicros();
766     TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
767     VLOG(1) << "Saved " << prefix() << " in "
768             << (EnvTime::NowMicros() - start_us) << "us";
769     return OkStatus();
770   }
771 
772  protected:
773   // Returns a node that models this iterator.
774   virtual std::shared_ptr<model::Node> CreateNode(
775       IteratorContext* ctx, model::Node::Args args) const = 0;
776 
777   // Restores the state of this iterator.
Restore(IteratorContext * ctx,IteratorStateReader * reader)778   virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
779     int64_t start_us = EnvTime::NowMicros();
780     TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
781     VLOG(1) << "Restored " << prefix() << " in "
782             << (EnvTime::NowMicros() - start_us) << "us";
783     return OkStatus();
784   }
785 
786   // This is needed so that sub-classes of IteratorBase can call
787   // `SaveInternal` on their input iterators.
SaveInput(SerializationContext * ctx,IteratorStateWriter * writer,const std::unique_ptr<IteratorBase> & input)788   Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
789                    const std::unique_ptr<IteratorBase>& input) {
790     return input->Save(ctx, writer);
791   }
792 
793   // This is needed so that sub-classes of IteratorBase can call
794   // `RestoreInternal` on their input iterators.
RestoreInput(IteratorContext * ctx,IteratorStateReader * reader,const std::unique_ptr<IteratorBase> & input)795   Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
796                       const std::unique_ptr<IteratorBase>& input) {
797     return input->Restore(ctx, reader);
798   }
799 
RestoreInput(IteratorContext && ctx,IteratorStateReader * reader,const std::unique_ptr<IteratorBase> & input)800   Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
801                       const std::unique_ptr<IteratorBase>& input) {
802     return RestoreInput(&ctx, reader, input);
803   }
804 
805   // Saves the state of this iterator.
806   //
807   // This method is used to store the state of the iterator in a checkpoint.
808   // implementations have an override.
809   virtual Status SaveInternal(SerializationContext* ctx,
810                               IteratorStateWriter* writer) = 0;
811 
812   // Restores the state of this iterator.
813   //
814   // This method is used to restore the state of the iterator from a checkpoint.
815   //
816   // Implementations may assume that the iterator is in a clean state. That is,
817   // its `Initialize` method has been called, but its `GetNext` method has
818   // never been called.
819   // implementations have an override.
820   virtual Status RestoreInternal(IteratorContext* ctx,
821                                  IteratorStateReader* reader) = 0;
822 
823   // Returns a pointer to the node representing this iterator in the performance
824   // model. It may be null, if performance modeling is not enabled for this
825   // iterator.
model_node()826   std::shared_ptr<model::Node> model_node() const { return node_; }
827 
828   // Returns the number of elements produced by this iterator.
num_elements()829   int64_t num_elements() const {
830     if (node_) return node_->num_elements();
831     return 0;
832   }
833 
834  private:
835   // For access to `AddCleanupFunction` and `Restore`.
836   friend class DatasetBase;
837   friend class DatasetBaseIterator;  // for access to `node_`
838 
839   std::vector<std::function<void()>> cleanup_fns_;
840   std::shared_ptr<model::Node> node_ = nullptr;
841   const IteratorBase* parent_ = nullptr;  // Not owned.
842   int64_t id_ = 0;
843   int64_t parent_id_ = 0;
844 };
845 
846 // Represents runtime information needed to construct a dataset.
847 class DatasetContext {
848  public:
849   struct Params {
850     string type_string;  // op type name of this dataset.
851     string node_name;    // graph node name of this dataset op, uniquely
852                          // identifying the dataset in the graph.
853   };
854 
DatasetContext(Params params)855   explicit DatasetContext(Params params) : params_(std::move(params)) {}
856 
DatasetContext(OpKernelContext * ctx)857   explicit DatasetContext(OpKernelContext* ctx) {
858     params_.type_string = ctx->op_kernel().type_string();
859     params_.node_name = ctx->op_kernel().name();
860   }
861 
type_string()862   const string& type_string() const { return params_.type_string; }
node_name()863   const string& node_name() const { return params_.node_name; }
864 
865  private:
866   Params params_;
867 };
868 
// Returns the number of bytes allocated for the tensors in the given element.
870 int64_t GetAllocatedBytes(const std::vector<Tensor>& element);
871 
// Returns the estimated memory usage in bytes of the tensors in the given
// element.
873 int64_t GetTotalBytes(const std::vector<Tensor>& element);
874 
875 // Validates and extracts a `DatasetBase` object from `tensor`.
876 //
877 // `tensor` must have been written by a call to SetVariantTensorToDataset().
878 //
879 // The retrieved pointer is a borrowed reference to the dataset, which is owned
880 // by the tensor. The consumer must either acquire its own reference to the
881 // dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
882 // destroyed or mutated while the retrieved pointer is in use.
883 Status GetDatasetFromVariantTensor(const Tensor& tensor,
884                                    DatasetBase** out_dataset);
885 
886 // Stores a `DatasetBase` object in `tensor`.
887 //
888 // The ownership of `dataset` is transferred to `tensor`.
889 Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
890 
891 // Represents a (potentially infinite) range of outputs, where each
892 // output is a tuple of tensors.
893 class DatasetBase : public core::RefCounted {
894  public:
895   // Key for storing the Dataset graph in the serialized format.
896   TF_EXPORT static const char kDatasetGraphKey[];
897 
898   // Key for storing the output node of the Dataset graph in the serialized
899   // format.
900   TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
901 
DatasetBase(DatasetContext && ctx)902   explicit DatasetBase(DatasetContext&& ctx)
903       : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}
904 
905   // Op type name of this dataset.
type_string()906   const string& type_string() const { return type_string_; }
907 
908   // Graph node name of this dataset op, uniquely identifying the dataset in
909   // the graph.
node_name()910   const string& node_name() const { return node_name_; }
911 
912   // Initializes the dataset.
913   void Initialize(const Metadata& metadata);
914 
metadata()915   const Metadata& metadata() const { return metadata_; }
916 
options()917   const Options& options() const { return options_; }
918 
num_sources()919   int64_t num_sources() const { return num_sources_; }
920 
921   // Returns a new iterator for iterating over the range of elements in
922   // this dataset.
923   //
924   // This method may be called multiple times on the same instance,
925   // and the resulting iterators will have distinct state. Each
926   // iterator will traverse all elements in this dataset from the
927   // start.
928   //
929   // The prefix identifies the sequence of iterators leading up to the newly
930   // created iterator.
931   Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
932                       const string& output_prefix,
933                       std::unique_ptr<IteratorBase>* iterator) const;
934 
MakeIterator(IteratorContext && ctx,const IteratorBase * parent,const string & output_prefix,std::unique_ptr<IteratorBase> * iterator)935   Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
936                       const string& output_prefix,
937                       std::unique_ptr<IteratorBase>* iterator) const {
938     return MakeIterator(&ctx, parent, output_prefix, iterator);
939   }
940 
941   // Returns a new iterator restored from the checkpoint data in `reader`.
MakeIteratorFromCheckpoint(IteratorContext * ctx,const string & output_prefix,IteratorStateReader * reader,std::unique_ptr<IteratorBase> * iterator)942   Status MakeIteratorFromCheckpoint(
943       IteratorContext* ctx, const string& output_prefix,
944       IteratorStateReader* reader,
945       std::unique_ptr<IteratorBase>* iterator) const {
946     std::unique_ptr<IteratorBase> it;
947     IteratorContext::Params params(ctx);
948     params.is_restoring = true;
949     IteratorContext restore_ctx(std::move(params));
950     TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx,
951                                     /*parent=*/nullptr, output_prefix, &it));
952     TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader));
953     *iterator = std::move(it);
954     return OkStatus();
955   }
956 
MakeIteratorFromCheckpoint(IteratorContext && ctx,const string & output_prefix,IteratorStateReader * reader,std::unique_ptr<IteratorBase> * iterator)957   Status MakeIteratorFromCheckpoint(
958       IteratorContext&& ctx, const string& output_prefix,
959       IteratorStateReader* reader,
960       std::unique_ptr<IteratorBase>* iterator) const {
961     return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
962   }
963 
964   // Returns a split provider which partitions the dataset's data into splits
965   // and provides them in a sequence. The split provider is stored in
966   // `*split_provider`.
967   virtual Status MakeSplitProviders(
968       std::vector<std::unique_ptr<SplitProvider>>* split_providers) const;
969 
970   // Returns a vector of DataType values, representing the respective
971   // element types of each tuple component in the outputs of this
972   // dataset.
973   virtual const DataTypeVector& output_dtypes() const = 0;
974 
975   // Returns a vector of tensor shapes, representing the respective
976   // (and possibly partially defined) shapes of each tuple component
977   // in the outputs of this dataset.
978   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
979 
980   // Returns the number of bytes allocated for tensors of this dataset.
AllocatedBytes()981   virtual int64_t AllocatedBytes() const { return 0; }
982 
983   // Returns the estimated number of bytes used for tensors of this dataset.
TotalBytes()984   virtual int64_t TotalBytes() const { return 0; }
985 
986   // Returns the cardinality of this dataset.
987   // TODO(shilpakrish): Remove this overload once all callers are migrated
988   // to the API which passes in the options parameter.
989   ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
990   int64_t Cardinality() const;
991 
992   // Returns the cardinality of this dataset based on the options.
993   int64_t Cardinality(CardinalityOptions options) const;
994 
995   // Internal implementation of cardinality for a dataset.
996   // TODO(shilpakrish): Remove this overload once all callers are migrated
997   // to the API which passes in the options parameter.
998   ABSL_DEPRECATED("Use the overload that passes in the options parameter.")
CardinalityInternal()999   virtual int64_t CardinalityInternal() const { return kUnknownCardinality; }
1000 
1001   // Internal implementation of cardinality for a dataset based on the options.
CardinalityInternal(CardinalityOptions options)1002   virtual int64_t CardinalityInternal(CardinalityOptions options) const {
1003     return kUnknownCardinality;
1004   }
1005 
1006   // A human-readable debug string for this dataset.
1007   virtual string DebugString() const = 0;
1008 
1009   // Stores the dataset's input datasets in `*inputs`. The pointers stored in
1010   // `*inputs` are borrowed. The only valid non-ok return status is
1011   // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
1012   // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
1013   // default implementation of `MakeSplitProvider` when there is a single input
1014   // dataset.
1015   virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;
1016 
1017   // Indicates whether the dataset depends on any external state which would
1018   // prevent it from being serializable. If so, the method returns
1019   // `errors::FailedPrecondition` with a message that identifies the external
1020   // state. Otherwise, the method returns `Status::OK()`.
1021   virtual Status CheckExternalState() const = 0;
1022 
1023   // Indicates whether the dataset is compatible with random access.
1024   Status CheckRandomAccessCompatible(const int64 index) const;
1025 
1026   // Return the element at a particular index for a randomly accessible dataset.
1027   virtual Status Get(OpKernelContext* ctx, int64 index,
1028                      std::vector<Tensor>* out_tensors) const;
1029 
1030   // Return a finalized version of the dataset.  The returned DatasetBase is
1031   // unowned and lives for as long as this dataset.
1032   virtual StatusOr<DatasetBase*> Finalize(
1033       OpKernelContext* ctx,
1034       std::function<StatusOr<core::RefCountPtr<DatasetBase>>()>
1035           make_finalized_dataset) const;
1036 
1037   // Wrapper around a GraphDefBuilder which provides support for serializing
1038   // Datasets as GraphDefs.
1039   class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
1040    public:
DatasetGraphDefBuilder(GraphDefBuilder * b)1041     explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
1042         : GraphDefBuilderWrapper(b) {}
1043     Status AddInputDataset(SerializationContext* ctx,
1044                            const DatasetBase* dataset, Node** output);
1045     Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
1046                               Node** output);
1047     Status AddIdentity(SerializationContext* ctx,
1048                        const std::string& name_prefix, Node** input,
1049                        Node** output);
1050 
1051    private:
1052     Status AddDatasetOrTensorHelper(SerializationContext* ctx,
1053                                     const Tensor& val, Node** output);
1054     Status AddResourceHelper(SerializationContext* ctx, const Tensor& val,
1055                              Node** output);
1056   };
1057 
1058  protected:
1059   friend class CapturedFunction;
1060 
1061   // Serializes the dataset into a `GraphDef`, which has two uses:
1062   //
1063   // 1) To perform static input pipeline optimizations, tf.data serializes the
1064   // dataset graph, applies graph rewrites, and then deserializes the graph.
1065   // If a subclass of `DatasetBase` does not implement this method, then it will
1066   // be excluded from static optimizations (and so will any upstream datasets).
1067   //
1068   // 2) To save the dataset so that it can restore at a later point (possibly in
1069   // different environment). If a subclass of `DatasetBase` does not implement
1070   // this method, then this migration will not be possible.
1071   virtual Status AsGraphDefInternal(SerializationContext* ctx,
1072                                     DatasetGraphDefBuilder* b,
1073                                     Node** node) const = 0;
1074 
1075   virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
1076       const string& prefix) const = 0;
1077 
set_options(const Options & options)1078   void set_options(const Options& options) { options_ = options; }
1079 
1080  private:
1081   // Computes and stores the cardinality of a given dataset.
1082   Status ComputeCardinality();
1083 
1084   // Computes the number of source datasets feeding into this dataset. A source
1085   // dataset is a leaf in the subtree of dataset inputs.
1086   Status ComputeNumSources();
1087 
1088   // Merges options from inputs to this dataset. If there is a conflict in a
1089   // field value, the options set on this dataset takes precedence over those in
1090   // the inputs. The order of precedence on the inputs is in the same order as
1091   // how they appear for this dataset.
1092   Status MergeOptionsFromInputs();
1093 
1094   const string type_string_;
1095   const string node_name_;
1096   Metadata metadata_;
1097   Options options_;
1098   mutable mutex mu_;
1099   mutable mutex cardinality_mu_;
1100   mutable core::RefCountPtr<DatasetBase> finalized_dataset_;
1101   //  The number of source datasets feeding into the dataset. A source dataset
1102   //  is a leaf in the subtree of dataset inputs.
1103   int64_t num_sources_ = -1;
1104   mutable int64_t cardinality_ TF_GUARDED_BY(cardinality_mu_) =
1105       kUnknownCardinality;
1106 };
1107 
1108 // Represents an iterator that is associated with a particular dataset.
1109 class DatasetBaseIterator : public IteratorBase {
1110  public:
1111   struct BaseParams {
1112     // Owns one reference on the shared dataset object.
1113     const DatasetBase* dataset;
1114 
1115     // Identifies the sequence of iterators leading up to this iterator.
1116     const string prefix;
1117   };
1118 
1119   explicit DatasetBaseIterator(const BaseParams& params);
1120 
1121   ~DatasetBaseIterator() override;
1122 
dataset()1123   virtual const DatasetBase* dataset() const { return params_.dataset; }
1124 
output_dtypes()1125   const DataTypeVector& output_dtypes() const override {
1126     return params_.dataset->output_dtypes();
1127   }
1128 
output_shapes()1129   const std::vector<PartialTensorShape>& output_shapes() const override {
1130     return params_.dataset->output_shapes();
1131   }
1132 
prefix()1133   const string& prefix() const override { return params_.prefix; }
1134 
1135   // Returns a name to be used for the TraceMe event.
1136   //
1137   // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
1138   // following format "name#arg_1=value_,...,arg_n=value_n".
1139   string BuildTraceMeName();
1140 
1141   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
1142                  bool* end_of_sequence) final;
1143 
GetNext(IteratorContext && ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)1144   Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
1145                  bool* end_of_sequence) {
1146     return GetNext(&ctx, out_tensors, end_of_sequence);
1147   }
1148 
1149   Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
1150               int* num_skipped) final;
1151 
Save(SerializationContext * ctx,IteratorStateWriter * writer)1152   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
1153     VLOG(2) << "Attempting to save checkpoints on iterator (prefix: "
1154             << prefix() << ") from " << dataset()->DebugString();
1155     return IteratorBase::Save(ctx, writer);
1156   }
1157 
1158  protected:
Restore(IteratorContext * ctx,IteratorStateReader * reader)1159   Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final {
1160     VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: "
1161             << prefix() << ") from " << dataset()->DebugString();
1162     return IteratorBase::Restore(ctx, reader);
1163   }
1164 
1165   // Internal implementation of GetNext that is wrapped in tracing logic.
1166   //
1167   // See the docstring of `GetNext` method regaring the contract for
1168   // `out_tensors` and `end_of_sequence`. Implementations may assume that
1169   // `*out_tensors` is empty.
1170   virtual Status GetNextInternal(IteratorContext* ctx,
1171                                  std::vector<Tensor>* out_tensors,
1172                                  bool* end_of_sequence) = 0;
1173 
1174   // Internal implementation of Skip that is wrapped in tracing logic
1175   virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
1176                               bool* end_of_sequence, int* num_skipped);
1177 
full_name(const string & name)1178   string full_name(const string& name) const {
1179     return FullName(params_.prefix, name);
1180   }
1181 
1182   // Returns a map of key-value pairs to included in the TraceMe string.
GetTraceMeMetadata()1183   virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }
1184 
1185   // By default we model iterators using an unknown node, which acts as
1186   // pass-through with respect to performance modeling.
CreateNode(IteratorContext * ctx,model::Node::Args args)1187   std::shared_ptr<model::Node> CreateNode(
1188       IteratorContext* ctx, model::Node::Args args) const override {
1189     return model::MakeUnknownNode(std::move(args));
1190   }
1191 
1192   // When modeling is enabled, this method disables autotuning for the given
1193   // iterator (and the transitive closure of its inputs).
DisableAutotune(IteratorContext * ctx,IteratorBase * iterator)1194   void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1195     if (iterator->node_) {
1196       iterator->node_->set_autotune(false);
1197     }
1198   }
1199 
1200   // When modeling is enabled, this method enables autotuning for the given
1201   // iterator (and the transitive closure of its inputs).
EnableAutotune(IteratorContext * ctx,IteratorBase * iterator)1202   void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1203     if (iterator->node_) {
1204       iterator->node_->set_autotune(true);
1205     }
1206   }
1207 
1208   // When modeling is enabled, this method records the fact that this iterator
1209   // has dequeued an element from an internal buffer.
RecordBufferDequeue(IteratorContext * ctx,const std::vector<Tensor> & element)1210   void RecordBufferDequeue(IteratorContext* ctx,
1211                            const std::vector<Tensor>& element) {
1212     if (collect_resource_usage(ctx)) {
1213       node_->record_buffer_event(-GetAllocatedBytes(element), -1);
1214 
1215       DCHECK_GE(node_->buffered_elements(), 0);
1216     }
1217   }
1218 
1219   // When modeling is enabled, this method records the fact that this iterator
1220   // has enqueued an element in an internal buffer.
RecordBufferEnqueue(IteratorContext * ctx,const std::vector<Tensor> & element)1221   void RecordBufferEnqueue(IteratorContext* ctx,
1222                            const std::vector<Tensor>& element) {
1223     if (collect_resource_usage(ctx)) {
1224       node_->record_buffer_event(GetAllocatedBytes(element), 1);
1225     }
1226   }
1227 
1228   // When modeling is enabled, this method records the fact that this iterator
1229   // has produced an element and its size in bytes.
RecordElement(IteratorContext * ctx,std::vector<Tensor> * out_tensors)1230   void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
1231     if (collect_resource_usage(ctx)) {
1232       int64_t num_bytes = GetAllocatedBytes(*out_tensors);
1233       node_->record_element();
1234       node_->record_bytes_produced(num_bytes);
1235       if (node_->output()) {
1236         node_->output()->record_bytes_consumed(num_bytes);
1237       }
1238     }
1239   }
1240 
1241   // When modeling is enabled, this method records the fact that a thread of
1242   // this iterator has started work.
RecordStart(IteratorContext * ctx)1243   void RecordStart(IteratorContext* ctx) {
1244     if (collect_resource_usage(ctx)) {
1245       int64_t now_nanos = EnvTime::NowNanos();
1246       node_->record_start(now_nanos);
1247     }
1248   }
1249 
1250   // When modeling is enabled, this method records the fact that a thread of
1251   // this iterator has stopped work.
RecordStop(IteratorContext * ctx)1252   void RecordStop(IteratorContext* ctx) {
1253     if (collect_resource_usage(ctx)) {
1254       int64_t now_nanos = EnvTime::NowNanos();
1255       node_->record_stop(now_nanos);
1256     }
1257   }
1258 
1259   // Returns whether work is currently being recorded, i.e. whether we are
1260   // currently between a `RecordStart` and a `RecordStop`.
IsRecording(IteratorContext * ctx)1261   bool IsRecording(IteratorContext* ctx) {
1262     return node_ && node_->is_recording();
1263   }
1264 
1265  private:
collect_resource_usage(IteratorContext * ctx)1266   bool collect_resource_usage(IteratorContext* ctx) {
1267     return ctx->model() && node_;
1268   }
1269 
1270   string traceme_metadata_;
1271   BaseParams params_;
1272 };
1273 
1274 // Represents an iterator that is associated with a particular dataset
1275 // with a particular type.
1276 template <class DatasetType>
1277 class DatasetIterator : public DatasetBaseIterator {
1278  public:
1279   struct Params {
1280     // Borrowed pointer to the dataset.
1281     const DatasetType* dataset;
1282 
1283     // Identifies the sequence of iterators leading up to this iterator.
1284     const string prefix;
1285   };
1286 
DatasetIterator(const Params & params)1287   explicit DatasetIterator(const Params& params)
1288       : DatasetBaseIterator({params.dataset, params.prefix}),
1289         typed_dataset_(params.dataset) {}
1290 
1291   // The dataset from which this iterator was created.
dataset()1292   const DatasetType* dataset() const final { return typed_dataset_; }
1293 
1294  private:
1295   const DatasetType* const typed_dataset_;  // Not owned.
1296 };
1297 
1298 template <typename T>
ParseScalarArgument(OpKernelContext * ctx,const StringPiece & argument_name,T * output)1299 Status ParseScalarArgument(OpKernelContext* ctx,
1300                            const StringPiece& argument_name, T* output) {
1301   const Tensor* argument_t;
1302   TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1303   if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
1304     return errors::InvalidArgument(argument_name, " must be a scalar");
1305   }
1306   *output = argument_t->scalar<T>()();
1307   return OkStatus();
1308 }
1309 
1310 template <typename T>
ParseVectorArgument(OpKernelContext * ctx,const StringPiece & argument_name,std::vector<T> * output)1311 Status ParseVectorArgument(OpKernelContext* ctx,
1312                            const StringPiece& argument_name,
1313                            std::vector<T>* output) {
1314   const Tensor* argument_t;
1315   TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1316   if (!TensorShapeUtils::IsVector(argument_t->shape())) {
1317     return errors::InvalidArgument(argument_name, " must be a vector");
1318   }
1319   int size = argument_t->vec<T>().size();
1320   output->reserve(size);
1321   for (int i = 0; i < size; ++i) {
1322     output->push_back(argument_t->vec<T>()(i));
1323   }
1324   return OkStatus();
1325 }
1326 
1327 // Encapsulates the work required to plug a DatasetBase into the core TensorFlow
1328 // graph execution engine.
1329 class DatasetOpKernel : public OpKernel {
1330  public:
DatasetOpKernel(OpKernelConstruction * ctx)1331   explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
1332     if (ctx->HasAttr(kMetadata)) {
1333       std::string serialized_metadata;
1334       OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata));
1335       OP_REQUIRES(ctx, metadata_.ParseFromString(serialized_metadata),
1336                   errors::InvalidArgument(absl::StrCat(
1337                       "Could not parse the 'metadata' attribute.")));
1338     }
1339   }
1340 
1341   void Compute(OpKernelContext* ctx) final;
1342 
1343   // Checks whether the given op is a tf.data operation.
1344   //
1345   // NOTE: The check uses a heuristic and can produce both false positives and
1346   // false negatives. In particular, tf.data operations are expected to use
1347   // names that end with "Dataset" or "DatasetV[0-9]+".
1348   static bool IsDatasetOp(const OpDef& op_def);
1349 
1350   string TraceString(const OpKernelContext& ctx, bool verbose) const override;
1351 
1352  protected:
1353   // Subclasses should implement this method. It will be called during Compute
1354   // execution.
1355   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
1356 
1357  private:
1358   Metadata metadata_;
1359 };
1360 
1361 // Encapsulates the work required to plug unary Datasets into the core
1362 // TensorFlow graph execution engine.
1363 class UnaryDatasetOpKernel : public DatasetOpKernel {
1364  public:
UnaryDatasetOpKernel(OpKernelConstruction * ctx)1365   explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
1366       : DatasetOpKernel(ctx) {}
1367 
1368  protected:
1369   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1370   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1371                            DatasetBase** output) = 0;
1372 };
1373 
1374 // Encapsulates the work required to plug binary Datasets into the core
1375 // TensorFlow graph execution engine.
1376 class BinaryDatasetOpKernel : public DatasetOpKernel {
1377  public:
BinaryDatasetOpKernel(OpKernelConstruction * ctx)1378   explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
1379       : DatasetOpKernel(ctx) {}
1380 
1381  protected:
1382   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1383   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1384                            DatasetBase* another_input,
1385                            DatasetBase** output) = 0;
1386 };
1387 
1388 // A simple background worker that executes closures asynchronously and without
1389 // blocking.
1390 //
1391 // A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
1392 // to avoid blocking an executor thread that may be required by the blocking
1393 // work.
1394 //
1395 // NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
1396 // purpose because its current implementation (in Eigen) uses a finite-length
1397 // queue and will block the caller when full. This can lead to deadlock under
1398 // heavy load. Since the number of concurrent work items in each user of a
1399 // `BackgroundWorker` is at most one per op invocation, the dynamic allocation
1400 // overhead is tolerable.
class BackgroundWorker {
 public:
  // Creates a worker that uses `env`. NOTE(review): `name` is stored as a raw
  // `const char*` (`name_`), so it presumably must outlive the worker (e.g. a
  // string literal) — TODO confirm.
  BackgroundWorker(Env* env, const char* name);

  // Defined out of line. Presumably sets `cancelled_`, wakes the worker via
  // `cond_var_`, and joins `thread_` — TODO confirm.
  ~BackgroundWorker();

  // Hands `work_item` to the worker for asynchronous execution. Presumably
  // appends to `work_queue_` under `mu_` and signals `cond_var_` — TODO
  // confirm against the out-of-line definition.
  void Schedule(std::function<void()> work_item);

 private:
  // Body of the background thread. Presumably drains `work_queue_` until
  // `cancelled_` is set — TODO confirm.
  void WorkerLoop();

  Env* const env_;
  const char* const name_;  // Not owned.

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  // Pending work items; guarded by `mu_`.
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
1421 
1422 }  // namespace data
1423 }  // namespace tensorflow
1424 
1425 #endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
1426