• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
17 
18 #include <deque>
19 #include <memory>
20 #include <unordered_map>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/attr_value_util.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/model.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/framework/register_types.h"
33 #include "tensorflow/core/framework/thread_factory.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/variant_encode_decode.h"
36 #include "tensorflow/core/framework/variant_tensor_data.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39 #include "tensorflow/core/lib/core/threadpool_interface.h"
40 #include "tensorflow/core/lib/strings/str_util.h"
41 #include "tensorflow/core/lib/strings/strcat.h"
42 #include "tensorflow/core/platform/cpu_info.h"
43 #include "tensorflow/core/platform/env.h"
44 #include "tensorflow/core/platform/tracing.h"
45 
46 // Polymorphic datasets should support all primitive TensorFlow
47 // types. Use this macro to expand `m(T)` once for each primitive type
48 // `T`, e.g. to build a `switch` statement.
49 #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
50 
51 namespace tensorflow {
52 
53 // Forward declarations to avoid introducing a dependency on headers in
54 // "tensorflow/core/graph/...".
55 class GraphDefBuilder;
56 class Node;
57 
58 namespace data {
59 
60 using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;
61 
62 constexpr char kTFDataFunction[] = "_tf_data_function";
63 
64 constexpr int kInfiniteCardinality = -1;
65 constexpr int kUnknownCardinality = -2;
66 
67 // This constant is a magic number that is used (as a prefix) to identify keys
68 // used for serialization of iterator state.
69 constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
70 constexpr char kPipe[] = "|";
71 constexpr char kColon[] = ":";
72 
73 constexpr char kTFDataResourceTag[] = "tfdata";
74 
75 class DatasetBase;
76 class SerializationContext;
77 
IsTFDataFunction(const FunctionDef & func)78 inline bool IsTFDataFunction(const FunctionDef& func) {
79   auto iter = func.attr().find(data::kTFDataFunction);
80   return (iter != func.attr().end() && iter->second.b());
81 }
82 
83 // Interface for reading values from a key-value store.
84 // Used for restoring iterator state. This class is thread safe.
85 // Please see comment on IteratorStateWriter for guidance around using the
86 // Read*(key, val) vs Read*(name, key, val).
87 class IteratorStateReader {
88  public:
89   virtual Status ReadScalar(StringPiece key, int64* val) const = 0;
90   virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
91   virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
92 
93   virtual Status ReadScalar(StringPiece name, StringPiece key,
94                             int64* val) const = 0;
95   virtual Status ReadScalar(StringPiece name, StringPiece key,
96                             tstring* val) const = 0;
97   virtual Status ReadTensor(StringPiece name, StringPiece key,
98                             Tensor* val) const = 0;
99 
100   virtual bool Contains(StringPiece key) const = 0;
101   virtual bool Contains(StringPiece name, StringPiece key) const = 0;
102 
~IteratorStateReader()103   virtual ~IteratorStateReader() {}
104 };
105 
106 // Interface for writing values to a key-value store.
107 // Used for saving iterator state. Not thread safe.
108 // The IteratorStateWriter creates a tensor for each unique iterator name it
109 // sees. For the Write*(key, val) API's the key is expected to encode this
110 // name as keys are required to be produced using the full_name() method.
111 // Each tensor has an upper limit of 2 GB and so if the state for an iterator
112 // might exceed the 2 GB limit, you can pass an explicit name in via the
113 // Write*(name, key, val) APIs allowing you to further split up the state
114 // into more manageable chunks.
115 class IteratorStateWriter {
116  public:
117   virtual Status WriteScalar(StringPiece key, const int64 val) = 0;
118   virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
119   virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
120 
121   virtual Status WriteScalar(StringPiece name, StringPiece key,
122                              const int64 val) = 0;
123   virtual Status WriteScalar(StringPiece name, StringPiece key,
124                              const tstring& val) = 0;
125   virtual Status WriteTensor(StringPiece name, StringPiece key,
126                              const Tensor& val) = 0;
127 
~IteratorStateWriter()128   virtual ~IteratorStateWriter() {}
129 };
130 
131 // Generates a full name key for iterator checkpointing. All keys generated for
132 // iterator checkpoints should go through this function.
133 std::string FullName(const std::string& prefix, const std::string& name);
134 
135 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
136 class GraphDefBuilderWrapper {
137  public:
GraphDefBuilderWrapper(GraphDefBuilder * b)138   explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}
139 
140   // Adds a Const node with scalar value to the Graph.
141   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
142   // non-null if the method returns with an OK status.
143   // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
144   template <typename T>
AddScalar(const T & val,Node ** output)145   Status AddScalar(const T& val, Node** output) {
146     Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
147     val_t.scalar<T>()() = val;
148     AddTensorInternal(val_t, output);
149     if (*output == nullptr) {
150       return errors::Internal("AddScalar: Failed to build Const op.");
151     }
152     return Status::OK();
153   }
154 
155   // Adds a Const node with vector value to the Graph.
156   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
157   // non-null if the method returns with an OK status.
158   // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
159   // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
160   template <typename T>
AddVector(const std::vector<T> & val,Node ** output)161   Status AddVector(const std::vector<T>& val, Node** output) {
162     Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
163                           TensorShape({static_cast<int64>(val.size())}));
164     for (size_t i = 0; i < val.size(); i++) {
165       val_t.flat<T>()(i) = val[i];
166     }
167     AddTensorInternal(val_t, output);
168     if (*output == nullptr) {
169       return errors::Internal("AddVector: Failed to build Const op.");
170     }
171     return Status::OK();
172   }
173 
AddVector(const std::vector<string> & val,Node ** output)174   Status AddVector(const std::vector<string>& val, Node** output) {
175     Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
176                           TensorShape({static_cast<int64>(val.size())}));
177     for (size_t i = 0; i < val.size(); i++) {
178       val_t.flat<tstring>()(i) = val[i];
179     }
180     AddTensorInternal(val_t, output);
181     if (*output == nullptr) {
182       return errors::Internal("AddVector: Failed to build Const op.");
183     }
184     return Status::OK();
185   }
186 
187   // Adds a `Const` node for the given tensor value to the graph.
188   //
189   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
190   // non-null if the method returns with an OK status. The returned `Node`
191   // pointer is owned by the backing graph of `GraphDefBuilder`.
AddTensor(const Tensor & val,Node ** output)192   Status AddTensor(const Tensor& val, Node** output) {
193     AddTensorInternal(val, output);
194     if (*output == nullptr) {
195       return errors::Internal("AddTensor: Failed to build Const op.");
196     }
197     return Status::OK();
198   }
199 
200   // Adds a `Placeholder` node for the given tensor value to the graph.
201   //
202   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
203   // non-null if the method returns with an OK status. The returned `Node`
204   // pointer is owned by the backing graph of `GraphDefBuilder`.
AddPlaceholder(const Tensor & val,Node ** output)205   Status AddPlaceholder(const Tensor& val, Node** output) {
206     AddPlaceholderInternal(val, output);
207     if (*output == nullptr) {
208       return errors::Internal(
209           "AddPlaceholder: Failed to build Placeholder op.");
210     }
211     return Status::OK();
212   }
213 
AddDataset(const DatasetBase * dataset,const std::vector<Node * > & inputs,Node ** output)214   Status AddDataset(const DatasetBase* dataset,
215                     const std::vector<Node*>& inputs, Node** output) {
216     return AddDataset(dataset, inputs, {}, output);
217   }
218 
219   // Adds a node corresponding to the `DatasetType` to the Graph.
220   // Return value of `DatasetType::type_string()` is used as the op type for the
221   // node.
222   // Values for the output_types and output_shapes node attributes are also
223   // written if those attributes are defined in the OpDef.
224   // `*output` contains a pointer to the output `Node`. It is guaranteed to be
225   // non-null if the method returns with an OK status.
226   // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
AddDataset(const DatasetBase * dataset,const std::vector<Node * > & inputs,const std::vector<std::pair<StringPiece,AttrValue>> & attrs,Node ** output)227   Status AddDataset(const DatasetBase* dataset,
228                     const std::vector<Node*>& inputs,
229                     const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
230                     Node** output) {
231     std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
232     for (size_t i = 0; i < inputs.size(); i++) {
233       enumerated_inputs[i] = std::make_pair(i, inputs[i]);
234     }
235     return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
236   }
237 
238   Status AddDataset(
239       const DatasetBase* dataset,
240       const std::vector<std::pair<size_t, Node*>>& inputs,
241       const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
242       const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
243       Node** output);
244 
245   // Adds a user-defined function with name `function_name` to the graph and
246   // recursively adds all functions it references. If a function with a matching
247   // name has already been added, returns with OK status. If a user-defined with
248   // name `function_name` is not found in the context's function library,
249   // returns an InvalidArgumentError. If the function with name `function_name`
250   // or any of its dependent functions are stateful, and the context does not
251   // explicitly permit stateful functions, returns an InvalidArgument error.
252   Status AddFunction(SerializationContext* ctx, const string& function_name,
253                      const FunctionLibraryDefinition& lib_def);
254 
255   template <typename T>
BuildAttrValue(const T & value,AttrValue * attr)256   void BuildAttrValue(const T& value, AttrValue* attr) {
257     SetAttrValue(value, attr);
258   }
259 
260  protected:
builder()261   GraphDefBuilder* builder() { return b_; }
262 
263  private:
264   void AddPlaceholderInternal(const Tensor& val, Node** output);
265   void AddTensorInternal(const Tensor& val, Node** output);
266   bool HasAttr(const string& op_type_name, const string& attr_name) const;
267 
HasAttr(const OpDef * op_def,const string & attr_name)268   bool HasAttr(const OpDef* op_def, const string& attr_name) const {
269     for (const auto& attr : op_def->attr()) {
270       if (attr.name() == attr_name) {
271         return true;
272       }
273     }
274     return false;
275   }
276 
AddAttrFunctions(SerializationContext * ctx,const AttrValue & attr_value,const FunctionLibraryDefinition & lib_def)277   Status AddAttrFunctions(SerializationContext* ctx,
278                           const AttrValue& attr_value,
279                           const FunctionLibraryDefinition& lib_def) {
280     if (attr_value.has_func()) {
281       TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
282     } else if (attr_value.has_list()) {
283       for (const NameAttrList& name_attr_list : attr_value.list().func()) {
284         TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
285       }
286     }
287     return Status::OK();
288   }
289 
290   GraphDefBuilder* b_;
291 };
292 
293 class StatsAggregator;
294 class FunctionHandleCache;
295 
// Utility for executing a closure in such a way that a `tensorflow::data`
// symbol is always present on the call stack.
class Runner {
 public:
  virtual ~Runner() = default;

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};
308 
309 // A class which provides a sequence of splits. Iterators created with a split
310 // provider will iterate over only the splits provided by the split provider.
311 class SplitProvider {
312  public:
~SplitProvider()313   virtual ~SplitProvider() {}
314   // Stores the next split in `*split`, setting `*end_of_splits` to indicate
315   // whether there were any splits left.
316   virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
317   // Resets the split provider to its beginning.
318   virtual Status Reset() = 0;
319   // Saves the state of this split provider.
320   virtual Status Save(std::function<std::string(std::string)> full_name,
321                       IteratorStateWriter* writer) = 0;
322   // Restores the state of this split provider.
323   virtual Status Restore(std::function<std::string(std::string)> full_name,
324                          IteratorStateReader* reader) = 0;
325 };
326 
327 // A cut-down version of `OpKernelContext` for running computations in
328 // iterators. Note that we cannot simply use `OpKernelContext` here because we
329 // might run computation in an iterator whose lifetime is not nested within the
330 // lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
331 //
332 // TODO(mrry): We're making some daring assumptions about the lifetime of the
333 // runner passed in here. A runner will be deleted when the original step ends,
334 // but all existing runners only close over session-lifetime (or longer-lived)
335 // state, so we can make a copy of the function. There's nothing in the
336 // definition of the API from which we took the runner to guarantee that what we
337 // are doing is safe. We should formalize the properties here.
338 class IteratorContext {
339  public:
340   struct Params {
ParamsParams341     explicit Params(IteratorContext* ctx)
342         : allocator_getter(ctx->allocator_getter()),
343           cancellation_manager(ctx->cancellation_manager()),
344           env(ctx->env()),
345           flr(ctx->flr()),
346           function_handle_cache(ctx->function_handle_cache()),
347           resource_mgr(ctx->resource_mgr()),
348           model(ctx->model()),
349           runner(*(ctx->runner())),
350           runner_threadpool_size(ctx->runner_threadpool_size()),
351           split_provider(ctx->split_provider()),
352           stats_aggregator(ctx->stats_aggregator()),
353           thread_factory(ctx->thread_factory()),
354           thread_pool(ctx->thread_pool()) {}
355 
ParamsParams356     explicit Params(OpKernelContext* ctx)
357         : env(ctx->env()), flr(ctx->function_library()) {
358       // NOTE: need reinterpret_cast because function.h forward-declares Device.
359       DeviceBase* device =
360           reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
361       allocator_getter = [device](AllocatorAttributes attrs) {
362         return device->GetAllocator(attrs);
363       };
364 
365       thread::ThreadPool* thread_pool =
366           ctx->device()->tensorflow_device_thread_pool();
367       if (thread_pool) {
368         runner_threadpool_size = thread_pool->NumThreads();
369       } else {
370         static const int32 kDefaultRunnerThreadpoolSize =
371             port::MaxParallelism();
372         runner_threadpool_size = kDefaultRunnerThreadpoolSize;
373       }
374 
375       // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
376       // that a symbol in the tensorflow::data namespace is always on the stack
377       // when executing a function inside a Dataset.
378       runner = std::bind(
379           [](
380               // Note: `runner` is a const reference to avoid copying it.
381               const std::function<void(std::function<void()>)>& ctx_runner,
382               std::function<void()> fn) {
383             std::function<void()> wrapped_fn = std::bind(
384                 [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
385                 std::move(fn));
386             ctx_runner(std::move(wrapped_fn));
387           },
388           *ctx->runner(), std::placeholders::_1);
389     }
390 
391     // The Allocator to be used to allocate the output of an iterator.
392     std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
393 
394     // The CancellationManager to be used to cancel execution of ops.
395     CancellationManager* cancellation_manager;
396 
397     // Interface to operating system functionality.
398     Env* env = nullptr;
399 
400     // The FunctionLibraryRuntime object to be used to make function calls.
401     FunctionLibraryRuntime* flr = nullptr;
402 
403     // A FunctionHandleCache that owns all the function handles. Not owned.
404     FunctionHandleCache* function_handle_cache = nullptr;
405 
406     // A resource manager for storing dataset-related state, e.g. random
407     // seeds or cached tensors. Not owned.
408     ResourceMgr* resource_mgr = nullptr;
409 
410     // If non-null, identifies the object used for performance modeling.
411     std::shared_ptr<model::Model> model = nullptr;
412 
413     // Function call support.
414     std::function<void(std::function<void()>)> runner = nullptr;
415 
416     // Number of threads used for executing user-defined functions.
417     int32 runner_threadpool_size = 0;
418 
419     // An optional split provider indicating which splits to process.
420     std::shared_ptr<SplitProvider> split_provider = nullptr;
421 
422     // The `StatsAggregator` object to record statistics about the iterator.
423     std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
424 
425     // A factory for creating threads to perform blocking work.
426     std::shared_ptr<ThreadFactory> thread_factory = nullptr;
427 
428     // A shared thread pool to schedule computation into.
429     thread::ThreadPoolInterface* thread_pool = nullptr;
430   };
431 
IteratorContext(IteratorContext * ctx)432   explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}
433 
IteratorContext(OpKernelContext * ctx)434   explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}
435 
IteratorContext(Params params)436   explicit IteratorContext(Params params) : params_(std::move(params)) {}
437 
allocator(AllocatorAttributes attrs)438   Allocator* allocator(AllocatorAttributes attrs) {
439     return params_.allocator_getter(attrs);
440   }
441 
allocator_getter()442   std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
443     return params_.allocator_getter;
444   }
445 
cancellation_manager()446   CancellationManager* cancellation_manager() {
447     return params_.cancellation_manager;
448   }
449 
env()450   Env* env() const { return params_.env; }
451 
flr()452   FunctionLibraryRuntime* flr() { return params_.flr; }
453 
function_handle_cache()454   FunctionHandleCache* function_handle_cache() {
455     return params_.function_handle_cache;
456   }
457 
resource_mgr()458   ResourceMgr* resource_mgr() { return params_.resource_mgr; }
459 
model()460   const std::shared_ptr<model::Model>& model() { return params_.model; }
461 
runner()462   std::function<void(std::function<void()>)>* runner() {
463     return &params_.runner;
464   }
465 
runner_threadpool_size()466   int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
467 
split_provider()468   std::shared_ptr<SplitProvider> split_provider() {
469     return params_.split_provider;
470   }
471 
stats_aggregator()472   std::shared_ptr<StatsAggregator> stats_aggregator() {
473     return params_.stats_aggregator;
474   }
475 
thread_factory()476   const std::shared_ptr<ThreadFactory>& thread_factory() {
477     return params_.thread_factory;
478   }
479 
thread_pool()480   thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
481 
params()482   Params params() { return params_; }
483 
CreateThreadPool(const string & name,int num_threads)484   std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
485                                                        int num_threads) {
486     if (params_.thread_pool) {
487       // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
488       // is an instance of `thread::ThreadPoolInterface`). Notably, the
489       // ownership of `params_.thread_pool` is *not* transferred onto the newly
490       // created `ThreadPool` instance.
491       return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
492     } else {
493       return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
494                                                    name, num_threads,
495                                                    /*low_latency_hint=*/false);
496     }
497   }
498 
StartThread(const string & name,std::function<void ()> fn)499   std::unique_ptr<Thread> StartThread(const string& name,
500                                       std::function<void()> fn) {
501     if (params_.thread_factory) {
502       return params_.thread_factory->StartThread(name, std::move(fn));
503     } else {
504       return absl::WrapUnique(
505           Env::Default()->StartThread({}, name, std::move(fn)));
506     }
507   }
508 
509  private:
510   Params params_;
511 };
512 
513 // Aggregates runtime support needed for dataset and iterator serialization.
514 class SerializationContext {
515  public:
516   // Enum describing what to do during serialization when external state is
517   // encountered.
518   enum class ExternalStatePolicy : int64 {
519     // Proceed with serialization, but log a warning about what state will be
520     // lost.
521     kWarn = 0,
522     // Proceed with serialization without logging any warning.
523     kIgnore = 1,
524     // Fail the serialization with an error.
525     kFail = 2,
526   };
527 
528   // Handles the CheckExternalState status according to the external state
529   // policy.
HandleCheckExternalStateStatus(Status s)530   Status HandleCheckExternalStateStatus(Status s) {
531     if (s.ok()) {
532       return s;
533     }
534     switch (params_.external_state_policy) {
535       case ExternalStatePolicy::kWarn:
536         LOG(WARNING) << s.ToString();
537         return Status::OK();
538       case ExternalStatePolicy::kIgnore:
539         VLOG(2) << "Ignoring error status: " << s.ToString();
540         return Status::OK();
541       case ExternalStatePolicy::kFail:
542         return s;
543     }
544     LOG(FATAL) << "Control should never reach here";
545   }
546 
547   struct Params {
548     std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
549 
550     // Indicates what to do if the dataset depends on external state.
551     ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;
552 
553     // Indicates whether an attempt to serialize a dataset that does not
554     // implement serialization should result in an error. If set to `false`, the
555     // serialized graph will replace the dataset with a placeholder returned in
556     // `input_list`.
557     bool fail_if_unimplemented = true;
558 
559     // Indicates whether (potentially large) data tensors should be
560     // serialized, or replaced with a placeholder returned in `input_list`. The
561     // latter makes sense to do when performing data agnostic graph rewrites to
562     // reduce the memory usage.
563     bool serialize_data_tensors = true;
564 
565     // Indicates whether datasets that use random seeds should have the values
566     // of random seeds serialized or not. If the values of random seeds are
567     // serialized, the deserialized dataset will have the same seeds as the
568     // original dataset. Otherwise, the deserialized dataset will use different
569     // seeds. This param does not affect datasets that use fixed seeds; fixed
570     // seeds will always be preserved.
571     bool preserve_random_seeds = true;
572   };
573 
SerializationContext(Params params)574   explicit SerializationContext(Params params) : params_(params) {}
575 
input_list()576   std::vector<std::pair<string, Tensor>>* input_list() {
577     return params_.input_list;
578   }
579 
external_state_policy()580   ExternalStatePolicy external_state_policy() const {
581     return params_.external_state_policy;
582   }
583 
fail_if_unimplemented()584   bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }
585 
serialize_data_tensors()586   bool serialize_data_tensors() const { return params_.serialize_data_tensors; }
587 
preserve_random_seeds()588   bool preserve_random_seeds() const { return params_.preserve_random_seeds; }
589 
590  private:
591   Params params_;
592 
593   TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
594 };
595 
596 // Represents the current position in a range of outputs, where the
597 // range of outputs is typically represented by an `DatasetBase`,
598 // defined below.
599 class IteratorBase {
600  public:
~IteratorBase()601   virtual ~IteratorBase() {
602     for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
603       (*rit)();
604     }
605   }
606 
607   // Gets the next output from the range that this iterator is traversing.
608   //
609   // If at least one output remains in this iterator's range, that
610   // output will be stored in `*out_tensors` and `false` will be
611   // stored in `*end_of_sequence`.
612   //
613   // If no more outputs remain in this iterator's range, `true` will
614   // be stored in `*end_of_sequence`, and the content of
615   // `*out_tensors` will be undefined.
616   //
617   // Implementations should never return `OutOfRange` error. If at end of
618   // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
619   // Internally raised `OutOfRange` errors that do not imply end of sequence
620   // should be converted to a different error type before being propagated to
621   // the caller.
622   //
623   // This method is thread-safe.
624   //
625   // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
626   // potentially remove this method.
627   virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
628                          bool* end_of_sequence) = 0;
629 
GetNext(IteratorContext && ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)630   Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
631                  bool* end_of_sequence) {
632     return GetNext(&ctx, out_tensors, end_of_sequence);
633   }
634 
635   // Skips the next `num_to_skip` outputs from the range that this iterator
636   // is traversing.
637   //
638   // If there are not enough outputs to skip, it will set
639   // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
640   // store the number of outputs that are skipped. When `*end_of_sequence` is
641   // `false`, `*num_skipped` should equal to `num_to_skip`.
642   virtual Status Skip(IteratorContext* ctx, int num_to_skip,
643                       bool* end_of_sequence, int* num_skipped) = 0;
644 
645   // Returns a vector of DataType values, representing the respective
646   // element types of each tuple component in the outputs of this
647   // iterator.
648   virtual const DataTypeVector& output_dtypes() const = 0;
649 
650   // Returns a vector of tensor shapes, representing the respective
651   // (and possibly partially defined) shapes of each tuple component
652   // in the outputs of this iterator.
653   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
654 
655   // Returns a string that identifies the sequence of iterators leading up to
656   // this iterator.
657   virtual const string& prefix() const = 0;
658 
659   // Performs initialization that needs to happen outside of a constructor to
660   // properly propagate errors.
Initialize(IteratorContext * ctx)661   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
662 
663   // Performs initialization of the base iterator.
664   Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);
665 
666   // Saves the state of this iterator.
Save(SerializationContext * ctx,IteratorStateWriter * writer)667   virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
668     int64 start_us = EnvTime::NowMicros();
669     TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
670     VLOG(1) << "Saved " << prefix() << " in "
671             << (EnvTime::NowMicros() - start_us) << "us";
672     return Status::OK();
673   }
674 
675  protected:
676   // Returns a node that models this iterator.
677   virtual std::shared_ptr<model::Node> CreateNode(
678       IteratorContext* ctx, model::Node::Args args) const = 0;
679 
680   // Restores the state of this iterator.
Restore(IteratorContext * ctx,IteratorStateReader * reader)681   Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
682     int64 start_us = EnvTime::NowMicros();
683     TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
684     VLOG(1) << "Restored " << prefix() << " in "
685             << (EnvTime::NowMicros() - start_us) << "us";
686     return Status::OK();
687   }
688 
689   // This is needed so that sub-classes of IteratorBase can call
690   // `SaveInternal` on their input iterators.
SaveInput(SerializationContext * ctx,IteratorStateWriter * writer,const std::unique_ptr<IteratorBase> & input)691   Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
692                    const std::unique_ptr<IteratorBase>& input) {
693     int64 start_us = EnvTime::NowMicros();
694     TF_RETURN_IF_ERROR(input->SaveInternal(ctx, writer));
695     VLOG(2) << "Saved " << input->prefix() << " in "
696             << (EnvTime::NowMicros() - start_us) << "us";
697     return Status::OK();
698   }
699 
700   // This is needed so that sub-classes of IteratorBase can call
701   // `RestoreInternal` on their input iterators.
RestoreInput(IteratorContext * ctx,IteratorStateReader * reader,const std::unique_ptr<IteratorBase> & input)702   Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
703                       const std::unique_ptr<IteratorBase>& input) {
704     int64 start_us = EnvTime::NowMicros();
705     TF_RETURN_IF_ERROR(input->RestoreInternal(ctx, reader));
706     VLOG(2) << "Restored " << input->prefix() << " in "
707             << (EnvTime::NowMicros() - start_us) << "us";
708     return Status::OK();
709   }
710 
RestoreInput(IteratorContext && ctx,IteratorStateReader * reader,const std::unique_ptr<IteratorBase> & input)711   Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
712                       const std::unique_ptr<IteratorBase>& input) {
713     return RestoreInput(&ctx, reader, input);
714   }
715 
716   // Saves the state of this iterator.
717   //
  // This method is used to store the state of the iterator in a checkpoint.
  // It is pure virtual, so every concrete iterator must provide an override.
