1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22import sys
23import threading
24import warnings
25import weakref
26
27import numpy as np
28import six
29from six.moves import queue as Queue  # pylint: disable=redefined-builtin
30
31from tensorflow.core.framework import graph_pb2
32from tensorflow.python import tf2
33from tensorflow.python.compat import compat
34from tensorflow.python.data.experimental.ops import distribute_options
35from tensorflow.python.data.experimental.ops import optimization_options
36from tensorflow.python.data.experimental.ops import stats_options
37from tensorflow.python.data.experimental.ops import threading_options
38from tensorflow.python.data.ops import iterator_ops
39from tensorflow.python.data.util import nest
40from tensorflow.python.data.util import options as options_lib
41from tensorflow.python.data.util import random_seed
42from tensorflow.python.data.util import sparse
43from tensorflow.python.data.util import structure
44from tensorflow.python.data.util import traverse
45from tensorflow.python.eager import context
46from tensorflow.python.eager import function as eager_function
47from tensorflow.python.framework import auto_control_deps
48from tensorflow.python.framework import composite_tensor
49from tensorflow.python.framework import constant_op
50from tensorflow.python.framework import dtypes
51from tensorflow.python.framework import function
52from tensorflow.python.framework import ops
53from tensorflow.python.framework import random_seed as core_random_seed
54from tensorflow.python.framework import smart_cond
55from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
56from tensorflow.python.framework import tensor_shape
57from tensorflow.python.framework import tensor_spec
58from tensorflow.python.framework import tensor_util
59from tensorflow.python.framework import type_spec
60from tensorflow.python.ops import array_ops
61from tensorflow.python.ops import control_flow_ops
62from tensorflow.python.ops import gen_dataset_ops
63from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
64from tensorflow.python.ops import gen_io_ops
65from tensorflow.python.ops import math_ops
66from tensorflow.python.ops import script_ops
67from tensorflow.python.ops import string_ops
68from tensorflow.python.training.tracking import base as tracking_base
69from tensorflow.python.training.tracking import tracking
70from tensorflow.python.util import deprecation
71from tensorflow.python.util import function_utils
72from tensorflow.python.util import lazy_loader
73from tensorflow.python.util import nest as tf_nest
74from tensorflow.python.util.tf_export import tf_export
75
76# Loaded lazily due to a circular dependency (roughly
77# tf.function->wrap_function->dataset->autograph->tf.function).
78# TODO(b/133251390): Use a regular import.
79wrap_function = lazy_loader.LazyLoader(
80    "wrap_function", globals(),
81    "tensorflow.python.eager.wrap_function")
82# TODO(mdan): Create a public API for this.
83autograph_ctx = lazy_loader.LazyLoader(
84    "autograph_ctx", globals(),
85    "tensorflow.python.autograph.core.ag_ctx")
86autograph = lazy_loader.LazyLoader(
87    "autograph", globals(),
88    "tensorflow.python.autograph.impl.api")
89
90ops.NotDifferentiable("ReduceDataset")
91
92# A constant that can be used to enable auto-tuning.
93AUTOTUNE = -1
94tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
95
96
97@tf_export("data.Dataset", v1=[])
98@six.add_metaclass(abc.ABCMeta)
99class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
100  """Represents a potentially large set of elements.
101
102  The `tf.data.Dataset` API supports writing descriptive and efficient input
103  pipelines. `Dataset` usage follows a common pattern:
104
105  1. Create a source dataset from your input data.
106  2. Apply dataset transformations to preprocess the data.
107  3. Iterate over the dataset and process the elements.
108
109  Iteration happens in a streaming fashion, so the full dataset does not need to
110  fit into memory.
111
112  Source Datasets:
113
114  The simplest way to create a dataset is to create it from a python `list`:
115
116  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
117  >>> for element in dataset:
118  ...   print(element)
119  tf.Tensor(1, shape=(), dtype=int32)
120  tf.Tensor(2, shape=(), dtype=int32)
121  tf.Tensor(3, shape=(), dtype=int32)
122
123  To process lines from files, use `tf.data.TextLineDataset`:
124
125  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
126
127  To process records written in the `TFRecord` format, use `TFRecordDataset`:
128
129  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
130
131  To create a dataset of all files matching a pattern, use
132  `tf.data.Dataset.list_files`:
133
134  >>> dataset = tf.data.Dataset.list_files("/path/*.txt")  # doctest: +SKIP
135
136  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
137  for more ways to create datasets.
138
139  Transformations:
140
141  Once you have a dataset, you can apply transformations to prepare the data for
142  your model:
143
144  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
145  >>> dataset = dataset.map(lambda x: x*2)
146  >>> list(dataset.as_numpy_iterator())
147  [2, 4, 6]
148
149  Common Terms:
150
151  **Element**: A single output from calling `next()` on a dataset iterator.
152    Elements may be nested structures containing multiple components. For
153    example, the element `(1, (3, "apple"))` has one tuple nested in another
154    tuple. The components are `1`, `3`, and `"apple"`.
155  **Component**: The leaf in the nested structure of an element.
156
157  Supported types:
158
159  Elements can be nested structures of tuples, named tuples, and dictionaries.
160  Element components can be of any type representable by `tf.TypeSpec`,
161  including `tf.Tensor`, `tf.data.Dataset`, `tf.SparseTensor`,
162  `tf.RaggedTensor`, and `tf.TensorArray`.
163
164  >>> a = 1 # Integer element
165  >>> b = 2.0 # Float element
166  >>> c = (1, 2) # Tuple element with 2 components
167  >>> d = {"a": (2, 2), "b": 3} # Dict element with 3 components
168  >>> Point = collections.namedtuple("Point", ["x", "y"]) # doctest: +SKIP
169  >>> e = Point(1, 2) # Named tuple # doctest: +SKIP
170  >>> f = tf.data.Dataset.range(10) # Dataset element
171
172  """
173
174  def __init__(self, variant_tensor):
175    """Creates a DatasetV2 object.
176
177    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
178    take anything in its constructor, whereas DatasetV2 expects its
179    subclasses to create a variant_tensor and pass it to the super() call.
180
181    Args:
182      variant_tensor: A DT_VARIANT tensor that represents the dataset.
183    """
184    self._variant_tensor_attr = variant_tensor
185    weak_self = weakref.proxy(self)
186    self._variant_tracker = self._track_trackable(
187        _VariantTracker(
188            self._variant_tensor,
189            # _trace_variant_creation only works when executing eagerly, so we
190            # don't want to run it immediately. We also want the _VariantTracker
191            # to have a weak reference to the Dataset to avoid creating
192            # reference cycles and making work for the garbage collector.
193            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
194        name="_variant_tracker")
195    self._graph_attr = ops.get_default_graph()
196
197  @property
198  def _variant_tensor(self):
199    return self._variant_tensor_attr
200
201  @_variant_tensor.setter
202  def _variant_tensor(self, _):
203    raise ValueError("The _variant_tensor property is read-only")
204
205  @deprecation.deprecated_args(None, "Use external_state_policy instead",
206                               "allow_stateful")
207  def _as_serialized_graph(
208      self,
209      allow_stateful=None,
210      strip_device_assignment=None,
211      external_state_policy=distribute_options.ExternalStatePolicy.WARN):
212    """Produces serialized graph representation of the dataset.
213
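    A rough sketch of how this is used within this module (compare
    `_trace_variant_creation` below); the call assumes eager execution:

    >>> ds = tf.data.Dataset.range(2)  # doctest: +SKIP
    >>> serialized = ds._as_serialized_graph()  # doctest: +SKIP
    >>> graph_def = graph_pb2.GraphDef().FromString(serialized.numpy())  # doctest: +SKIP
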
214    Args:
215      allow_stateful: If true, we allow stateful ops to be present in the graph
216        def. In that case, the state in these ops would be thrown away.
217      strip_device_assignment: If true, non-local (i.e. job and task) device
218        assignment is stripped from ops in the serialized graph.
219      external_state_policy: The ExternalStatePolicy enum that determines how
220        input pipelines that depend on external state are handled. By default,
221        it is set to WARN.
222
223    Returns:
224      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
225      serialized graph.
226    """
227    if external_state_policy:
228      # Convert the ExternalStatePolicy enum to the integer value expected by
229      # the serialization op.
230      policy = external_state_policy.value
231      return gen_dataset_ops.dataset_to_graph_v2(
232          self._variant_tensor,
233          external_state_policy=policy,
234          strip_device_assignment=strip_device_assignment)
235    if strip_device_assignment:
236      return gen_dataset_ops.dataset_to_graph(
237          self._variant_tensor,
238          allow_stateful=allow_stateful,
239          strip_device_assignment=strip_device_assignment)
240    return gen_dataset_ops.dataset_to_graph(
241        self._variant_tensor, allow_stateful=allow_stateful)
242
243  def _trace_variant_creation(self):
244    """Traces a function which outputs a variant `tf.Tensor` for this dataset.
245
246    Note that creating this function involves evaluating an op, and is currently
247    only supported when executing eagerly.
248
249    Returns:
250      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
251    """
252    variant = self._variant_tensor
253    if not isinstance(variant, ops.EagerTensor):
254      raise NotImplementedError(
255          "Can only export Datasets which were created executing eagerly. "
256          "Please file a feature request if this is important to you.")
257    with context.eager_mode(), ops.device("CPU"):
258      # pylint: disable=protected-access
259      graph_def = graph_pb2.GraphDef().FromString(
260          self._as_serialized_graph(external_state_policy=distribute_options
261                                    .ExternalStatePolicy.FAIL).numpy())
262    output_node_name = None
263    for node in graph_def.node:
264      if node.op == "_Retval":
265        if output_node_name is not None:
266          raise AssertionError(
267              "Found multiple return values from the dataset's graph, expected "
268              "only one.")
269        output_node_name, = node.input
270    if output_node_name is None:
271      raise AssertionError("Could not find the dataset's output node.")
272    # Add functions used in this Dataset to the function's graph, since they
273    # need to follow it around (and for example be added to a SavedModel which
274    # references the dataset).
275    variant_function = wrap_function.function_from_graph_def(
276        graph_def, inputs=[], outputs=output_node_name + ":0")
277    for used_function in self._functions():
278      used_function.function.add_to_graph(variant_function.graph)
279    return variant_function
280
281  @abc.abstractmethod
282  def _inputs(self):
283    """Returns a list of the input datasets of the dataset."""
284
285    raise NotImplementedError("Dataset._inputs")
286
287  @property
288  def _graph(self):
289    return self._graph_attr
290
291  @_graph.setter
292  def _graph(self, _):
293    raise ValueError("The _graph property is read-only")
294
295  def _has_captured_ref(self):
296    """Whether this dataset uses a function that captures ref variables.
297
298    Returns:
299      A boolean, which if true indicates that the dataset or one of its inputs
300      uses a function that captures ref variables.
301    """
302    if context.executing_eagerly():
303      # RefVariables are not supported in eager mode
304      return False
305
306    def is_tensor_or_parent_ref(tensor):
307      if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
308        return True
309      # If the captured tensor is an eager tensor, we cannot trace its inputs.
310      if isinstance(tensor, ops._EagerTensorBase):  # pylint: disable=protected-access
311        return False
312      return any(is_tensor_or_parent_ref(x) for x in tensor.op.inputs)
313
314    for fn in self._functions():
315      if any(is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs):
316        return True
317
318    return any(
319        [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access
320
321  # TODO(jsimsa): Change this to be the transitive closure of functions used
322  # by this dataset and its inputs.
323  def _functions(self):
324    """Returns a list of functions associated with this dataset.
325
326    Returns:
327      A list of `StructuredFunctionWrapper` objects.
328    """
329    return []
330
331  def options(self):
332    """Returns the options for this dataset and its inputs.
333
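    A small usage sketch (assuming `tf.data.Options` and
    `Dataset.with_options`, which are part of this API):

    >>> dataset = tf.data.Dataset.range(3)
    >>> options = tf.data.Options()
    >>> options.experimental_deterministic = False
    >>> dataset = dataset.with_options(options)
    >>> dataset.options().experimental_deterministic
    False
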
334    Returns:
335      A `tf.data.Options` object representing the dataset options.
336    """
337    options = Options()
338    for input_dataset in self._inputs():
339      input_options = input_dataset.options()
340      if input_options is not None:
341        options = options.merge(input_options)
342    return options
343
344  def _apply_options(self):
345    """Apply options, such as optimization configuration, to the dataset."""
346
347    dataset = self
348    options = self.options()
349
350    # (1) Apply threading options
351    if options.experimental_threading is not None:
352      t_options = options.experimental_threading
353      if t_options.max_intra_op_parallelism is not None:
354        dataset = _MaxIntraOpParallelismDataset(
355            dataset, t_options.max_intra_op_parallelism)
356      if t_options.private_threadpool_size is not None:
357        dataset = _PrivateThreadPoolDataset(dataset,
358                                            t_options.private_threadpool_size)
359
360    # (2) Apply graph rewrite options
361    # pylint: disable=protected-access
362    graph_rewrites = options._graph_rewrites()
363    graph_rewrite_configs = options._graph_rewrite_configs()
364    # pylint: enable=protected-access
365    if graph_rewrites:
366      if self._has_captured_ref():
367        warnings.warn(
368            "tf.data graph rewrites are not compatible with tf.Variable. "
369            "The following rewrites will be disabled: %s. To enable "
370            "rewrites, use resource variables instead by calling "
371            "`tf.enable_resource_variables()` at the start of the program." %
372            ", ".join(graph_rewrites))
373      else:
374        dataset = _OptimizeDataset(dataset, graph_rewrites,
375                                   graph_rewrite_configs)
376
377    # (3) Apply autotune options
378    autotune, algorithm, cpu_budget = options._autotune_settings()  # pylint: disable=protected-access
379
380    if autotune:
381      dataset = _ModelDataset(dataset, algorithm, cpu_budget)
382
383    # (4) Apply stats aggregator options
384    if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
385      dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
386          dataset, options.experimental_stats.aggregator,
387          options.experimental_stats.prefix,
388          options.experimental_stats.counter_prefix)
389    return dataset
390
391  def __iter__(self):
392    """Creates an `Iterator` for enumerating the elements of this dataset.
393
394    The returned iterator implements the Python iterator protocol and therefore
395    can only be used in eager mode.
396
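    A minimal sketch (requires eager execution):

    >>> dataset = tf.data.Dataset.range(2)
    >>> it = iter(dataset)
    >>> print(next(it))
    tf.Tensor(0, shape=(), dtype=int64)
    >>> print(next(it))
    tf.Tensor(1, shape=(), dtype=int64)
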
397    Returns:
398      An `Iterator` over the elements of this dataset.
399
400    Raises:
401      RuntimeError: If not inside of tf.function and not executing eagerly.
402    """
403    if (context.executing_eagerly()
404        or ops.get_default_graph()._building_function):  # pylint: disable=protected-access
405      return iterator_ops.OwnedIterator(self)
406    else:
407      raise RuntimeError("__iter__() is only supported inside of tf.function "
408                         "or when eager execution is enabled.")
409
410  @abc.abstractproperty
411  def element_spec(self):
412    """The type specification of an element of this dataset.
413
414    >>> tf.data.Dataset.from_tensor_slices([1, 2, 3]).element_spec
415    TensorSpec(shape=(), dtype=tf.int32, name=None)
416
417    Returns:
418      A nested structure of `tf.TypeSpec` objects matching the structure of an
419      element of this dataset and specifying the type of individual components.
420    """
421    raise NotImplementedError("Dataset.element_spec")
422
423  def __repr__(self):
424    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
425    output_shapes = str(output_shapes).replace("'", "")
426    output_types = nest.map_structure(repr, get_legacy_output_types(self))
427    output_types = str(output_types).replace("'", "")
428    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
429                                            output_types))
430
431  def as_numpy_iterator(self):
432    """Returns an iterator which converts all elements of the dataset to numpy.
433
434    Use `as_numpy_iterator` to inspect the content of your dataset. To see
435    element shapes and types, print dataset elements directly instead of using
436    `as_numpy_iterator`.
437
438    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
439    >>> for element in dataset:
440    ...   print(element)
441    tf.Tensor(1, shape=(), dtype=int32)
442    tf.Tensor(2, shape=(), dtype=int32)
443    tf.Tensor(3, shape=(), dtype=int32)
444
445    This method requires that you are running in eager mode and the dataset's
446    element_spec contains only `TensorSpec` components.
447
448    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
449    >>> for element in dataset.as_numpy_iterator():
450    ...   print(element)
451    1
452    2
453    3
454
455    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
456    >>> print(list(dataset.as_numpy_iterator()))
457    [1, 2, 3]
458
459    `as_numpy_iterator()` will preserve the nested structure of dataset
460    elements.
461
462    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
463    ...                                               'b': [5, 6]})
464    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
465    ...                                       {'a': (2, 4), 'b': 6}]
466    True
467
468    Returns:
469      An iterable over the elements of the dataset, with their tensors converted
470      to numpy arrays.
471
472    Raises:
473      TypeError: if an element contains a non-`Tensor` value.
474      RuntimeError: if eager execution is not enabled.
475    """
476    if not context.executing_eagerly():
477      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
478                         "functions")
479    for component_spec in nest.flatten(self.element_spec):
480      if not isinstance(component_spec, tensor_spec.TensorSpec):
481        raise TypeError(
482            "Dataset.as_numpy_iterator() does not support datasets containing "
483            + str(component_spec.value_type))
484
485    return _NumpyIterator(self)
486
487  @property
488  def _flat_shapes(self):
489    """Returns a list `tf.TensorShapes`s for the element tensor representation.
490
491    Returns:
492      A list `tf.TensorShapes`s for the element tensor representation.
493    """
494    return structure.get_flat_tensor_shapes(self.element_spec)
495
496  @property
497  def _flat_types(self):
498    """Returns a list `tf.DType`s for the element tensor representation.
499
500    Returns:
501      A list `tf.DType`s for the element tensor representation.
502    """
503    return structure.get_flat_tensor_types(self.element_spec)
504
505  @property
506  def _flat_structure(self):
507    """Helper for setting `output_shapes` and `output_types` attrs of an op.
508
509    Most dataset op constructors expect `output_shapes` and `output_types`
510    arguments that represent the flattened structure of an element. This helper
511    function generates these attrs as a keyword argument dictionary, allowing
512    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
513    to the op constructor.
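
    For example, a dataset op constructor might be invoked roughly as follows
    (`some_dataset_op` is a hypothetical op name):

    ```python
    variant_tensor = gen_dataset_ops.some_dataset_op(
        input_dataset._variant_tensor, **self._flat_structure)
    ```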
514
515    Returns:
516      A dictionary of keyword arguments that can be passed to a dataset op
517      constructor.
518    """
519    return {
520        "output_shapes": self._flat_shapes,
521        "output_types": self._flat_types,
522    }
523
524  @property
525  def _type_spec(self):
526    return DatasetSpec(self.element_spec)
527
528  @staticmethod
529  def from_tensors(tensors):
530    """Creates a `Dataset` with a single element, comprising the given tensors.
531
532    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
533    >>> list(dataset.as_numpy_iterator())
534    [array([1, 2, 3], dtype=int32)]
535    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
536    >>> list(dataset.as_numpy_iterator())
537    [(array([1, 2, 3], dtype=int32), b'A')]
538
539    Note that if `tensors` contains a NumPy array, and eager execution is not
540    enabled, the values will be embedded in the graph as one or more
541    `tf.constant` operations. For large datasets (> 1 GB), this can waste
542    memory and run into byte limits of graph serialization. If `tensors`
543    contains one or more large NumPy arrays, consider the alternative described
544    in [this
545    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
546
547    Args:
548      tensors: A dataset element.
549
550    Returns:
551      Dataset: A `Dataset`.
552    """
553    return TensorDataset(tensors)
554
555  @staticmethod
556  def from_tensor_slices(tensors):
557    """Creates a `Dataset` whose elements are slices of the given tensors.
558
559    The given tensors are sliced along their first dimension. This operation
560    preserves the structure of the input tensors, removing the first dimension
561    of each tensor and using it as the dataset dimension. All input tensors
562    must have the same size in their first dimensions.
563
564    >>> # Slicing a 1D tensor produces scalar tensor elements.
565    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
566    >>> list(dataset.as_numpy_iterator())
567    [1, 2, 3]
568
569    >>> # Slicing a 2D tensor produces 1D tensor elements.
570    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
571    >>> list(dataset.as_numpy_iterator())
572    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
573
574    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
575    >>> # scalar tensors.
576    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
577    >>> list(dataset.as_numpy_iterator())
578    [(1, 3, 5), (2, 4, 6)]
579
580    >>> # Dictionary structure is also preserved.
581    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
582    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
583    ...                                       {'a': 2, 'b': 4}]
584    True
585
586    >>> # Two tensors can be combined into one Dataset object.
587    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
588    >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3 element vector
589    >>> dataset = Dataset.from_tensor_slices((features, labels))
590    >>> # Both the features and the labels tensors can be converted
591    >>> # to a Dataset object separately and combined after.
592    >>> features_dataset = Dataset.from_tensor_slices(features)
593    >>> labels_dataset = Dataset.from_tensor_slices(labels)
594    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
595    >>> # A batched feature and label set can be converted to a Dataset
596    >>> # in similar fashion.
597    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
598    ...                                 [[2, 1], [1, 2]],
599    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
600    >>> batched_labels = tf.constant([['A', 'A'],
601    ...                               ['B', 'B'],
602    ...                               ['A', 'B']], shape=(3, 2, 1))
603    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
604    >>> for element in dataset.as_numpy_iterator():
605    ...   print(element)
606    (array([[1, 3],
607           [2, 3]], dtype=int32), array([[b'A'],
608           [b'A']], dtype=object))
609    (array([[2, 1],
610           [1, 2]], dtype=int32), array([[b'B'],
611           [b'B']], dtype=object))
612    (array([[3, 3],
613           [3, 2]], dtype=int32), array([[b'A'],
614           [b'B']], dtype=object))
615
616    Note that if `tensors` contains a NumPy array, and eager execution is not
617    enabled, the values will be embedded in the graph as one or more
618    `tf.constant` operations. For large datasets (> 1 GB), this can waste
619    memory and run into byte limits of graph serialization. If `tensors`
620    contains one or more large NumPy arrays, consider the alternative described
621    in [this guide](
622    https://tensorflow.org/guide/data#consuming_numpy_arrays).
623
624    Args:
625      tensors: A dataset element, with each component having the same size in
626        the first dimension.
627
628    Returns:
629      Dataset: A `Dataset`.
630    """
631    return TensorSliceDataset(tensors)
632
633  class _GeneratorState(object):
634    """Stores outstanding iterators created from a Python generator.
635
636    This class keeps track of potentially multiple iterators that may have
637    been created from a generator, e.g. in the case that the dataset is
638    repeated, or nested within a parallel computation.
639    """
640
641    def __init__(self, generator):
642      self._generator = generator
643      self._lock = threading.Lock()
644      self._next_id = 0  # GUARDED_BY(self._lock)
645      self._args = {}
646      self._iterators = {}
647
648    def get_next_id(self, *args):
649      with self._lock:
650        ret = self._next_id
651        self._next_id += 1
652      self._args[ret] = args
653      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
654      # casting in `py_func()` will create an array of `np.int32` on Windows,
655      # leading to a runtime error.
656      return np.array(ret, dtype=np.int64)
657
658    def get_iterator(self, iterator_id):
659      try:
660        return self._iterators[iterator_id]
661      except KeyError:
662        iterator = iter(self._generator(*self._args.pop(iterator_id)))
663        self._iterators[iterator_id] = iterator
664        return iterator
665
666    def iterator_completed(self, iterator_id):
667      del self._iterators[iterator_id]
668
669  @staticmethod
670  def from_generator(generator, output_types, output_shapes=None, args=None):
671    """Creates a `Dataset` whose elements are generated by `generator`.
672
673    The `generator` argument must be a callable object that returns
674    an object that supports the `iter()` protocol (e.g. a generator function).
675    The elements generated by `generator` must be compatible with the given
676    `output_types` and (optional) `output_shapes` arguments.
677
678    >>> import itertools
679    >>>
680    >>> def gen():
681    ...   for i in itertools.count(1):
682    ...     yield (i, [1] * i)
683    >>>
684    >>> dataset = tf.data.Dataset.from_generator(
685    ...      gen,
686    ...      (tf.int64, tf.int64),
687    ...      (tf.TensorShape([]), tf.TensorShape([None])))
688    >>>
689    >>> list(dataset.take(3).as_numpy_iterator())
690    [(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]
691
692    NOTE: The current implementation of `Dataset.from_generator()` uses
693    `tf.numpy_function` and inherits the same constraints. In particular, it
694    requires the `Dataset`- and `Iterator`-related operations to be placed
695    on a device in the same process as the Python program that called
696    `Dataset.from_generator()`. The body of `generator` will not be
697    serialized in a `GraphDef`, and you should not use this method if you
698    need to serialize your model and restore it in a different environment.
699
700    NOTE: If `generator` depends on mutable global variables or other external
701    state, be aware that the runtime may invoke `generator` multiple times
702    (in order to support repeating the `Dataset`) and at any time
703    between the call to `Dataset.from_generator()` and the production of the
704    first element from the generator. Mutating global variables or external
705    state can cause undefined behavior, and we recommend that you explicitly
706    cache any external state in `generator` before calling
707    `Dataset.from_generator()`.
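
    A small sketch of parameterizing the generator through `args` (the values
    are evaluated and passed to `generator` as NumPy values):

    >>> def gen(limit):
    ...   for i in range(limit):
    ...     yield i
    >>> dataset = tf.data.Dataset.from_generator(
    ...     gen, tf.int64, tf.TensorShape([]), args=(3,))
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2]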
708
709    Args:
710      generator: A callable object that returns an object that supports the
711        `iter()` protocol. If `args` is not specified, `generator` must take no
712        arguments; otherwise it must take as many arguments as there are values
713        in `args`.
714      output_types: A nested structure of `tf.DType` objects corresponding to
715        each component of an element yielded by `generator`.
716      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
717        corresponding to each component of an element yielded by `generator`.
718      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
719        and passed to `generator` as NumPy-array arguments.
720
721    Returns:
722      Dataset: A `Dataset`.
723    """
724    if not callable(generator):
725      raise TypeError("`generator` must be callable.")
726    if output_shapes is None:
727      output_shapes = nest.map_structure(
728          lambda _: tensor_shape.TensorShape(None), output_types)
729    else:
730      output_shapes = nest.map_structure_up_to(
731          output_types, tensor_shape.as_shape, output_shapes)
732    if args is None:
733      args = ()
734    else:
735      args = tuple(ops.convert_n_to_tensor(args, name="args"))
736
737    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
738    flattened_shapes = nest.flatten(output_shapes)
739
740    generator_state = DatasetV2._GeneratorState(generator)
741
742    def get_iterator_id_fn(unused_dummy):
743      """Creates a unique `iterator_id` for each pass over the dataset.
744
745      The returned `iterator_id` disambiguates between multiple concurrently
746      existing iterators.
747
748      Args:
749        unused_dummy: Ignored value.
750
751      Returns:
752        A `tf.int64` tensor whose value uniquely identifies an iterator in
753        `generator_state`.
754      """
755      return script_ops.numpy_function(generator_state.get_next_id, args,
756                                       dtypes.int64)
757
758    def generator_next_fn(iterator_id_t):
759      """Generates the next element from iterator with ID `iterator_id_t`.
760
761      We map this function across an infinite repetition of the
762      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
763
764      Args:
765        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
766          iterator in `generator_state` from which to generate an element.
767
768      Returns:
769        The next element to generate from the iterator.
770      """
771
772      def generator_py_func(iterator_id):
773        """A `py_func` that will be called to invoke the iterator."""
774        # `next()` raises `StopIteration` when there are no more
775        # elements remaining to be generated.
776        values = next(generator_state.get_iterator(iterator_id))
777
778        # Use the same _convert function from the py_func() implementation to
779        # convert the returned values to arrays early, so that we can inspect
780        # their values.
781        try:
782          flattened_values = nest.flatten_up_to(output_types, values)
783        except (TypeError, ValueError):
784          six.reraise(TypeError, TypeError(
785              "`generator` yielded an element that did not match the expected "
786              "structure. The expected structure was %s, but the yielded "
787              "element was %s." % (output_types, values)), sys.exc_info()[2])
788        ret_arrays = []
789        for ret, dtype in zip(flattened_values, flattened_types):
790          try:
791            ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
792                ret, dtype=dtype.as_numpy_dtype))
793          except (TypeError, ValueError):
794            six.reraise(TypeError, TypeError(
795                "`generator` yielded an element that could not be converted to "
796                "the expected type. The expected type was %s, but the yielded "
797                "element was %s." % (dtype.name, ret)), sys.exc_info()[2])
798
799        # Additional type and shape checking to ensure that the components
800        # of the generated element match the `output_types` and `output_shapes`
801        # arguments.
802        for (ret_array, expected_dtype, expected_shape) in zip(
803            ret_arrays, flattened_types, flattened_shapes):
804          if ret_array.dtype != expected_dtype.as_numpy_dtype:
805            raise TypeError(
806                "`generator` yielded an element of type %s where an element "
807                "of type %s was expected." % (ret_array.dtype,
808                                              expected_dtype.as_numpy_dtype))
809          if not expected_shape.is_compatible_with(ret_array.shape):
810            raise ValueError(
811                "`generator` yielded an element of shape %s where an element "
812                "of shape %s was expected." % (ret_array.shape, expected_shape))
813
814        return ret_arrays
815
816      flat_values = script_ops.numpy_function(generator_py_func,
817                                              [iterator_id_t], flattened_types)
818
819      # The `py_func()` op drops the inferred shapes, so we add them back in
820      # here.
821      if output_shapes is not None:
822        for ret_t, shape in zip(flat_values, flattened_shapes):
823          ret_t.set_shape(shape)
824
825      return nest.pack_sequence_as(output_types, flat_values)
826
827    def finalize_fn(iterator_id_t):
828      """Releases host-side state for the iterator with ID `iterator_id_t`."""
829
830      def finalize_py_func(iterator_id):
831        generator_state.iterator_completed(iterator_id)
832        # We return a dummy value so that the `finalize_fn` has a valid
833        # signature.
834        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
835        # casting in `py_func()` will create an array of `np.int32` on Windows,
836        # leading to a runtime error.
837        return np.array(0, dtype=np.int64)
838
839      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
840                                       dtypes.int64)
841
842    # This function associates each traversal of `generator` with a unique
843    # iterator ID.
844    def flat_map_fn(dummy_arg):
845      # The `get_iterator_id_fn` gets a unique ID for the current instance
846      # of the generator.
847      # The `generator_next_fn` gets the next element from the iterator with the
848      # given ID, and raises StopIteration when that iterator contains no
849      # more elements.
850      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
851                               finalize_fn)
852
853    # A single-element dataset that, each time it is evaluated, contains a
854    # freshly-generated and unique (for the returned dataset) int64
855    # ID that will be used to identify the appropriate Python state, which
856    # is encapsulated in `generator_state`, and captured in
857    # `get_iterator_id_map_fn`.
858    dummy = 0
859    id_dataset = Dataset.from_tensors(dummy)
860
861    # A dataset that contains all of the elements generated by a
862    # single iterator created from `generator`, identified by the
863    # iterator ID contained in `id_dataset`. Lifting the iteration
864    # into a flat_map here enables multiple repetitions and/or nested
865    # versions of the returned dataset to be created, because it forces
866    # the generation of a new ID for each version.
867    return id_dataset.flat_map(flat_map_fn)
868
869  @staticmethod
870  def range(*args, **kwargs):
871    """Creates a `Dataset` of a step-separated range of values.
872
873    >>> list(Dataset.range(5).as_numpy_iterator())
874    [0, 1, 2, 3, 4]
875    >>> list(Dataset.range(2, 5).as_numpy_iterator())
876    [2, 3, 4]
877    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
878    [1, 3]
879    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
880    []
881    >>> list(Dataset.range(5, 1).as_numpy_iterator())
882    []
883    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
884    [5, 3]
885    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
886    [2, 3, 4]
887    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
888    [1.0, 3.0]
889
890    Args:
891      *args: follows the same semantics as Python's built-in `range`.
892        len(args) == 1 -> start = 0, stop = args[0], step = 1
893        len(args) == 2 -> start = args[0], stop = args[1], step = 1
894        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]
895      **kwargs:
896        - output_type: The element dtype. (Optional, default: `tf.int64`).
897
898    Returns:
899      Dataset: A `RangeDataset`.
900
901    Raises:
902      ValueError: if len(args) == 0.
903    """
904    return RangeDataset(*args, **kwargs)
905
906  @staticmethod
907  def zip(datasets):
908    """Creates a `Dataset` by zipping together the given datasets.
909
910    This method has similar semantics to the built-in `zip()` function
911    in Python, with the main difference being that the `datasets`
912    argument can be an arbitrary nested structure of `Dataset` objects.
913
914    >>> # The nested structure of the `datasets` argument determines the
915    >>> # structure of elements in the resulting dataset.
916    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
917    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
918    >>> ds = tf.data.Dataset.zip((a, b))
919    >>> list(ds.as_numpy_iterator())
920    [(1, 4), (2, 5), (3, 6)]
921    >>> ds = tf.data.Dataset.zip((b, a))
922    >>> list(ds.as_numpy_iterator())
923    [(4, 1), (5, 2), (6, 3)]
924    >>>
925    >>> # The `datasets` argument may contain an arbitrary number of datasets.
926    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
927    ...                                            #       [9, 10],
928    ...                                            #       [11, 12] ]
929    >>> ds = tf.data.Dataset.zip((a, b, c))
930    >>> for element in ds.as_numpy_iterator():
931    ...   print(element)
932    (1, 4, array([7, 8]))
933    (2, 5, array([ 9, 10]))
934    (3, 6, array([11, 12]))
935    >>>
936    >>> # The number of elements in the resulting dataset is the same as
937    >>> # the size of the smallest dataset in `datasets`.
938    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
939    >>> ds = tf.data.Dataset.zip((a, d))
940    >>> list(ds.as_numpy_iterator())
941    [(1, 13), (2, 14)]
942
943    Args:
944      datasets: A nested structure of datasets.
945
946    Returns:
947      Dataset: A `Dataset`.
948    """
949    return ZipDataset(datasets)
950
951  def concatenate(self, dataset):
952    """Creates a `Dataset` by concatenating the given dataset with this dataset.
953
954    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
955    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
956    >>> ds = a.concatenate(b)
957    >>> list(ds.as_numpy_iterator())
958    [1, 2, 3, 4, 5, 6, 7]
959    >>> # The input dataset and dataset to be concatenated should have the same
960    >>> # nested structures and output types.
961    >>> c = tf.data.Dataset.zip((a, b))
962    >>> a.concatenate(c)
963    Traceback (most recent call last):
964    TypeError: Two datasets to concatenate have different types
965    <dtype: 'int64'> and (tf.int64, tf.int64)
966    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
967    >>> a.concatenate(d)
968    Traceback (most recent call last):
969    TypeError: Two datasets to concatenate have different types
970    <dtype: 'int64'> and <dtype: 'string'>
971
972    Args:
973      dataset: `Dataset` to be concatenated.
974
975    Returns:
976      Dataset: A `Dataset`.
977    """
978    return ConcatenateDataset(self, dataset)
979
980  def prefetch(self, buffer_size):
981    """Creates a `Dataset` that prefetches elements from this dataset.
982
983    Most dataset input pipelines should end with a call to `prefetch`. This
984    allows later elements to be prepared while the current element is being
985    processed. This often improves latency and throughput, at the cost of
986    using additional memory to store prefetched elements.
987
988    Note: Like other `Dataset` methods, prefetch operates on the
989    elements of the input dataset. It has no concept of examples vs. batches.
990    `examples.prefetch(2)` will prefetch two elements (2 examples),
991    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
992    (2 batches, of 20 examples each).
993
994    >>> dataset = tf.data.Dataset.range(3)
995    >>> dataset = dataset.prefetch(2)
996    >>> list(dataset.as_numpy_iterator())
997    [0, 1, 2]
998
999    Args:
1000      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
1001        number of elements that will be buffered when prefetching.
1002
1003    Returns:
1004      Dataset: A `Dataset`.
1005    """
1006    return PrefetchDataset(self, buffer_size)
1007
1008  @staticmethod
1009  def list_files(file_pattern, shuffle=None, seed=None):
1010    """A dataset of all files matching one or more glob patterns.
1011
1012    The `file_pattern` argument should be a small number of glob patterns.
1013    If your filenames have already been globbed, use
1014    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
1015    filename with `list_files` may result in poor performance with remote
1016    storage systems.
1017
1018    NOTE: The default behavior of this method is to return filenames in
1019    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
1020    to get results in a deterministic order.
1021
1022    Example:
1023      If we had the following files on our filesystem:
1024        - /path/to/dir/a.txt
1025        - /path/to/dir/b.py
1026        - /path/to/dir/c.py
1027      If we pass "/path/to/dir/*.py" as the pattern, the dataset
1028      would produce:
1029        - /path/to/dir/b.py
1030        - /path/to/dir/c.py
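
      The same example as a code sketch (the paths are hypothetical, so the
      snippet is skipped):

      >>> dataset = tf.data.Dataset.list_files(
      ...     "/path/to/dir/*.py", shuffle=False)  # doctest: +SKIP
      >>> for f in dataset.as_numpy_iterator():  # doctest: +SKIP
      ...   print(f.decode())
      /path/to/dir/b.py
      /path/to/dir/c.py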
1031
1032    Args:
1033      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
1034        (scalar or vector), representing the filename glob (i.e. shell wildcard)
1035        pattern(s) that will be matched.
1036      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
1037        Defaults to `True`.
1038      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1039        seed that will be used to create the distribution. See
1040        `tf.random.set_seed` for behavior.
1041
1042    Returns:
1043     Dataset: A `Dataset` of strings corresponding to file names.
1044    """
1045    with ops.name_scope("list_files"):
1046      if shuffle is None:
1047        shuffle = True
1048      file_pattern = ops.convert_to_tensor(
1049          file_pattern, dtype=dtypes.string, name="file_pattern")
1050      matching_files = gen_io_ops.matching_files(file_pattern)
1051
1052      # Raise an exception if `file_pattern` does not match any files.
1053      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
1054                                   name="match_not_empty")
1055
1056      message = math_ops.add(
1057          "No files matched pattern: ",
1058          string_ops.reduce_join(file_pattern, separator=", "), name="message")
1059
1060      assert_not_empty = control_flow_ops.Assert(
1061          condition, [message], summarize=1, name="assert_not_empty")
1062      with ops.control_dependencies([assert_not_empty]):
1063        matching_files = array_ops.identity(matching_files)
1064
1065      dataset = Dataset.from_tensor_slices(matching_files)
1066      if shuffle:
1067        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
1068        # list of files might be empty.
1069        buffer_size = math_ops.maximum(
1070            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
1071        dataset = dataset.shuffle(buffer_size, seed=seed)
1072      return dataset
1073
1074  def repeat(self, count=None):
1075    """Repeats this dataset so each original value is seen `count` times.
1076
1077    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1078    >>> dataset = dataset.repeat(3)
1079    >>> list(dataset.as_numpy_iterator())
1080    [1, 2, 3, 1, 2, 3, 1, 2, 3]
1081
1082    NOTE: If this dataset is a function of global state (e.g. a random number
1083    generator), then different repetitions may produce different elements.
1084
1085    Args:
1086      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1087        number of times the dataset should be repeated. The default behavior (if
1088        `count` is `None` or `-1`) is for the dataset to be repeated indefinitely.
1089
1090    Returns:
1091      Dataset: A `Dataset`.
1092    """
1093    return RepeatDataset(self, count)
1094
1095  def enumerate(self, start=0):
1096    """Enumerates the elements of this dataset.
1097
1098    It is similar to python's `enumerate`.
1099
1100    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1101    >>> dataset = dataset.enumerate(start=5)
1102    >>> for element in dataset.as_numpy_iterator():
1103    ...   print(element)
1104    (5, 1)
1105    (6, 2)
1106    (7, 3)
1107
1108    >>> # The nested structure of the input dataset determines the structure of
1109    >>> # elements in the resulting dataset.
1110    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
1111    >>> dataset = dataset.enumerate()
1112    >>> for element in dataset.as_numpy_iterator():
1113    ...   print(element)
1114    (0, array([7, 8], dtype=int32))
1115    (1, array([ 9, 10], dtype=int32))
1116
1117    Args:
1118      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
1119        enumeration.
1120
1121    Returns:
1122      Dataset: A `Dataset`.
1123    """
1124
1125    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
1126    return Dataset.zip((Dataset.range(start, max_value), self))
1127
1128  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
1129    """Randomly shuffles the elements of this dataset.
1130
1131    This dataset fills a buffer with `buffer_size` elements, then randomly
1132    samples elements from this buffer, replacing the selected elements with new
1133    elements. For perfect shuffling, a buffer size greater than or equal to the
1134    full size of the dataset is required.
1135
1136    For instance, if your dataset contains 10,000 elements but `buffer_size` is
1137    set to 1,000, then `shuffle` will initially select a random element from
1138    only the first 1,000 elements in the buffer. Once an element is selected,
1139    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
1140    maintaining the 1,000 element buffer.
1141
1142    `reshuffle_each_iteration` controls whether the shuffle order should be
1143    different for each epoch. In TF 1.X, the idiomatic way to create epochs
1144    was through the `repeat` transformation:
1145
1146    >>> dataset = tf.data.Dataset.range(3)
1147    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1148    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
1149    [1, 0, 2, 1, 2, 0]
1150
1151    >>> dataset = tf.data.Dataset.range(3)
1152    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1153    >>> dataset = dataset.repeat(2)  # doctest: +SKIP
1154    [1, 0, 2, 1, 0, 2]
1155
1156    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
1157    possible to also create epochs through Python iteration:
1158
1159    >>> dataset = tf.data.Dataset.range(3)
1160    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1161    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1162    [1, 0, 2]
1163    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1164    [1, 2, 0]
1165
1166    >>> dataset = tf.data.Dataset.range(3)
1167    >>> dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1168    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1169    [1, 0, 2]
1170    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1171    [1, 0, 2]
1172
1173    Args:
1174      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1175        elements from this dataset from which the new dataset will sample.
1176      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1177        seed that will be used to create the distribution. See
1178        `tf.random.set_seed` for behavior.
1179      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1180        that the dataset should be pseudorandomly reshuffled each time it is
1181        iterated over. (Defaults to `True`.)
1182
1183    Returns:
1184      Dataset: A `Dataset`.
1185    """
1186    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
1187
1188  def cache(self, filename=""):
1189    """Caches the elements in this dataset.
1190
1191    The first time the dataset is iterated over, its elements will be cached
1192    either in the specified file or in memory. Subsequent iterations will
1193    use the cached data.
1194
1195    Note: For the cache to be finalized, the input dataset must be iterated
1196    through in its entirety. Otherwise, subsequent iterations will not use
1197    cached data.
1198
1199    >>> dataset = tf.data.Dataset.range(5)
1200    >>> dataset = dataset.map(lambda x: x**2)
1201    >>> dataset = dataset.cache()
1202    >>> # The first time reading through the data will generate the data using
1203    >>> # `range` and `map`.
1204    >>> list(dataset.as_numpy_iterator())
1205    [0, 1, 4, 9, 16]
1206    >>> # Subsequent iterations read from the cache.
1207    >>> list(dataset.as_numpy_iterator())
1208    [0, 1, 4, 9, 16]
1209
1210    When caching to a file, the cached data will persist across runs. Even the
1211    first iteration through the data will read from the cache file. Changing
1212    the input pipeline before the call to `.cache()` will have no effect until
1213    the cache file is removed or the filename is changed.
1214
1215    >>> dataset = tf.data.Dataset.range(5)
1216    >>> dataset = dataset.cache("/path/to/file")  # doctest: +SKIP
1217    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1218    [0, 1, 2, 3, 4]
1219    >>> dataset = tf.data.Dataset.range(10)
1220    >>> dataset = dataset.cache("/path/to/file")  # Same file! # doctest: +SKIP
1221    >>> list(dataset.as_numpy_iterator())  # doctest: +SKIP
1222    [0, 1, 2, 3, 4]
1223
1224    Note: `cache` will produce exactly the same elements during each iteration
1225    through the dataset. If you wish to randomize the iteration order, make sure
1226    to call `shuffle` *after* calling `cache`.
1227
1228    Args:
1229      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
1230        directory on the filesystem to use for caching elements in this Dataset.
1231        If a filename is not provided, the dataset will be cached in memory.
1232
1233    Returns:
1234      Dataset: A `Dataset`.
1235    """
1236    return CacheDataset(self, filename)
1237
1238  def take(self, count):
1239    """Creates a `Dataset` with at most `count` elements from this dataset.
1240
1241    >>> dataset = tf.data.Dataset.range(10)
1242    >>> dataset = dataset.take(3)
1243    >>> list(dataset.as_numpy_iterator())
1244    [0, 1, 2]
1245
1246    Args:
1247      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1248        elements of this dataset that should be taken to form the new dataset.
1249        If `count` is -1, or if `count` is greater than the size of this
1250        dataset, the new dataset will contain all elements of this dataset.
1251
1252    Returns:
1253      Dataset: A `Dataset`.
1254    """
1255    return TakeDataset(self, count)
1256
1257  def skip(self, count):
1258    """Creates a `Dataset` that skips `count` elements from this dataset.
1259
1260    >>> dataset = tf.data.Dataset.range(10)
1261    >>> dataset = dataset.skip(7)
1262    >>> list(dataset.as_numpy_iterator())
1263    [7, 8, 9]
1264
1265    Args:
1266      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1267        elements of this dataset that should be skipped to form the new dataset.
1268        If `count` is greater than the size of this dataset, the new dataset
1269        will contain no elements.  If `count` is -1, skips the entire dataset.
1270
1271    Returns:
1272      Dataset: A `Dataset`.
1273    """
1274    return SkipDataset(self, count)
1275
1276  def shard(self, num_shards, index):
1277    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
1278
1279    `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
1280    contain all elements of A whose index mod n = i.
1281
1282    >>> A = tf.data.Dataset.range(10)
1283    >>> B = A.shard(num_shards=3, index=0)
1284    >>> list(B.as_numpy_iterator())
1285    [0, 3, 6, 9]
1286    >>> C = A.shard(num_shards=3, index=1)
1287    >>> list(C.as_numpy_iterator())
1288    [1, 4, 7]
1289    >>> D = A.shard(num_shards=3, index=2)
1290    >>> list(D.as_numpy_iterator())
1291    [2, 5, 8]
1292
1293    This dataset operator is very useful when running distributed training, as
1294    it allows each worker to read a unique subset.
1295
1296    When reading a single input file, you can shard elements as follows:
1297
1298    ```python
1299    d = tf.data.TFRecordDataset(input_file)
1300    d = d.shard(num_workers, worker_index)
1301    d = d.repeat(num_epochs)
1302    d = d.shuffle(shuffle_buffer_size)
1303    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1304    ```
1305
1306    Important caveats:
1307
1308    - Be sure to shard before you use any randomizing operator (such as
1309      shuffle).
1310    - Generally it is best if the shard operator is used early in the dataset
1311      pipeline. For example, when reading from a set of TFRecord files, shard
1312      before converting the dataset to input samples. This avoids reading every
1313      file on every worker. The following is an example of an efficient
1314      sharding strategy within a complete pipeline:
1315
1316    ```python
1317    d = Dataset.list_files(pattern)
1318    d = d.shard(num_workers, worker_index)
1319    d = d.repeat(num_epochs)
1320    d = d.shuffle(shuffle_buffer_size)
1321    d = d.interleave(tf.data.TFRecordDataset,
1322                     cycle_length=num_readers, block_length=1)
1323    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1324    ```
1325
1326    Args:
1327      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
1328        shards operating in parallel.
1329      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
1330
1331    Returns:
1332      Dataset: A `Dataset`.
1333
1334    Raises:
1335      InvalidArgumentError: if `num_shards` or `index` are illegal values.
1336        Note: error checking is done on a best-effort basis, and errors aren't
1337        guaranteed to be caught upon dataset creation. (e.g. providing a
1338        placeholder tensor bypasses the early checking, and will instead result
1339        in an error during a session.run call.)
1340    """
1341    return ShardDataset(self, num_shards, index)
1342
1343  def batch(self, batch_size, drop_remainder=False):
1344    """Combines consecutive elements of this dataset into batches.
1345
1346    >>> dataset = tf.data.Dataset.range(8)
1347    >>> dataset = dataset.batch(3)
1348    >>> list(dataset.as_numpy_iterator())
1349    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
1350
1351    >>> dataset = tf.data.Dataset.range(8)
1352    >>> dataset = dataset.batch(3, drop_remainder=True)
1353    >>> list(dataset.as_numpy_iterator())
1354    [array([0, 1, 2]), array([3, 4, 5])]
1355
1356    The components of the resulting element will have an additional outer
1357    dimension, which will be `batch_size` (or `N % batch_size` for the last
1358    element if `batch_size` does not divide the number of input elements `N`
1359    evenly and `drop_remainder` is `False`). If your program depends on the
1360    batches having the same outer dimension, you should set the `drop_remainder`
1361    argument to `True` to prevent the smaller batch from being produced.
1362
1363    Args:
1364      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1365        consecutive elements of this dataset to combine in a single batch.
1366      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1367        whether the last batch should be dropped in the case it has fewer than
1368        `batch_size` elements; the default behavior is not to drop the smaller
1369        batch.
1370
1371    Returns:
1372      Dataset: A `Dataset`.
1373    """
1374    return BatchDataset(self, batch_size, drop_remainder)
1375
1376  def padded_batch(self,
1377                   batch_size,
1378                   padded_shapes=None,
1379                   padding_values=None,
1380                   drop_remainder=False):
1381    """Combines consecutive elements of this dataset into padded batches.
1382
1383    This transformation combines multiple consecutive elements of the input
1384    dataset into a single element.
1385
1386    Like `tf.data.Dataset.batch`, the components of the resulting element will
1387    have an additional outer dimension, which will be `batch_size` (or
1388    `N % batch_size` for the last element if `batch_size` does not divide the
1389    number of input elements `N` evenly and `drop_remainder` is `False`). If
1390    your program depends on the batches having the same outer dimension, you
1391    should set the `drop_remainder` argument to `True` to prevent the smaller
1392    batch from being produced.
1393
1394    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
1395    different shapes, and this transformation will pad each component to the
1396    respective shape in `padded_shapes`. The `padded_shapes` argument
1397    determines the resulting shape for each dimension of each component in an
1398    output element:
1399
1400    * If the dimension is a constant, the component will be padded out to that
1401      length in that dimension.
1402    * If the dimension is unknown, the component will be padded out to the
1403      maximum length of all elements in that dimension.
1404
1405    >>> A = (tf.data.Dataset
1406    ...      .range(1, 5, output_type=tf.int32)
1407    ...      .map(lambda x: tf.fill([x], x)))
1408    >>> # Pad to the smallest per-batch size that fits all elements.
1409    >>> B = A.padded_batch(2)
1410    >>> for element in B.as_numpy_iterator():
1411    ...   print(element)
1412    [[1 0]
1413     [2 2]]
1414    [[3 3 3 0]
1415     [4 4 4 4]]
1416    >>> # Pad to a fixed size.
1417    >>> C = A.padded_batch(2, padded_shapes=5)
1418    >>> for element in C.as_numpy_iterator():
1419    ...   print(element)
1420    [[1 0 0 0 0]
1421     [2 2 0 0 0]]
1422    [[3 3 3 0 0]
1423     [4 4 4 4 0]]
1424    >>> # Pad with a custom value.
1425    >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
1426    >>> for element in D.as_numpy_iterator():
1427    ...   print(element)
1428    [[ 1 -1 -1 -1 -1]
1429     [ 2  2 -1 -1 -1]]
1430    [[ 3  3  3 -1 -1]
1431     [ 4  4  4  4 -1]]
1432    >>> # Components of nested elements can be padded independently.
1433    >>> elements = [([1, 2, 3], [10]),
1434    ...             ([4, 5], [11, 12])]
1435    >>> dataset = tf.data.Dataset.from_generator(
1436    ...     lambda: iter(elements), (tf.int32, tf.int32))
1437    >>> # Pad the first component of the tuple to length 4, and the second
1438    >>> # component to the smallest size that fits.
1439    >>> dataset = dataset.padded_batch(2,
1440    ...     padded_shapes=([4], [None]),
1441    ...     padding_values=(-1, 100))
1442    >>> list(dataset.as_numpy_iterator())
1443    [(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
1444      array([[ 10, 100], [ 11,  12]], dtype=int32))]
1445
1446    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
1447    elements that may have different shapes into a `tf.SparseTensor`.
1448
1449    Args:
1450      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1451        consecutive elements of this dataset to combine in a single batch.
1452      padded_shapes: (Optional.) A nested structure of `tf.TensorShape` or
1453        `tf.int64` vector tensor-like objects representing the shape to which
1454        the respective component of each input element should be padded prior
1455        to batching. Any unknown dimensions will be padded to the maximum size
1456        of that dimension in each batch. If unset, all dimensions of all
1457        components are padded to the maximum size in the batch. `padded_shapes`
1458        must be set if any component has an unknown rank.
1459      padding_values: (Optional.) A nested structure of scalar-shaped
1460        `tf.Tensor`, representing the padding values to use for the respective
1461        components. `None` indicates that the nested structure should be padded
1462        with default values. Defaults are `0` for numeric types and the empty
1463        string for string types.
1464      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1465        whether the last batch should be dropped in the case it has fewer than
1466        `batch_size` elements; the default behavior is not to drop the smaller
1467        batch.
1468
1469    Returns:
1470      Dataset: A `Dataset`.
1471
1472    Raises:
1473      ValueError: If a component has an unknown rank, and the `padded_shapes`
1474        argument is not set.
1475    """
1476    if padded_shapes is None:
1477      padded_shapes = get_legacy_output_shapes(self)
1478      # A `tf.TensorShape` is only falsey if its *rank* is unknown:
1479      # bool(tf.TensorShape(None)) is False
1480      if not all(nest.flatten(padded_shapes)):
1481        raise ValueError("You must set the `padded_shapes` argument to "
1482                         "`Dataset.padded_batch` if any component of its input "
1483                         "has an unknown rank.")
1484    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
1485                              drop_remainder)
1486
1487  def map(self, map_func, num_parallel_calls=None):
1488    """Maps `map_func` across the elements of this dataset.
1489
1490    This transformation applies `map_func` to each element of this dataset, and
1491    returns a new dataset containing the transformed elements, in the same
1492    order as they appeared in the input. `map_func` can be used to change both
1493    the values and the structure of a dataset's elements. For example, adding 1
1494    to each element, or projecting a subset of element components.
1495
1496    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1497    >>> dataset = dataset.map(lambda x: x + 1)
1498    >>> list(dataset.as_numpy_iterator())
1499    [2, 3, 4, 5, 6]
1500
1501    The input signature of `map_func` is determined by the structure of each
1502    element in this dataset.
1503
1504    >>> dataset = Dataset.range(5)
1505    >>> # `map_func` takes a single argument of type `tf.Tensor` with the same
1506    >>> # shape and dtype.
1507    >>> result = dataset.map(lambda x: x + 1)
1508
1509    >>> # Each element is a tuple containing two `tf.Tensor` objects.
1510    >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")]
1511    >>> dataset = tf.data.Dataset.from_generator(
1512    ...     lambda: elements, (tf.int32, tf.string))
1513    >>> # `map_func` takes two arguments of type `tf.Tensor`. This function
1514    >>> # projects out just the first component.
1515    >>> result = dataset.map(lambda x_int, y_str: x_int)
1516    >>> list(result.as_numpy_iterator())
1517    [1, 2, 3]
1518
1519    >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects.
1520    >>> elements =  ([{"a": 1, "b": "foo"},
1521    ...               {"a": 2, "b": "bar"},
1522    ...               {"a": 3, "b": "baz"}])
1523    >>> dataset = tf.data.Dataset.from_generator(
1524    ...     lambda: elements, {"a": tf.int32, "b": tf.string})
1525    >>> # `map_func` takes a single argument of type `dict` with the same keys
1526    >>> # as the elements.
1527    >>> result = dataset.map(lambda d: tf.strings.as_string(d["a"]) + d["b"])
1528
1529    The value or values returned by `map_func` determine the structure of each
1530    element in the returned dataset.
1531
1532    >>> dataset = tf.data.Dataset.range(3)
1533    >>> # `map_func` returns two `tf.Tensor` objects.
1534    >>> def g(x):
1535    ...   return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
1536    >>> result = dataset.map(g)
1537    >>> result.element_spec
1538    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \
1539dtype=tf.string, name=None))
1540    >>> # Python primitives, lists, and NumPy arrays are implicitly converted to
1541    >>> # `tf.Tensor`.
1542    >>> def h(x):
1543    ...   return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
1544    >>> result = dataset.map(h)
1545    >>> result.element_spec
1546    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \
1547dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \
1548name=None))
1549    >>> # `map_func` can return nested structures.
1550    >>> def i(x):
1551    ...   return (37.0, [42, 16]), "foo"
1552    >>> result = dataset.map(i)
1553    >>> result.element_spec
1554    ((TensorSpec(shape=(), dtype=tf.float32, name=None),
1555      TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
1556     TensorSpec(shape=(), dtype=tf.string, name=None))
1557
1558    `map_func` can accept as arguments and return any type of dataset element.
1559
1560    Note that irrespective of the context in which `map_func` is defined (eager
1561    vs. graph), tf.data traces the function and executes it as a graph. To use
1562    Python code inside the function, you have two options:
1563
1564    1) Rely on AutoGraph to convert Python code into an equivalent graph
1565    computation. The downside of this approach is that AutoGraph can convert
1566    some but not all Python code.
1567
1568    2) Use `tf.py_function`, which allows you to write arbitrary Python code but
1569    will generally result in worse performance than 1). For example:
1570
1571    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
1572    >>> # transform a string tensor to upper case string using a Python function
1573    >>> def upper_case_fn(t: tf.Tensor):
1574    ...   return t.numpy().decode('utf-8').upper()
1575    >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn,
1576    ...           inp=[x], Tout=tf.string))
1577    >>> list(d.as_numpy_iterator())
1578    [b'HELLO', b'WORLD']
1579
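    A minimal sketch of processing elements in parallel with
    `num_parallel_calls` (`tf.data.experimental.AUTOTUNE`, described in the
    Args below, lets tf.data pick the degree of parallelism):

    >>> dataset = tf.data.Dataset.range(5)
    >>> dataset = dataset.map(lambda x: x * 2,
    ...     num_parallel_calls=tf.data.experimental.AUTOTUNE)
    >>> list(dataset.as_numpy_iterator())
    [0, 2, 4, 6, 8]
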
1580    Args:
1581      map_func: A function mapping a dataset element to another dataset element.
1582      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
1583        representing the number of elements to process asynchronously in parallel.
1584        If not specified, elements will be processed sequentially. If the value
1585        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
1586        calls is set dynamically based on available CPU.
1587
1588    Returns:
1589      Dataset: A `Dataset`.
1590    """
1591    if num_parallel_calls is None:
1592      return MapDataset(self, map_func, preserve_cardinality=True)
1593    else:
1594      return ParallelMapDataset(
1595          self, map_func, num_parallel_calls, preserve_cardinality=True)
1596
1597  def flat_map(self, map_func):
1598    """Maps `map_func` across this dataset and flattens the result.
1599
1600    Use `flat_map` if you want to make sure that the order of your dataset
1601    stays the same. For example, to flatten a dataset of batches into a
1602    dataset of their elements:
1603
1604    >>> dataset = Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1605    >>> dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
1606    >>> list(dataset.as_numpy_iterator())
1607    [1, 2, 3, 4, 5, 6, 7, 8, 9]
1608
1609    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
1610    `flat_map` produces the same output as
1611    `tf.data.Dataset.interleave(cycle_length=1)`.
1612
1613    Args:
1614      map_func: A function mapping a dataset element to a dataset.
1615
1616    Returns:
1617      Dataset: A `Dataset`.
1618    """
1619    return FlatMapDataset(self, map_func)
1620
1621  def interleave(self,
1622                 map_func,
1623                 cycle_length=AUTOTUNE,
1624                 block_length=1,
1625                 num_parallel_calls=None,
1626                 deterministic=None):
1627    """Maps `map_func` across this dataset, and interleaves the results.
1628
1629    For example, you can use `Dataset.interleave()` to process many input files
1630    concurrently:
1631
1632    >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records
1633    >>> # from each file.
1634    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1635    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1636    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1637    >>> def parse_fn(filename):
1638    ...   return tf.data.Dataset.range(10)
1639    >>> dataset = dataset.interleave(lambda x:
1640    ...     tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
1641    ...     cycle_length=4, block_length=16)
1642
1643    The `cycle_length` and `block_length` arguments control the order in which
1644    elements are produced. `cycle_length` controls the number of input elements
1645    that are processed concurrently. If you set `cycle_length` to 1, this
1646    transformation will handle one input element at a time, and will produce
1647    identical results to `tf.data.Dataset.flat_map`. In general,
1648    this transformation will apply `map_func` to `cycle_length` input elements,
1649    open iterators on the returned `Dataset` objects, and cycle through them
1650    producing `block_length` consecutive elements from each iterator, and
1651    consuming the next input element each time it reaches the end of an
1652    iterator.
1653
1654    For example:
1655
1656    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1657    >>> # NOTE: New lines indicate "block" boundaries.
1658    >>> dataset = dataset.interleave(
1659    ...     lambda x: Dataset.from_tensors(x).repeat(6),
1660    ...     cycle_length=2, block_length=4)
1661    >>> list(dataset.as_numpy_iterator())
1662    [1, 1, 1, 1,
1663     2, 2, 2, 2,
1664     1, 1,
1665     2, 2,
1666     3, 3, 3, 3,
1667     4, 4, 4, 4,
1668     3, 3,
1669     4, 4,
1670     5, 5, 5, 5,
1671     5, 5]
1672
1673    NOTE: The order of elements yielded by this transformation is
1674    deterministic, as long as `map_func` is a pure function and
1675    `deterministic=True`. If `map_func` contains any stateful operations, the
1676    order in which that state is accessed is undefined.
1677
1678    Performance can often be improved by setting `num_parallel_calls` so that
1679    `interleave` will use multiple threads to fetch elements. If determinism
1680    isn't required, it can also improve performance to set
1681    `deterministic=False`.
1682
1683    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1684    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1685    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1686    >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
1687    ...     cycle_length=4, num_parallel_calls=tf.data.experimental.AUTOTUNE,
1688    ...     deterministic=False)
1689
1690    Args:
1691      map_func: A function mapping a dataset element to a dataset.
1692      cycle_length: (Optional.) The number of input elements that will be
1693        processed concurrently. If not specified, the value will be derived from
1694        the number of available CPU cores. If the `num_parallel_calls` argument
1695        is set to `tf.data.experimental.AUTOTUNE`, the `cycle_length` argument
1696        also identifies the maximum degree of parallelism.
1697      block_length: (Optional.) The number of consecutive elements to produce
1698        from each input element before cycling to another input element.
1699      num_parallel_calls: (Optional.) If specified, the implementation creates a
1700        threadpool, which is used to fetch inputs from cycle elements
1701        asynchronously and in parallel. The default behavior is to fetch inputs
1702        from cycle elements synchronously with no parallelism. If the value
1703        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
1704        calls is set dynamically based on available CPU.
1705      deterministic: (Optional.) A boolean controlling whether determinism
1706        should be traded for performance by allowing elements to be produced out
1707        of order.  If `deterministic` is `None`, the
1708        `tf.data.Options.experimental_deterministic` dataset option (`True` by
1709        default) is used to decide whether to produce elements
1710        deterministically.
1711
1712    Returns:
1713      Dataset: A `Dataset`.
1714    """
1715    if num_parallel_calls is None:
1716      return InterleaveDataset(self, map_func, cycle_length, block_length)
1717    else:
1718      return ParallelInterleaveDataset(self, map_func, cycle_length,
1719                                       block_length, num_parallel_calls,
1720                                       deterministic)
1721
1722  def filter(self, predicate):
1723    """Filters this dataset according to `predicate`.
1724
1725    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1726    >>> dataset = dataset.filter(lambda x: x < 3)
1727    >>> list(dataset.as_numpy_iterator())
1728    [1, 2]
1729    >>> # `tf.math.equal(x, y)` is required for equality comparison
1730    >>> def filter_fn(x):
1731    ...   return tf.math.equal(x, 1)
1732    >>> dataset = dataset.filter(filter_fn)
1733    >>> list(dataset.as_numpy_iterator())
1734    [1]
1735
1736    Args:
1737      predicate: A function mapping a dataset element to a boolean.
1738
1739    Returns:
1740      Dataset: The `Dataset` containing the elements of this dataset for which
1741          `predicate` is `True`.
1742    """
1743    return FilterDataset(self, predicate)
1744
1745  def apply(self, transformation_func):
1746    """Applies a transformation function to this dataset.
1747
1748    `apply` enables chaining of custom `Dataset` transformations, which are
1749    represented as functions that take one `Dataset` argument and return a
1750    transformed `Dataset`.
1751
1752    >>> dataset = tf.data.Dataset.range(100)
1753    >>> def dataset_fn(ds):
1754    ...   return ds.filter(lambda x: x < 5)
1755    >>> dataset = dataset.apply(dataset_fn)
1756    >>> list(dataset.as_numpy_iterator())
1757    [0, 1, 2, 3, 4]
1758
1759    Args:
1760      transformation_func: A function that takes one `Dataset` argument and
1761        returns a `Dataset`.
1762
1763    Returns:
1764      Dataset: The `Dataset` returned by applying `transformation_func` to this
1765          dataset.
1766    """
1767    dataset = transformation_func(self)
1768    if not isinstance(dataset, DatasetV2):
1769      raise TypeError(
1770          "`transformation_func` must return a Dataset. Got {}.".format(
1771              dataset))
1772    dataset._input_datasets = [self]  # pylint: disable=protected-access
1773    return dataset
1774
1775  def window(self, size, shift=None, stride=1, drop_remainder=False):
1776    """Combines (nests of) input elements into a dataset of (nests of) windows.
1777
1778    A "window" is a finite dataset of flat elements of size `size` (or possibly
1779    fewer if there are not enough input elements to fill the window and
1780    `drop_remainder` evaluates to false).
1781
1782    The `shift` argument determines how far the window advances along the input
1783    on each iteration; `stride` is the step between elements within each window.
1784
1785    >>> dataset = tf.data.Dataset.range(7).window(2)
1786    >>> for window in dataset:
1787    ...   print(list(window.as_numpy_iterator()))
1788    [0, 1]
1789    [2, 3]
1790    [4, 5]
1791    [6]
1792    >>> dataset = tf.data.Dataset.range(7).window(3, 2, 1, True)
1793    >>> for window in dataset:
1794    ...   print(list(window.as_numpy_iterator()))
1795    [0, 1, 2]
1796    [2, 3, 4]
1797    [4, 5, 6]
1798    >>> dataset = tf.data.Dataset.range(7).window(3, 1, 2, True)
1799    >>> for window in dataset:
1800    ...   print(list(window.as_numpy_iterator()))
1801    [0, 2, 4]
1802    [1, 3, 5]
1803    [2, 4, 6]
1804
1805    Note that when the `window` transformation is applied to a dataset of
1806    nested elements, it produces a dataset of nested windows.
1807
1808    >>> nested = ([1, 2, 3, 4], [5, 6, 7, 8])
1809    >>> dataset = tf.data.Dataset.from_tensor_slices(nested).window(2)
1810    >>> for window in dataset:
1811    ...   def to_numpy(ds):
1812    ...     return list(ds.as_numpy_iterator())
1813    ...   print(tuple(to_numpy(component) for component in window))
1814    ([1, 2], [5, 6])
1815    ([3, 4], [7, 8])
1816
1817    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3, 4]})
1818    >>> dataset = dataset.window(2)
1819    >>> for window in dataset:
1820    ...   def to_numpy(ds):
1821    ...     return list(ds.as_numpy_iterator())
1822    ...   print({'a': to_numpy(window['a'])})
1823    {'a': [1, 2]}
1824    {'a': [3, 4]}
1825
1826    Args:
1827      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
1828        of the input dataset to combine into a window.
1829      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1830        forward shift of the sliding window in each iteration. Defaults to
1831        `size`.
1832      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1833        stride of the input elements in the sliding window.
1834      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1835        whether a window should be dropped in case its size is smaller than
1836        `size`.
1837
1838    Returns:
1839      Dataset: A `Dataset` of (nests of) windows -- finite datasets of flat
1840        elements created from the (nests of) input elements.
1841
1842    """
1843    if shift is None:
1844      shift = size
1845    return WindowDataset(self, size, shift, stride, drop_remainder)
1846
1847  def reduce(self, initial_state, reduce_func):
1848    """Reduces the input dataset to a single element.
1849
1850    The transformation calls `reduce_func` successively on every element of
1851    the input dataset until the dataset is exhausted, aggregating information in
1852    its internal state. The `initial_state` argument is used for the initial
1853    state and the final state is returned as the result.
1854
1855    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
1856    5
1857    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
1858    10
1859
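    The state may also be a nested structure; a minimal sketch using a
    `(count, sum)` tuple as the state:

    >>> ds = tf.data.Dataset.range(5)
    >>> count, total = ds.reduce(
    ...     (np.int64(0), np.int64(0)),
    ...     lambda state, x: (state[0] + 1, state[1] + x))
    >>> print(count.numpy(), total.numpy())
    5 10
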
1860    Args:
1861      initial_state: An element representing the initial state of the
1862        transformation.
1863      reduce_func: A function that maps `(old_state, input_element)` to
1864        `new_state`. It must take two arguments and return a new element.
1865        The structure of `new_state` must match the structure of
1866        `initial_state`.
1867
1868    Returns:
1869      A dataset element corresponding to the final state of the transformation.
1870
1871    """
1872
1873    with ops.name_scope("initial_state"):
1874      initial_state = structure.normalize_element(initial_state)
1875    state_structure = structure.type_spec_from_value(initial_state)
1876
1877    # Iteratively rerun the reduce function until reaching a fixed point on
1878    # `state_structure`.
1879    need_to_rerun = True
1880    while need_to_rerun:
1881
1882      wrapped_func = StructuredFunctionWrapper(
1883          reduce_func,
1884          "reduce()",
1885          input_structure=(state_structure, self.element_spec),
1886          add_to_graph=False)
1887
1888      # Extract and validate class information from the returned values.
1889      output_classes = wrapped_func.output_classes
1890      state_classes = nest.map_structure(
1891          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
1892          state_structure)
1893      for new_state_class, state_class in zip(
1894          nest.flatten(output_classes), nest.flatten(state_classes)):
1895        if not issubclass(new_state_class, state_class):
1896          raise TypeError(
1897              "The element classes for the new state must match the initial "
1898              "state. Expected %s; got %s." %
1899              (state_classes, wrapped_func.output_classes))
1900
1901      # Extract and validate type information from the returned values.
1902      output_types = wrapped_func.output_types
1903      state_types = nest.map_structure(
1904          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
1905          state_structure)
1906      for new_state_type, state_type in zip(
1907          nest.flatten(output_types), nest.flatten(state_types)):
1908        if new_state_type != state_type:
1909          raise TypeError(
1910              "The element types for the new state must match the initial "
1911              "state. Expected %s; got %s." %
1912              (state_types, wrapped_func.output_types))
1913
1914      # Extract shape information from the returned values.
1915      output_shapes = wrapped_func.output_shapes
1916      state_shapes = nest.map_structure(
1917          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
1918          state_structure)
1919      flat_state_shapes = nest.flatten(state_shapes)
1920      flat_new_state_shapes = nest.flatten(output_shapes)
1921      weakened_state_shapes = [
1922          original.most_specific_compatible_shape(new)
1923          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
1924      ]
1925
1926      need_to_rerun = False
1927      for original_shape, weakened_shape in zip(flat_state_shapes,
1928                                                weakened_state_shapes):
1929        if original_shape.ndims is not None and (
1930            weakened_shape.ndims is None or
1931            original_shape.as_list() != weakened_shape.as_list()):
1932          need_to_rerun = True
1933          break
1934
1935      if need_to_rerun:
1936        # TODO(b/110122868): Support a "most specific compatible structure"
1937        # method for combining structures, to avoid using legacy structures
1938        # here.
1939        state_structure = structure.convert_legacy_structure(
1940            state_types,
1941            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
1942            state_classes)
1943
1944    reduce_func = wrapped_func.function
1945    reduce_func.add_to_graph(ops.get_default_graph())
1946
1947    dataset = self._apply_options()
1948
1949    # pylint: disable=protected-access
1950    return structure.from_compatible_tensor_list(
1951        state_structure,
1952        gen_dataset_ops.reduce_dataset(
1953            dataset._variant_tensor,
1954            structure.to_tensor_list(state_structure, initial_state),
1955            reduce_func.captured_inputs,
1956            f=reduce_func,
1957            output_shapes=structure.get_flat_tensor_shapes(state_structure),
1958            output_types=structure.get_flat_tensor_types(state_structure)))
1959
1960  def unbatch(self):
1961    """Splits elements of a dataset into multiple elements.
1962
1963    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
1964    where `B` may vary for each input element, then for each element in the
1965    dataset, the unbatched dataset will contain `B` consecutive elements
1966    of shape `[a0, a1, ...]`.
1967
1968    >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
1969    >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
1970    >>> dataset = dataset.unbatch()
1971    >>> list(dataset.as_numpy_iterator())
1972    [1, 2, 3, 1, 2, 1, 2, 3, 4]
1973
1974    Returns:
1975      A `Dataset`.
1976    """
1977    normalized_dataset = normalize_to_dense(self)
1978    return _UnbatchDataset(normalized_dataset)
1979
1980  def with_options(self, options):
1981    """Returns a new `tf.data.Dataset` with the given options set.
1982
1983    The options are "global" in the sense they apply to the entire dataset.
1984    If options are set multiple times, they are merged as long as the same
1985    option is not set to different non-default values.
1986
1987    >>> ds = tf.data.Dataset.range(5)
1988    >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
1989    ...                    cycle_length=3,
1990    ...                    num_parallel_calls=3)
1991    >>> options = tf.data.Options()
1992    >>> # This will make the interleave order non-deterministic.
1993    >>> options.experimental_deterministic = False
1994    >>> ds = ds.with_options(options)
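
    Options set at different points in the pipeline are merged; a minimal
    sketch that sets a second, different option on the same dataset
    (`experimental_slack` is used purely for illustration):

    >>> other_options = tf.data.Options()
    >>> other_options.experimental_slack = True
    >>> ds = ds.with_options(other_options)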
1995
1996    Args:
1997      options: A `tf.data.Options` that identifies the options to use.
1998
1999    Returns:
2000      Dataset: A `Dataset` with the given options.
2001
2002    Raises:
2003      ValueError: When an option is set more than once to a non-default value.
2004    """
2005    return _OptionsDataset(self, options)
2006
2007
2008@tf_export(v1=["data.Dataset"])
2009class DatasetV1(DatasetV2):
2010  """Represents a potentially large set of elements.
2011
2012  A `Dataset` can be used to represent an input pipeline as a
2013  collection of elements and a "logical plan" of transformations that act on
2014  those elements.
2015  """
2016
2017  def __init__(self):
2018    try:
2019      variant_tensor = self._as_variant_tensor()
2020    except AttributeError as e:
2021      if "_as_variant_tensor" in str(e):
2022        raise AttributeError("Please use _variant_tensor instead of "
2023                             "_as_variant_tensor() to obtain the variant "
2024                             "associated with a dataset")
2025      raise AttributeError("{}: A likely cause of this error is that the super "
2026                           "call for this dataset is not the last line of the "
2027                           "__init__ method. The base class invokes the "
2028                           "_as_variant_tensor call in its constructor and "
2029                           "if that uses attributes defined in the __init__ "
2030                           "method, those attrs need to be defined before the "
2031                           "super call.".format(e))
2032    super(DatasetV1, self).__init__(variant_tensor)
2033
2034  @abc.abstractmethod
2035  def _as_variant_tensor(self):
2036    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
2037
2038    Returns:
2039      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
2040    """
2041    raise NotImplementedError("Dataset._as_variant_tensor")
2042
2043  @deprecation.deprecated(
2044      None, "Use `for ... in dataset:` to iterate over a dataset. If using "
2045      "`tf.estimator`, return the `Dataset` object directly from your input "
2046      "function. As a last resort, you can use "
2047      "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
2048  def make_one_shot_iterator(self):
2049    """Creates an `Iterator` for enumerating the elements of this dataset.
2050
2051    Note: The returned iterator will be initialized automatically.
2052    A "one-shot" iterator does not currently support re-initialization.
2053
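    For example, a minimal graph-mode sketch (assumes a `tf.compat.v1.Session`):

    ```python
    dataset = tf.data.Dataset.range(3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with tf.compat.v1.Session() as sess:
      print(sess.run(next_element))  # ==> 0
    ```
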
2054    Returns:
2055      An `Iterator` over the elements of this dataset.
2056    """
2057    return self._make_one_shot_iterator()
2058
2059  def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
2060    if context.executing_eagerly():
2061      return iterator_ops.OwnedIterator(self)
2062
2063    _ensure_same_dataset_graph(self)
2064    # Now that we create datasets at python object creation time, the capture
2065    # by value _make_dataset() function would try to capture these variant
2066    # tensor dataset inputs, which are marked as stateful ops and would throw
2067    # an error if we try to capture them. We therefore traverse the graph
2068    # to find all these ops and whitelist them so that, instead of throwing an
2069    # error, the capturing logic recreates these ops, which is what was
2070    # happening before.
2071    all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
2072    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
2073
2074    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
2075    # a 0-argument function.
2076    @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
2077    def _make_dataset():
2078      """Factory function for a dataset."""
2079      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
2080      # enclosing graph, so if a graph-level seed is present we set the local
2081      # graph seed based on a combination of the graph- and op-level seeds.
2082      if graph_level_seed is not None:
2083        assert op_level_seed is not None
2084        core_random_seed.set_random_seed(
2085            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
2086
2087      dataset = self._apply_options()
2088      return dataset._variant_tensor  # pylint: disable=protected-access
2089
2090    try:
2091      _make_dataset.add_to_graph(ops.get_default_graph())
2092    except ValueError as err:
2093      if "Cannot capture a stateful node" in str(err):
2094        raise ValueError(
2095            "Failed to create a one-shot iterator for a dataset. "
2096            "`Dataset.make_one_shot_iterator()` does not support datasets that "
2097            "capture stateful objects, such as a `Variable` or `LookupTable`. "
2098            "In these cases, use `Dataset.make_initializable_iterator()`. "
2099            "(Original error: %s)" % err)
2100      else:
2101        six.reraise(ValueError, err)
2102
2103    # pylint: disable=protected-access
2104    return iterator_ops.Iterator(
2105        gen_dataset_ops.one_shot_iterator(
2106            dataset_factory=_make_dataset, **self._flat_structure), None,
2107        get_legacy_output_types(self), get_legacy_output_shapes(self),
2108        get_legacy_output_classes(self))
2109
2110  @deprecation.deprecated(
2111      None, "Use `for ... in dataset:` to iterate over a dataset. If using "
2112      "`tf.estimator`, return the `Dataset` object directly from your input "
2113      "function. As a last resort, you can use "
2114      "`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
2115  def make_initializable_iterator(self, shared_name=None):
2116    """Creates an `Iterator` for enumerating the elements of this dataset.
2117
2118    Note: The returned iterator will be in an uninitialized state,
2119    and you must run the `iterator.initializer` operation before using it:
2120
2121    ```python
2122    dataset = ...
2123    iterator = dataset.make_initializable_iterator()
2124    # ...
2125    sess.run(iterator.initializer)
2126    ```
2127
2128    Args:
2129      shared_name: (Optional.) If non-empty, the returned iterator will be
2130        shared under the given name across multiple sessions that share the same
2131        devices (e.g. when using a remote server).
2132
2133    Returns:
2134      An `Iterator` over the elements of this dataset.
2135
2136    Raises:
2137      RuntimeError: If eager execution is enabled.
2138    """
2139
2140    return self._make_initializable_iterator(shared_name)
2141
2142  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
2143    if context.executing_eagerly():
2144      raise RuntimeError(
2145          "dataset.make_initializable_iterator is not supported when eager "
2146          "execution is enabled. Use `for element in dataset` instead.")
2147    _ensure_same_dataset_graph(self)
2148    dataset = self._apply_options()
2149    if shared_name is None:
2150      shared_name = ""
2151    iterator_resource = gen_dataset_ops.iterator_v2(
2152        container="", shared_name=shared_name, **self._flat_structure)
2153    with ops.colocate_with(iterator_resource):
2154      initializer = gen_dataset_ops.make_iterator(
2155          dataset._variant_tensor,  # pylint: disable=protected-access
2156          iterator_resource)
2157    # pylint: disable=protected-access
2158    return iterator_ops.Iterator(
2159        iterator_resource, initializer, get_legacy_output_types(dataset),
2160        get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
2161
2162  @property
2163  @deprecation.deprecated(
2164      None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.")
2165  def output_classes(self):
2166    """Returns the class of each component of an element of this dataset.
2167
2168    Returns:
2169      A nested structure of Python `type` objects corresponding to each
2170      component of an element of this dataset.
2171    """
2172    return nest.map_structure(
2173        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2174        self.element_spec)
2175
2176  @property
2177  @deprecation.deprecated(
2178      None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.")
2179  def output_shapes(self):
2180    """Returns the shape of each component of an element of this dataset.
2181
2182    Returns:
2183      A nested structure of `tf.TensorShape` objects corresponding to each
2184      component of an element of this dataset.
2185    """
2186    return nest.map_structure(
2187        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2188        self.element_spec)
2189
2190  @property
2191  @deprecation.deprecated(
2192      None, "Use `tf.compat.v1.data.get_output_types(dataset)`.")
2193  def output_types(self):
2194    """Returns the type of each component of an element of this dataset.
2195
2196    Returns:
2197      A nested structure of `tf.DType` objects corresponding to each component
2198      of an element of this dataset.
2199    """
2200    return nest.map_structure(
2201        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2202        self.element_spec)
2203
2204  @property
2205  def element_spec(self):
2206    # TODO(b/110122868): Remove this override once all `Dataset` instances
2207    # implement `element_structure`.
2208    return structure.convert_legacy_structure(
2209        self.output_types, self.output_shapes, self.output_classes)
2210
2211  @staticmethod
2212  @functools.wraps(DatasetV2.from_tensors)
2213  def from_tensors(tensors):
2214    return DatasetV1Adapter(DatasetV2.from_tensors(tensors))
2215
2216  @staticmethod
2217  @functools.wraps(DatasetV2.from_tensor_slices)
2218  def from_tensor_slices(tensors):
2219    return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors))
2220
2221  @staticmethod
2222  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
2223  def from_sparse_tensor_slices(sparse_tensor):
2224    """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
2225
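    For example, a minimal sketch (a 2-D sparse tensor yields rank-1 sparse
    elements, one per row):

    ```python
    st = tf.SparseTensor(indices=[[0, 0], [1, 2]],
                         values=[1, 2],
                         dense_shape=[2, 3])
    dataset = tf.compat.v1.data.Dataset.from_sparse_tensor_slices(st)
    # Each element represents one row of `st` as a rank-1 sparse tensor.
    ```
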
2226    Args:
2227      sparse_tensor: A `tf.SparseTensor`.
2228
2229    Returns:
2230      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
2231    """
2232    return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
2233
2234  @staticmethod
2235  @functools.wraps(DatasetV2.from_generator)
2236  def from_generator(generator, output_types, output_shapes=None, args=None):
2237    return DatasetV1Adapter(DatasetV2.from_generator(
2238        generator, output_types, output_shapes, args))
2239
2240  @staticmethod
2241  @functools.wraps(DatasetV2.range)
2242  def range(*args, **kwargs):
2243    return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
2244
2245  @staticmethod
2246  @functools.wraps(DatasetV2.zip)
2247  def zip(datasets):
2248    return DatasetV1Adapter(DatasetV2.zip(datasets))
2249
2250  @functools.wraps(DatasetV2.concatenate)
2251  def concatenate(self, dataset):
2252    return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset))
2253
2254  @functools.wraps(DatasetV2.prefetch)
2255  def prefetch(self, buffer_size):
2256    return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size))
2257
2258  @staticmethod
2259  @functools.wraps(DatasetV2.list_files)
2260  def list_files(file_pattern, shuffle=None, seed=None):
2261    return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed))
2262
2263  @functools.wraps(DatasetV2.repeat)
2264  def repeat(self, count=None):
2265    return DatasetV1Adapter(super(DatasetV1, self).repeat(count))
2266
2267  @functools.wraps(DatasetV2.shuffle)
2268  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
2269    return DatasetV1Adapter(super(DatasetV1, self).shuffle(
2270        buffer_size, seed, reshuffle_each_iteration))
2271
2272  @functools.wraps(DatasetV2.cache)
2273  def cache(self, filename=""):
2274    return DatasetV1Adapter(super(DatasetV1, self).cache(filename))
2275
2276  @functools.wraps(DatasetV2.take)
2277  def take(self, count):
2278    return DatasetV1Adapter(super(DatasetV1, self).take(count))
2279
2280  @functools.wraps(DatasetV2.skip)
2281  def skip(self, count):
2282    return DatasetV1Adapter(super(DatasetV1, self).skip(count))
2283
2284  @functools.wraps(DatasetV2.shard)
2285  def shard(self, num_shards, index):
2286    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
2287
2288  @functools.wraps(DatasetV2.batch)
2289  def batch(self, batch_size, drop_remainder=False):
2290    return DatasetV1Adapter(super(DatasetV1, self).batch(
2291        batch_size, drop_remainder))
2292
2293  @functools.wraps(DatasetV2.padded_batch)
2294  def padded_batch(self,
2295                   batch_size,
2296                   padded_shapes=None,
2297                   padding_values=None,
2298                   drop_remainder=False):
2299    return DatasetV1Adapter(super(DatasetV1, self).padded_batch(
2300        batch_size, padded_shapes, padding_values, drop_remainder))
2301
2302  @functools.wraps(DatasetV2.map)
2303  def map(self, map_func, num_parallel_calls=None):
2304    if num_parallel_calls is None:
2305      return DatasetV1Adapter(
2306          MapDataset(self, map_func, preserve_cardinality=False))
2307    else:
2308      return DatasetV1Adapter(
2309          ParallelMapDataset(
2310              self, map_func, num_parallel_calls, preserve_cardinality=False))
2311
2312  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`.")
2313  def map_with_legacy_function(self, map_func, num_parallel_calls=None):
2314    """Maps `map_func` across the elements of this dataset.
2315
2316    NOTE: This is an escape hatch for existing uses of `map` that do not work
2317    with V2 functions. New uses are strongly discouraged and existing uses
2318    should migrate to `map` as this method will be removed in V2.
2319
2320    Args:
2321      map_func: A function mapping a nested structure of tensors (having shapes
2322        and types defined by `self.output_shapes` and `self.output_types`) to
2323        another nested structure of tensors.
2324      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
2325        representing the number of elements to process asynchronously in parallel.
2326        If not specified, elements will be processed sequentially. If the value
2327        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
2328        calls is set dynamically based on available CPU.
2329
2330    Returns:
2331      Dataset: A `Dataset`.
2332    """
2333    if num_parallel_calls is None:
2334      return DatasetV1Adapter(
2335          MapDataset(
2336              self,
2337              map_func,
2338              preserve_cardinality=False,
2339              use_legacy_function=True))
2340    else:
2341      return DatasetV1Adapter(
2342          ParallelMapDataset(
2343              self,
2344              map_func,
2345              num_parallel_calls,
2346              preserve_cardinality=False,
2347              use_legacy_function=True))
2348
2349  @functools.wraps(DatasetV2.flat_map)
2350  def flat_map(self, map_func):
2351    return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func))
2352
2353  @functools.wraps(DatasetV2.interleave)
2354  def interleave(self,
2355                 map_func,
2356                 cycle_length=AUTOTUNE,
2357                 block_length=1,
2358                 num_parallel_calls=None,
2359                 deterministic=None):
2360    return DatasetV1Adapter(
2361        super(DatasetV1, self).interleave(map_func, cycle_length, block_length,
2362                                          num_parallel_calls, deterministic))
2363
2364  @functools.wraps(DatasetV2.filter)
2365  def filter(self, predicate):
2366    return DatasetV1Adapter(super(DatasetV1, self).filter(predicate))
2367
2368  @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()`.")
2369  def filter_with_legacy_function(self, predicate):
2370    """Filters this dataset according to `predicate`.
2371
2372    NOTE: This is an escape hatch for existing uses of `filter` that do not work
2373    with V2 functions. New uses are strongly discouraged and existing uses
2374    should migrate to `filter` as this method will be removed in V2.
2375
2376    Args:
2377      predicate: A function mapping a nested structure of tensors (having shapes
2378        and types defined by `self.output_shapes` and `self.output_types`) to a
2379        scalar `tf.bool` tensor.
2380
2381    Returns:
2382      Dataset: The `Dataset` containing the elements of this dataset for which
2383          `predicate` is `True`.
2384    """
2385    return FilterDataset(self, predicate, use_legacy_function=True)
2386
2387  @functools.wraps(DatasetV2.apply)
2388  def apply(self, transformation_func):
2389    return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
2390
2391  @functools.wraps(DatasetV2.window)
2392  def window(self, size, shift=None, stride=1, drop_remainder=False):
2393    return DatasetV1Adapter(super(DatasetV1, self).window(
2394        size, shift, stride, drop_remainder))
2395
2396  @functools.wraps(DatasetV2.unbatch)
2397  def unbatch(self):
2398    return DatasetV1Adapter(super(DatasetV1, self).unbatch())
2399
2400  @functools.wraps(DatasetV2.with_options)
2401  def with_options(self, options):
2402    return DatasetV1Adapter(super(DatasetV1, self).with_options(options))
2403
2404
2405if tf2.enabled():
2406  Dataset = DatasetV2
2407else:
2408  Dataset = DatasetV1
2409
2410
2411class DatasetV1Adapter(DatasetV1):
2412  """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
2413
2414  def __init__(self, dataset):
2415    self._dataset = dataset
2416    super(DatasetV1Adapter, self).__init__()
2417
2418  def _as_variant_tensor(self):
2419    return self._dataset._variant_tensor  # pylint: disable=protected-access
2420
2421  def _has_captured_ref(self):
2422    return self._dataset._has_captured_ref()  # pylint: disable=protected-access
2423
2424  def _inputs(self):
2425    return self._dataset._inputs()  # pylint: disable=protected-access
2426
2427  def _functions(self):
2428    return self._dataset._functions()  # pylint: disable=protected-access
2429
2430  def options(self):
2431    return self._dataset.options()
2432
2433  @property
2434  def element_spec(self):
2435    return self._dataset.element_spec  # pylint: disable=protected-access
2436
2437  def __iter__(self):
2438    return iter(self._dataset)
2439
2440
2441def _ensure_same_dataset_graph(dataset):
2442  """Walks the dataset graph to ensure all datasets come from the same graph."""
2443  # pylint: disable=protected-access
2444  current_graph = ops.get_default_graph()
2445  bfs_q = Queue.Queue()
2446  bfs_q.put(dataset)
2447  visited = []
2448  while not bfs_q.empty():
2449    ds = bfs_q.get()
2450    visited.append(ds)
2451    ds_graph = ds._graph
2452    if current_graph != ds_graph:
2453      raise ValueError(
2454          "The graph (" + str(current_graph) + ") of the iterator is different "
2455          "from the graph (" + str(ds_graph) + ") the dataset: " +
2456          str(ds._variant_tensor) + " was  created in. If you are using the "
2457          "Estimator API, make sure that no part of the dataset returned by "
2458          "the `input_fn` function is defined outside the `input_fn` function. "
2459          "Please ensure that all datasets in the pipeline are created in the "
2460          "same graph as the iterator.")
2461    for input_ds in ds._inputs():
2462      if input_ds not in visited:
2463        bfs_q.put(input_ds)
2464
2465
2466@tf_export(v1=["data.make_one_shot_iterator"])
2467def make_one_shot_iterator(dataset):
2468  """Creates a `tf.compat.v1.data.Iterator` for enumerating dataset elements.
2469
2470  Note: The returned iterator will be initialized automatically.
2471  A "one-shot" iterator does not support re-initialization.
2472
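  For example, a minimal graph-mode sketch:

  ```python
  dataset = tf.data.Dataset.range(3)
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  next_element = iterator.get_next()
  # In a `tf.compat.v1.Session`, evaluating `next_element` yields 0, 1, 2.
  ```
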
2473  Args:
2474    dataset: A `tf.data.Dataset`.
2475
2476  Returns:
2477    A `tf.compat.v1.data.Iterator` over the elements of this dataset.
2478  """
2479  try:
2480    # Call the defined `_make_one_shot_iterator()` if there is one, because some
2481    # datasets (e.g. for prefetching) override its behavior.
2482    return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
2483  except AttributeError:
2484    return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
2485
2486
2487@tf_export(v1=["data.make_initializable_iterator"])
2488def make_initializable_iterator(dataset, shared_name=None):
2489  """Creates a `tf.compat.v1.data.Iterator` for enumerating the elements of a dataset.
2490
2491  Note: The returned iterator will be in an uninitialized state,
2492  and you must run the `iterator.initializer` operation before using it:
2493
2494  ```python
2495  dataset = ...
2496  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
2497  # ...
2498  sess.run(iterator.initializer)
2499  ```
2500
2501  Args:
2502    dataset: A `tf.data.Dataset`.
2503    shared_name: (Optional.) If non-empty, the returned iterator will be shared
2504      under the given name across multiple sessions that share the same devices
2505      (e.g. when using a remote server).
2506
2507  Returns:
2508    A `tf.compat.v1.data.Iterator` over the elements of `dataset`.
2509
2510  Raises:
2511    RuntimeError: If eager execution is enabled.
2512  """
2513  try:
2514    # Call the defined `_make_initializable_iterator()` if there is one, because
2515    # some datasets (e.g. for prefetching) override its behavior.
2516    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
2517  except AttributeError:
2518    return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
2519
2520
2521@tf_export("data.experimental.get_structure")
2522def get_structure(dataset_or_iterator):
2523  """Returns the type specification of an element of a `Dataset` or `Iterator`.
2524
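  For example (a minimal illustration):

  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> tf.data.experimental.get_structure(dataset)
  TensorSpec(shape=(), dtype=tf.int32, name=None)
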
2525  Args:
2526    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2527
2528  Returns:
2529    A nested structure of `tf.TypeSpec` objects matching the structure of an
2530    element of `dataset_or_iterator` and specifying the type of individual
2531    components.
2532
2533  Raises:
2534    TypeError: If `dataset_or_iterator` is not a `Dataset` or `Iterator` object.
2535  """
2536  try:
2537    return dataset_or_iterator.element_spec  # pylint: disable=protected-access
2538  except AttributeError:
2539    raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator "
2540                    "object, but got %s." % type(dataset_or_iterator))
2541
2542
2543@tf_export(v1=["data.get_output_classes"])
2544def get_legacy_output_classes(dataset_or_iterator):
2545  """Returns the output classes of a `Dataset` or `Iterator` elements.
2546
2547  This utility method replaces the deprecated-in-V2
2548  `tf.compat.v1.data.Dataset.output_classes` property.
2549
2550  Args:
2551    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2552
2553  Returns:
2554    A nested structure of Python `type` objects matching the structure of the
2555    dataset / iterator elements and specifying the class of the individual
2556    components.
2557  """
2558  return nest.map_structure(
2559      lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2560      get_structure(dataset_or_iterator))
2561
2562
2563@tf_export(v1=["data.get_output_shapes"])
2564def get_legacy_output_shapes(dataset_or_iterator):
2565  """Returns the output shapes of a `Dataset` or `Iterator` elements.
2566
2567  This utility method replaces the deprecated-in-V2
2568  `tf.compat.v1.data.Dataset.output_shapes` property.
2569
2570  Args:
2571    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2572
2573  Returns:
2574    A nested structure of `tf.TensorShape` objects matching the structure of
2575    the dataset / iterator elements and specifying the shape of the individual
2576    components.
2577  """
2578  return nest.map_structure(
2579      lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2580      get_structure(dataset_or_iterator))
2581
2582
2583@tf_export(v1=["data.get_output_types"])
2584def get_legacy_output_types(dataset_or_iterator):
2585  """Returns the output shapes of a `Dataset` or `Iterator` elements.
2586
2587  This utility method replaces the deprecated-in-V2
2588  `tf.compat.v1.data.Dataset.output_types` property.
2589
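  For example (a minimal illustration):

  >>> dataset = tf.data.Dataset.range(3)
  >>> tf.compat.v1.data.get_output_types(dataset)
  tf.int64
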
2590  Args:
2591    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
2592
2593  Returns:
2594    A nested structure of `tf.DType` objects matching the structure of the
2595    dataset / iterator elements and specifying the type of the individual
2596    components.
2597  """
2598  return nest.map_structure(
2599      lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2600      get_structure(dataset_or_iterator))
2601
2602
2603@tf_export("data.Options")
2604class Options(options_lib.OptionsBase):
2605  """Represents options for tf.data.Dataset.
2606
2607  An `Options` object can be used, for instance, to control which graph
2608  optimizations to apply or whether to use performance modeling to dynamically
2609  tune the parallelism of operations such as `tf.data.Dataset.map` or
2610  `tf.data.Dataset.interleave`.
2611
2612  After constructing an `Options` object, use `dataset.with_options(options)` to
2613  apply the options to a dataset.
2614
2615  >>> dataset = tf.data.Dataset.range(3)
2616  >>> options = tf.data.Options()
2617  >>> # Set options here.
2618  >>> dataset = dataset.with_options(options)
2619  """
2620
2621  experimental_deterministic = options_lib.create_option(
2622      name="experimental_deterministic",
2623      ty=bool,
2624      docstring=
2625      "Whether the outputs need to be produced in deterministic order. If None,"
2626      " defaults to True.")
2627
2628  experimental_distribute = options_lib.create_option(
2629      name="experimental_distribute",
2630      ty=distribute_options.DistributeOptions,
2631      docstring=
2632      "The distribution strategy options associated with the dataset. See "
2633      "`tf.data.experimental.DistributeOptions` for more details.",
2634      default_factory=distribute_options.DistributeOptions)
2635
2636  experimental_optimization = options_lib.create_option(
2637      name="experimental_optimization",
2638      ty=optimization_options.OptimizationOptions,
2639      docstring=
2640      "The optimization options associated with the dataset. See "
2641      "`tf.data.experimental.OptimizationOptions` for more details.",
2642      default_factory=optimization_options.OptimizationOptions)
2643
2644  experimental_slack = options_lib.create_option(
2645      name="experimental_slack",
2646      ty=bool,
2647      docstring="Whether to introduce 'slack' in the last `prefetch` of the "
2648      "input pipeline, if it exists. This may reduce CPU contention with "
2649      "accelerator host-side activity at the start of a step. The slack "
2650      "frequency is determined by the number of devices attached to this "
2651      "input pipeline. If None, defaults to False.")
2652
2653  experimental_stats = options_lib.create_option(
2654      name="experimental_stats",
2655      ty=stats_options.StatsOptions,
2656      docstring=
2657      "The statistics options associated with the dataset. See "
2658      "`tf.data.experimental.StatsOptions` for more details.",
2659      default_factory=stats_options.StatsOptions)
2660
2661  experimental_threading = options_lib.create_option(
2662      name="experimental_threading",
2663      ty=threading_options.ThreadingOptions,
2664      docstring=
2665      "The threading options associated with the dataset. See "
2666      "`tf.data.experimental.ThreadingOptions` for more details.",
2667      default_factory=threading_options.ThreadingOptions)
2668
2669  experimental_external_state_policy = options_lib.create_option(
2670      name="experimental_external_state_policy",
2671      ty=distribute_options.ExternalStatePolicy,
2672      docstring="By default, tf.data will refuse to serialize a dataset or "
2673      "checkpoint its iterator if the dataset contains a stateful op as the "
2674      "serialization / checkpointing won't be able to capture its state. "
2675      "Users can -- at their own risk -- override this restriction by "
2676      "explicitly specifying that they are fine throwing away the state "
2677      "in these ops. There are three settings available - IGNORE: we "
2678      "completely ignore any state; WARN: we warn the user that some state "
2679      "might be thrown away; FAIL: we fail if any state is being captured.",
2680      default_factory=lambda: distribute_options.ExternalStatePolicy.WARN)
2681
2682  def _graph_rewrites(self):
2683    """Produces the list of enabled static graph rewrites."""
2684    result = []
2685    if self.experimental_optimization is not None:
2686      result.extend(self.experimental_optimization._graph_rewrites())  # pylint: disable=protected-access
2687    else:
2688      # Apply default options
2689      result.extend(
2690          optimization_options.OptimizationOptions()._graph_rewrites())  # pylint: disable=protected-access
2691
2692    if self.experimental_deterministic is False:
2693      result.append("make_sloppy")
2694    if self.experimental_stats and self.experimental_stats.latency_all_edges:
2695      result.append("latency_all_edges")
2696    if self.experimental_slack:
2697      result.append("slack")
2698    if (self.experimental_distribute and
2699        self.experimental_distribute._make_stateless):  # pylint: disable=protected-access
2700      result.append("make_stateless")
2701    return result
2702
2703  def _graph_rewrite_configs(self):
2704    """Produces the list of configurations for enabled graph optimizations."""
2705    result = []
2706    if self.experimental_optimization:
2707      result.extend(self.experimental_optimization._graph_rewrite_configs())  # pylint: disable=protected-access
2708
2709    if self.experimental_slack:
2710      num_devices = self.experimental_distribute.num_devices
2711      if num_devices is None:
2712        num_devices = 1
2713      result.append("slack:slack_period:%d" % num_devices)
2714    return result
2715
2716  def _autotune_settings(self):
2717    if self.experimental_optimization is not None:
2718      return self.experimental_optimization._autotune_settings()  # pylint: disable=protected-access
2719
2720    # Return default autotune options
2721    return optimization_options.OptimizationOptions()._autotune_settings()  # pylint: disable=protected-access
2722
2723  def merge(self, options):
2724    """Merges itself with the given `tf.data.Options`.
2725
2726    The given `tf.data.Options` can be merged as long as no attribute is set to
2727    different values in `self` and `options`.
2728
2729    Args:
2730      options: a `tf.data.Options` to merge with.
2731
2732    Raises:
2733      ValueError: if the given `tf.data.Options` cannot be merged.
2734
2735    Returns:
2736      New `tf.data.Options()` object which is the result of merging self with
2737      the input `tf.data.Options`.
2738    """
2739    return options_lib.merge_options(self, options)
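

# Example (illustrative sketch, not part of the original module): constructing
# two `tf.data.Options` objects, merging them, and applying the result to a
# dataset. Assumes `import tensorflow as tf`.
#
#   options = tf.data.Options()
#   options.experimental_deterministic = False
#   other = tf.data.Options()
#   other.experimental_slack = True
#   merged = options.merge(other)
#   ds = tf.data.Dataset.range(10).with_options(merged)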
2740
2741
2742class DatasetSource(DatasetV2):
2743  """Abstract class representing a dataset with no inputs."""
2744
2745  def _inputs(self):
2746    return []
2747
2748
2749class UnaryDataset(DatasetV2):
2750  """Abstract class representing a dataset with one input."""
2751
2752  def __init__(self, input_dataset, variant_tensor):
2753    self._input_dataset = input_dataset
2754    super(UnaryDataset, self).__init__(variant_tensor)
2755
2756  def _inputs(self):
2757    return [self._input_dataset]
2758
2759
2760class UnaryUnchangedStructureDataset(UnaryDataset):
2761  """Represents a unary dataset with the same input and output structure."""
2762
2763  def __init__(self, input_dataset, variant_tensor):
2764    self._input_dataset = input_dataset
2765    super(UnaryUnchangedStructureDataset, self).__init__(
2766        input_dataset, variant_tensor)
2767
2768  @property
2769  def element_spec(self):
2770    return self._input_dataset.element_spec
2771
2772
2773class TensorDataset(DatasetSource):
2774  """A `Dataset` with a single element."""
2775
2776  def __init__(self, element):
2777    """See `Dataset.from_tensors()` for details."""
2778    element = structure.normalize_element(element)
2779    self._structure = structure.type_spec_from_value(element)
2780    self._tensors = structure.to_tensor_list(self._structure, element)
2781
2782    variant_tensor = gen_dataset_ops.tensor_dataset(
2783        self._tensors,
2784        output_shapes=structure.get_flat_tensor_shapes(self._structure))
2785    super(TensorDataset, self).__init__(variant_tensor)
2786
2787  @property
2788  def element_spec(self):
2789    return self._structure
2790
2791
2792class TensorSliceDataset(DatasetSource):
2793  """A `Dataset` of slices from a dataset element."""
2794
2795  def __init__(self, element):
2796    """See `Dataset.from_tensor_slices()` for details."""
2797    element = structure.normalize_element(element)
2798    batched_spec = structure.type_spec_from_value(element)
2799    self._tensors = structure.to_batched_tensor_list(batched_spec, element)
2800    self._structure = nest.map_structure(
2801        lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access
2802
2803    batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
2804        self._tensors[0].get_shape()[0]))
2805    for t in self._tensors[1:]:
2806      batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
2807          tensor_shape.dimension_value(t.get_shape()[0])))
2808
2809    variant_tensor = gen_dataset_ops.tensor_slice_dataset(
2810        self._tensors,
2811        output_shapes=structure.get_flat_tensor_shapes(self._structure))
2812    super(TensorSliceDataset, self).__init__(variant_tensor)
2813
2814  @property
2815  def element_spec(self):
2816    return self._structure
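

# Example (illustrative sketch, not part of the original module): the public
# `Dataset.from_tensor_slices` entry point that constructs a
# `TensorSliceDataset`, slicing along the first (batch) dimension. Assumes
# `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
#   [t.numpy() for t in ds]
#   # [array([1, 2], dtype=int32), array([3, 4], dtype=int32),
#   #  array([5, 6], dtype=int32)]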
2817
2818
2819class SparseTensorSliceDataset(DatasetSource):
2820  """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
2821
2822  def __init__(self, sparse_tensor):
2823    """See `Dataset.from_sparse_tensor_slices()` for details."""
2824    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
2825      raise TypeError(
2826          "`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format(
2827              sparse_tensor))
2828    self._sparse_tensor = sparse_tensor
2829
2830    indices_shape = self._sparse_tensor.indices.get_shape()
2831    shape_shape = self._sparse_tensor.dense_shape.get_shape()
2832    rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
2833    self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64),
2834                       tensor_spec.TensorSpec([None],
2835                                              self._sparse_tensor.dtype),
2836                       tensor_spec.TensorSpec([rank], dtypes.int64))
2837
2838    variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
2839        self._sparse_tensor.indices, self._sparse_tensor.values,
2840        self._sparse_tensor.dense_shape)
2841    super(SparseTensorSliceDataset, self).__init__(variant_tensor)
2842
2843  @property
2844  def element_spec(self):
2845    return self._structure
2846
2847
2848class _VariantDataset(DatasetV2):
2849  """A Dataset wrapper around a `tf.variant`-typed function argument."""
2850
2851  def __init__(self, dataset_variant, structure):
2852    self._structure = structure
2853    super(_VariantDataset, self).__init__(dataset_variant)
2854
2855  def _inputs(self):
2856    return []
2857
2858  @property
2859  def element_spec(self):
2860    return self._structure
2861
2862
2863class _NestedVariant(composite_tensor.CompositeTensor):
2864
2865  def __init__(self, variant_tensor, element_spec, dataset_shape):
2866    self._variant_tensor = variant_tensor
2867    self._element_spec = element_spec
2868    self._dataset_shape = dataset_shape
2869
2870  @property
2871  def _type_spec(self):
2872    return DatasetSpec(self._element_spec, self._dataset_shape)
2873
2874
2875@tf_export("data.experimental.from_variant")
2876def from_variant(variant, structure):
2877  """Constructs a dataset from the given variant and structure.
2878
2879  Args:
2880    variant: A scalar `tf.variant` tensor representing a dataset.
2881    structure: A `tf.data.experimental.Structure` object representing the
2882      structure of each element in the dataset.
2883
2884  Returns:
2885    A `tf.data.Dataset` instance.
2886  """
2887  return _VariantDataset(variant, structure)  # pylint: disable=protected-access
2888
2889
2890@tf_export("data.experimental.to_variant")
2891def to_variant(dataset):
2892  """Returns a variant representing the given dataset.
2893
2894  Args:
2895    dataset: A `tf.data.Dataset`.
2896
2897  Returns:
2898    A scalar `tf.variant` tensor representing the given dataset.
2899  """
2900  return dataset._variant_tensor  # pylint: disable=protected-access
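

# Example (illustrative sketch, not part of the original module): a round trip
# through the variant representation using the two functions exported above.
# Assumes `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.range(3)
#   variant = tf.data.experimental.to_variant(ds)
#   spec = tf.data.experimental.get_structure(ds)
#   restored = tf.data.experimental.from_variant(variant, structure=spec)
#   [t.numpy() for t in restored]  # [0, 1, 2]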
2901
2902
2903@tf_export(
2904    "data.DatasetSpec",
2905    v1=["data.DatasetSpec", "data.experimental.DatasetStructure"])
2906class DatasetSpec(type_spec.BatchableTypeSpec):
2907  """Type specification for `tf.data.Dataset`.
2908
2909  See `tf.TypeSpec` for more information about TensorFlow type specifications.
2910
2911  >>> dataset = tf.data.Dataset.range(3)
2912  >>> tf.data.DatasetSpec.from_value(dataset)
2913  DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([]))
2914  """
2915
2916  __slots__ = ["_element_spec", "_dataset_shape"]
2917
2918  def __init__(self, element_spec, dataset_shape=()):
2919    self._element_spec = element_spec
2920    self._dataset_shape = tensor_shape.as_shape(dataset_shape)
2921
2922  @property
2923  def value_type(self):
2924    return _VariantDataset
2925
2926  def _serialize(self):
2927    return (self._element_spec, self._dataset_shape)
2928
2929  @property
2930  def _component_specs(self):
2931    return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant)
2932
2933  def _to_components(self, value):
2934    return value._variant_tensor  # pylint: disable=protected-access
2935
2936  def _from_components(self, components):
2937    # pylint: disable=protected-access
2938    if self._dataset_shape.ndims == 0:
2939      return _VariantDataset(components, self._element_spec)
2940    else:
2941      return _NestedVariant(components, self._element_spec, self._dataset_shape)
2942
2943  def _to_tensor_list(self, value):
2944    return [
2945        ops.convert_to_tensor(
2946            tf_nest.map_structure(lambda x: x._variant_tensor, value))  # pylint: disable=protected-access
2947    ]
2948
2949  @staticmethod
2950  def from_value(value):
2951    """Creates a `DatasetSpec` for the given `tf.data.Dataset` value."""
2952    return DatasetSpec(value.element_spec)  # pylint: disable=protected-access
2953
2954  def _batch(self, batch_size):
2955    return DatasetSpec(
2956        self._element_spec,
2957        tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
2958
2959  def _unbatch(self):
2960    if self._dataset_shape.ndims == 0:
2961      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
2962    return DatasetSpec(self._element_spec, self._dataset_shape[1:])
2963
2964  def _to_batched_tensor_list(self, value):
2965    if self._dataset_shape.ndims == 0:
2966      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
2967    return self._to_tensor_list(value)
2968
2969  def _to_legacy_output_types(self):
2970    return self
2971
2972  def _to_legacy_output_shapes(self):
2973    return self
2974
2975  def _to_legacy_output_classes(self):
2976    return self
2977
2978
2979class StructuredFunctionWrapper(object):
2980  """A function wrapper that supports structured arguments and return values."""
2981
2982  # pylint: disable=protected-access
2983  def __init__(self,
2984               func,
2985               transformation_name,
2986               dataset=None,
2987               input_classes=None,
2988               input_shapes=None,
2989               input_types=None,
2990               input_structure=None,
2991               add_to_graph=True,
2992               use_legacy_function=False,
2993               defun_kwargs=None):
2994    """Creates a new `StructuredFunctionWrapper` for the given function.
2995
2996    Args:
2997      func: A function from a nested structure to another nested structure.
2998      transformation_name: Human-readable name of the transformation in which
2999        this function is being instantiated, for error messages.
3000      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
3001        dataset will be used as the structure for `func` arguments; otherwise
3002        `input_classes`, `input_shapes`, and `input_types` must be defined.
3003      input_classes: (Optional.) A nested structure of `type`. If given, this
3004        argument defines the Python types for `func` arguments.
3005      input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
3006        given, this argument defines the shapes and structure for `func`
3007        arguments.
3008      input_types: (Optional.) A nested structure of `tf.DType`. If given, this
3009        argument defines the element types and structure for `func` arguments.
3010      input_structure: (Optional.) A `Structure` object. If given, this argument
3011        defines the element types and structure for `func` arguments.
3012      add_to_graph: (Optional.) If `True`, the function will be added to the
3013        default graph, if it exists.
3014      use_legacy_function: (Optional.) A boolean that determines whether the
3015        function be created using `tensorflow.python.eager.function.defun`
3016        (default behavior) or `tensorflow.python.framework.function.Defun`
3017        (legacy behavior).
3018      defun_kwargs: (Optional.) A dictionary mapping string argument names to
3019        values. If supplied, will be passed to `function` as keyword arguments.
3020
3021    Raises:
3022      ValueError: If an invalid combination of `dataset`, `input_classes`,
3023        `input_shapes`, and `input_types` is passed.
3024    """
3025    if input_structure is None:
3026      if dataset is None:
3027        if input_classes is None or input_shapes is None or input_types is None:
3028          raise ValueError("Either `dataset`, `input_structure` or all of "
3029                           "`input_classes`, `input_shapes`, and `input_types` "
3030                           "must be specified.")
3031        self._input_structure = structure.convert_legacy_structure(
3032            input_types, input_shapes, input_classes)
3033      else:
3034        if not (input_classes is None and input_shapes is None and
3035                input_types is None):
3036          raise ValueError("Either `dataset`, `input_structure` or all of "
3037                           "`input_classes`, `input_shapes`, and `input_types` "
3038                           "must be specified.")
3039        self._input_structure = dataset.element_spec
3040    else:
3041      if not (dataset is None and input_classes is None and input_shapes is None
3042              and input_types is None):
3043        raise ValueError("Either `dataset`, `input_structure`, or all of "
3044                         "`input_classes`, `input_shapes`, and `input_types` "
3045                         "must be specified.")
3046      self._input_structure = input_structure
3047
3048    self._func = func
3049
3050    # There is no graph to add in eager mode.
3051    add_to_graph &= not context.executing_eagerly()
3052    # There are some lifetime issues when a legacy function is not added to an
3053    # out-living graph. It's already deprecated, so we de-prioritize the fix.
3054    add_to_graph |= use_legacy_function
3055
3056    if defun_kwargs is None:
3057      defun_kwargs = {}
3058
3059    readable_transformation_name = transformation_name.replace(
3060        ".", "_")[:-2] if len(transformation_name) > 2 else ""
3061
3062    func_name = "_".join(
3063        [readable_transformation_name,
3064         function_utils.get_func_name(func)])
3065    # Sanitize function name to remove symbols that interfere with graph
3066    # construction.
3067    for symbol in ["<", ">", "\\", "'", " "]:
3068      func_name = func_name.replace(symbol, "")
3069
3070    ag_ctx = autograph_ctx.control_status_ctx()
3071
3072    def _warn_if_collections(transformation_name):
3073      """Prints a warning if the given graph uses common graph collections.
3074
3075      NOTE(mrry): Currently a warning is only generated for resources. Any
3076      variables created will be automatically hoisted out to the outermost scope
3077      using `init_scope()`. Some collections (such as for control-flow contexts)
3078      are benign and should not generate a warning.
3079
3080      Args:
3081        transformation_name: A human-readable name for the transformation.
3082      """
3083      warnings.warn("Creating resources inside a function passed to %s "
3084                    "is not supported. Create each resource outside the "
3085                    "function, and capture it inside the function to use it." %
3086                    transformation_name, stacklevel=5)
3087
3088    def _wrapper_helper(*args):
3089      """Wrapper for passing nested structures to and from tf.data functions."""
3090      nested_args = structure.from_compatible_tensor_list(
3091          self._input_structure, args)
3092      if not _should_unpack_args(nested_args):
3093        nested_args = (nested_args,)
3094
3095      ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
3096      # If `func` returns a list of tensors, `nest.flatten()` and
3097      # `ops.convert_to_tensor()` would conspire to attempt to stack
3098      # those tensors into a single tensor, because the customized
3099      # version of `nest.flatten()` does not recurse into lists. Since
3100      # it is more likely that the list arose from returning the
3101      # result of an operation (such as `tf.numpy_function()`) that returns a
3102      # list of not-necessarily-stackable tensors, we treat the
3103      # returned value as a `tuple` instead. A user wishing to pack
3104      # the return value into a single tensor can use an explicit
3105      # `tf.stack()` before returning.
3106      if isinstance(ret, list):
3107        ret = tuple(ret)
3108
3109      try:
3110        self._output_structure = structure.type_spec_from_value(ret)
3111      except (ValueError, TypeError):
3112        six.reraise(
3113            TypeError,
3114            TypeError("Unsupported return value from function passed to "
3115                      "%s: %s." % (transformation_name, ret)),
3116            sys.exc_info()[2])
3117      return ret
3118
3119    if use_legacy_function:
3120      func_name = func_name + "_" + str(ops.uid())
3121
3122      @function.Defun(
3123          *structure.get_flat_tensor_types(self._input_structure),
3124          func_name=func_name,
3125          **defun_kwargs)
3126      def wrapper_fn(*args):
3127        ret = _wrapper_helper(*args)
3128        # _warn_if_collections(transformation_name, ops.get_default_graph(), 0)
3129        return structure.to_tensor_list(self._output_structure, ret)
3130
3131      self._function = wrapper_fn
3132      resource_tracker = tracking.ResourceTracker()
3133      with tracking.resource_tracker_scope(resource_tracker):
3134        if add_to_graph:
3135          self._function.add_to_graph(ops.get_default_graph())
3136        else:
3137          # Use the private method that will execute `wrapper_fn` but delay
3138          # adding it to the graph in case (e.g.) we need to rerun the function.
3139          self._function._create_definition_if_needed()
3140      if resource_tracker.resources:
3141        _warn_if_collections(transformation_name)
3142
3143    else:
3144      defun_kwargs.update({"func_name": func_name})
3145
3146      # Note: _wrapper_helper will apply autograph based on context.
3147      @eager_function.defun_with_attributes(
3148          input_signature=structure.get_flat_tensor_specs(
3149              self._input_structure),
3150          autograph=False,
3151          attributes=defun_kwargs)
3152      def wrapper_fn(*args):  # pylint: disable=missing-docstring
3153        ret = _wrapper_helper(*args)
3154        ret = structure.to_tensor_list(self._output_structure, ret)
3155        return [ops.convert_to_tensor(t) for t in ret]
3156
3157      resource_tracker = tracking.ResourceTracker()
3158      with tracking.resource_tracker_scope(resource_tracker):
3159        # TODO(b/141462134): Switch to using garbage collection.
3160        self._function = wrapper_fn.get_concrete_function()
3161
3162        if add_to_graph:
3163          self._function.add_to_graph(ops.get_default_graph())
3164      if resource_tracker.resources:
3165        _warn_if_collections(transformation_name)
3166
3167      outer_graph_seed = ops.get_default_graph().seed
3168      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
3169        if self._function.graph._seed_used:
3170          warnings.warn(
3171              "Seed %s from outer graph might be getting used by function %s, "
3172              "if the random op has not been provided any seed. Explicitly set "
3173              "the seed in the function if this is not the intended behavior."
3174              % (outer_graph_seed, func_name), stacklevel=4)
3175  # pylint: enable=protected-access
3176
3177  @property
3178  def output_structure(self):
3179    return self._output_structure
3180
3181  @property
3182  def output_classes(self):
3183    return nest.map_structure(
3184        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
3185        self._output_structure)
3186
3187  @property
3188  def output_shapes(self):
3189    return nest.map_structure(
3190        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
3191        self._output_structure)
3192
3193  @property
3194  def output_types(self):
3195    return nest.map_structure(
3196        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
3197        self._output_structure)
3198
3199  @property
3200  def function(self):
3201    return self._function
3202
3203
3204class _GeneratorDataset(DatasetSource):
3205  """A `Dataset` that generates elements by invoking a function."""
3206
3207  def __init__(self, init_args, init_func, next_func, finalize_func):
3208    """Constructs a `_GeneratorDataset`.
3209
3210    Args:
3211      init_args: A nested structure representing the arguments to `init_func`.
3212      init_func: A TensorFlow function that will be called on `init_args` each
3213        time a C++ iterator over this dataset is constructed. Returns a nested
3214        structure representing the "state" of the dataset.
3215      next_func: A TensorFlow function that will be called on the result of
3216        `init_func` to produce each element, and that raises `OutOfRangeError`
3217        to terminate iteration.
3218      finalize_func: A TensorFlow function that will be called on the result of
3219        `init_func` immediately before a C++ iterator over this dataset is
3220        destroyed. The return value is ignored.
3221    """
3222    self._init_args = init_args
3223
3224    self._init_structure = structure.type_spec_from_value(init_args)
3225
3226    self._init_func = StructuredFunctionWrapper(
3227        init_func,
3228        self._transformation_name(),
3229        input_structure=self._init_structure)
3230
3231    self._next_func = StructuredFunctionWrapper(
3232        next_func,
3233        self._transformation_name(),
3234        input_structure=self._init_func.output_structure)
3235
3236    self._finalize_func = StructuredFunctionWrapper(
3237        finalize_func,
3238        self._transformation_name(),
3239        input_structure=self._init_func.output_structure)
3240    variant_tensor = gen_dataset_ops.generator_dataset(
3241        structure.to_tensor_list(self._init_structure, self._init_args) +
3242        self._init_func.function.captured_inputs,
3243        self._next_func.function.captured_inputs,
3244        self._finalize_func.function.captured_inputs,
3245        init_func=self._init_func.function,
3246        next_func=self._next_func.function,
3247        finalize_func=self._finalize_func.function,
3248        **self._flat_structure)
3249    super(_GeneratorDataset, self).__init__(variant_tensor)
3250
3251  @property
3252  def element_spec(self):
3253    return self._next_func.output_structure
3254
3255  def _transformation_name(self):
3256    return "Dataset.from_generator()"
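

# Example (illustrative sketch, not part of the original module): the public
# `Dataset.from_generator` entry point, which is backed by generator-style
# datasets such as `_GeneratorDataset`. Assumes `import tensorflow as tf` and
# eager execution.
#
#   def gen():
#     for i in range(3):
#       yield i
#
#   ds = tf.data.Dataset.from_generator(gen, output_types=tf.int64)
#   [t.numpy() for t in ds]  # [0, 1, 2]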
3257
3258
3259class ZipDataset(DatasetV2):
3260  """A `Dataset` that zips its inputs together."""
3261
3262  def __init__(self, datasets):
3263    """See `Dataset.zip()` for details."""
3264    for ds in nest.flatten(datasets):
3265      if not isinstance(ds, DatasetV2):
3266        if isinstance(ds, list):
3267          message = ("The argument to `Dataset.zip()` must be a nested "
3268                     "structure of `Dataset` objects. Nested structures do not "
3269                     "support Python lists; please use a tuple instead.")
3270        else:
3271          message = ("The argument to `Dataset.zip()` must be a nested "
3272                     "structure of `Dataset` objects.")
3273        raise TypeError(message)
3274    self._datasets = datasets
3275    self._structure = nest.pack_sequence_as(
3276        self._datasets,
3277        [ds.element_spec for ds in nest.flatten(self._datasets)])
3278    variant_tensor = gen_dataset_ops.zip_dataset(
3279        [ds._variant_tensor for ds in nest.flatten(self._datasets)],
3280        **self._flat_structure)
3281    super(ZipDataset, self).__init__(variant_tensor)
3282
3283  def _inputs(self):
3284    return nest.flatten(self._datasets)
3285
3286  @property
3287  def element_spec(self):
3288    return self._structure
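

# Example (illustrative sketch, not part of the original module): zipping two
# datasets; note that the argument must be a tuple (or other non-list nested
# structure), as enforced above. Assumes `import tensorflow as tf` and eager
# execution.
#
#   a = tf.data.Dataset.range(3)        # 0, 1, 2
#   b = tf.data.Dataset.range(4, 7)     # 4, 5, 6
#   ds = tf.data.Dataset.zip((a, b))
#   [(x.numpy(), y.numpy()) for x, y in ds]  # [(0, 4), (1, 5), (2, 6)]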
3289
3290
3291class ConcatenateDataset(DatasetV2):
3292  """A `Dataset` that concatenates its input with the given dataset."""
3293
3294  def __init__(self, input_dataset, dataset_to_concatenate):
3295    """See `Dataset.concatenate()` for details."""
3296    self._input_dataset = input_dataset
3297    self._dataset_to_concatenate = dataset_to_concatenate
3298
3299    output_types = get_legacy_output_types(input_dataset)
3300    if output_types != get_legacy_output_types(dataset_to_concatenate):
3301      raise TypeError(
3302          "Two datasets to concatenate have different types %s and %s" %
3303          (output_types, get_legacy_output_types(dataset_to_concatenate)))
3304
3305    output_classes = get_legacy_output_classes(input_dataset)
3306    if output_classes != get_legacy_output_classes(dataset_to_concatenate):
3307      raise TypeError(
3308          "Two datasets to concatenate have different classes %s and %s" %
3309          (output_classes, get_legacy_output_classes(dataset_to_concatenate)))
3310
3311    input_shapes = get_legacy_output_shapes(self._input_dataset)
3312    output_shapes = nest.pack_sequence_as(input_shapes, [
3313        ts1.most_specific_compatible_shape(ts2)
3314        for (ts1, ts2) in zip(
3315            nest.flatten(input_shapes),
3316            nest.flatten(get_legacy_output_shapes(
3317                self._dataset_to_concatenate)))
3318    ])
3319
3320    self._structure = structure.convert_legacy_structure(
3321        output_types, output_shapes, output_classes)
3322
3323    self._input_datasets = [input_dataset, dataset_to_concatenate]
3324    # pylint: disable=protected-access
3325    variant_tensor = gen_dataset_ops.concatenate_dataset(
3326        input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
3327        **self._flat_structure)
3328    # pylint: enable=protected-access
3329    super(ConcatenateDataset, self).__init__(variant_tensor)
3330
3331  def _inputs(self):
3332    return self._input_datasets
3333
3334  @property
3335  def element_spec(self):
3336    return self._structure
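

# Example (illustrative sketch, not part of the original module): concatenation
# requires matching element types, as checked above. Assumes
# `import tensorflow as tf` and eager execution.
#
#   a = tf.data.Dataset.range(3)                      # elements are tf.int64
#   b = tf.data.Dataset.from_tensor_slices([10, 11])  # elements are tf.int32
#   a.concatenate(b)                     # raises TypeError: types differ
#   b64 = b.map(lambda x: tf.cast(x, tf.int64))
#   [t.numpy() for t in a.concatenate(b64)]  # [0, 1, 2, 10, 11]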
3337
3338
3339class RepeatDataset(UnaryUnchangedStructureDataset):
3340  """A `Dataset` that repeats its input several times."""
3341
3342  def __init__(self, input_dataset, count):
3343    """See `Dataset.repeat()` for details."""
3344    self._input_dataset = input_dataset
3345    if count is None:
3346      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
3347    else:
3348      self._count = ops.convert_to_tensor(
3349          count, dtype=dtypes.int64, name="count")
3350    variant_tensor = gen_dataset_ops.repeat_dataset(
3351        input_dataset._variant_tensor,  # pylint: disable=protected-access
3352        count=self._count,
3353        **self._flat_structure)
3354    super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
3355
3356
3357class RangeDataset(DatasetSource):
3358  """A `Dataset` of a step-separated range of values."""
3359
3360  def __init__(self, *args, **kwargs):
3361    """See `Dataset.range()` for details."""
3362    self._parse_args(*args, **kwargs)
3363    self._structure = tensor_spec.TensorSpec([], self._output_type)
3364    variant_tensor = gen_dataset_ops.range_dataset(
3365        start=self._start,
3366        stop=self._stop,
3367        step=self._step,
3368        **self._flat_structure)
3369    super(RangeDataset, self).__init__(variant_tensor)
3370
3371  def _parse_args(self, *args, **kwargs):
3372    """Parse arguments according to the same rules as the `range()` builtin."""
3373    if len(args) == 1:
3374      self._start = self._build_tensor(0, "start")
3375      self._stop = self._build_tensor(args[0], "stop")
3376      self._step = self._build_tensor(1, "step")
3377    elif len(args) == 2:
3378      self._start = self._build_tensor(args[0], "start")
3379      self._stop = self._build_tensor(args[1], "stop")
3380      self._step = self._build_tensor(1, "step")
3381    elif len(args) == 3:
3382      self._start = self._build_tensor(args[0], "start")
3383      self._stop = self._build_tensor(args[1], "stop")
3384      self._step = self._build_tensor(args[2], "step")
3385    else:
3386      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
3387    if "output_type" in kwargs:
3388      self._output_type = kwargs["output_type"]
3389    else:
3390      self._output_type = dtypes.int64
3391
3392  def _build_tensor(self, int64_value, name):
3393    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
3394
3395  @property
3396  def element_spec(self):
3397    return self._structure
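

# Example (illustrative sketch, not part of the original module): the argument
# forms accepted by `Dataset.range`, mirroring the `range()` builtin parsed
# above; elements default to `tf.int64`. Assumes `import tensorflow as tf`.
#
#   tf.data.Dataset.range(5)         # 0, 1, 2, 3, 4
#   tf.data.Dataset.range(2, 5)      # 2, 3, 4
#   tf.data.Dataset.range(1, 7, 2)   # 1, 3, 5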
3398
3399
3400class _MemoryCacheDeleter(object):
3401  """An object which cleans up an anonymous memory cache resource.
3402
3403  An alternative to defining a __del__ method on an object. Even if the parent
3404  object is part of a reference cycle, the cycle will be collectable.
3405  """
3406
3407  def __init__(self, handle, device, deleter):
3408    self._deleter = deleter
3409    self._handle = handle
3410    self._device = device
3411    self._eager_mode = context.executing_eagerly()
3412
3413  def __del__(self):
3414    with ops.device(self._device):
3415      # Make sure the resource is deleted in the same mode as it was created in.
3416      if self._eager_mode:
3417        with context.eager_mode():
3418          gen_dataset_ops.delete_memory_cache(
3419              handle=self._handle, deleter=self._deleter)
3420      else:
3421        with context.graph_mode():
3422          gen_dataset_ops.delete_memory_cache(
3423              handle=self._handle, deleter=self._deleter)
3424
3425
3426class _MemoryCache(object):
3427  """Represents a memory cache resource."""
3428
3429  def __init__(self):
3430    super(_MemoryCache, self).__init__()
3431    self._device = context.context().device_name
3432    self._handle, self._deleter = (gen_dataset_ops.anonymous_memory_cache())
3433    self._resource_deleter = _MemoryCacheDeleter(
3434        handle=self._handle, device=self._device, deleter=self._deleter)
3435
3436  @property
3437  def handle(self):
3438    return self._handle
3439
3440
3441class CacheDataset(UnaryUnchangedStructureDataset):
3442  """A `Dataset` that caches elements of its input."""
3443
3444  def __init__(self, input_dataset, filename):
3445    """See `Dataset.cache()` for details."""
3446    self._input_dataset = input_dataset
3447    self._filename = ops.convert_to_tensor(
3448        filename, dtype=dtypes.string, name="filename")
3449    if tf2.enabled() and (context.executing_eagerly() or
3450                          ops.get_default_graph()._building_function):  # pylint: disable=protected-access
3451      self._cache = _MemoryCache()
3452      variant_tensor = gen_dataset_ops.cache_dataset_v2(
3453          input_dataset._variant_tensor,  # pylint: disable=protected-access
3454          filename=self._filename,
3455          cache=self._cache.handle,
3456          **self._flat_structure)
3457    else:
3458      variant_tensor = gen_dataset_ops.cache_dataset(
3459          input_dataset._variant_tensor,  # pylint: disable=protected-access
3460          filename=self._filename,
3461          **self._flat_structure)
3462    super(CacheDataset, self).__init__(input_dataset, variant_tensor)
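

# Example (illustrative sketch, not part of the original module): caching a
# dataset either in memory (empty filename) or in files with the given prefix.
# Assumes `import tensorflow as tf`; `expensive_fn` and the path below are
# placeholders.
#
#   ds = tf.data.Dataset.range(5).map(expensive_fn)
#   in_memory = ds.cache()                      # backed by a memory cache
#   on_disk = ds.cache("/tmp/example_cache")    # backed by cache files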
3463
3464
3465class _RandomSeedGeneratorDeleter(object):
3466  """An object which cleans up an anonymous random seed generator resource.
3467
3468  An alternative to defining a __del__ method on an object. Even if the parent
3469  object is part of a reference cycle, the cycle will be collectable.
3470  """
3471
3472  def __init__(self, handle, device, deleter):
3473    self._deleter = deleter
3474    self._handle = handle
3475    self._device = device
3476    self._eager_mode = context.executing_eagerly()
3477
3478  def __del__(self):
3479    with ops.device(self._device):
3480      # Make sure the resource is deleted in the same mode as it was created in.
3481      if self._eager_mode:
3482        with context.eager_mode():
3483          gen_dataset_ops.delete_random_seed_generator(
3484              handle=self._handle, deleter=self._deleter)
3485      else:
3486        with context.graph_mode():
3487          gen_dataset_ops.delete_random_seed_generator(
3488              handle=self._handle, deleter=self._deleter)
3489
3490
3491class _RandomSeedGenerator(object):
3492  """Represents a random seed generator resource."""
3493
3494  def __init__(self, seed, seed2):
3495    super(_RandomSeedGenerator, self).__init__()
3496    self._device = context.context().device_name
3497    self._handle, self._deleter = (
3498        gen_dataset_ops.anonymous_random_seed_generator(seed=seed, seed2=seed2))
3499    self._resource_deleter = _RandomSeedGeneratorDeleter(
3500        handle=self._handle, device=self._device, deleter=self._deleter)
3501
3502  @property
3503  def handle(self):
3504    return self._handle
3505
3506
3507class ShuffleDataset(UnaryUnchangedStructureDataset):
3508  """A `Dataset` that randomly shuffles the elements of its input."""
3509
3510  def __init__(self,
3511               input_dataset,
3512               buffer_size,
3513               seed=None,
3514               reshuffle_each_iteration=None):
3515    """Randomly shuffles the elements of this dataset.
3516
3517    Args:
3518      input_dataset: The input dataset.
3519      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
3520        elements from this dataset from which the new dataset will sample.
3521      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
3522        seed that will be used to create the distribution. See
3523        `tf.random.set_seed` for behavior.
3524      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
3525        that the dataset should be pseudorandomly reshuffled each time it is
3526        iterated over. (Defaults to `True`.)
3527
3528    Returns:
3529      A `Dataset`.
3530
3531    Raises:
3532      ValueError: if invalid arguments are provided.
3533    """
3534    self._input_dataset = input_dataset
3535    self._buffer_size = ops.convert_to_tensor(
3536        buffer_size, dtype=dtypes.int64, name="buffer_size")
3537    self._seed, self._seed2 = random_seed.get_seed(seed)
3538
3539    if reshuffle_each_iteration is None:
3540      self._reshuffle_each_iteration = True
3541    else:
3542      self._reshuffle_each_iteration = reshuffle_each_iteration
3543
3544    if tf2.enabled() and self._reshuffle_each_iteration and (
3545        context.executing_eagerly() or
3546        ops.get_default_graph()._building_function):  # pylint: disable=protected-access
3547      self._seed_generator = _RandomSeedGenerator(self._seed, self._seed2)
3548      variant_tensor = gen_dataset_ops.shuffle_dataset_v2(
3549          input_dataset._variant_tensor,  # pylint: disable=protected-access
3550          buffer_size=self._buffer_size,
3551          seed_generator=self._seed_generator.handle,
3552          **self._flat_structure)
3553    else:
3554      variant_tensor = gen_dataset_ops.shuffle_dataset(
3555          input_dataset._variant_tensor,  # pylint: disable=protected-access
3556          buffer_size=self._buffer_size,
3557          seed=self._seed,
3558          seed2=self._seed2,
3559          reshuffle_each_iteration=self._reshuffle_each_iteration,
3560          **self._flat_structure)
3561    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
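

# Example (illustrative sketch, not part of the original module): shuffling
# with a buffer covering the whole dataset; with `reshuffle_each_iteration`
# left at its default of True, each pass yields a new permutation. Assumes
# `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)
#   first_pass = [t.numpy() for t in ds]
#   second_pass = [t.numpy() for t in ds]   # generally a different ordering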
3562
3563
3564class TakeDataset(UnaryUnchangedStructureDataset):
3565  """A `Dataset` containing the first `count` elements from its input."""
3566
3567  def __init__(self, input_dataset, count):
3568    """See `Dataset.take()` for details."""
3569    self._input_dataset = input_dataset
3570    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
3571    variant_tensor = gen_dataset_ops.take_dataset(
3572        input_dataset._variant_tensor,  # pylint: disable=protected-access
3573        count=self._count,
3574        **self._flat_structure)
3575    super(TakeDataset, self).__init__(input_dataset, variant_tensor)
3576
3577
3578class SkipDataset(UnaryUnchangedStructureDataset):
3579  """A `Dataset` skipping the first `count` elements from its input."""
3580
3581  def __init__(self, input_dataset, count):
3582    """See `Dataset.skip()` for details."""
3583    self._input_dataset = input_dataset
3584    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
3585    variant_tensor = gen_dataset_ops.skip_dataset(
3586        input_dataset._variant_tensor,  # pylint: disable=protected-access
3587        count=self._count,
3588        **self._flat_structure)
3589    super(SkipDataset, self).__init__(input_dataset, variant_tensor)
3590
3591
3592class ShardDataset(UnaryUnchangedStructureDataset):
3593  """A `Dataset` for sharding its input."""
3594
3595  def __init__(self, input_dataset, num_shards, index):
3596    """See `Dataset.shard()` for details."""
3597    self._input_dataset = input_dataset
3598    self._num_shards = ops.convert_to_tensor(
3599        num_shards, dtype=dtypes.int64, name="num_shards")
3600    self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
3601    variant_tensor = gen_dataset_ops.shard_dataset(
3602        input_dataset._variant_tensor,  # pylint: disable=protected-access
3603        num_shards=self._num_shards,
3604        index=self._index,
3605        **self._flat_structure)
3606    super(ShardDataset, self).__init__(input_dataset, variant_tensor)
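

# Example (illustrative sketch, not part of the original module): splitting a
# dataset across two workers with `Dataset.shard`. Assumes
# `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.range(8)
#   worker0 = ds.shard(num_shards=2, index=0)   # 0, 2, 4, 6
#   worker1 = ds.shard(num_shards=2, index=1)   # 1, 3, 5, 7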
3607
3608
3609class BatchDataset(UnaryDataset):
3610  """A `Dataset` that batches contiguous elements from its input."""
3611
3612  def __init__(self, input_dataset, batch_size, drop_remainder):
3613    """See `Dataset.batch()` for details."""
3614    self._input_dataset = input_dataset
3615    self._batch_size = ops.convert_to_tensor(
3616        batch_size, dtype=dtypes.int64, name="batch_size")
3617    self._drop_remainder = ops.convert_to_tensor(
3618        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
3619
3620    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
3621    # pylint: disable=protected-access
3622    if constant_drop_remainder:
3623      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
3624      # or `False` (explicitly retaining the remainder).
3625      # pylint: disable=g-long-lambda
3626      constant_batch_size = tensor_util.constant_value(self._batch_size)
3627      self._structure = nest.map_structure(
3628          lambda component_spec: component_spec._batch(constant_batch_size),
3629          input_dataset.element_spec)
3630    else:
3631      self._structure = nest.map_structure(
3632          lambda component_spec: component_spec._batch(None),
3633          input_dataset.element_spec)
3634    variant_tensor = gen_dataset_ops.batch_dataset_v2(
3635        input_dataset._variant_tensor,
3636        batch_size=self._batch_size,
3637        drop_remainder=self._drop_remainder,
3638        **self._flat_structure)
3639    super(BatchDataset, self).__init__(input_dataset, variant_tensor)
3640
3641  @property
3642  def element_spec(self):
3643    return self._structure
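

# Example (illustrative sketch, not part of the original module): batching with
# and without `drop_remainder`; only in the latter case is the batch dimension
# statically known, as reflected in the element spec computed above. Assumes
# `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.range(7).batch(3)
#   [t.numpy() for t in ds]   # [array([0, 1, 2]), array([3, 4, 5]), array([6])]
#   ds = tf.data.Dataset.range(7).batch(3, drop_remainder=True)
#   ds.element_spec           # TensorSpec(shape=(3,), dtype=tf.int64, name=None)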
3644
3645
3646class _NumpyIterator(object):
3647  """Iterator over a dataset with elements converted to numpy."""
3648
3649  def __init__(self, dataset):
3650    self._iterator = iter(dataset)
3651
3652  def __iter__(self):
3653    return self
3654
3655  def next(self):
3656    return nest.map_structure(lambda x: x.numpy(), next(self._iterator))
3657
3658  def __next__(self):
3659    return self.next()
3660
3661
3662class _VariantTracker(tracking.CapturableResource):
3663  """Allows export of functions capturing a Dataset in SavedModels.
3664
3665  When saving a SavedModel, `tf.saved_model.save` traverses the object
3666  graph. Since Datasets reference _VariantTracker objects, that traversal will
3667  find a _VariantTracker for each Dataset and so know how to save and restore
3668  functions which reference the Dataset's variant Tensor.
3669  """
3670
3671  def __init__(self, variant_tensor, resource_creator):
3672    """Record that `variant_tensor` is associated with `resource_creator`.
3673
3674    Args:
3675      variant_tensor: The variant-dtype Tensor associated with the Dataset. This
3676        Tensor will be a captured input to functions which use the Dataset, and
3677        is used by saving code to identify the corresponding _VariantTracker.
3678      resource_creator: A zero-argument function which creates a new
3679        variant-dtype Tensor. This function will be included in SavedModels and
3680        run to re-create the Dataset's variant Tensor on restore.
3681    """
3682    super(_VariantTracker, self).__init__(device="CPU")
3683    self._resource_handle = variant_tensor
3684    self._create_resource = resource_creator
3685
3686
3687def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
3688  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
3689
3690  Args:
3691    padded_shape: A `tf.TensorShape`.
3692    input_component_shape: A `tf.TensorShape`.
3693
3694  Returns:
3695    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
3696    `False`.
3697  """
3698
3699  if padded_shape.dims is None or input_component_shape.dims is None:
3700    return True
3701  if len(padded_shape.dims) != len(input_component_shape.dims):
3702    return False
3703  for padded_dim, input_dim in zip(
3704      padded_shape.dims, input_component_shape.dims):
3705    if (padded_dim.value is not None and input_dim.value is not None
3706        and padded_dim.value < input_dim.value):
3707      return False
3708  return True
3709
3710
3711def _padded_shape_to_tensor(padded_shape, input_component_shape):
3712  """Converts `padded_shape` to a `tf.Tensor` representing that shape.
3713
3714  Args:
3715    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
3716      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
3717    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
3718      be compatible.
3719
3720  Returns:
3721    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
3722
3723  Raises:
3724    ValueError: If `padded_shape` is not a shape or not compatible with
3725      `input_component_shape`.
3726    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
3727  """
3728  try:
3729    # Try to convert the `padded_shape` to a `tf.TensorShape`
3730    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
3731    # We will return the "canonical" tensor representation, which uses
3732    # `-1` in place of `None`.
3733    ret = ops.convert_to_tensor(
3734        [dim if dim is not None else -1
3735         for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
3736  except (TypeError, ValueError):
3737    # The argument was not trivially convertible to a
3738    # `tf.TensorShape`, so fall back on the conversion to tensor
3739    # machinery.
3740    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
3741    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
3742      six.reraise(ValueError, ValueError(
3743          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
3744          "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
3745    if ret.dtype != dtypes.int64:
3746      six.reraise(
3747          TypeError,
3748          TypeError(
3749              "Padded shape %s must be a 1-D tensor of tf.int64 values, but "
3750              "its element type was %s." % (padded_shape, ret.dtype.name)),
3751          sys.exc_info()[2])
3752    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
3753
3754  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
3755                                          input_component_shape):
3756    raise ValueError("The padded shape %s is not compatible with the "
3757                     "corresponding input component shape %s."
3758                     % (padded_shape_as_shape, input_component_shape))
3759
3760  return ret
3761
3762
3763def _padding_value_to_tensor(value, output_type):
3764  """Converts the padding value to a tensor.
3765
3766  Args:
3767    value: The padding value.
3768    output_type: Its expected dtype.
3769
3770  Returns:
3771    A scalar `Tensor`.
3772
3773  Raises:
3774    ValueError: if the padding value is not a scalar.
3775    TypeError: if the padding value's type does not match `output_type`.
3776  """
3777  value = ops.convert_to_tensor(value, name="padding_value")
3778  if not value.shape.is_compatible_with(tensor_shape.TensorShape([])):
3779    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
3780  if value.dtype != output_type:
3781    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
3782                    (value, output_type))
3783  return value
3784
3785
3786def _padding_values_or_default(padding_values, input_dataset):
3787  """Returns padding values with None elements replaced with default values."""
3788  def make_zero(t):
3789    if t.base_dtype == dtypes.string:
3790      return ""
3791    elif t.base_dtype == dtypes.variant:
3792      error_msg = ("Unable to create padding for field of type 'variant' "
3793                   "because t.base_dtype == dtypes.variant == "
3794                   "{}.".format(
3795                       t.base_dtype))
3796      raise TypeError(error_msg)
3797    else:
3798      return np.zeros_like(t.as_numpy_dtype())
3799  def value_or_default(value, default):
3800    return default if value is None else value
3801
3802  default_padding = nest.map_structure(make_zero,
3803                                       get_legacy_output_types(input_dataset))
3804  return nest.map_structure_up_to(padding_values, value_or_default,
3805                                  padding_values, default_padding)
3806
3807
3808class PaddedBatchDataset(UnaryDataset):
3809  """A `Dataset` that batches and pads contiguous elements from its input."""
3810
3811  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
3812               drop_remainder):
3813    """See `Dataset.padded_batch()` for details."""
3814    self._input_dataset = input_dataset
3815    if sparse.any_sparse(get_legacy_output_classes(input_dataset)):
3816      # TODO(b/63669786): support batching of sparse tensors
3817      raise TypeError(
3818          "Batching of padded sparse tensors is not currently supported")
3820    self._batch_size = ops.convert_to_tensor(
3821        batch_size, dtype=dtypes.int64, name="batch_size")
3822    padding_values = _padding_values_or_default(padding_values, input_dataset)
3823
3824    input_shapes = get_legacy_output_shapes(input_dataset)
3825    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
3826
3827    flat_padded_shapes_as_tensors = []
3828
3829    for input_component_shape, padded_shape in zip(
3830        nest.flatten(input_shapes), flat_padded_shapes):
3831      flat_padded_shapes_as_tensors.append(
3832          _padded_shape_to_tensor(padded_shape, input_component_shape))
3833
3834    self._padded_shapes = nest.pack_sequence_as(input_shapes,
3835                                                flat_padded_shapes_as_tensors)
3836
3837    self._padding_values = nest.map_structure_up_to(
3838        input_shapes, _padding_value_to_tensor, padding_values,
3839        get_legacy_output_types(input_dataset))
3840    self._drop_remainder = ops.convert_to_tensor(
3841        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
3842
3843    def _padded_shape_to_batch_shape(s):
3844      return tensor_shape.TensorShape([
3845          tensor_util.constant_value(self._batch_size)
3846          if smart_cond.smart_constant_value(self._drop_remainder) else None
3847      ]).concatenate(tensor_util.constant_value_as_shape(s))
3848
3849    output_shapes = nest.map_structure(
3850        _padded_shape_to_batch_shape, self._padded_shapes)
3851    self._structure = structure.convert_legacy_structure(
3852        get_legacy_output_types(self._input_dataset), output_shapes,
3853        get_legacy_output_classes(self._input_dataset))
3854
3855    # pylint: disable=protected-access
3856    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
3857    if smart_cond.smart_constant_value(self._drop_remainder) is False:
3858      variant_tensor = gen_dataset_ops.padded_batch_dataset(
3859          input_dataset._variant_tensor,  # pylint: disable=protected-access
3860          batch_size=self._batch_size,
3861          padded_shapes=[
3862              ops.convert_to_tensor(s, dtype=dtypes.int64)
3863              for s in nest.flatten(self._padded_shapes)
3864          ],
3865          padding_values=nest.flatten(self._padding_values),
3866          output_shapes=structure.get_flat_tensor_shapes(self._structure))
3867    else:
3868      variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
3869          input_dataset._variant_tensor,  # pylint: disable=protected-access
3870          batch_size=self._batch_size,
3871          padded_shapes=[
3872              ops.convert_to_tensor(s, dtype=dtypes.int64)
3873              for s in nest.flatten(self._padded_shapes)
3874          ],
3875          padding_values=nest.flatten(self._padding_values),
3876          drop_remainder=self._drop_remainder,
3877          output_shapes=structure.get_flat_tensor_shapes(self._structure))
3878    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
3879
3880  @property
3881  def element_spec(self):
3882    return self._structure
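

# Example (illustrative sketch, not part of the original module): padding
# variable-length elements to a common shape with `Dataset.padded_batch`.
# Assumes `import tensorflow as tf` and eager execution.
#
#   ds = tf.data.Dataset.range(1, 4).map(tf.range)   # [0], [0, 1], [0, 1, 2]
#   ds = ds.padded_batch(batch_size=2, padded_shapes=[None])
#   [t.numpy() for t in ds]
#   # [array([[0, 0],
#   #         [0, 1]]),
#   #  array([[0, 1, 2]])]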
3883
3884
3885def _should_unpack_args(args):
3886  """Returns `True` if `args` should be `*args` when passed to a callable."""
3887  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck
3888
3889
3890class MapDataset(UnaryDataset):
3891  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self,
               input_dataset,
               map_func,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._preserve_cardinality = preserve_cardinality
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    variant_tensor = gen_dataset_ops.map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **self._flat_structure)
    super(MapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"


class ParallelMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
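    # `num_parallel_calls` may be `tf.data.experimental.AUTOTUNE` (-1), in
    # which case the runtime picks the degree of parallelism dynamically.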
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
    self._preserve_cardinality = preserve_cardinality
    variant_tensor = gen_dataset_ops.parallel_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        num_parallel_calls=self._num_parallel_calls,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **self._flat_structure)
    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"


class FlatMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""
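  # For example, the mapped function must itself return a dataset:
  #   ds.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x))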

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **self._flat_structure)
    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.flat_map()"


class InterleaveDataset(UnaryDataset):
  """A `Dataset` that interleaves the result of transformed inputs."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

    variant_tensor = gen_dataset_ops.interleave_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        f=self._map_func.function,
        **self._flat_structure)
    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"


class ParallelInterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result."""

  def __init__(self,
               input_dataset,
               map_func,
               cycle_length,
               block_length,
               num_parallel_calls,
               deterministic=None):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    if deterministic is None:
      deterministic_string = "default"
    elif deterministic:
      deterministic_string = "true"
    else:
      deterministic_string = "false"

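    # The V3 kernel accepts a `deterministic` attribute, so it is used whenever
    # determinism was requested explicitly or once the forward-compatibility
    # window for the new op has passed; otherwise the V2 kernel is used.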
    if deterministic is not None or compat.forward_compatible(2020, 2, 20):
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
          self._cycle_length,
          self._block_length,
          self._num_parallel_calls,
          f=self._map_func.function,
          deterministic=deterministic_string,
          **self._flat_structure)
    else:
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
          self._cycle_length,
          self._block_length,
          self._num_parallel_calls,
          f=self._map_func.function,
          **self._flat_structure)
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"


class FilterDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that filters its input according to a predicate function."""
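  # For example, the predicate must evaluate to a scalar boolean per element:
  #   ds.filter(lambda x: x > 0)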

  def __init__(self, input_dataset, predicate, use_legacy_function=False):
    """See `Dataset.filter()` for details."""
    self._input_dataset = input_dataset
    wrapped_func = StructuredFunctionWrapper(
        predicate,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    if not wrapped_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.bool)):
      error_msg = ("`predicate` return type must be convertible to a scalar "
                   "boolean tensor. Was {}.").format(
                       wrapped_func.output_structure)
      raise ValueError(error_msg)
    self._predicate = wrapped_func
    variant_tensor = gen_dataset_ops.filter_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
    super(FilterDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._predicate]

  def _transformation_name(self):
    return "Dataset.filter()"


class PrefetchDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size, slack_period=None):
    """See `Dataset.prefetch()` for details.

    Args:
      input_dataset: The input dataset.
      buffer_size: See `Dataset.prefetch()` for details.
      slack_period: (Optional.) An integer. If non-zero, determines the number
        of GetNext calls before injecting slack into the execution. This may
        reduce CPU contention at the start of a step. Note that a TensorFlow
        user should not have to set this manually; instead, this behavior is
        enabled automatically via `tf.data.Options.experimental_slack`.
        Defaults to None.
    """
    self._input_dataset = input_dataset
    if buffer_size is None:
      buffer_size = -1  # This is the sentinel for auto-tuning.
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        slack_period=slack_period,
        **self._flat_structure)
    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)


class WindowDataset(UnaryDataset):
  """A dataset that creates window datasets from the input elements."""

  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
    """See `window_dataset()` for more details."""
    self._input_dataset = input_dataset
    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
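    # Each component of a window element is itself a (nested) dataset of that
    # component's values, so the element spec mirrors the input structure with
    # a `DatasetSpec` per component.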
    self._structure = nest.pack_sequence_as(
        get_legacy_output_classes(input_dataset), [
            DatasetSpec(  # pylint: disable=g-complex-comprehension
                structure.convert_legacy_structure(
                    output_type, output_shape, output_class))
            for output_class, output_shape, output_type in zip(
                nest.flatten(get_legacy_output_classes(input_dataset)),
                nest.flatten(get_legacy_output_shapes(input_dataset)),
                nest.flatten(get_legacy_output_types(input_dataset)))
        ])
    variant_tensor = gen_dataset_ops.window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._size,
        self._shift,
        self._stride,
        self._drop_remainder,
        **self._flat_structure)
    super(WindowDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure


class _OptionsDataset(UnaryUnchangedStructureDataset):
  """An identity `Dataset` that stores options."""

  def __init__(self, input_dataset, options):
    self._input_dataset = input_dataset
    self._options = input_dataset.options()
    if self._options:
      self._options = self._options.merge(options)
    else:
      self._options = options
    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)

  def options(self):
    return self._options


class _ModelDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and models performance."""

  def __init__(self, input_dataset, algorithm, cpu_budget):
    self._input_dataset = input_dataset
    variant_tensor = gen_dataset_ops.model_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        algorithm=algorithm.value,
        cpu_budget=cpu_budget,
        **self._flat_structure)
    super(_ModelDataset, self).__init__(input_dataset, variant_tensor)


class _OptimizeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

  def __init__(self, input_dataset, optimizations, optimization_configs=None):
    self._input_dataset = input_dataset
    if optimizations is None:
      optimizations = []
    if optimization_configs is None:
      optimization_configs = []
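    # `optimizations` names tf.data graph rewrites (for example,
    # "noop_elimination"), while `optimization_configs` carries optional
    # per-rewrite configuration strings.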
    self._optimizations = ops.convert_to_tensor(
        optimizations, dtype=dtypes.string, name="optimizations")
    variant_tensor = gen_dataset_ops.optimize_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._optimizations,
        optimization_configs=optimization_configs,
        **self._flat_structure)
    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)


class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and sets a stats aggregator."""

  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
    self._input_dataset = input_dataset
    self._stats_aggregator = aggregator
    self._prefix = prefix
    self._counter_prefix = counter_prefix
    variant_tensor = ged_ops.set_stats_aggregator_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._stats_aggregator._resource,  # pylint: disable=protected-access
        self._prefix,
        self._counter_prefix,
        **self._flat_structure)
    super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                     variant_tensor)


class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""

  def __init__(self, input_dataset, max_intra_op_parallelism):
    self._input_dataset = input_dataset
    self._max_intra_op_parallelism = ops.convert_to_tensor(
        max_intra_op_parallelism,
        dtype=dtypes.int64,
        name="max_intra_op_parallelism")
    variant_tensor = ged_ops.max_intra_op_parallelism_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._max_intra_op_parallelism,
        **self._flat_structure)
    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                        variant_tensor)


class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, setting a private threadpool."""

  def __init__(self, input_dataset, num_threads):
    self._input_dataset = input_dataset
    self._num_threads = ops.convert_to_tensor(
        num_threads, dtype=dtypes.int64, name="num_threads")
    variant_tensor = ged_ops.private_thread_pool_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._num_threads,
        **self._flat_structure)
    super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                    variant_tensor)


def normalize_to_dense(dataset):
  """Normalizes non-tensor components in a dataset to dense representations.

  This is necessary for dataset transformations that slice along the batch
  dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.

  Args:
    dataset: Dataset to normalize.

  Returns:
    A dataset whose sparse and ragged tensors have been normalized to their
    dense representations.
  """

  # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
  # non-tensor components.
  #
  # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
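  # After normalization each element is the flat list of batched tensors
  # produced by `structure.to_batched_tensor_list`; the original structure is
  # re-applied below via `_RestructuredDataset`.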
  if _should_unpack_args(dataset.element_spec):
    def normalize(*args):
      return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
  else:
    def normalize(arg):
      return structure.to_batched_tensor_list(dataset.element_spec, arg)

  normalized_dataset = dataset.map(normalize)

  # NOTE(mrry): Our `map()` has lost information about the structure of
  # non-tensor components, so re-apply the structure of the original dataset.
  return _RestructuredDataset(normalized_dataset, dataset.element_spec)


class _RestructuredDataset(UnaryDataset):
  """An internal helper for changing the structure and shape of a dataset."""

  def __init__(self, dataset, structure):
    self._input_dataset = dataset
    self._structure = structure

    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure


class _UnbatchDataset(UnaryDataset):
  """A dataset that splits the elements of its input into multiple elements."""

  def __init__(self, input_dataset):
    """See `unbatch()` for more details."""
    flat_shapes = input_dataset._flat_shapes  # pylint: disable=protected-access
    if any(s.ndims == 0 for s in flat_shapes):
      raise ValueError("Cannot unbatch an input with scalar components.")
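    # Verify that every component agrees on its (statically known) batch
    # dimension; `Dimension.merge_with` raises if two known batch sizes differ.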
    known_batch_dim = tensor_shape.Dimension(None)
    for s in flat_shapes:
      try:
        known_batch_dim = known_batch_dim.merge_with(s[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")
    self._input_dataset = input_dataset
    self._structure = nest.map_structure(
        lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
        get_structure(input_dataset))
    variant_tensor = ged_ops.unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **self._flat_structure)
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._structure


def _collect_resource_inputs(op):
  """Collects resource inputs for the given op (and its variant inputs)."""

  def _process(op_queue, seen_ops):
    """Processes the next element of the op queue."""

    result = []
    op = op_queue.pop()
    if op in seen_ops:
      return result
    seen_ops.add(op)
    for t in op.inputs:
      if t.dtype == dtypes.variant:
        # Conservatively assume that any variant inputs are datasets.
        op_queue.append(t.op)
      elif t.dtype == dtypes.resource:
        result.append(t)
    return result

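  # Iteratively walk backwards through variant-producing ops (conservatively
  # assumed to be dataset ops), collecting every resource-typed input found
  # along the way.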
  op_queue = [op]
  seen_ops = set()
  resource_inputs = []
  while op_queue:
    resource_inputs.extend(_process(op_queue, seen_ops))

  return resource_inputs


@auto_control_deps.register_acd_resource_resolver
def _resource_resolver(op, resource_inputs):
  """Updates resource inputs for tf.data ops with indirect dependencies."""

  updated = False
  if op.type in [
      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
  ]:
    indirect_resource_inputs = _collect_resource_inputs(op)
    for inp in indirect_resource_inputs:
      if inp not in resource_inputs:
        updated = True
        resource_inputs.add(inp)

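  # For ops that read from an iterator, the resources used by the underlying
  # dataset are reachable only through the `MakeIterator` op that initialized
  # the iterator, so trace through that op when there is exactly one.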
  if op.type in [
      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
  ]:
    iterator_resource = op.inputs[0]
    make_iterator_ops = [
        op for op in iterator_resource.consumers() if op.type == "MakeIterator"
    ]

    if len(make_iterator_ops) == 1:
      indirect_resource_inputs = _collect_resource_inputs(make_iterator_ops[0])
      for inp in indirect_resource_inputs:
        if inp not in resource_inputs:
          updated = True
          resource_inputs.add(inp)

  return updated