/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_

#include <deque>
#include <memory>
#include <unordered_map>

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/thread_factory.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"

// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
// `T`, e.g. to build a `switch` statement.
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
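
// For example (an illustrative sketch, not part of this header), a per-type
// dispatch over a `DataType` value could be generated as:
//
//   #define HANDLE_TYPE(T)                      \
//     case DataTypeToEnum<T>::value: {          \
//       /* handle tensors of element type T */  \
//       break;                                  \
//     }
//   switch (dtype) {
//     TF_CALL_DATASET_TYPES(HANDLE_TYPE)
//     default:
//       return errors::Unimplemented("Unsupported dtype");
//   }
//   #undef HANDLE_TYPE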

namespace tensorflow {

// Forward declarations to avoid introducing a dependency on headers in
// "tensorflow/core/graph/...".
class GraphDefBuilder;
class Node;

namespace data {

namespace internal {
// Merges Options from source to destination. If there is a conflict on a field,
// the field value from the source takes precedence.
void MergeOptions(const protobuf::Message& source,
                  protobuf::Message* destination);
void MergeOptions(const protobuf::MessageLite& source,
                  protobuf::MessageLite* destination);
}  // namespace internal

using TraceMeMetadata = std::vector<std::pair<StringPiece, string>>;

constexpr char kTFDataFunction[] = "_tf_data_function";

constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;

// This constant is a magic number that is used (as a prefix) to identify keys
// used for serialization of iterator state.
constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b";
constexpr char kPipe[] = "|";
constexpr char kColon[] = ":";

constexpr char kTFDataResourceTag[] = "tfdata";
constexpr char kTraceInfoUnavailable[] = "unavailable";

class DatasetBase;
class SerializationContext;

inline bool IsTFDataFunction(const FunctionDef& func) {
  auto iter = func.attr().find(data::kTFDataFunction);
  return (iter != func.attr().end() && iter->second.b());
}

// Interface for reading values from a key-value store.
// Used for restoring iterator state. This class is thread safe.
// See the comment on IteratorStateWriter for guidance on using the
// Read*(key, val) vs. Read*(name, key, val) overloads.
class IteratorStateReader {
 public:
  // Determines whether the iterator state contains the given key.
  virtual bool Contains(StringPiece key) const = 0;
  virtual bool Contains(StringPiece name, StringPiece key) const = 0;

  // Reads an integer for the given key.
  virtual Status ReadScalar(StringPiece key, int64* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            int64* val) const = 0;

  // Reads a string for the given key.
  virtual Status ReadScalar(StringPiece key, tstring* val) const = 0;
  virtual Status ReadScalar(StringPiece name, StringPiece key,
                            tstring* val) const = 0;

  // Reads a tensor for the given key.
  // TODO(jsimsa): Remove non-FLR overrides once all callers are updated.
  virtual Status ReadTensor(StringPiece key, Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(StringPiece name, StringPiece key,
                            Tensor* val) const = 0;
  virtual Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name,
                            StringPiece key, Tensor* val) const = 0;

  virtual ~IteratorStateReader() {}
};

// Interface for writing values to a key-value store.
// Used for saving iterator state. Not thread safe.
// The IteratorStateWriter creates a tensor for each unique iterator name it
// sees. For the Write*(key, val) APIs, the key is expected to encode this
// name, as keys are required to be produced using the full_name() method.
// Each tensor has an upper limit of 2 GB, so if the state for an iterator
// might exceed that limit, you can pass an explicit name via the
// Write*(name, key, val) APIs to further split up the state into more
// manageable chunks.
class IteratorStateWriter {
 public:
  // Writes an integer for the given key.
  virtual Status WriteScalar(StringPiece key, const int64_t val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const int64_t val) = 0;

  // Writes a string for the given key.
  virtual Status WriteScalar(StringPiece key, const tstring& val) = 0;
  virtual Status WriteScalar(StringPiece name, StringPiece key,
                             const tstring& val) = 0;

  // Writes a tensor for the given key.
  virtual Status WriteTensor(StringPiece key, const Tensor& val) = 0;
  virtual Status WriteTensor(StringPiece name, StringPiece key,
                             const Tensor& val) = 0;

  virtual ~IteratorStateWriter() {}
};
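
// For example (an illustrative sketch, assuming a hypothetical iterator with
// a `counter_` member), state is typically written and read back with keys
// produced by full_name():
//
//   Status SaveInternal(SerializationContext* ctx,
//                       IteratorStateWriter* writer) override {
//     return writer->WriteScalar(full_name("counter"), counter_);
//   }
//
//   Status RestoreInternal(IteratorContext* ctx,
//                          IteratorStateReader* reader) override {
//     return reader->ReadScalar(full_name("counter"), &counter_);
//   }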

// Generates a full name key for iterator checkpointing. All keys generated for
// iterator checkpoints should go through this function.
std::string FullName(const std::string& prefix, const std::string& name);

// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
 public:
  explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {}

  // Adds a Const node with scalar value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  template <typename T>
  Status AddScalar(const T& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
    val_t.scalar<T>()() = val;
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddScalar: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a Const node with vector value to the Graph.
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status.
  // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
  // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice?
  template <typename T>
  Status AddVector(const std::vector<T>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<T>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<T>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  Status AddVector(const std::vector<string>& val, Node** output) {
    Tensor val_t = Tensor(DataTypeToEnum<tstring>::v(),
                          TensorShape({static_cast<int64>(val.size())}));
    for (size_t i = 0; i < val.size(); i++) {
      val_t.flat<tstring>()(i) = val[i];
    }
    AddTensorInternal(val_t, output);
    if (*output == nullptr) {
      return errors::Internal("AddVector: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Const` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddTensor(const Tensor& val, Node** output) {
    AddTensorInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal("AddTensor: Failed to build Const op.");
    }
    return Status::OK();
  }

  // Adds a `Placeholder` node for the given tensor value to the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing graph of `GraphDefBuilder`.
  Status AddPlaceholder(const Tensor& val, Node** output) {
    AddPlaceholderInternal(val, output);
    if (*output == nullptr) {
      return errors::Internal(
          "AddPlaceholder: Failed to build Placeholder op.");
    }
    return Status::OK();
  }

  // Adds a node for the given dataset to the `Graph`. The value of
  // `DatasetBase::type_string()` is used as the op type for the node. Values
  // for the `output_types` and `output_shapes` node attributes are also written
  // if those attributes are defined in the `OpDef`.
  //
  // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is
  // used as the op name for the node. This argument should only be set when
  // serializing `DatasetBase` instances which might not have been created
  // through op kernel execution to make sure the dataset op name is preserved
  // across serialization boundaries, which is in turn needed to make sure
  // iterator checkpoints are valid across serialization boundaries. When
  // `use_dataset_name` is set, the caller is responsible for making sure that
  // the op name is unique across the graph.
  //
  // `*output` contains a pointer to the output `Node`. It is guaranteed to be
  // non-null if the method returns with an OK status. The returned `Node`
  // pointer is owned by the backing `Graph` of `GraphDefBuilder`.
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs, Node** output);
  Status AddDataset(const DatasetBase* dataset,
                    const std::vector<Node*>& inputs,
                    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
                    Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      Node** output);
  Status AddDataset(
      const DatasetBase* dataset,
      const std::vector<std::pair<size_t, Node*>>& inputs,
      const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
      const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
      bool use_dataset_name, Node** output);

  // Adds a user-defined function with name `function_name` to the graph and
  // recursively adds all functions it references. If a function with a matching
  // name has already been added, returns with OK status. If a user-defined
  // function with name `function_name` is not found in the context's function
  // library, returns an InvalidArgumentError. If the function with name
  // `function_name` or any of its dependent functions are stateful, and the
  // context does not explicitly permit stateful functions, returns an
  // InvalidArgument error.
  Status AddFunction(SerializationContext* ctx, const string& function_name,
                     const FunctionLibraryDefinition& lib_def);

  template <typename T>
  void BuildAttrValue(const T& value, AttrValue* attr) {
    SetAttrValue(value, attr);
  }

 protected:
  GraphDefBuilder* builder() { return b_; }

 private:
  void AddPlaceholderInternal(const Tensor& val, Node** output);
  void AddTensorInternal(const Tensor& val, Node** output);
  bool HasAttr(const string& op_type_name, const string& attr_name) const;

  bool HasAttr(const OpDef* op_def, const string& attr_name) const {
    for (const auto& attr : op_def->attr()) {
      if (attr.name() == attr_name) {
        return true;
      }
    }
    return false;
  }

  Status AddAttrFunctions(SerializationContext* ctx,
                          const AttrValue& attr_value,
                          const FunctionLibraryDefinition& lib_def) {
    if (attr_value.has_func()) {
      TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def));
    } else if (attr_value.has_list()) {
      for (const NameAttrList& name_attr_list : attr_value.list().func()) {
        TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def));
      }
    }
    return Status::OK();
  }

  GraphDefBuilder* b_;
};

class StatsAggregator;

// A utility class for running a function and ensuring that there is always a
// `tensorflow::data` symbol on the stack.
class Runner {
 public:
  virtual ~Runner() {}

  // Runs the given function.
  virtual void Run(const std::function<void()>& f) = 0;

  // Returns a global singleton Runner.
  static Runner* get();
};

// A class which provides a sequence of splits. Splits represent subdivisions of
// a dataset, e.g. filenames or ranges within files. We use splitting to
// partition input data into smaller pieces for distributed processing (see
// go/tf-data-splitting-design).
//
// Datasets provide a `MakeSplitProvider` method to expose a listing of their
// splits.
//
// Iterators created with a split provider will only iterate over the splits
// provided by the split provider.
class SplitProvider {
 public:
  virtual ~SplitProvider() {}
  // Stores the next split in `*split`, setting `*end_of_splits` to indicate
  // whether there were any splits left.
  virtual Status GetNext(Tensor* split, bool* end_of_splits) = 0;
  // Resets the split provider to its beginning.
  virtual Status Reset() = 0;
  // Saves the state of this split provider.
  virtual Status Save(std::function<std::string(std::string)> full_name,
                      IteratorStateWriter* writer) = 0;
  // Restores the state of this split provider.
  virtual Status Restore(std::function<std::string(std::string)> full_name,
                         IteratorStateReader* reader) = 0;
};
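
// As an illustrative sketch (not part of this header), a minimal in-memory
// split provider over `n` integer splits might look like:
//
//   class IndexSplitProvider : public SplitProvider {
//    public:
//     explicit IndexSplitProvider(int64 n) : n_(n) {}
//     Status GetNext(Tensor* split, bool* end_of_splits) override {
//       if (i_ >= n_) {
//         *end_of_splits = true;
//         return Status::OK();
//       }
//       *split = Tensor(DT_INT64, TensorShape{});
//       split->scalar<int64>()() = i_++;
//       *end_of_splits = false;
//       return Status::OK();
//     }
//     Status Reset() override {
//       i_ = 0;
//       return Status::OK();
//     }
//     Status Save(std::function<std::string(std::string)> full_name,
//                 IteratorStateWriter* writer) override {
//       return writer->WriteScalar(full_name("i"), i_);
//     }
//     Status Restore(std::function<std::string(std::string)> full_name,
//                    IteratorStateReader* reader) override {
//       return reader->ReadScalar(full_name("i"), &i_);
//     }
//
//    private:
//     int64 n_;
//     int64 i_ = 0;
//   };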

// A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the
// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching).
//
// TODO(mrry): We're making some daring assumptions about the lifetime of the
// runner passed in here. A runner will be deleted when the original step ends,
// but all existing runners only close over session-lifetime (or longer-lived)
// state, so we can make a copy of the function. There's nothing in the
// definition of the API from which we took the runner to guarantee that what we
// are doing is safe. We should formalize the properties here.
class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
        : allocator_getter(ctx->allocator_getter()),
          cancellation_manager(ctx->cancellation_manager()),
          collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->flr()),
          function_handle_cache(ctx->function_handle_cache()),
          is_restoring(ctx->is_restoring()),
          resource_mgr(ctx->resource_mgr()),
          model(ctx->model()),
          runner(*(ctx->runner())),
          runner_threadpool_size(ctx->runner_threadpool_size()),
          split_providers(ctx->split_providers()),
          stats_aggregator(ctx->stats_aggregator()),
          thread_factory(ctx->thread_factory()),
          thread_pool(ctx->thread_pool()) {}

    explicit Params(OpKernelContext* ctx)
        : collective_executor(ctx->collective_executor()),
          env(ctx->env()),
          flr(ctx->function_library()) {
      // NOTE: need reinterpret_cast because function.h forward-declares Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
      allocator_getter = [device](AllocatorAttributes attrs) {
        return device->GetAllocator(attrs);
      };

      thread::ThreadPool* thread_pool =
          ctx->device()->tensorflow_device_thread_pool();
      if (thread_pool) {
        runner_threadpool_size = thread_pool->NumThreads();
      } else {
        static const int32_t kDefaultRunnerThreadpoolSize =
            port::MaxParallelism();
        runner_threadpool_size = kDefaultRunnerThreadpoolSize;
      }

      // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so
      // that a symbol in the tensorflow::data namespace is always on the stack
      // when executing a function inside a Dataset.
      runner = std::bind(
          [](
              // Note: `runner` is a const reference to avoid copying it.
              const std::function<void(std::function<void()>)>& ctx_runner,
              std::function<void()> fn) {
            std::function<void()> wrapped_fn = std::bind(
                [](const std::function<void()>& fn) { Runner::get()->Run(fn); },
                std::move(fn));
            ctx_runner(std::move(wrapped_fn));
          },
          *ctx->runner(), std::placeholders::_1);
    }

    // The Allocator to be used to allocate the output of an iterator.
    std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

    // The CancellationManager to be used to cancel execution of ops.
    CancellationManager* cancellation_manager;

    // Collective support.
    CollectiveExecutor* collective_executor = nullptr;

    // Interface to operating system functionality.
    Env* env = nullptr;

    // The FunctionLibraryRuntime object to be used to make function calls.
    FunctionLibraryRuntime* flr = nullptr;

    // A FunctionHandleCache that owns all the function handles. Not owned.
    FunctionHandleCache* function_handle_cache = nullptr;

    // Marks whether the iterator is restored from a checkpoint.
    bool is_restoring = false;

    // A resource manager for storing dataset-related state, e.g. random
    // seeds or cached tensors. Not owned.
    ResourceMgr* resource_mgr = nullptr;

    // If non-null, identifies the object used for performance modeling.
    std::shared_ptr<model::Model> model = nullptr;

    // Function call support.
    std::function<void(std::function<void()>)> runner = nullptr;

    // Number of threads used for executing user-defined functions.
    int32 runner_threadpool_size = 0;

    // Split providers indicating which splits to process. May be empty,
    // indicating that the iterator should process all splits.
    std::vector<std::shared_ptr<SplitProvider>> split_providers;

    // The `StatsAggregator` object to record statistics about the iterator.
    //
    // TODO(b/147325552): Remove this API and any of its uses after we switch to
    // using C++ based implementation for tf.data options (on 4/12/2021).
    std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;

    // A factory for creating threads to perform blocking work.
    std::shared_ptr<ThreadFactory> thread_factory = nullptr;

    // A shared thread pool to schedule computation into.
    thread::ThreadPoolInterface* thread_pool = nullptr;
  };

  explicit IteratorContext(IteratorContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(OpKernelContext* ctx) : params_(Params{ctx}) {}

  explicit IteratorContext(Params params) : params_(std::move(params)) {}

  Allocator* allocator(AllocatorAttributes attrs) {
    return params_.allocator_getter(attrs);
  }

  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
    return params_.allocator_getter;
  }

  CancellationManager* cancellation_manager() {
    return params_.cancellation_manager;
  }

  CollectiveExecutor* collective_executor() {
    return params_.collective_executor;
  }

  Env* env() const { return params_.env; }

  FunctionLibraryRuntime* flr() { return params_.flr; }

  FunctionHandleCache* function_handle_cache() {
    return params_.function_handle_cache;
  }

  bool is_restoring() { return params_.is_restoring; }

  ResourceMgr* resource_mgr() { return params_.resource_mgr; }

  const std::shared_ptr<model::Model>& model() { return params_.model; }

  std::function<void(std::function<void()>)>* runner() {
    return &params_.runner;
  }

  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }

  std::vector<std::shared_ptr<SplitProvider>> split_providers() {
    return params_.split_providers;
  }

  std::shared_ptr<StatsAggregator> stats_aggregator() {
    return params_.stats_aggregator;
  }

  const std::shared_ptr<ThreadFactory>& thread_factory() {
    return params_.thread_factory;
  }

  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }

  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
                                                       int num_threads) {
    if (params_.thread_pool) {
      // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
      // is an instance of `thread::ThreadPoolInterface`). Notably, the
      // ownership of `params_.thread_pool` is *not* transferred onto the newly
      // created `ThreadPool` instance.
      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
    } else {
      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
                                                   name, num_threads,
                                                   /*low_latency_hint=*/false);
    }
  }

  std::unique_ptr<Thread> StartThread(const string& name,
                                      std::function<void()> fn) {
    if (params_.thread_factory) {
      return params_.thread_factory->StartThread(name, std::move(fn));
    } else {
      return absl::WrapUnique(
          Env::Default()->StartThread({}, name, std::move(fn)));
    }
  }

 private:
  Params params_;
};
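
// As an illustrative sketch (not part of this header), a nested iterator can
// derive a child context from its parent and override selected fields, e.g.:
//
//   IteratorContext::Params params(ctx);  // copy the parent's params
//   params.split_providers.clear();       // e.g. drop the split providers
//   IteratorContext child_ctx(std::move(params));
//   // pass &child_ctx to input_impl_->GetNext(...) or similar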

// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
 public:
  // Enum describing what to do during serialization when external state is
  // encountered.
  enum class ExternalStatePolicy : int64 {
    // Proceed with serialization, but log a warning about what state will be
    // lost.
    kWarn = 0,
    // Proceed with serialization without logging any warning.
    kIgnore = 1,
    // Fail the serialization with an error.
    kFail = 2,
  };

  // Handles the CheckExternalState status according to the external state
  // policy.
  Status HandleCheckExternalStateStatus(Status s) {
    if (s.ok()) {
      return s;
    }
    switch (params_.external_state_policy) {
      case ExternalStatePolicy::kWarn:
        LOG(WARNING) << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kIgnore:
        VLOG(2) << "Ignoring error status: " << s.ToString();
        return Status::OK();
      case ExternalStatePolicy::kFail:
        return s;
    }
    LOG(FATAL) << "Control should never reach here";
  }

  struct Params {
    explicit Params() {}

    explicit Params(OpKernelContext* ctx)
        : resource_mgr(ctx->resource_manager()),
          device_name(ctx->device()->attributes().name()) {}

    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.

    // Indicates what to do if the dataset depends on external state.
    ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

    // Indicates whether an attempt to serialize a dataset that does not
    // implement serialization should result in an error. If set to `false`, the
    // serialized graph will replace the dataset with a placeholder returned in
    // `input_list`.
    bool fail_if_unimplemented = true;

    // Indicates whether (potentially large) data tensors should be
    // serialized, or replaced with a placeholder returned in `input_list`. The
    // latter makes sense to do when performing data agnostic graph rewrites to
    // reduce the memory usage.
    bool serialize_data_tensors = true;

    // Indicates whether datasets that use random seeds should have the values
    // of random seeds serialized or not. If the values of random seeds are
    // serialized, the deserialized dataset will have the same seeds as the
    // original dataset. Otherwise, the deserialized dataset will use different
    // seeds. This param does not affect datasets that use fixed seeds; fixed
    // seeds will always be preserved.
    bool preserve_random_seeds = true;

    // A resource manager for looking up resources during serialization.
    ResourceMgr* resource_mgr;

    // The name of the device doing the serialization.
    std::string device_name;
  };

  explicit SerializationContext(Params params) : params_(params) {}

  std::vector<std::pair<string, Tensor>>* input_list() {
    return params_.input_list;
  }

  ExternalStatePolicy external_state_policy() const {
    return params_.external_state_policy;
  }

  bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

  bool serialize_data_tensors() const { return params_.serialize_data_tensors; }

  bool preserve_random_seeds() const { return params_.preserve_random_seeds; }

  const ResourceMgr* resource_mgr() const { return params_.resource_mgr; }

  const std::string& device_name() const { return params_.device_name; }

 private:
  Params params_;

  TF_DISALLOW_COPY_AND_ASSIGN(SerializationContext);
};
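
// As an illustrative sketch (not part of this header), a caller that wants
// serialization to fail on any external state can configure the context as:
//
//   SerializationContext::Params params;
//   params.external_state_policy =
//       SerializationContext::ExternalStatePolicy::kFail;
//   SerializationContext serialization_ctx(params);
//   TF_RETURN_IF_ERROR(serialization_ctx.HandleCheckExternalStateStatus(
//       dataset->CheckExternalState()));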

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
 public:
  virtual ~IteratorBase() {
    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
      (*rit)();
    }
  }

  // Gets the next output from the range that this iterator is traversing.
  //
  // If at least one output remains in this iterator's range, that
  // output will be stored in `*out_tensors` and `false` will be
  // stored in `*end_of_sequence`.
  //
  // If no more outputs remain in this iterator's range, `true` will be stored
  // in `*end_of_sequence`, and `*out_tensors` will be empty.
  //
  // Implementations should never return an `OutOfRange` error. If at end of
  // sequence, set `*end_of_sequence = true` and return `Status::OK()`.
  // Internally raised `OutOfRange` errors that do not imply end of sequence
  // should be converted to a different error type before being propagated to
  // the caller.
  //
  // Implementations must explicitly set `*end_of_sequence = false` if a
  // `Status::OK()` status is returned and the iterator is not at the end of the
  // sequence.
  //
  // This method is thread-safe.
  //
  // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and
  // potentially remove this method.
  virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) = 0;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  // Skips the next `num_to_skip` outputs from the range that this iterator
  // is traversing.
  //
  // If there are not enough outputs to skip, it will set
  // `*end_of_sequence = true` and return `Status::OK()`. `*num_skipped` will
  // store the number of outputs that are skipped. When `*end_of_sequence` is
  // `false`, `*num_skipped` should equal `num_to_skip`.
  virtual Status Skip(IteratorContext* ctx, int num_to_skip,
                      bool* end_of_sequence, int* num_skipped) = 0;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // iterator.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns a string that identifies the sequence of iterators leading up to
  // this iterator.
  virtual const string& prefix() const = 0;

  // Performs initialization that needs to happen outside of a constructor to
  // properly propagate errors.
  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }

  // Performs initialization of the base iterator.
  Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent);

  // Saves the state of this iterator.
  virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(SaveInternal(ctx, writer));
    VLOG(1) << "Saved " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

 protected:
  // Returns a node that models this iterator.
  virtual std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const = 0;

  // Restores the state of this iterator.
  virtual Status Restore(IteratorContext* ctx, IteratorStateReader* reader) {
    int64_t start_us = EnvTime::NowMicros();
    TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader));
    VLOG(1) << "Restored " << prefix() << " in "
            << (EnvTime::NowMicros() - start_us) << "us";
    return Status::OK();
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `SaveInternal` on their input iterators.
  Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
                   const std::unique_ptr<IteratorBase>& input) {
    return input->Save(ctx, writer);
  }

  // This is needed so that sub-classes of IteratorBase can call
  // `RestoreInternal` on their input iterators.
  Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return input->Restore(ctx, reader);
  }

  Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader,
                      const std::unique_ptr<IteratorBase>& input) {
    return RestoreInput(&ctx, reader, input);
  }

  // Saves the state of this iterator.
  //
  // This method is used to store the state of the iterator in a checkpoint.
  // Subclasses must provide an override.
  virtual Status SaveInternal(SerializationContext* ctx,
                              IteratorStateWriter* writer) = 0;

  // Restores the state of this iterator.
  //
  // This method is used to restore the state of the iterator from a checkpoint.
  //
  // Implementations may assume that the iterator is in a clean state. That is,
  // its `Initialize` method has been called, but its `GetNext` method has
  // never been called.
  // Subclasses must provide an override.
  virtual Status RestoreInternal(IteratorContext* ctx,
                                 IteratorStateReader* reader) = 0;

  // Returns a pointer to the node representing this iterator in the performance
  // model. It may be null, if performance modeling is not enabled for this
  // iterator.
  std::shared_ptr<model::Node> model_node() const { return node_; }

  // Returns the number of elements produced by this iterator.
  int64 num_elements() const {
    if (node_) return node_->num_elements();
    return 0;
  }

 private:
  // For access to `AddCleanupFunction` and `Restore`.
  friend class DatasetBase;
  friend class DatasetBaseIterator;  // for access to `node_`

  std::vector<std::function<void()>> cleanup_fns_;
  std::shared_ptr<model::Node> node_ = nullptr;
  const IteratorBase* parent_ = nullptr;  // Not owned.
  int64 id_ = 0;
  int64 parent_id_ = 0;
};

// Represents runtime information needed to construct a dataset.
class DatasetContext {
 public:
  struct Params {
    string type_string;  // op type name of this dataset.
    string node_name;    // graph node name of this dataset op, uniquely
                         // identifying the dataset in the graph.
  };

  explicit DatasetContext(Params params) : params_(std::move(params)) {}

  explicit DatasetContext(OpKernelContext* ctx) {
    params_.type_string = ctx->op_kernel().type_string();
    params_.node_name = ctx->op_kernel().name();
  }

  const string& type_string() const { return params_.type_string; }
  const string& node_name() const { return params_.node_name; }

 private:
  Params params_;
};

// Returns the number of bytes allocated for the given tensor.
int64 GetAllocatedBytes(const std::vector<Tensor>& element);

// Returns the estimated memory usage in bytes of the given tensor.
int64 GetTotalBytes(const std::vector<Tensor>& element);

// Validates and extracts a `DatasetBase` object from `tensor`.
//
// `tensor` must have been written by a call to SetVariantTensorToDataset().
//
// The retrieved pointer is a borrowed reference to the dataset, which is owned
// by the tensor. The consumer must either acquire its own reference to the
// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
// destroyed or mutated while the retrieved pointer is in use.
Status GetDatasetFromVariantTensor(const Tensor& tensor,
                                   DatasetBase** out_dataset);

// Stores a `DatasetBase` object in `tensor`.
//
// The ownership of `dataset` is transferred to `tensor`.
Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor);
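
// As an illustrative sketch (not part of this header), a consumer that needs
// the dataset to outlive the input tensor can retain its own reference:
//
//   DatasetBase* dataset = nullptr;
//   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(tensor, &dataset));
//   dataset->Ref();                    // acquire our own reference
//   core::ScopedUnref unref(dataset);  // release it when done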

// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
 public:
  // Key for storing the Dataset graph in the serialized format.
  TF_EXPORT static const char kDatasetGraphKey[];

  // Key for storing the output node of the Dataset graph in the serialized
  // format.
  TF_EXPORT static const char kDatasetGraphOutputNodeKey[];

  explicit DatasetBase(DatasetContext&& ctx)
      : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {}

  // Op type name of this dataset.
  const string& type_string() const { return type_string_; }

  // Graph node name of this dataset op, uniquely identifying the dataset in
  // the graph.
  const string& node_name() const { return node_name_; }

  // Initializes the dataset.
  void Initialize();

  const Options& options() const { return options_; }

  int64 num_sources() const { return num_sources_; }

  // Returns a new iterator for iterating over the range of elements in
  // this dataset.
  //
  // This method may be called multiple times on the same instance,
  // and the resulting iterators will have distinct state. Each
  // iterator will traverse all elements in this dataset from the
  // start.
  //
  // The prefix identifies the sequence of iterators leading up to the newly
  // created iterator.
  Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const;

  Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent,
                      const string& output_prefix,
                      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIterator(&ctx, parent, output_prefix, iterator);
  }

  // Returns a new iterator restored from the checkpoint data in `reader`.
  Status MakeIteratorFromCheckpoint(
      IteratorContext* ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    std::unique_ptr<IteratorBase> it;
    IteratorContext::Params params(ctx);
    params.is_restoring = true;
    TF_RETURN_IF_ERROR(MakeIterator(IteratorContext(std::move(params)),
                                    /*parent=*/nullptr, output_prefix, &it));
    TF_RETURN_IF_ERROR(it->Restore(ctx, reader));
    *iterator = std::move(it);
    return Status::OK();
  }

  Status MakeIteratorFromCheckpoint(
      IteratorContext&& ctx, const string& output_prefix,
      IteratorStateReader* reader,
      std::unique_ptr<IteratorBase>* iterator) const {
    return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator);
  }

  // Returns split providers which partition the dataset's data into splits
  // and provide them in a sequence. The split providers are stored in
  // `*split_providers`.
  virtual Status MakeSplitProviders(
      std::vector<std::unique_ptr<SplitProvider>>* split_providers) const;

  // Returns a vector of DataType values, representing the respective
  // element types of each tuple component in the outputs of this
  // dataset.
  virtual const DataTypeVector& output_dtypes() const = 0;

  // Returns a vector of tensor shapes, representing the respective
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this dataset.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;

  // Returns the number of bytes allocated for tensors of this dataset.
  virtual int64 AllocatedBytes() const { return 0; }

  // Returns the estimated number of bytes used for tensors of this dataset.
  virtual int64 TotalBytes() const { return 0; }

  // Returns the cardinality of this dataset.
  virtual int64 Cardinality() const { return kUnknownCardinality; }

  // A human-readable debug string for this dataset.
  virtual string DebugString() const = 0;

  // Stores the dataset's input datasets in `*inputs`. The pointers stored in
  // `*inputs` are borrowed. The only valid non-ok return status is
  // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset
  // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a
  // default implementation of `MakeSplitProvider` when there is a single input
  // dataset.
  virtual Status InputDatasets(std::vector<const DatasetBase*>* inputs) const;

  // Indicates whether the dataset depends on any external state which would
  // prevent it from being serializable. If so, the method returns
  // `errors::FailedPrecondition` with a message that identifies the external
  // state. Otherwise, the method returns `Status::OK()`.
  virtual Status CheckExternalState() const = 0;

  // Returns the element at a particular index for a randomly accessible
  // dataset.
  virtual Status Get(OpKernelContext* ctx, int64 index,
                     std::vector<Tensor>* out_tensors);

  // Wrapper around a GraphDefBuilder which provides support for serializing
  // Datasets as GraphDefs.
  class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
   public:
    explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
        : GraphDefBuilderWrapper(b) {}
    Status AddInputDataset(SerializationContext* ctx,
                           const DatasetBase* dataset, Node** output);
    Status AddDatasetOrTensor(SerializationContext* ctx, const Tensor& val,
                              Node** output);

   private:
    Status AddDatasetOrTensorHelper(SerializationContext* ctx,
                                    const Tensor& val, Node** output);
    Status AddResourceHelper(SerializationContext* ctx, const Tensor& val,
                             Node** output);
  };

 protected:
  friend Status AsGraphDef(
      OpKernelContext* ctx, const DatasetBase* dataset,
      SerializationContext&& serialization_ctx,
      GraphDef* graph_def);  // For access to graph related members.
  friend class CapturedFunction;

  // Serializes the dataset into a `GraphDef`, which has two uses:
  //
  // 1) To perform static input pipeline optimizations, tf.data serializes the
  // dataset graph, applies graph rewrites, and then deserializes the graph.
  // If a subclass of `DatasetBase` does not implement this method, then it will
  // be excluded from static optimizations (and so will any upstream datasets).
  //
  // 2) To save the dataset so that it can be restored at a later point
  // (possibly in a different environment). If a subclass of `DatasetBase` does
  // not implement this method, then this migration will not be possible.
  virtual Status AsGraphDefInternal(SerializationContext* ctx,
                                    DatasetGraphDefBuilder* b,
                                    Node** node) const = 0;

  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const = 0;

  void set_options(const Options& options) { options_ = options; }

 private:
  // Computes the number of source datasets feeding into this dataset. A source
  // dataset is a leaf in the subtree of dataset inputs.
  Status ComputeNumSources();

  // Merges options from inputs to this dataset. If there is a conflict in a
  // field value, the options set on this dataset take precedence over those in
  // the inputs. The order of precedence on the inputs is the same order as
  // they appear for this dataset.
  Status MergeOptionsFromInputs();

  const string type_string_;
  const string node_name_;
  Options options_;
  // The number of source datasets feeding into the dataset. A source dataset is
  // a leaf in the subtree of dataset inputs.
  int64 num_sources_ = -1;
};
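
// As an illustrative sketch (not part of this header), a unary dataset's
// `AsGraphDefInternal` typically serializes its input and any parameters, then
// emits its own node:
//
//   Status AsGraphDefInternal(SerializationContext* ctx,
//                             DatasetGraphDefBuilder* b,
//                             Node** output) const override {
//     Node* input_graph_node = nullptr;
//     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
//     Node* count = nullptr;
//     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));  // hypothetical param
//     return b->AddDataset(this, {input_graph_node, count}, output);
//   }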

// Represents an iterator that is associated with a particular dataset.
class DatasetBaseIterator : public IteratorBase {
 public:
  struct BaseParams {
    // Owns one reference on the shared dataset object.
    const DatasetBase* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetBaseIterator(const BaseParams& params);

  ~DatasetBaseIterator() override;

  virtual const DatasetBase* dataset() const { return params_.dataset; }

  const DataTypeVector& output_dtypes() const override {
    return params_.dataset->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return params_.dataset->output_shapes();
  }

  const string& prefix() const override { return params_.prefix; }

  // Returns a name to be used for the TraceMe event.
  //
  // NOTE: TraceMe supports passing key-value pairs of "arguments" using the
  // following format "name#arg_1=value_,...,arg_n=value_n".
  string BuildTraceMeName();

  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final;

  Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) {
    return GetNext(&ctx, out_tensors, end_of_sequence);
  }

  Status Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence,
              int* num_skipped) final;

  Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
    VLOG(2) << "Attempting to save checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Save(ctx, writer);
  }

 protected:
  Status Restore(IteratorContext* ctx, IteratorStateReader* reader) final {
    VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: "
            << prefix() << ") from " << dataset()->DebugString();
    return IteratorBase::Restore(ctx, reader);
  }

  // Internal implementation of GetNext that is wrapped in tracing logic.
  //
  // See the docstring of the `GetNext` method regarding the contract for
  // `out_tensors` and `end_of_sequence`. Implementations may assume that
  // `*out_tensors` is empty.
  virtual Status GetNextInternal(IteratorContext* ctx,
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

  // Internal implementation of Skip that is wrapped in tracing logic.
  virtual Status SkipInternal(IteratorContext* ctx, int num_to_skip,
                              bool* end_of_sequence, int* num_skipped);

  string full_name(const string& name) const {
    return FullName(params_.prefix, name);
  }

  // Returns a map of key-value pairs to be included in the TraceMe string.
  virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; }

  // By default we model iterators using an unknown node, which acts as
  // pass-through with respect to performance modeling.
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeUnknownNode(std::move(args));
  }

  // When modeling is enabled, this method disables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(false);
    }
  }

  // When modeling is enabled, this method enables autotuning for the given
  // iterator (and the transitive closure of its inputs).
  void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) {
    if (iterator->node_) {
      iterator->node_->set_autotune(true);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has dequeued an element from an internal buffer.
  void RecordBufferDequeue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(-GetAllocatedBytes(element), -1);

      DCHECK_GE(node_->buffered_elements(), 0);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has enqueued an element in an internal buffer.
  void RecordBufferEnqueue(IteratorContext* ctx,
                           const std::vector<Tensor>& element) {
    if (collect_resource_usage(ctx)) {
      node_->record_buffer_event(GetAllocatedBytes(element), 1);
    }
  }

  // When modeling is enabled, this method records the fact that this iterator
  // has produced an element and its size in bytes.
  void RecordElement(IteratorContext* ctx, std::vector<Tensor>* out_tensors) {
    if (node_) {
      int64_t num_bytes = GetAllocatedBytes(*out_tensors);
      node_->record_element();
      node_->record_bytes_produced(num_bytes);
      if (node_->output()) {
        node_->output()->record_bytes_consumed(num_bytes);
      }
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has started work.
  void RecordStart(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_start(now_nanos);
    }
  }

  // When modeling is enabled, this method records the fact that a thread of
  // this iterator has stopped work.
  void RecordStop(IteratorContext* ctx) {
    if (collect_resource_usage(ctx)) {
      int64_t now_nanos = EnvTime::NowNanos();
      node_->record_stop(now_nanos);
    }
  }

  // Returns whether work is currently being recorded, i.e. whether we are
  // currently between a `RecordStart` and a `RecordStop`.
  bool IsRecording(IteratorContext* ctx) {
    return collect_resource_usage(ctx) && node_->is_recording();
  }

 private:
  bool collect_resource_usage(IteratorContext* ctx) {
    auto model = ctx->model();
    return model && model->collect_resource_usage() && node_;
  }

  string traceme_metadata_;
  BaseParams params_;
};

// Represents an iterator that is associated with a particular dataset
// with a particular type.
template <class DatasetType>
class DatasetIterator : public DatasetBaseIterator {
 public:
  struct Params {
    // Borrowed pointer to the dataset.
    const DatasetType* dataset;

    // Identifies the sequence of iterators leading up to this iterator.
    const string prefix;
  };

  explicit DatasetIterator(const Params& params)
      : DatasetBaseIterator({params.dataset, params.prefix}),
        typed_dataset_(params.dataset) {}

  // The dataset from which this iterator was created.
  const DatasetType* dataset() const final { return typed_dataset_; }

 private:
  const DatasetType* const typed_dataset_;  // Not owned.
};
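
// As an illustrative sketch (not part of this header), a concrete iterator for
// a hypothetical `MyDataset` would subclass `DatasetIterator<MyDataset>` and
// implement `GetNextInternal` (plus `SaveInternal`/`RestoreInternal`):
//
//   class MyIterator : public DatasetIterator<MyDataset> {
//    public:
//     explicit MyIterator(const Params& params)
//         : DatasetIterator<MyDataset>(params) {}
//
//     Status GetNextInternal(IteratorContext* ctx,
//                            std::vector<Tensor>* out_tensors,
//                            bool* end_of_sequence) override {
//       mutex_lock l(mu_);
//       if (i_ >= dataset()->count()) {  // hypothetical accessor
//         *end_of_sequence = true;
//         return Status::OK();
//       }
//       out_tensors->emplace_back(DT_INT64, TensorShape{});
//       out_tensors->back().scalar<int64>()() = i_++;
//       *end_of_sequence = false;
//       return Status::OK();
//     }
//
//    private:
//     mutex mu_;
//     int64 i_ TF_GUARDED_BY(mu_) = 0;
//   };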

template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name, T* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a scalar");
  }
  *output = argument_t->scalar<T>()();
  return Status::OK();
}

template <typename T>
Status ParseVectorArgument(OpKernelContext* ctx,
                           const StringPiece& argument_name,
                           std::vector<T>* output) {
  const Tensor* argument_t;
  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
  if (!TensorShapeUtils::IsVector(argument_t->shape())) {
    return errors::InvalidArgument(argument_name, " must be a vector");
  }
  int size = argument_t->vec<T>().size();
  output->reserve(size);
  for (int i = 0; i < size; ++i) {
    output->push_back(argument_t->vec<T>()(i));
  }
  return Status::OK();
}
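
// As an illustrative sketch (not part of this header), a dataset op kernel can
// use these helpers to read its inputs (here, hypothetical "count" and
// "filenames" op inputs):
//
//   int64 count;
//   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//   std::vector<tstring> filenames;
//   OP_REQUIRES_OK(
//       ctx, ParseVectorArgument<tstring>(ctx, "filenames", &filenames));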

// Encapsulates the work required to plug a DatasetBase into the core TensorFlow
// graph execution engine.
class DatasetOpKernel : public OpKernel {
 public:
  explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) final;

  // Indicates whether the given op corresponds to an op whose kernels subclass
  // the `DatasetOpKernel` class.
  static bool IsDatasetOp(const OpDef* op_def);

  string TraceString(const OpKernelContext& ctx, bool verbose) const override;

 protected:
  // Subclasses should implement this method. It will be called during Compute
  // execution.
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0;
};

// Encapsulates the work required to plug unary Datasets into the core
// TensorFlow graph execution engine.
class UnaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase** output) = 0;
};
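
// As an illustrative sketch (not part of this header), a unary dataset op
// kernel for a hypothetical `MyDatasetOp` overrides the three-argument
// `MakeDataset`, consuming the already-unwrapped input dataset:
//
//   class MyDatasetOp : public UnaryDatasetOpKernel {
//    public:
//     explicit MyDatasetOp(OpKernelConstruction* ctx)
//         : UnaryDatasetOpKernel(ctx) {}
//
//    protected:
//     void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
//                      DatasetBase** output) override {
//       int64 count;
//       OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
//       *output = new MyDataset(DatasetContext(ctx), input, count);
//     }
//   };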

// Encapsulates the work required to plug binary Datasets into the core
// TensorFlow graph execution engine.
class BinaryDatasetOpKernel : public DatasetOpKernel {
 public:
  explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx)
      : DatasetOpKernel(ctx) {}

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final;
  virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                           DatasetBase* another_input,
                           DatasetBase** output) = 0;
};

// A simple background worker that executes closures asynchronously and without
// blocking.
//
// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
// to avoid blocking an executor thread that may be required by the blocking
// work.
//
// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
// purpose because its current implementation (in Eigen) uses a finite-length
// queue and will block the caller when full. This can lead to deadlock under
// heavy load. Since the number of concurrent work items in each user of a
// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
// overhead is tolerable.
class BackgroundWorker {
 public:
  BackgroundWorker(Env* env, const char* name);

  ~BackgroundWorker();

  void Schedule(std::function<void()> work_item);

 private:
  void WorkerLoop();

  Env* const env_;
  const char* const name_;

  std::unique_ptr<Thread> thread_;
  mutex mu_;
  condition_variable cond_var_;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
  std::deque<std::function<void()>> work_queue_ TF_GUARDED_BY(mu_);
};
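
// As an illustrative sketch (not part of this header), an `AsyncOpKernel` can
// offload blocking work to a `BackgroundWorker` (member names hypothetical):
//
//   // In the kernel's constructor initializer list:
//   //   background_worker_(ctx->env(), "my_op_background")
//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//     background_worker_.Schedule([this, ctx, done]() {
//       // ... perform blocking work ...
//       done();
//     });
//   }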

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_