720   virtual Status SaveInternal(SerializationContext* ctx,
721                               IteratorStateWriter* writer) = 0;
722 
723   // Restores the state of this iterator.
724   //
725   // This method is used to restore the state of the iterator from a checkpoint.
726   //
727   // Implementations may assume that the iterator is in a clean state. That is,
728   // its `Initialize` method has been called, but its `GetNext` method has
729   // never been called.
  // It is pure virtual, so every concrete iterator must provide an override.
731   virtual Status RestoreInternal(IteratorContext* ctx,
732                                  IteratorStateReader* reader) = 0;
733 
734   // Returns a pointer to the node representing this iterator in the performance
735   // model. It may be null, if performance modeling is not enabled for this
736   // iterator.
model_node()737   std::shared_ptr<model::Node> model_node() const { return node_; }
738 
739   // Returns the number of elements produced by this iterator.
num_elements()740   int64 num_elements() const {
741     if (node_) return node_->num_elements();
742     return 0;
743   }
744 
745  private:
746   // For access to `AddCleanupFunction` and `Restore`.
747   friend class DatasetBase;
748   friend class DatasetBaseIterator;  // for access to `node_`
749 
750   std::vector<std::function<void()>> cleanup_fns_;
751   std::shared_ptr<model::Node> node_ = nullptr;
752   const IteratorBase* parent_ = nullptr;  // Not owned.
753   int64 id_ = 0;
754   int64 parent_id_ = 0;
755 };
756 
757 // Represents runtime information needed to construct a dataset.
758 class DatasetContext {
759  public:
760   struct Params {
761     string type_string;  // op type name of this dataset.
762     string node_name;    // graph node name of this dataset op, uniquely
763                          // identifying the dataset in the graph.
764   };
765 
DatasetContext(Params params)766   explicit DatasetContext(Params params) : params_(std::move(params)) {}
767 
DatasetContext(OpKernelContext * ctx)768   explicit DatasetContext(OpKernelContext* ctx) {
769     params_.type_string = ctx->op_kernel().type_string();
770     params_.node_name = ctx->op_kernel().name();
771   }
772 
type_string()773   const string& type_string() const { return params_.type_string; }
node_name()774   const string& node_name() const { return params_.node_name; }
775 
776  private:
777   Params params_;
778 };
779 
// Returns the number of bytes allocated for the tensors of the given element.
781 int64 GetAllocatedBytes(const std::vector<Tensor>& element);
782 
// Returns the estimated memory usage in bytes of the tensors of the given
// element.
784 int64 GetTotalBytes(const std::vector<Tensor>& element);
785 
786 // Validates and extracts a `DatasetBase` object from `tensor`.
787 //
// `tensor` must have been written by a call to StoreDatasetInVariantTensor().
789 //
790 // The retrieved pointer is a borrowed reference to the dataset, which is owned
791 // by the tensor. The consumer must either acquire its own reference to the
792 // dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
793 // destroyed or mutated while the retrieved pointer is in use.
794 Status GetDatasetFromVariantTensor(const Tensor& tensor,
795                                    DatasetBase** out_dataset);
796 
797 // Stores a `DatasetBase` object in `tensor`.
798 //
799 // The ownership of `dataset` is transferred to `tensor`.
800 Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
801 
802 // Represents a (potentially infinite) range of outputs, where each
803 // output is a tuple of tensors.
804 class DatasetBase : public core::RefCounted {
805  public:
806   // Key for storing the Dataset graph in the serialized format.
807   TF_EXPORT static const char kDatasetGraphKey[];
808 
809   // Key for storing the output node of the Dataset graph in the serialized
810   // format.
811   TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
812 
DatasetBase(DatasetContext && ctx)813   explicit DatasetBase(DatasetContext&& ctx)
814       : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}
815 
816   // Op type name of this dataset.
type_string()817   const string& type_string() const { return type_string_; }
818 
819   // Graph node name of this dataset op, uniquely identifying the dataset in
820   // the graph.
node_name()821   const string& node_name() const { return node_name_; }
822 
823   // Returns a new iterator for iterating over the range of elements in
824   // this dataset.
825   //
826   // This method may be called multiple times on the same instance,
827   // and the resulting iterators will have distinct state. Each
828   // iterator will traverse all elements in this dataset from the
829   // start.
830   //
831   // The prefix identifies the sequence of iterators leading up to the newly
832   // created iterator.
833   Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
834                       const string& output_prefix,
835                       std::unique_ptr<IteratorBase>* iterator) const;
836 
MakeIterator(IteratorContext && ctx,const IteratorBase * parent,const string & output_prefix,std::unique_ptr<IteratorBase> * iterator)837   Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
838                       const string& output_prefix,
839                       std::unique_ptr<IteratorBase>* iterator) const {
840     return MakeIterator(&ctx, parent, output_prefix, iterator);
841   }
842 
843   // Returns a new iterator restored from the checkpoint data in `reader`.
MakeIteratorFromCheckpoint(IteratorContext * ctx,const string & output_prefix,IteratorStateReader * reader,std::unique_ptr<IteratorBase> * iterator)844   Status MakeIteratorFromCheckpoint(
845       IteratorContext* ctx, const string& output_prefix,
846       IteratorStateReader* reader,
847       std::unique_ptr<IteratorBase>* iterator) const {
848     std::unique_ptr<IteratorBase> it;
849     TF_RETURN_IF_ERROR(
850         MakeIterator(ctx, /*parent=*/nullptr, output_prefix, &it));
851     TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
852     *iterator = std::move(it);
853     return Status::OK();
854   }
855 
MakeIteratorFromCheckpoint(IteratorContext && ctx,const string & output_prefix,IteratorStateReader * reader,std::unique_ptr<IteratorBase> * iterator)856   Status MakeIteratorFromCheckpoint(
857       IteratorContext&& ctx, const string& output_prefix,
858       IteratorStateReader* reader,
859       std::unique_ptr<IteratorBase>* iterator) const {
860     return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
861   }
862 
863   // Returns a split provider which partitions the dataset's data into splits
864   // and provides them in a sequence. The split provider is stored in
865   // `*split_provider`.
866   virtual Status MakeSplitProvider(
867       std::unique_ptr<SplitProvider>* split_provider) const;
868 
869   // Returns a vector of DataType values, representing the respective
870   // element types of each tuple component in the outputs of this
871   // dataset.
872   virtual const DataTypeVector& output_dtypes() const = 0;
873 
874   // Returns a vector of tensor shapes, representing the respective
875   // (and possibly partially defined) shapes of each tuple component
876   // in the outputs of this dataset.
877   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
878 
879   // Returns the number of bytes allocated for tensors of this dataset.
AllocatedBytes()880   virtual int64 AllocatedBytes() const { return 0; }
881 
882   // Returns the estimated number of bytes used for tensors of this dataset.
TotalBytes()883   virtual int64 TotalBytes() const { return 0; }
884 
885   // Returns the cardinality of this dataset.
Cardinality()886   virtual int64 Cardinality() const { return kUnknownCardinality; }
887 
888   // A human-readable debug string for this dataset.
889   virtual string DebugString() const = 0;
890 
891   // Stores the dataset's input datasets in `*inputs`. The pointers stored in
892   // `*inputs` are borrowed. The only valid non-ok return status is
893   // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
894   // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
895   // default implementation of `MakeSplitProvider` when there is a single input
896   // dataset.
897   virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;
898 
899   // Indicates whether the dataset depends on any external state which would
900   // prevent it from being serializable. If so, the method returns
901   // `errors::FailedPrecondition` with a message that identifies the external
902   // state. Otherwise, the method returns `Status::OK()`.
903   virtual Status CheckExternalState() const = 0;
904 
905  protected:
906   friend Status AsGraphDef(
907       OpKernelContext* ctx, const DatasetBase* dataset,
908       SerializationContext&& serialization_ctx,
909       GraphDef* graph_def);  // For access to graph related members.
910   friend class CapturedFunction;
911 
912   class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
913    public:
DatasetGraphDefBuilder(GraphDefBuilder * b)914     explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
915         : GraphDefBuilderWrapper(b) {}
916     Status AddInputDataset(SerializationContext* ctx,
917                            const DatasetBase* dataset, Node** output);
918     Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
919                               Node** output);
920 
921    private:
922     Status AddDatasetOrTensorHelper(SerializationContext* ctx,
923                                     const Tensor& val, Node** output);
924   };
925 
926   // Serializes the dataset into a `GraphDef`, which has two uses:
927   //
928   // 1) To perform static input pipeline optimizations, tf.data serializes the
929   // dataset graph, applies graph rewrites, and then deserializes the graph.
930   // If a subclass of `DatasetBase` does not implement this method, then it will
931   // be excluded from static optimizations (and so will any upstream datasets).
932   //
933   // 2) To save the dataset so that it can restore at a later point (possibly in
934   // different environment). If a subclass of `DatasetBase` does not implement
935   // this method, then this migration will not be possible.
936   virtual Status AsGraphDefInternal(SerializationContext* ctx,
937                                     DatasetGraphDefBuilder* b,
938                                     Node** node) const = 0;
939 
940   virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
941       const string& prefix) const = 0;
942 
943  private:
944   const string type_string_;
945   const string node_name_;
946 };
947 
948 // Represents an iterator that is associated with a particular dataset.
949 class DatasetBaseIterator : public IteratorBase {
950  public:
951   struct BaseParams {
952     // Owns one reference on the shared dataset object.
953     const DatasetBase* dataset;
954 
955     // Identifies the sequence of iterators leading up to this iterator.
956     const string prefix;
957   };
958 
959   explicit DatasetBaseIterator(const BaseParams& params);
960 
961   ~DatasetBaseIterator() override;
962 
dataset()963   virtual const DatasetBase* dataset() const { return params_.dataset; }
964 
output_dtypes()965   const DataTypeVector& output_dtypes() const override {
966     return params_.dataset->output_dtypes();
967   }
968 
output_shapes()969   const std::vector<PartialTensorShape>& output_shapes() const override {
970     return params_.dataset->output_shapes();
971   }
972 
prefix()973   const string& prefix() const override { return params_.prefix; }
974 
975   // Returns a name to be used for the TraceMe event.
976   //
977   // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
978   // following format "name#arg_1=value_,...,arg_n=value_n".
979   string BuildTraceMeName();
980 
981   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
982                  bool* end_of_sequence) final;
983 
GetNext(IteratorContext && ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)984   Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
985                  bool* end_of_sequence) {
986     return GetNext(&ctx, out_tensors, end_of_sequence);
987   }
988 
989   Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
990               int* num_skipped) final;
991 
Save(SerializationContext * ctx,IteratorStateWriter * writer)992   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
993     return IteratorBase::Save(ctx, writer);
994   }
995 
996  protected:
997   // Internal implementation of GetNext that is wrapped in tracing logic.
998   virtual Status GetNextInternal(IteratorContext* ctx,
999                                  std::vector<Tensor>* out_tensors,
1000                                  bool* end_of_sequence) = 0;
1001 
1002   // Internal implementation of Skip that is wrapped in tracing logic
1003   virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
1004                               bool* end_of_sequence, int* num_skipped);
1005 
full_name(const string & name)1006   string full_name(const string& name) const {
1007     return FullName(params_.prefix, name);
1008   }
1009 
1010   // Returns a map of key-value pairs to included in the TraceMe string.
GetTraceMeMetadata()1011   virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }
1012 
1013   // By default we model iterators using an unknown node, which acts as
1014   // pass-through with respect to performance modeling.
CreateNode(IteratorContext * ctx,model::Node::Args args)1015   std::shared_ptr<model::Node> CreateNode(
1016       IteratorContext* ctx, model::Node::Args args) const override {
1017     return model::MakeUnknownNode(std::move(args));
1018   }
1019 
1020   // When modeling is enabled, this method disables autotuning for the given
1021   // iterator (and the transitive closure of its inputs).
DisableAutotune(IteratorContext * ctx,IteratorBase * iterator)1022   void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1023     if (iterator->node_) {
1024       iterator->node_->set_autotune(false);
1025     }
1026   }
1027 
1028   // When modeling is enabled, this method enables autotuning for the given
1029   // iterator (and the transitive closure of its inputs).
EnableAutotune(IteratorContext * ctx,IteratorBase * iterator)1030   void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
1031     if (iterator->node_) {
1032       iterator->node_->set_autotune(true);
1033     }
1034   }
1035 
1036   // When modeling is enabled, this method records the fact that this iterator
1037   // has dequeued an element from an internal buffer.
RecordBufferDequeue(IteratorContext * ctx,const std::vector<Tensor> & element)1038   void RecordBufferDequeue(IteratorContext* ctx,
1039                            const std::vector<Tensor>& element) {
1040     if (collect_resource_usage(ctx)) {
1041       node_->record_buffer_event(-GetAllocatedBytes(element), -1);
1042     }
1043   }
1044 
1045   // When modeling is enabled, this method records the fact that this iterator
1046   // has enqueued an element in an internal buffer.
RecordBufferEnqueue(IteratorContext * ctx,const std::vector<Tensor> & element)1047   void RecordBufferEnqueue(IteratorContext* ctx,
1048                            const std::vector<Tensor>& element) {
1049     if (collect_resource_usage(ctx)) {
1050       node_->record_buffer_event(GetAllocatedBytes(element), 1);
1051     }
1052   }
1053 
1054   // When modeling is enabled, this method records the fact that this iterator
1055   // has produced an element and its size in bytes.
RecordElement(IteratorContext * ctx,std::vector<Tensor> * out_tensors)1056   void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
1057     if (node_) {
1058       int64 num_bytes = GetAllocatedBytes(*out_tensors);
1059       node_->record_element();
1060       node_->record_bytes_produced(num_bytes);
1061       if (node_->output()) {
1062         node_->output()->record_bytes_consumed(num_bytes);
1063       }
1064     }
1065   }
1066 
1067   // When modeling is enabled, this method records the fact that a thread of
1068   // this iterator has started work.
RecordStart(IteratorContext * ctx)1069   void RecordStart(IteratorContext* ctx) {
1070     if (collect_resource_usage(ctx)) {
1071       int64 now_nanos = EnvTime::NowNanos();
1072       node_->record_start(now_nanos);
1073     }
1074   }
1075 
1076   // When modeling is enabled, this method records the fact that a thread of
1077   // this iterator has stopped work.
RecordStop(IteratorContext * ctx)1078   void RecordStop(IteratorContext* ctx) {
1079     if (collect_resource_usage(ctx)) {
1080       int64 now_nanos = EnvTime::NowNanos();
1081       node_->record_stop(now_nanos);
1082     }
1083   }
1084 
1085   // Returns whether work is currently being recorded, i.e. whether we are
1086   // currently between a `RecordStart` and a `RecordStop`.
IsRecording(IteratorContext * ctx)1087   bool IsRecording(IteratorContext* ctx) {
1088     return collect_resource_usage(ctx) && node_->is_recording();
1089   }
1090 
1091  private:
collect_resource_usage(IteratorContext * ctx)1092   bool collect_resource_usage(IteratorContext* ctx) {
1093     auto model = ctx->model();
1094     return model && model->collect_resource_usage() && node_;
1095   }
1096 
1097   BaseParams params_;
1098 };
1099 
1100 // Represents an iterator that is associated with a particular dataset
1101 // with a particular type.
1102 template <class DatasetType>
1103 class DatasetIterator : public DatasetBaseIterator {
1104  public:
1105   struct Params {
1106     // Borrowed pointer to the dataset.
1107     const DatasetType* dataset;
1108 
1109     // Identifies the sequence of iterators leading up to this iterator.
1110     const string prefix;
1111   };
1112 
DatasetIterator(const Params & params)1113   explicit DatasetIterator(const Params& params)
1114       : DatasetBaseIterator({params.dataset, params.prefix}),
1115         typed_dataset_(params.dataset) {}
1116 
1117   // The dataset from which this iterator was created.
dataset()1118   const DatasetType* dataset() const final { return typed_dataset_; }
1119 
1120  private:
1121   const DatasetType* const typed_dataset_;  // Not owned.
1122 };
1123 
1124 template <typename T>
ParseScalarArgument(OpKernelContext * ctx,const StringPiece & argument_name,T * output)1125 Status ParseScalarArgument(OpKernelContext* ctx,
1126                            const StringPiece& argument_name, T* output) {
1127   const Tensor* argument_t;
1128   TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1129   if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
1130     return errors::InvalidArgument(argument_name, " must be a scalar");
1131   }
1132   *output = argument_t->scalar<T>()();
1133   return Status::OK();
1134 }
1135 
1136 template <typename T>
ParseVectorArgument(OpKernelContext * ctx,const StringPiece & argument_name,std::vector<T> * output)1137 Status ParseVectorArgument(OpKernelContext* ctx,
1138                            const StringPiece& argument_name,
1139                            std::vector<T>* output) {
1140   const Tensor* argument_t;
1141   TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
1142   if (!TensorShapeUtils::IsVector(argument_t->shape())) {
1143     return errors::InvalidArgument(argument_name, " must be a vector");
1144   }
1145   int size = argument_t->vec<T>().size();
1146   output->reserve(size);
1147   for (int i = 0; i < size; ++i) {
1148     output->push_back(argument_t->vec<T>()(i));
1149   }
1150   return Status::OK();
1151 }
1152 
1153 // Encapsulates the work required to plug a DatasetBase into the core TensorFlow
1154 // graph execution engine.
1155 class DatasetOpKernel : public OpKernel {
1156  public:
DatasetOpKernel(OpKernelConstruction * ctx)1157   explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
1158 
1159   void Compute(OpKernelContext* ctx) final;
1160 
1161   // Indicates whether the given op corresponds to an op whose kernels subclass
1162   // the `DatasetOpKernel` class.
1163   static bool IsDatasetOp(const OpDef* op_def);
1164 
1165   string TraceString(const OpKernelContext& ctx, bool verbose) const override;
1166 
1167  protected:
1168   // Subclasses should implement this method. It will be called during Compute
1169   // execution.
1170   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
1171 };
1172 
1173 // Encapsulates the work required to plug unary Datasets into the core
1174 // TensorFlow graph execution engine.
1175 class UnaryDatasetOpKernel : public DatasetOpKernel {
1176  public:
UnaryDatasetOpKernel(OpKernelConstruction * ctx)1177   explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
1178       : DatasetOpKernel(ctx) {}
1179 
1180  protected:
1181   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1182   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1183                            DatasetBase** output) = 0;
1184 };
1185 
1186 // Encapsulates the work required to plug binary Datasets into the core
1187 // TensorFlow graph execution engine.
1188 class BinaryDatasetOpKernel : public DatasetOpKernel {
1189  public:
BinaryDatasetOpKernel(OpKernelConstruction * ctx)1190   explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
1191       : DatasetOpKernel(ctx) {}
1192 
1193  protected:
1194   void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
1195   virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1196                            DatasetBase* another_input,
1197                            DatasetBase** output) = 0;
1198 };
1199 
1200 // A simple background worker that executes closures asynchronously and without
1201 // blocking.
1202 //
1203 // A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
1204 // to avoid blocking an executor thread that may be required by the blocking
1205 // work.
1206 //
1207 // NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
1208 // purpose because its current implementation (in Eigen) uses a finite-length
1209 // queue and will block the caller when full. This can lead to deadlock under
1210 // heavy load. Since the number of concurrent work items in each user of a
1211 // `BackgroundWorker` is at most one per op invocation, the dynamic allocation
1212 // overhead is tolerable.
1213 class BackgroundWorker {
1214  public:
1215   BackgroundWorker(Env* env, const char* name);
1216 
1217   ~BackgroundWorker();
1218 
1219   void Schedule(std::function<void()> work_item);
1220 
1221  private:
1222   void WorkerLoop();
1223 
1224   Env* const env_;
1225   const char* const name_;
1226 
1227   std::unique_ptr<Thread> thread_;
1228   mutex mu_;
1229   condition_variable cond_var_;
1230   bool cancelled_ TF_GUARDED_BY(mu_) = false;
1231   std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
1232 };
1233 
1234 // Registry of names of ops whose kernels subclass the `DatasetOpKernel` class.
1235 class DatasetOpRegistry {
1236  public:
1237   // Registers the op name.
1238   static void Register(const string& op_name);
1239 
1240   // Checks whether the given op name has been registered.
1241   static bool IsRegistered(const string& op_name);
1242 };
1243 
1244 // Helper class to register dataset op name.
1245 class DatasetOpRegistrar {
1246  public:
DatasetOpRegistrar(const string & op_name)1247   explicit DatasetOpRegistrar(const string& op_name) {
1248     DatasetOpRegistry::Register(op_name);
1249   }
1250 };
1251 
// Declares a uniquely named static `DatasetOpRegistrar` whose constructor
// registers `op_name`. `ctr` only serves to make the object's name unique.
#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name) \
  static ::tensorflow::data::DatasetOpRegistrar     \
      registrar__body__##ctr##__object(op_name)

// Extra level of indirection so that `ctr` (e.g. `__COUNTER__`) is fully
// macro-expanded before it is token-pasted by the macro above.
#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ(ctr, op_name)

// Macro that can be used to register an op name of a dataset op.
#define REGISTER_DATASET_OP_NAME(op_name) \
  REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, op_name)
1262 
1263 }  // namespace data
1264 }  // namespace tensorflow
1265 
1266 #endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
1267