• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22import multiprocessing
23import sys
24import threading
25import warnings
26import weakref
27
28import numpy as np
29import six
30from six.moves import queue as Queue  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import dataset_options_pb2
33from tensorflow.core.framework import graph_pb2
34from tensorflow.python import tf2
35from tensorflow.python.data.ops import iterator_ops
36from tensorflow.python.data.ops import options as options_lib
37from tensorflow.python.data.util import nest
38from tensorflow.python.data.util import random_seed
39from tensorflow.python.data.util import structure
40from tensorflow.python.data.util import traverse
41from tensorflow.python.eager import context
42from tensorflow.python.eager import def_function
43from tensorflow.python.eager import function as eager_function
44from tensorflow.python.framework import auto_control_deps
45from tensorflow.python.framework import auto_control_deps_utils as acd_utils
46from tensorflow.python.framework import composite_tensor
47from tensorflow.python.framework import constant_op
48from tensorflow.python.framework import dtypes
49from tensorflow.python.framework import function
50from tensorflow.python.framework import ops
51from tensorflow.python.framework import random_seed as core_random_seed
52from tensorflow.python.framework import smart_cond
53from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
54from tensorflow.python.framework import tensor_shape
55from tensorflow.python.framework import tensor_spec
56from tensorflow.python.framework import tensor_util
57from tensorflow.python.framework import type_spec
58from tensorflow.python.ops import array_ops
59from tensorflow.python.ops import check_ops
60from tensorflow.python.ops import control_flow_ops
61from tensorflow.python.ops import gen_dataset_ops
62from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
63from tensorflow.python.ops import gen_io_ops
64from tensorflow.python.ops import logging_ops
65from tensorflow.python.ops import math_ops
66from tensorflow.python.ops import random_ops
67from tensorflow.python.ops import script_ops
68from tensorflow.python.ops import string_ops
69from tensorflow.python.ops.ragged import ragged_tensor
70from tensorflow.python.training.tracking import base as tracking_base
71from tensorflow.python.training.tracking import tracking
72from tensorflow.python.util import deprecation
73from tensorflow.python.util import function_utils
74from tensorflow.python.util import lazy_loader
75from tensorflow.python.util import nest as tf_nest
76from tensorflow.python.util.compat import collections_abc
77from tensorflow.python.util.tf_export import tf_export
78
79# Loaded lazily due to a circular dependency (roughly
80# tf.function->wrap_function->dataset->autograph->tf.function).
81# TODO(b/133251390): Use a regular import.
82wrap_function = lazy_loader.LazyLoader(
83    "wrap_function", globals(),
84    "tensorflow.python.eager.wrap_function")
85# TODO(mdan): Create a public API for this.
86autograph_ctx = lazy_loader.LazyLoader(
87    "autograph_ctx", globals(),
88    "tensorflow.python.autograph.core.ag_ctx")
89autograph = lazy_loader.LazyLoader(
90    "autograph", globals(),
91    "tensorflow.python.autograph.impl.api")
92# Loaded lazily due to a circular dependency
93# dataset_ops->interleave_ops->dataset_ops
94# TODO(aaudibert): Switch to the core sample_from_datasets after it is migrated
95# out of experimental. Then we can remove this lazy loading.
96interleave_ops = lazy_loader.LazyLoader(
97    "interleave_ops", globals(),
98    "tensorflow.python.data.experimental.ops.interleave_ops"
99)
100
101ops.NotDifferentiable("ReduceDataset")
102
103# A constant that can be used to enable auto-tuning.
104AUTOTUNE = -1
105tf_export("data.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
106# TODO(b/168128531): Deprecate and remove this symbol.
107tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")
108
109# Constants representing infinite and unknown cardinalities.
110INFINITE = -1
111UNKNOWN = -2
112tf_export("data.INFINITE_CARDINALITY").export_constant(__name__, "INFINITE")
113tf_export("data.UNKNOWN_CARDINALITY").export_constant(__name__, "UNKNOWN")
114
115
116@tf_export("data.Dataset", v1=[])
117@six.add_metaclass(abc.ABCMeta)
118class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
119                composite_tensor.CompositeTensor):
120  """Represents a potentially large set of elements.
121
122  The `tf.data.Dataset` API supports writing descriptive and efficient input
123  pipelines. `Dataset` usage follows a common pattern:
124
125  1. Create a source dataset from your input data.
126  2. Apply dataset transformations to preprocess the data.
127  3. Iterate over the dataset and process the elements.
128
129  Iteration happens in a streaming fashion, so the full dataset does not need to
130  fit into memory.
131
132  Source Datasets:
133
134  The simplest way to create a dataset is to create it from a python `list`:
135
136  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
137  >>> for element in dataset:
138  ...   print(element)
139  tf.Tensor(1, shape=(), dtype=int32)
140  tf.Tensor(2, shape=(), dtype=int32)
141  tf.Tensor(3, shape=(), dtype=int32)
142
143  To process lines from files, use `tf.data.TextLineDataset`:
144
145  >>> dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
146
147  To process records written in the `TFRecord` format, use `TFRecordDataset`:
148
149  >>> dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
150
151  To create a dataset of all files matching a pattern, use
152  `tf.data.Dataset.list_files`:
153
154  ```python
155  dataset = tf.data.Dataset.list_files("/path/*.txt")
156  ```
157
158  See `tf.data.FixedLengthRecordDataset` and `tf.data.Dataset.from_generator`
159  for more ways to create datasets.
160
161  Transformations:
162
163  Once you have a dataset, you can apply transformations to prepare the data for
164  your model:
165
166  >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
167  >>> dataset = dataset.map(lambda x: x*2)
168  >>> list(dataset.as_numpy_iterator())
169  [2, 4, 6]
170
171  Common Terms:
172
173  **Element**: A single output from calling `next()` on a dataset iterator.
174    Elements may be nested structures containing multiple components. For
175    example, the element `(1, (3, "apple"))` has one tuple nested in another
176    tuple. The components are `1`, `3`, and `"apple"`.
177
178  **Component**: The leaf in the nested structure of an element.
179
180  Supported types:
181
182  Elements can be nested structures of tuples, named tuples, and dictionaries.
183  Note that Python lists are *not* treated as nested structures of components.
184  Instead, lists are converted to tensors and treated as components. For
185  example, the element `(1, [1, 2, 3])` has only two components; the tensor `1`
186  and the tensor `[1, 2, 3]`. Element components can be of any type
187  representable by `tf.TypeSpec`, including `tf.Tensor`, `tf.data.Dataset`,
188  `tf.sparse.SparseTensor`, `tf.RaggedTensor`, and `tf.TensorArray`.
189
190  ```python
191  a = 1 # Integer element
192  b = 2.0 # Float element
193  c = (1, 2) # Tuple element with 2 components
194  d = {"a": (2, 2), "b": 3} # Dict element with 3 components
195  Point = collections.namedtuple("Point", ["x", "y"])
196  e = Point(1, 2) # Named tuple
197  f = tf.data.Dataset.range(10) # Dataset element
198  ```
199
200  For more information,
201  read [this guide](https://www.tensorflow.org/guide/data).
202  """
203
204  def __init__(self, variant_tensor):
205    """Creates a DatasetV2 object.
206
207    This is a difference between DatasetV1 and DatasetV2. DatasetV1 does not
208    take anything in its constructor whereas in the DatasetV2, we expect
209    subclasses to create a variant_tensor and pass it in to the super() call.
210
211    Args:
212      variant_tensor: A DT_VARIANT tensor that represents the dataset.
213    """
214    self._variant_tensor_attr = variant_tensor
215    weak_self = weakref.proxy(self)
216    self._variant_tracker = self._track_trackable(
217        _VariantTracker(
218            self._variant_tensor,
219            # _trace_variant_creation only works when executing eagerly, so we
220            # don't want to run it immediately. We also want the _VariantTracker
221            # to have a weak reference to the Dataset to avoid creating
222            # reference cycles and making work for the garbage collector.
223            lambda: weak_self._trace_variant_creation()()),  # pylint: disable=unnecessary-lambda,protected-access
224        name="_variant_tracker")
225    self._graph_attr = ops.get_default_graph()
226
227    # Initialize the options for this dataset and its inputs.
228    self._options_attr = options_lib.Options()
229    for input_dataset in self._inputs():
230      input_options = None
231      if isinstance(input_dataset, DatasetV1):
232        # If the V1 dataset does not have the `_dataset` attribute, we assume it
233        # is a dataset source and hence does not have options. Otherwise, we
234        # grab the options of `_dataset` object
235        if hasattr(input_dataset, "_dataset"):
236          if not isinstance(input_dataset._dataset, DatasetV2):
237            raise AssertionError(
238                "The input_dataset._dataset of dataset %s should be DatasetV2."
239                % type(self))
240          input_options = input_dataset._dataset._options_attr
241      elif isinstance(input_dataset, DatasetV2):
242        input_options = input_dataset._options_attr
243      else:
244        raise TypeError("Unexpected dataset type: ", type(input_dataset))
245      if input_options is not None:
246        self._options_attr = self._options_attr.merge(input_options)
247    self._options_attr._set_mutable(False)  # pylint: disable=protected-access
248
249  @property
250  def _variant_tensor(self):
251    return self._variant_tensor_attr
252
253  @_variant_tensor.setter
254  def _variant_tensor(self, _):
255    raise ValueError("The _variant_tensor property is read-only")
256
257  @deprecation.deprecated_args(None, "Use external_state_policy instead",
258                               "allow_stateful")
259  def _as_serialized_graph(
260      self,
261      allow_stateful=None,
262      strip_device_assignment=None,
263      external_state_policy=options_lib.ExternalStatePolicy.WARN):
264    """Produces serialized graph representation of the dataset.
265
266    Args:
267      allow_stateful: If true, we allow stateful ops to be present in the graph
268        def. In that case, the state in these ops would be thrown away.
269      strip_device_assignment: If true, non-local (i.e. job and task) device
270        assignment is stripped from ops in the serialized graph.
271      external_state_policy: The ExternalStatePolicy enum that determines how we
272        handle input pipelines that depend on external state. By default, its
273        set to WARN.
274
275    Returns:
276      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
277      serialized graph.
278    """
279    if external_state_policy:
280      policy = external_state_policy.value
281      return gen_dataset_ops.dataset_to_graph_v2(
282          self._variant_tensor,
283          external_state_policy=policy,
284          strip_device_assignment=strip_device_assignment)
285    if strip_device_assignment:
286      return gen_dataset_ops.dataset_to_graph(
287          self._variant_tensor,
288          allow_stateful=allow_stateful,
289          strip_device_assignment=strip_device_assignment)
290    return gen_dataset_ops.dataset_to_graph(
291        self._variant_tensor, allow_stateful=allow_stateful)
292
293  def _trace_variant_creation(self):
294    """Traces a function which outputs a variant `tf.Tensor` for this dataset.
295
296    Note that creating this function involves evaluating an op, and is currently
297    only supported when executing eagerly.
298
299    Returns:
300      A zero-argument `ConcreteFunction` which outputs a variant `tf.Tensor`.
301    """
302    variant = self._variant_tensor
303    if not isinstance(variant, ops.EagerTensor):
304      raise NotImplementedError(
305          "Can only export Datasets which were created executing eagerly. "
306          "Please file a feature request if this is important to you.")
307    with context.eager_mode(), ops.device("CPU"):
308      # pylint: disable=protected-access
309      graph_def = graph_pb2.GraphDef().FromString(
310          self._as_serialized_graph(external_state_policy=options_lib
311                                    .ExternalStatePolicy.FAIL).numpy())
312    output_node_name = None
313    for node in graph_def.node:
314      if node.op == "_Retval":
315        if output_node_name is not None:
316          raise AssertionError(
317              "Found multiple return values from the dataset's graph, expected "
318              "only one.")
319        output_node_name, = node.input
320    if output_node_name is None:
321      raise AssertionError("Could not find the dataset's output node.")
322    # Add functions used in this Dataset to the function's graph, since they
323    # need to follow it around (and for example be added to a SavedModel which
324    # references the dataset).
325    variant_function = wrap_function.function_from_graph_def(
326        graph_def, inputs=[], outputs=output_node_name + ":0")
327    for used_function in self._functions():
328      used_function.function.add_to_graph(variant_function.graph)
329    return variant_function
330
331  @abc.abstractmethod
332  def _inputs(self):
333    """Returns a list of the input datasets of the dataset."""
334
335    raise NotImplementedError("Dataset._inputs")
336
337  @property
338  def _graph(self):
339    return self._graph_attr
340
341  @_graph.setter
342  def _graph(self, _):
343    raise ValueError("The _graph property is read-only")
344
345  # TODO(jsimsa): Change this to be the transitive closure of functions used
346  # by this dataset and its inputs.
347  def _functions(self):
348    """Returns a list of functions associated with this dataset.
349
350    Returns:
351      A list of `StructuredFunctionWrapper` objects.
352    """
353    return []
354
355  def _options(self):
356    """Returns the options tensor for this dataset."""
357    # pylint: disable=protected-access
358    return gen_dataset_ops.get_options(self._variant_tensor)
359
360  @classmethod
361  def _options_tensor_to_options(cls, serialized_options):
362    """Converts options tensor to tf.data.Options object."""
363    options = options_lib.Options()
364    if tensor_util.constant_value(serialized_options) is not None:
365      pb = dataset_options_pb2.Options.FromString(tensor_util.constant_value(
366          serialized_options))
367      options._from_proto(pb)  # pylint: disable=protected-access
368    return options
369
370  def options(self):
371    """Returns the options for this dataset and its inputs.
372
373    Returns:
374      A `tf.data.Options` object representing the dataset options.
375    """
376    if context.executing_eagerly():
377      options = self._options_tensor_to_options(self._options())
378      options._set_mutable(False)  # pylint: disable=protected-access
379      return options
380    warnings.warn("To make it possible to preserve tf.data options across "
381                  "serialization boundaries, their implementation has moved to "
382                  "be part of the TensorFlow graph. As a consequence, the "
383                  "options value is in general no longer known at graph "
384                  "construction time. Invoking this method in graph mode "
385                  "retains the legacy behavior of the original implementation, "
386                  "but note that the returned value might not reflect the "
387                  "actual value of the options.")
388    return self._options_attr
389
390  def _apply_debug_options(self):
391    if DEBUG_MODE:
392      # Disable autotuning and static optimizations that could introduce
393      # parallelism or asynchrony.
394      options = options_lib.Options()
395      options.autotune.enabled = False
396      options.experimental_optimization.map_and_batch_fusion = False
397      options.experimental_optimization.map_parallelization = False
398      dataset = _OptionsDataset(self, options)
399    else:
400      dataset = self
401
402    return dataset
403
404  def __iter__(self):
405    """Creates an iterator for elements of this dataset.
406
407    The returned iterator implements the Python Iterator protocol.
408
409    Returns:
410      An `tf.data.Iterator` for the elements of this dataset.
411
412    Raises:
413      RuntimeError: If not inside of tf.function and not executing eagerly.
414    """
415    if context.executing_eagerly() or ops.inside_function():
416      with ops.colocate_with(self._variant_tensor):
417        return iterator_ops.OwnedIterator(self)
418    else:
419      raise RuntimeError("__iter__() is only supported inside of tf.function "
420                         "or when eager execution is enabled.")
421
422  def __bool__(self):
423    return True  # Required as __len__ is defined
424
425  __nonzero__ = __bool__  # Python 2 backward compatibility
426
427  def __len__(self):
428    """Returns the length of the dataset if it is known and finite.
429
430    This method requires that you are running in eager mode, and that the
431    length of the dataset is known and non-infinite. When the length may be
432    unknown or infinite, or if you are running in graph mode, use
433    `tf.data.Dataset.cardinality` instead.
434
435    Returns:
436      An integer representing the length of the dataset.
437
438    Raises:
439      RuntimeError: If the dataset length is unknown or infinite, or if eager
440        execution is not enabled.
441    """
442    if not context.executing_eagerly():
443      raise TypeError("__len__() is not supported while tracing functions. "
444                      "Use `tf.data.Dataset.cardinality` instead.")
445    length = self.cardinality()
446    if length.numpy() == INFINITE:
447      raise TypeError("dataset length is infinite.")
448    if length.numpy() == UNKNOWN:
449      raise TypeError("dataset length is unknown.")
450    return length
451
452  @abc.abstractproperty
453  def element_spec(self):
454    """The type specification of an element of this dataset.
455
456    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
457    >>> dataset.element_spec
458    TensorSpec(shape=(), dtype=tf.int32, name=None)
459
460    For more information,
461    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).
462
463    Returns:
464      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
465      element of this dataset and specifying the type of individual components.
466    """
467    raise NotImplementedError("Dataset.element_spec")
468
469  def __repr__(self):
470    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
471    output_shapes = str(output_shapes).replace("'", "")
472    output_types = nest.map_structure(repr, get_legacy_output_types(self))
473    output_types = str(output_types).replace("'", "")
474    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
475                                            output_types))
476
477  def as_numpy_iterator(self):
478    """Returns an iterator which converts all elements of the dataset to numpy.
479
480    Use `as_numpy_iterator` to inspect the content of your dataset. To see
481    element shapes and types, print dataset elements directly instead of using
482    `as_numpy_iterator`.
483
484    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
485    >>> for element in dataset:
486    ...   print(element)
487    tf.Tensor(1, shape=(), dtype=int32)
488    tf.Tensor(2, shape=(), dtype=int32)
489    tf.Tensor(3, shape=(), dtype=int32)
490
491    This method requires that you are running in eager mode and the dataset's
492    element_spec contains only `TensorSpec` components.
493
494    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
495    >>> for element in dataset.as_numpy_iterator():
496    ...   print(element)
497    1
498    2
499    3
500
501    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
502    >>> print(list(dataset.as_numpy_iterator()))
503    [1, 2, 3]
504
505    `as_numpy_iterator()` will preserve the nested structure of dataset
506    elements.
507
508    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
509    ...                                               'b': [5, 6]})
510    >>> list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
511    ...                                       {'a': (2, 4), 'b': 6}]
512    True
513
514    Returns:
515      An iterable over the elements of the dataset, with their tensors converted
516      to numpy arrays.
517
518    Raises:
519      TypeError: if an element contains a non-`Tensor` value.
520      RuntimeError: if eager execution is not enabled.
521    """
522    if not context.executing_eagerly():
523      raise RuntimeError("as_numpy_iterator() is not supported while tracing "
524                         "functions")
525    for component_spec in nest.flatten(self.element_spec):
526      if not isinstance(
527          component_spec,
528          (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
529        raise TypeError(
530            "Dataset.as_numpy_iterator() does not support datasets containing "
531            + str(component_spec.value_type))
532
533    return _NumpyIterator(self)
534
535  @property
536  def _flat_shapes(self):
537    """Returns a list `tf.TensorShapes`s for the element tensor representation.
538
539    Returns:
540      A list `tf.TensorShapes`s for the element tensor representation.
541    """
542    return structure.get_flat_tensor_shapes(self.element_spec)
543
544  @property
545  def _flat_types(self):
546    """Returns a list `tf.DType`s for the element tensor representation.
547
548    Returns:
549      A list `tf.DType`s for the element tensor representation.
550    """
551    return structure.get_flat_tensor_types(self.element_spec)
552
553  @property
554  def _flat_structure(self):
555    """Helper for setting `output_shapes` and `output_types` attrs of an op.
556
557    Most dataset op constructors expect `output_shapes` and `output_types`
558    arguments that represent the flattened structure of an element. This helper
559    function generates these attrs as a keyword argument dictionary, allowing
560    `Dataset._variant_tensor` implementations to pass `**self._flat_structure`
561    to the op constructor.
562
563    Returns:
564      A dictionary of keyword arguments that can be passed to a dataset op
565      constructor.
566    """
567    return {
568        "output_shapes": self._flat_shapes,
569        "output_types": self._flat_types,
570    }
571
572  @property
573  def _type_spec(self):
574    return DatasetSpec(self.element_spec)
575
576  @staticmethod
577  def from_tensors(tensors):
578    """Creates a `Dataset` with a single element, comprising the given tensors.
579
580    `from_tensors` produces a dataset containing only a single element. To slice
581    the input tensor into multiple elements, use `from_tensor_slices` instead.
582
583    >>> dataset = tf.data.Dataset.from_tensors([1, 2, 3])
584    >>> list(dataset.as_numpy_iterator())
585    [array([1, 2, 3], dtype=int32)]
586    >>> dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
587    >>> list(dataset.as_numpy_iterator())
588    [(array([1, 2, 3], dtype=int32), b'A')]
589
590    >>> # You can use `from_tensors` to produce a dataset which repeats
591    >>> # the same example many times.
592    >>> example = tf.constant([1,2,3])
593    >>> dataset = tf.data.Dataset.from_tensors(example).repeat(2)
594    >>> list(dataset.as_numpy_iterator())
595    [array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]
596
597    Note that if `tensors` contains a NumPy array, and eager execution is not
598    enabled, the values will be embedded in the graph as one or more
599    `tf.constant` operations. For large datasets (> 1 GB), this can waste
600    memory and run into byte limits of graph serialization. If `tensors`
601    contains one or more large NumPy arrays, consider the alternative described
602    in [this
603    guide](https://tensorflow.org/guide/data#consuming_numpy_arrays).
604
605    Args:
606      tensors: A dataset "element". Supported values are documented
607        [here](https://www.tensorflow.org/guide/data#dataset_structure).
608
609    Returns:
610      Dataset: A `Dataset`.
611    """
612    return TensorDataset(tensors)
613
614  @staticmethod
615  def from_tensor_slices(tensors):
616    """Creates a `Dataset` whose elements are slices of the given tensors.
617
618    The given tensors are sliced along their first dimension. This operation
619    preserves the structure of the input tensors, removing the first dimension
620    of each tensor and using it as the dataset dimension. All input tensors
621    must have the same size in their first dimensions.
622
623    >>> # Slicing a 1D tensor produces scalar tensor elements.
624    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
625    >>> list(dataset.as_numpy_iterator())
626    [1, 2, 3]
627
628    >>> # Slicing a 2D tensor produces 1D tensor elements.
629    >>> dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
630    >>> list(dataset.as_numpy_iterator())
631    [array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
632
633    >>> # Slicing a tuple of 1D tensors produces tuple elements containing
634    >>> # scalar tensors.
635    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
636    >>> list(dataset.as_numpy_iterator())
637    [(1, 3, 5), (2, 4, 6)]
638
639    >>> # Dictionary structure is also preserved.
640    >>> dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
641    >>> list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
642    ...                                       {'a': 2, 'b': 4}]
643    True
644
645    >>> # Two tensors can be combined into one Dataset object.
646    >>> features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
647    >>> labels = tf.constant(['A', 'B', 'A']) # ==> 3x1 tensor
648    >>> dataset = Dataset.from_tensor_slices((features, labels))
649    >>> # Both the features and the labels tensors can be converted
650    >>> # to a Dataset object separately and combined after.
651    >>> features_dataset = Dataset.from_tensor_slices(features)
652    >>> labels_dataset = Dataset.from_tensor_slices(labels)
653    >>> dataset = Dataset.zip((features_dataset, labels_dataset))
654    >>> # A batched feature and label set can be converted to a Dataset
655    >>> # in similar fashion.
656    >>> batched_features = tf.constant([[[1, 3], [2, 3]],
657    ...                                 [[2, 1], [1, 2]],
658    ...                                 [[3, 3], [3, 2]]], shape=(3, 2, 2))
659    >>> batched_labels = tf.constant([['A', 'A'],
660    ...                               ['B', 'B'],
661    ...                               ['A', 'B']], shape=(3, 2, 1))
662    >>> dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
663    >>> for element in dataset.as_numpy_iterator():
664    ...   print(element)
665    (array([[1, 3],
666           [2, 3]], dtype=int32), array([[b'A'],
667           [b'A']], dtype=object))
668    (array([[2, 1],
669           [1, 2]], dtype=int32), array([[b'B'],
670           [b'B']], dtype=object))
671    (array([[3, 3],
672           [3, 2]], dtype=int32), array([[b'A'],
673           [b'B']], dtype=object))
674
675    Note that if `tensors` contains a NumPy array, and eager execution is not
676    enabled, the values will be embedded in the graph as one or more
677    `tf.constant` operations. For large datasets (> 1 GB), this can waste
678    memory and run into byte limits of graph serialization. If `tensors`
679    contains one or more large NumPy arrays, consider the alternative described
680    in [this guide](
681    https://tensorflow.org/guide/data#consuming_numpy_arrays).
682
683    Args:
684      tensors: A dataset element, whose components have the same first
685        dimension. Supported values are documented
686        [here](https://www.tensorflow.org/guide/data#dataset_structure).
687
688    Returns:
689      Dataset: A `Dataset`.
690    """
691    return TensorSliceDataset(tensors)
692
693  class _GeneratorState(object):
694    """Stores outstanding iterators created from a Python generator.
695
696    This class keeps track of potentially multiple iterators that may have
697    been created from a generator, e.g. in the case that the dataset is
698    repeated, or nested within a parallel computation.
699    """
700
701    def __init__(self, generator):
702      self._generator = generator
703      self._lock = threading.Lock()
704      self._next_id = 0  # GUARDED_BY(self._lock)
705      self._args = {}
706      self._iterators = {}
707
708    def get_next_id(self, *args):
709      with self._lock:
710        ret = self._next_id
711        self._next_id += 1
712      self._args[ret] = args
713      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
714      # casting in `py_func()` will create an array of `np.int32` on Windows,
715      # leading to a runtime error.
716      return np.array(ret, dtype=np.int64)
717
718    def get_iterator(self, iterator_id):
719      try:
720        return self._iterators[iterator_id]
721      except KeyError:
722        iterator = iter(self._generator(*self._args.pop(iterator_id)))
723        self._iterators[iterator_id] = iterator
724        return iterator
725
726    def iterator_completed(self, iterator_id):
727      del self._iterators[iterator_id]
728
729  @staticmethod
730  @deprecation.deprecated_args(None, "Use output_signature instead",
731                               "output_types", "output_shapes")
732  def from_generator(generator,
733                     output_types=None,
734                     output_shapes=None,
735                     args=None,
736                     output_signature=None):
737    """Creates a `Dataset` whose elements are generated by `generator`.
738
739    The `generator` argument must be a callable object that returns
740    an object that supports the `iter()` protocol (e.g. a generator function).
741
742    The elements generated by `generator` must be compatible with either the
743    given `output_signature` argument or with the given `output_types` and
744    (optionally) `output_shapes` arguments, whichever was specified.
745
746    The recommended way to call `from_generator` is to use the
747    `output_signature` argument. In this case the output will be assumed to
748    consist of objects with the classes, shapes and types defined by
749    `tf.TypeSpec` objects from `output_signature` argument:
750
751    >>> def gen():
752    ...   ragged_tensor = tf.ragged.constant([[1, 2], [3]])
753    ...   yield 42, ragged_tensor
754    >>>
755    >>> dataset = tf.data.Dataset.from_generator(
756    ...      gen,
757    ...      output_signature=(
758    ...          tf.TensorSpec(shape=(), dtype=tf.int32),
759    ...          tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))
760    >>>
761    >>> list(dataset.take(1))
762    [(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
763    <tf.RaggedTensor [[1, 2], [3]]>)]
764
765    There is also a deprecated way to call `from_generator` by either with
766    `output_types` argument alone or together with `output_shapes` argument.
767    In this case the output of the function will be assumed to consist of
768    `tf.Tensor` objects with the types defined by `output_types` and with the
769    shapes which are either unknown or defined by `output_shapes`.
770
771    Note: The current implementation of `Dataset.from_generator()` uses
772    `tf.numpy_function` and inherits the same constraints. In particular, it
773    requires the dataset and iterator related operations to be placed
774    on a device in the same process as the Python program that called
775    `Dataset.from_generator()`. The body of `generator` will not be
776    serialized in a `GraphDef`, and you should not use this method if you
777    need to serialize your model and restore it in a different environment.
778
779    Note: If `generator` depends on mutable global variables or other external
780    state, be aware that the runtime may invoke `generator` multiple times
781    (in order to support repeating the `Dataset`) and at any time
782    between the call to `Dataset.from_generator()` and the production of the
783    first element from the generator. Mutating global variables or external
784    state can cause undefined behavior, and we recommend that you explicitly
785    cache any external state in `generator` before calling
786    `Dataset.from_generator()`.
787
788    Note: While the `output_signature` parameter makes it possible to yield
789    `Dataset` elements, the scope of `Dataset.from_generator()` should be
790    limited to logic that cannot be expressed through tf.data operations. Using
791    tf.data operations within the generator function is an anti-pattern and may
792    result in incremental memory growth.
793
794    Args:
795      generator: A callable object that returns an object that supports the
796        `iter()` protocol. If `args` is not specified, `generator` must take no
797        arguments; otherwise it must take as many arguments as there are values
798        in `args`.
799      output_types: (Optional.) A (nested) structure of `tf.DType` objects
800        corresponding to each component of an element yielded by `generator`.
801      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
802        objects corresponding to each component of an element yielded by
803        `generator`.
804      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
805        and passed to `generator` as NumPy-array arguments.
806      output_signature: (Optional.) A (nested) structure of `tf.TypeSpec`
807        objects corresponding to each component of an element yielded by
808        `generator`.
809
810    Returns:
811      Dataset: A `Dataset`.
812    """
813    if not callable(generator):
814      raise TypeError("`generator` must be callable.")
815
816    if output_signature is not None:
817      if output_types is not None:
818        raise TypeError("`output_types` can not be used together with "
819                        "`output_signature`")
820      if output_shapes is not None:
821        raise TypeError("`output_shapes` can not be used together with "
822                        "`output_signature`")
823      if not all(
824          isinstance(_, type_spec.TypeSpec)
825          for _ in nest.flatten(output_signature)):
826        raise TypeError("All the elements of `output_signature` must be "
827                        "`tf.TypeSpec` objects.")
828    else:
829      if output_types is None:
830        raise TypeError("Either `output_signature` or `output_types` must "
831                        "be specified")
832
833    if output_signature is None:
834      if output_shapes is None:
835        output_shapes = nest.map_structure(
836            lambda _: tensor_shape.TensorShape(None), output_types)
837      else:
838        output_shapes = nest.map_structure_up_to(output_types,
839                                                 tensor_shape.as_shape,
840                                                 output_shapes)
841      output_signature = nest.map_structure_up_to(output_types,
842                                                  tensor_spec.TensorSpec,
843                                                  output_shapes, output_types)
844    if all(
845        isinstance(x, tensor_spec.TensorSpec)
846        for x in nest.flatten(output_signature)):
847      output_types = nest.pack_sequence_as(
848          output_signature, [x.dtype for x in nest.flatten(output_signature)])
849      output_shapes = nest.pack_sequence_as(
850          output_signature, [x.shape for x in nest.flatten(output_signature)])
851
852    if args is None:
853      args = ()
854    else:
855      args = tuple(ops.convert_n_to_tensor(args, name="args"))
856
857    generator_state = DatasetV2._GeneratorState(generator)
858
859    def get_iterator_id_fn(unused_dummy):
860      """Creates a unique `iterator_id` for each pass over the dataset.
861
862      The returned `iterator_id` disambiguates between multiple concurrently
863      existing iterators.
864
865      Args:
866        unused_dummy: Ignored value.
867
868      Returns:
869        A `tf.int64` tensor whose value uniquely identifies an iterator in
870        `generator_state`.
871      """
872      return script_ops.numpy_function(generator_state.get_next_id, args,
873                                       dtypes.int64)
874
875    def generator_next_fn(iterator_id_t):
876      """Generates the next element from iterator with ID `iterator_id_t`.
877
878      We map this function across an infinite repetition of the
879      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
880
881      Args:
882        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies the
883          iterator in `generator_state` from which to generate an element.
884
885      Returns:
886        The next element to generate from the iterator.
887      """
888      if output_types and output_shapes:
889        flattened_types = [
890            dtypes.as_dtype(dt) for dt in nest.flatten(output_types)
891        ]
892        flattened_shapes = nest.flatten(output_shapes)
893
894        def generator_py_func(iterator_id):
895          """A `py_func` that will be called to invoke the iterator."""
896          # `next()` raises `StopIteration` when there are no more
897          # elements remaining to be generated.
898          values = next(generator_state.get_iterator(iterator_id))
899
900          # Use the same _convert function from the py_func() implementation to
901          # convert the returned values to arrays early, so that we can inspect
902          # their values.
903          try:
904            flattened_values = nest.flatten_up_to(output_types, values)
905          except (TypeError, ValueError):
906            six.reraise(
907                TypeError,
908                TypeError(
909                    "`generator` yielded an element that did not match the "
910                    "expected structure. The expected structure was %s, but "
911                    "the yielded element was %s." % (output_types, values)),
912                sys.exc_info()[2])
913          ret_arrays = []
914          for ret, dtype in zip(flattened_values, flattened_types):
915            try:
916              ret_arrays.append(
917                  script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
918                      ret,
919                      dtype=dtype.as_numpy_dtype))
920            except (TypeError, ValueError):
921              six.reraise(
922                  TypeError,
923                  TypeError(
924                      "`generator` yielded an element that could not be "
925                      "converted to the expected type. The expected type was "
926                      "%s, but the yielded element was %s." %
927                      (dtype.name, ret)),
928                  sys.exc_info()[2])
929
930          # Additional type and shape checking to ensure that the components of
931          # the generated element match the `output_types` and `output_shapes`
932          # arguments.
933          for (ret_array, expected_dtype,
934               expected_shape) in zip(ret_arrays, flattened_types,
935                                      flattened_shapes):
936            if ret_array.dtype != expected_dtype.as_numpy_dtype:
937              raise TypeError(
938                  "`generator` yielded an element of type %s where an element "
939                  "of type %s was expected." %
940                  (ret_array.dtype, expected_dtype.as_numpy_dtype))
941            if not expected_shape.is_compatible_with(ret_array.shape):
942              raise ValueError(
943                  "`generator` yielded an element of shape %s where an element "
944                  "of shape %s was expected." %
945                  (ret_array.shape, expected_shape))
946
947          return ret_arrays
948
949        flat_values = script_ops.numpy_function(generator_py_func,
950                                                [iterator_id_t],
951                                                flattened_types)
952
953        # The `py_func()` op drops the inferred shapes, so we add them back in
954        # here.
955        if output_shapes is not None:
956          for ret_t, shape in zip(flat_values, flattened_shapes):
957            ret_t.set_shape(shape)
958
959        return nest.pack_sequence_as(output_types, flat_values)
960      else:
961        flat_output_types = structure.get_flat_tensor_types(output_signature)
962
963        def generator_py_func(iterator_id):
964          """A `py_func` that will be called to invoke the iterator."""
965          # `next()` raises `StopIteration` when there are no more
966          # elements remaining to be generated.
967          values = next(generator_state.get_iterator(iterator_id.numpy()))
968
969          try:
970            values = structure.normalize_element(values, output_signature)
971          except (TypeError, ValueError):
972            six.reraise(
973                TypeError,
974                TypeError(
975                    "`generator` yielded an element that did not match the "
976                    "expected structure. The expected structure was %s, but "
977                    "the yielded element was %s." % (output_signature, values)),
978                sys.exc_info()[2])
979
980          values_spec = structure.type_spec_from_value(values)
981
982          if not structure.are_compatible(values_spec, output_signature):
983            raise TypeError(
984                "`generator` yielded an element of %s where an element "
985                "of %s was expected." % (values_spec, output_signature))
986
987          return structure.to_tensor_list(output_signature, values)
988
989        return script_ops._eager_py_func(  # pylint: disable=protected-access
990            generator_py_func,
991            inp=[iterator_id_t],
992            Tout=flat_output_types,
993            use_tape_cache=False)
994
995    def finalize_fn(iterator_id_t):
996      """Releases host-side state for the iterator with ID `iterator_id_t`."""
997
998      def finalize_py_func(iterator_id):
999        generator_state.iterator_completed(iterator_id)
1000        # We return a dummy value so that the `finalize_fn` has a valid
1001        # signature.
1002        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
1003        # casting in `py_func()` will create an array of `np.int32` on Windows,
1004        # leading to a runtime error.
1005        return np.array(0, dtype=np.int64)
1006
1007      return script_ops.numpy_function(finalize_py_func, [iterator_id_t],
1008                                       dtypes.int64)
1009
1010    # This function associates each traversal of `generator` with a unique
1011    # iterator ID.
1012    def flat_map_fn(dummy_arg):
1013      # The `get_iterator_id_fn` gets a unique ID for the current instance of
1014      # of the generator.
1015      # The `generator_next_fn` gets the next element from the iterator with the
1016      # given ID, and raises StopIteration when that iterator contains no
1017      # more elements.
1018      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
1019                               finalize_fn, output_signature)
1020
1021    # A single-element dataset that, each time it is evaluated, contains a
1022    # freshly-generated and unique (for the returned dataset) int64
1023    # ID that will be used to identify the appropriate Python state, which
1024    # is encapsulated in `generator_state`, and captured in
1025    # `get_iterator_id_map_fn`.
1026    dummy = 0
1027    id_dataset = Dataset.from_tensors(dummy)
1028
1029    # A dataset that contains all of the elements generated by a
1030    # single iterator created from `generator`, identified by the
1031    # iterator ID contained in `id_dataset`. Lifting the iteration
1032    # into a flat_map here enables multiple repetitions and/or nested
1033    # versions of the returned dataset to be created, because it forces
1034    # the generation of a new ID for each version.
1035    return id_dataset.flat_map(flat_map_fn)
1036
1037  @staticmethod
1038  def range(*args, **kwargs):
1039    """Creates a `Dataset` of a step-separated range of values.
1040
1041    >>> list(Dataset.range(5).as_numpy_iterator())
1042    [0, 1, 2, 3, 4]
1043    >>> list(Dataset.range(2, 5).as_numpy_iterator())
1044    [2, 3, 4]
1045    >>> list(Dataset.range(1, 5, 2).as_numpy_iterator())
1046    [1, 3]
1047    >>> list(Dataset.range(1, 5, -2).as_numpy_iterator())
1048    []
1049    >>> list(Dataset.range(5, 1).as_numpy_iterator())
1050    []
1051    >>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
1052    [5, 3]
1053    >>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
1054    [2, 3, 4]
1055    >>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
1056    [1.0, 3.0]
1057
1058    Args:
1059      *args: follows the same semantics as python's xrange.
1060        len(args) == 1 -> start = 0, stop = args[0], step = 1.
1061        len(args) == 2 -> start = args[0], stop = args[1], step = 1.
1062        len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
1063      **kwargs:
1064        - output_type: Its expected dtype. (Optional, default: `tf.int64`).
1065
1066    Returns:
1067      Dataset: A `RangeDataset`.
1068
1069    Raises:
1070      ValueError: if len(args) == 0.
1071    """
1072    return RangeDataset(*args, **kwargs)
1073
1074  @staticmethod
1075  def zip(datasets):
1076    """Creates a `Dataset` by zipping together the given datasets.
1077
1078    This method has similar semantics to the built-in `zip()` function
1079    in Python, with the main difference being that the `datasets`
1080    argument can be a (nested) structure of `Dataset` objects. The supported
1081    nesting mechanisms are documented
1082    [here] (https://www.tensorflow.org/guide/data#dataset_structure).
1083
1084    >>> # The nested structure of the `datasets` argument determines the
1085    >>> # structure of elements in the resulting dataset.
1086    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1087    >>> b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
1088    >>> ds = tf.data.Dataset.zip((a, b))
1089    >>> list(ds.as_numpy_iterator())
1090    [(1, 4), (2, 5), (3, 6)]
1091    >>> ds = tf.data.Dataset.zip((b, a))
1092    >>> list(ds.as_numpy_iterator())
1093    [(4, 1), (5, 2), (6, 3)]
1094    >>>
1095    >>> # The `datasets` argument may contain an arbitrary number of datasets.
1096    >>> c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
1097    ...                                            #       [9, 10],
1098    ...                                            #       [11, 12] ]
1099    >>> ds = tf.data.Dataset.zip((a, b, c))
1100    >>> for element in ds.as_numpy_iterator():
1101    ...   print(element)
1102    (1, 4, array([7, 8]))
1103    (2, 5, array([ 9, 10]))
1104    (3, 6, array([11, 12]))
1105    >>>
1106    >>> # The number of elements in the resulting dataset is the same as
1107    >>> # the size of the smallest dataset in `datasets`.
1108    >>> d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
1109    >>> ds = tf.data.Dataset.zip((a, d))
1110    >>> list(ds.as_numpy_iterator())
1111    [(1, 13), (2, 14)]
1112
1113    Args:
1114      datasets: A (nested) structure of datasets.
1115
1116    Returns:
1117      Dataset: A `Dataset`.
1118    """
1119    return ZipDataset(datasets)
1120
1121  def concatenate(self, dataset):
1122    """Creates a `Dataset` by concatenating the given dataset with this dataset.
1123
1124    >>> a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
1125    >>> b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
1126    >>> ds = a.concatenate(b)
1127    >>> list(ds.as_numpy_iterator())
1128    [1, 2, 3, 4, 5, 6, 7]
1129    >>> # The input dataset and dataset to be concatenated should have
1130    >>> # compatible element specs.
1131    >>> c = tf.data.Dataset.zip((a, b))
1132    >>> a.concatenate(c)
1133    Traceback (most recent call last):
1134    TypeError: Two datasets to concatenate have different types
1135    <dtype: 'int64'> and (tf.int64, tf.int64)
1136    >>> d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
1137    >>> a.concatenate(d)
1138    Traceback (most recent call last):
1139    TypeError: Two datasets to concatenate have different types
1140    <dtype: 'int64'> and <dtype: 'string'>
1141
1142    Args:
1143      dataset: `Dataset` to be concatenated.
1144
1145    Returns:
1146      Dataset: A `Dataset`.
1147    """
1148    return ConcatenateDataset(self, dataset)
1149
1150  def prefetch(self, buffer_size):
1151    """Creates a `Dataset` that prefetches elements from this dataset.
1152
1153    Most dataset input pipelines should end with a call to `prefetch`. This
1154    allows later elements to be prepared while the current element is being
1155    processed. This often improves latency and throughput, at the cost of
1156    using additional memory to store prefetched elements.
1157
1158    Note: Like other `Dataset` methods, prefetch operates on the
1159    elements of the input dataset. It has no concept of examples vs. batches.
1160    `examples.prefetch(2)` will prefetch two elements (2 examples),
1161    while `examples.batch(20).prefetch(2)` will prefetch 2 elements
1162    (2 batches, of 20 examples each).
1163
1164    >>> dataset = tf.data.Dataset.range(3)
1165    >>> dataset = dataset.prefetch(2)
1166    >>> list(dataset.as_numpy_iterator())
1167    [0, 1, 2]
1168
1169    Args:
1170      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
1171        number of elements that will be buffered when prefetching. If the value
1172        `tf.data.AUTOTUNE` is used, then the buffer size is dynamically tuned.
1173    Returns:
1174      Dataset: A `Dataset`.
1175    """
1176    if DEBUG_MODE:
1177      return self
1178    return PrefetchDataset(self, buffer_size)
1179
1180  @staticmethod
1181  def list_files(file_pattern, shuffle=None, seed=None):
1182    """A dataset of all files matching one or more glob patterns.
1183
1184    The `file_pattern` argument should be a small number of glob patterns.
1185    If your filenames have already been globbed, use
1186    `Dataset.from_tensor_slices(filenames)` instead, as re-globbing every
1187    filename with `list_files` may result in poor performance with remote
1188    storage systems.
1189
1190    Note: The default behavior of this method is to return filenames in
1191    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
1192    to get results in a deterministic order.
1193
1194    Example:
1195      If we had the following files on our filesystem:
1196
1197        - /path/to/dir/a.txt
1198        - /path/to/dir/b.py
1199        - /path/to/dir/c.py
1200
1201      If we pass "/path/to/dir/*.py" as the directory, the dataset
1202      would produce:
1203
1204        - /path/to/dir/b.py
1205        - /path/to/dir/c.py
1206
1207    Args:
1208      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
1209        (scalar or vector), representing the filename glob (i.e. shell wildcard)
1210        pattern(s) that will be matched.
1211      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
1212        Defaults to `True`.
1213      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1214        seed that will be used to create the distribution. See
1215        `tf.random.set_seed` for behavior.
1216
1217    Returns:
1218     Dataset: A `Dataset` of strings corresponding to file names.
1219    """
1220    with ops.name_scope("list_files"):
1221      if shuffle is None:
1222        shuffle = True
1223      file_pattern = ops.convert_to_tensor(
1224          file_pattern, dtype=dtypes.string, name="file_pattern")
1225      matching_files = gen_io_ops.matching_files(file_pattern)
1226
1227      # Raise an exception if `file_pattern` does not match any files.
1228      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
1229                                   name="match_not_empty")
1230
1231      message = math_ops.add(
1232          "No files matched pattern: ",
1233          string_ops.reduce_join(file_pattern, separator=", "), name="message")
1234
1235      assert_not_empty = control_flow_ops.Assert(
1236          condition, [message], summarize=1, name="assert_not_empty")
1237      with ops.control_dependencies([assert_not_empty]):
1238        matching_files = array_ops.identity(matching_files)
1239
1240      dataset = Dataset.from_tensor_slices(matching_files)
1241      if shuffle:
1242        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
1243        # list of files might be empty.
1244        buffer_size = math_ops.maximum(
1245            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
1246        dataset = dataset.shuffle(buffer_size, seed=seed)
1247      return dataset
1248
1249  def repeat(self, count=None):
1250    """Repeats this dataset so each original value is seen `count` times.
1251
1252    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1253    >>> dataset = dataset.repeat(3)
1254    >>> list(dataset.as_numpy_iterator())
1255    [1, 2, 3, 1, 2, 3, 1, 2, 3]
1256
1257    Note: If this dataset is a function of global state (e.g. a random number
1258    generator), then different repetitions may produce different elements.
1259
1260    Args:
1261      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1262        number of times the dataset should be repeated. The default behavior (if
1263        `count` is `None` or `-1`) is for the dataset be repeated indefinitely.
1264
1265    Returns:
1266      Dataset: A `Dataset`.
1267    """
1268    return RepeatDataset(self, count)
1269
1270  def enumerate(self, start=0):
1271    """Enumerates the elements of this dataset.
1272
1273    It is similar to python's `enumerate`.
1274
1275    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1276    >>> dataset = dataset.enumerate(start=5)
1277    >>> for element in dataset.as_numpy_iterator():
1278    ...   print(element)
1279    (5, 1)
1280    (6, 2)
1281    (7, 3)
1282
1283    >>> # The (nested) structure of the input dataset determines the
1284    >>> # structure of elements in the resulting dataset.
1285    >>> dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
1286    >>> dataset = dataset.enumerate()
1287    >>> for element in dataset.as_numpy_iterator():
1288    ...   print(element)
1289    (0, array([7, 8], dtype=int32))
1290    (1, array([ 9, 10], dtype=int32))
1291
1292    Args:
1293      start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
1294        enumeration.
1295
1296    Returns:
1297      Dataset: A `Dataset`.
1298    """
1299
1300    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
1301    return Dataset.zip((Dataset.range(start, max_value), self))
1302
1303  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
1304    """Randomly shuffles the elements of this dataset.
1305
1306    This dataset fills a buffer with `buffer_size` elements, then randomly
1307    samples elements from this buffer, replacing the selected elements with new
1308    elements. For perfect shuffling, a buffer size greater than or equal to the
1309    full size of the dataset is required.
1310
1311    For instance, if your dataset contains 10,000 elements but `buffer_size` is
1312    set to 1,000, then `shuffle` will initially select a random element from
1313    only the first 1,000 elements in the buffer. Once an element is selected,
1314    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
1315    maintaining the 1,000 element buffer.
1316
1317    `reshuffle_each_iteration` controls whether the shuffle order should be
1318    different for each epoch. In TF 1.X, the idiomatic way to create epochs
1319    was through the `repeat` transformation:
1320
1321    ```python
1322    dataset = tf.data.Dataset.range(3)
1323    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1324    dataset = dataset.repeat(2)
1325    # [1, 0, 2, 1, 2, 0]
1326
1327    dataset = tf.data.Dataset.range(3)
1328    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1329    dataset = dataset.repeat(2)
1330    # [1, 0, 2, 1, 0, 2]
1331    ```
1332
1333    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
1334    possible to also create epochs through Python iteration:
1335
1336    ```python
1337    dataset = tf.data.Dataset.range(3)
1338    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
1339    list(dataset.as_numpy_iterator())
1340    # [1, 0, 2]
1341    list(dataset.as_numpy_iterator())
1342    # [1, 2, 0]
1343    ```
1344
1345    ```python
1346    dataset = tf.data.Dataset.range(3)
1347    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
1348    list(dataset.as_numpy_iterator())
1349    # [1, 0, 2]
1350    list(dataset.as_numpy_iterator())
1351    # [1, 0, 2]
1352    ```
1353
1354    Args:
1355      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1356        elements from this dataset from which the new dataset will sample.
1357      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
1358        seed that will be used to create the distribution. See
1359        `tf.random.set_seed` for behavior.
1360      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1361        that the dataset should be pseudorandomly reshuffled each time it is
1362        iterated over. (Defaults to `True`.)
1363
1364    Returns:
1365      Dataset: A `Dataset`.
1366    """
1367    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
1368
1369  def cache(self, filename=""):
1370    """Caches the elements in this dataset.
1371
1372    The first time the dataset is iterated over, its elements will be cached
1373    either in the specified file or in memory. Subsequent iterations will
1374    use the cached data.
1375
1376    Note: For the cache to be finalized, the input dataset must be iterated
1377    through in its entirety. Otherwise, subsequent iterations will not use
1378    cached data.
1379
1380    >>> dataset = tf.data.Dataset.range(5)
1381    >>> dataset = dataset.map(lambda x: x**2)
1382    >>> dataset = dataset.cache()
1383    >>> # The first time reading through the data will generate the data using
1384    >>> # `range` and `map`.
1385    >>> list(dataset.as_numpy_iterator())
1386    [0, 1, 4, 9, 16]
1387    >>> # Subsequent iterations read from the cache.
1388    >>> list(dataset.as_numpy_iterator())
1389    [0, 1, 4, 9, 16]
1390
1391    When caching to a file, the cached data will persist across runs. Even the
1392    first iteration through the data will read from the cache file. Changing
1393    the input pipeline before the call to `.cache()` will have no effect until
1394    the cache file is removed or the filename is changed.
1395
1396    ```python
1397    dataset = tf.data.Dataset.range(5)
1398    dataset = dataset.cache("/path/to/file")
1399    list(dataset.as_numpy_iterator())
1400    # [0, 1, 2, 3, 4]
1401    dataset = tf.data.Dataset.range(10)
1402    dataset = dataset.cache("/path/to/file")  # Same file!
1403    list(dataset.as_numpy_iterator())
1404    # [0, 1, 2, 3, 4]
1405    ```
1406
1407    Note: `cache` will produce exactly the same elements during each iteration
1408    through the dataset. If you wish to randomize the iteration order, make sure
1409    to call `shuffle` *after* calling `cache`.
1410
1411    Args:
1412      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
1413        directory on the filesystem to use for caching elements in this Dataset.
1414        If a filename is not provided, the dataset will be cached in memory.
1415
1416    Returns:
1417      Dataset: A `Dataset`.
1418    """
1419    return CacheDataset(self, filename)
1420
1421  def take(self, count):
1422    """Creates a `Dataset` with at most `count` elements from this dataset.
1423
1424    >>> dataset = tf.data.Dataset.range(10)
1425    >>> dataset = dataset.take(3)
1426    >>> list(dataset.as_numpy_iterator())
1427    [0, 1, 2]
1428
1429    Args:
1430      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1431        elements of this dataset that should be taken to form the new dataset.
1432        If `count` is -1, or if `count` is greater than the size of this
1433        dataset, the new dataset will contain all elements of this dataset.
1434
1435    Returns:
1436      Dataset: A `Dataset`.
1437    """
1438    return TakeDataset(self, count)
1439
1440  def skip(self, count):
1441    """Creates a `Dataset` that skips `count` elements from this dataset.
1442
1443    >>> dataset = tf.data.Dataset.range(10)
1444    >>> dataset = dataset.skip(7)
1445    >>> list(dataset.as_numpy_iterator())
1446    [7, 8, 9]
1447
1448    Args:
1449      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
1450        elements of this dataset that should be skipped to form the new dataset.
1451        If `count` is greater than the size of this dataset, the new dataset
1452        will contain no elements.  If `count` is -1, skips the entire dataset.
1453
1454    Returns:
1455      Dataset: A `Dataset`.
1456    """
1457    return SkipDataset(self, count)
1458
1459  def shard(self, num_shards, index):
1460    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
1461
1462    `shard` is deterministic. The Dataset produced by `A.shard(n, i)` will
1463    contain all elements of A whose index mod n = i.
1464
1465    >>> A = tf.data.Dataset.range(10)
1466    >>> B = A.shard(num_shards=3, index=0)
1467    >>> list(B.as_numpy_iterator())
1468    [0, 3, 6, 9]
1469    >>> C = A.shard(num_shards=3, index=1)
1470    >>> list(C.as_numpy_iterator())
1471    [1, 4, 7]
1472    >>> D = A.shard(num_shards=3, index=2)
1473    >>> list(D.as_numpy_iterator())
1474    [2, 5, 8]
1475
1476    This dataset operator is very useful when running distributed training, as
1477    it allows each worker to read a unique subset.
1478
1479    When reading a single input file, you can shard elements as follows:
1480
1481    ```python
1482    d = tf.data.TFRecordDataset(input_file)
1483    d = d.shard(num_workers, worker_index)
1484    d = d.repeat(num_epochs)
1485    d = d.shuffle(shuffle_buffer_size)
1486    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1487    ```
1488
1489    Important caveats:
1490
1491    - Be sure to shard before you use any randomizing operator (such as
1492      shuffle).
1493    - Generally it is best if the shard operator is used early in the dataset
1494      pipeline. For example, when reading from a set of TFRecord files, shard
1495      before converting the dataset to input samples. This avoids reading every
1496      file on every worker. The following is an example of an efficient
1497      sharding strategy within a complete pipeline:
1498
1499    ```python
1500    d = Dataset.list_files(pattern)
1501    d = d.shard(num_workers, worker_index)
1502    d = d.repeat(num_epochs)
1503    d = d.shuffle(shuffle_buffer_size)
1504    d = d.interleave(tf.data.TFRecordDataset,
1505                     cycle_length=num_readers, block_length=1)
1506    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
1507    ```
1508
1509    Args:
1510      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
1511        shards operating in parallel.
1512      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
1513
1514    Returns:
1515      Dataset: A `Dataset`.
1516
1517    Raises:
1518      InvalidArgumentError: if `num_shards` or `index` are illegal values.
1519
1520        Note: error checking is done on a best-effort basis, and errors aren't
1521        guaranteed to be caught upon dataset creation. (e.g. providing in a
1522        placeholder tensor bypasses the early checking, and will instead result
1523        in an error during a session.run call.)
1524    """
1525    return ShardDataset(self, num_shards, index)
1526
1527  def batch(self,
1528            batch_size,
1529            drop_remainder=False,
1530            num_parallel_calls=None,
1531            deterministic=None):
1532    """Combines consecutive elements of this dataset into batches.
1533
1534    >>> dataset = tf.data.Dataset.range(8)
1535    >>> dataset = dataset.batch(3)
1536    >>> list(dataset.as_numpy_iterator())
1537    [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
1538
1539    >>> dataset = tf.data.Dataset.range(8)
1540    >>> dataset = dataset.batch(3, drop_remainder=True)
1541    >>> list(dataset.as_numpy_iterator())
1542    [array([0, 1, 2]), array([3, 4, 5])]
1543
1544    The components of the resulting element will have an additional outer
1545    dimension, which will be `batch_size` (or `N % batch_size` for the last
1546    element if `batch_size` does not divide the number of input elements `N`
1547    evenly and `drop_remainder` is `False`). If your program depends on the
1548    batches having the same outer dimension, you should set the `drop_remainder`
1549    argument to `True` to prevent the smaller batch from being produced.
1550
1551    Note: If your program requires data to have a statically known shape (e.g.,
1552    when using XLA), you should use `drop_remainder=True`. Without
1553    `drop_remainder=True` the shape of the output dataset will have an unknown
1554    leading dimension due to the possibility of a smaller final batch.
1555
1556    Args:
1557      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1558        consecutive elements of this dataset to combine in a single batch.
1559      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1560        whether the last batch should be dropped in the case it has fewer than
1561        `batch_size` elements; the default behavior is not to drop the smaller
1562        batch.
1563      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1564        representing the number of batches to compute asynchronously in
1565        parallel.
1566        If not specified, batches will be computed sequentially. If the value
1567        `tf.data.AUTOTUNE` is used, then the number of parallel
1568        calls is set dynamically based on available resources.
1569      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
1570        boolean is specified (`True` or `False`), it controls the order in which
1571        the transformation produces elements. If set to `False`, the
1572        transformation is allowed to yield elements out of order to trade
1573        determinism for performance. If not specified, the
1574        `tf.data.Options.deterministic` option (`True` by default) controls the
1575        behavior.
1576
1577    Returns:
1578      Dataset: A `Dataset`.
1579    """
1580    if num_parallel_calls is None or DEBUG_MODE:
1581      if deterministic is not None and not DEBUG_MODE:
1582        warnings.warn("The `deterministic` argument has no effect unless the "
1583                      "`num_parallel_calls` argument is specified.")
1584      return BatchDataset(self, batch_size, drop_remainder)
1585    else:
1586      return ParallelBatchDataset(self, batch_size, drop_remainder,
1587                                  num_parallel_calls, deterministic)
1588
1589  def padded_batch(self,
1590                   batch_size,
1591                   padded_shapes=None,
1592                   padding_values=None,
1593                   drop_remainder=False):
1594    """Combines consecutive elements of this dataset into padded batches.
1595
1596    This transformation combines multiple consecutive elements of the input
1597    dataset into a single element.
1598
1599    Like `tf.data.Dataset.batch`, the components of the resulting element will
1600    have an additional outer dimension, which will be `batch_size` (or
1601    `N % batch_size` for the last element if `batch_size` does not divide the
1602    number of input elements `N` evenly and `drop_remainder` is `False`). If
1603    your program depends on the batches having the same outer dimension, you
1604    should set the `drop_remainder` argument to `True` to prevent the smaller
1605    batch from being produced.
1606
1607    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
1608    different shapes, and this transformation will pad each component to the
1609    respective shape in `padded_shapes`. The `padded_shapes` argument
1610    determines the resulting shape for each dimension of each component in an
1611    output element:
1612
1613    * If the dimension is a constant, the component will be padded out to that
1614      length in that dimension.
1615    * If the dimension is unknown, the component will be padded out to the
1616      maximum length of all elements in that dimension.
1617
1618    >>> A = (tf.data.Dataset
1619    ...      .range(1, 5, output_type=tf.int32)
1620    ...      .map(lambda x: tf.fill([x], x)))
1621    >>> # Pad to the smallest per-batch size that fits all elements.
1622    >>> B = A.padded_batch(2)
1623    >>> for element in B.as_numpy_iterator():
1624    ...   print(element)
1625    [[1 0]
1626     [2 2]]
1627    [[3 3 3 0]
1628     [4 4 4 4]]
1629    >>> # Pad to a fixed size.
1630    >>> C = A.padded_batch(2, padded_shapes=5)
1631    >>> for element in C.as_numpy_iterator():
1632    ...   print(element)
1633    [[1 0 0 0 0]
1634     [2 2 0 0 0]]
1635    [[3 3 3 0 0]
1636     [4 4 4 4 0]]
1637    >>> # Pad with a custom value.
1638    >>> D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
1639    >>> for element in D.as_numpy_iterator():
1640    ...   print(element)
1641    [[ 1 -1 -1 -1 -1]
1642     [ 2  2 -1 -1 -1]]
1643    [[ 3  3  3 -1 -1]
1644     [ 4  4  4  4 -1]]
1645    >>> # Components of nested elements can be padded independently.
1646    >>> elements = [([1, 2, 3], [10]),
1647    ...             ([4, 5], [11, 12])]
1648    >>> dataset = tf.data.Dataset.from_generator(
1649    ...     lambda: iter(elements), (tf.int32, tf.int32))
1650    >>> # Pad the first component of the tuple to length 4, and the second
1651    >>> # component to the smallest size that fits.
1652    >>> dataset = dataset.padded_batch(2,
1653    ...     padded_shapes=([4], [None]),
1654    ...     padding_values=(-1, 100))
1655    >>> list(dataset.as_numpy_iterator())
1656    [(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
1657      array([[ 10, 100], [ 11,  12]], dtype=int32))]
1658    >>> # Pad with a single value and multiple components.
1659    >>> E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
1660    >>> for element in E.as_numpy_iterator():
1661    ...   print(element)
1662    (array([[ 1, -1],
1663           [ 2,  2]], dtype=int32), array([[ 1, -1],
1664           [ 2,  2]], dtype=int32))
1665    (array([[ 3,  3,  3, -1],
1666           [ 4,  4,  4,  4]], dtype=int32), array([[ 3,  3,  3, -1],
1667           [ 4,  4,  4,  4]], dtype=int32))
1668
1669    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
1670    elements that may have different shapes into a `tf.sparse.SparseTensor`.
1671
1672    Args:
1673      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
1674        consecutive elements of this dataset to combine in a single batch.
1675      padded_shapes: (Optional.) A (nested) structure of `tf.TensorShape` or
1676        `tf.int64` vector tensor-like objects representing the shape to which
1677        the respective component of each input element should be padded prior
1678        to batching. Any unknown dimensions will be padded to the maximum size
1679        of that dimension in each batch. If unset, all dimensions of all
1680        components are padded to the maximum size in the batch. `padded_shapes`
1681        must be set if any component has an unknown rank.
1682      padding_values: (Optional.) A (nested) structure of scalar-shaped
1683        `tf.Tensor`, representing the padding values to use for the respective
1684        components. None represents that the (nested) structure should be padded
1685        with default values.  Defaults are `0` for numeric types and the empty
1686        string for string types. The `padding_values` should have the same
1687        (nested) structure as the input dataset. If `padding_values` is a single
1688        element and the input dataset has multiple components, then the same
1689        `padding_values` will be used to pad every component of the dataset.
1690        If `padding_values` is a scalar, then its value will be broadcasted
1691        to match the shape of each component.
1692      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1693        whether the last batch should be dropped in the case it has fewer than
1694        `batch_size` elements; the default behavior is not to drop the smaller
1695        batch.
1696
1697    Returns:
1698      Dataset: A `Dataset`.
1699
1700    Raises:
1701      ValueError: If a component has an unknown rank, and  the `padded_shapes`
1702        argument is not set.
1703    """
1704    if padded_shapes is None:
1705      padded_shapes = get_legacy_output_shapes(self)
1706      # A `tf.TensorShape` is only false if its *rank* is unknown:
1707      # bool(tf.TensorShape(None)) is False
1708      if not all(nest.flatten(padded_shapes)):
1709        raise ValueError("You must set the `padded_shapes` argument to "
1710                         "`Dataset.padded_batch` if any component of its "
1711                         "input has an unknown rank")
1712    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
1713                              drop_remainder)
1714
1715  def map(self, map_func, num_parallel_calls=None, deterministic=None):
1716    """Maps `map_func` across the elements of this dataset.
1717
1718    This transformation applies `map_func` to each element of this dataset, and
1719    returns a new dataset containing the transformed elements, in the same
1720    order as they appeared in the input. `map_func` can be used to change both
1721    the values and the structure of a dataset's elements. Supported structure
1722    constructs are documented
1723    [here](https://www.tensorflow.org/guide/data#dataset_structure).
1724
1725    For example, `map` can be used for adding 1 to each element, or projecting a
1726    subset of element components.
1727
1728    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1729    >>> dataset = dataset.map(lambda x: x + 1)
1730    >>> list(dataset.as_numpy_iterator())
1731    [2, 3, 4, 5, 6]
1732
1733    The input signature of `map_func` is determined by the structure of each
1734    element in this dataset.
1735
1736    >>> dataset = Dataset.range(5)
1737    >>> # `map_func` takes a single argument of type `tf.Tensor` with the same
1738    >>> # shape and dtype.
1739    >>> result = dataset.map(lambda x: x + 1)
1740
1741    >>> # Each element is a tuple containing two `tf.Tensor` objects.
1742    >>> elements = [(1, "foo"), (2, "bar"), (3, "baz")]
1743    >>> dataset = tf.data.Dataset.from_generator(
1744    ...     lambda: elements, (tf.int32, tf.string))
1745    >>> # `map_func` takes two arguments of type `tf.Tensor`. This function
1746    >>> # projects out just the first component.
1747    >>> result = dataset.map(lambda x_int, y_str: x_int)
1748    >>> list(result.as_numpy_iterator())
1749    [1, 2, 3]
1750
1751    >>> # Each element is a dictionary mapping strings to `tf.Tensor` objects.
1752    >>> elements =  ([{"a": 1, "b": "foo"},
1753    ...               {"a": 2, "b": "bar"},
1754    ...               {"a": 3, "b": "baz"}])
1755    >>> dataset = tf.data.Dataset.from_generator(
1756    ...     lambda: elements, {"a": tf.int32, "b": tf.string})
1757    >>> # `map_func` takes a single argument of type `dict` with the same keys
1758    >>> # as the elements.
1759    >>> result = dataset.map(lambda d: str(d["a"]) + d["b"])
1760
1761    The value or values returned by `map_func` determine the structure of each
1762    element in the returned dataset.
1763
1764    >>> dataset = tf.data.Dataset.range(3)
1765    >>> # `map_func` returns two `tf.Tensor` objects.
1766    >>> def g(x):
1767    ...   return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
1768    >>> result = dataset.map(g)
1769    >>> result.element_spec
1770    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), \
1771dtype=tf.string, name=None))
1772    >>> # Python primitives, lists, and NumPy arrays are implicitly converted to
1773    >>> # `tf.Tensor`.
1774    >>> def h(x):
1775    ...   return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
1776    >>> result = dataset.map(h)
1777    >>> result.element_spec
1778    (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), \
1779dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, \
1780name=None))
1781    >>> # `map_func` can return nested structures.
1782    >>> def i(x):
1783    ...   return (37.0, [42, 16]), "foo"
1784    >>> result = dataset.map(i)
1785    >>> result.element_spec
1786    ((TensorSpec(shape=(), dtype=tf.float32, name=None),
1787      TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
1788     TensorSpec(shape=(), dtype=tf.string, name=None))
1789
1790    `map_func` can accept as arguments and return any type of dataset element.
1791
1792    Note that irrespective of the context in which `map_func` is defined (eager
1793    vs. graph), tf.data traces the function and executes it as a graph. To use
1794    Python code inside of the function you have a few options:
1795
1796    1) Rely on AutoGraph to convert Python code into an equivalent graph
1797    computation. The downside of this approach is that AutoGraph can convert
1798    some but not all Python code.
1799
1800    2) Use `tf.py_function`, which allows you to write arbitrary Python code but
1801    will generally result in worse performance than 1). For example:
1802
1803    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
1804    >>> # transform a string tensor to upper case string using a Python function
1805    >>> def upper_case_fn(t: tf.Tensor):
1806    ...   return t.numpy().decode('utf-8').upper()
1807    >>> d = d.map(lambda x: tf.py_function(func=upper_case_fn,
1808    ...           inp=[x], Tout=tf.string))
1809    >>> list(d.as_numpy_iterator())
1810    [b'HELLO', b'WORLD']
1811
1812    3) Use `tf.numpy_function`, which also allows you to write arbitrary
1813    Python code. Note that `tf.py_function` accepts `tf.Tensor` whereas
1814    `tf.numpy_function` accepts numpy arrays and returns only numpy arrays.
1815    For example:
1816
1817    >>> d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
1818    >>> def upper_case_fn(t: np.ndarray):
1819    ...   return t.decode('utf-8').upper()
1820    >>> d = d.map(lambda x: tf.numpy_function(func=upper_case_fn,
1821    ...           inp=[x], Tout=tf.string))
1822    >>> list(d.as_numpy_iterator())
1823    [b'HELLO', b'WORLD']
1824
1825    Note that the use of `tf.numpy_function` and `tf.py_function`
1826    in general precludes the possibility of executing user-defined
1827    transformations in parallel (because of Python GIL).
1828
1829    Performance can often be improved by setting `num_parallel_calls` so that
1830    `map` will use multiple threads to process elements. If deterministic order
1831    isn't required, it can also improve performance to set
1832    `deterministic=False`.
1833
1834    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1835    >>> dataset = dataset.map(lambda x: x + 1,
1836    ...     num_parallel_calls=tf.data.AUTOTUNE,
1837    ...     deterministic=False)
1838
1839    The order of elements yielded by this transformation is deterministic if
1840    `deterministic=True`. If `map_func` contains stateful operations and
1841    `num_parallel_calls > 1`, the order in which that state is accessed is
1842    undefined, so the values of output elements may not be deterministic
1843    regardless of the `deterministic` flag value.
1844
1845    Args:
1846      map_func: A function mapping a dataset element to another dataset element.
1847      num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
1848        representing the number elements to process asynchronously in parallel.
1849        If not specified, elements will be processed sequentially. If the value
1850        `tf.data.AUTOTUNE` is used, then the number of parallel
1851        calls is set dynamically based on available CPU.
1852      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
1853        boolean is specified (`True` or `False`), it controls the order in which
1854        the transformation produces elements. If set to `False`, the
1855        transformation is allowed to yield elements out of order to trade
1856        determinism for performance. If not specified, the
1857        `tf.data.Options.deterministic` option (`True` by default) controls the
1858        behavior.
1859
1860    Returns:
1861      Dataset: A `Dataset`.
1862    """
1863    if num_parallel_calls is None or DEBUG_MODE:
1864      if deterministic is not None and not DEBUG_MODE:
1865        warnings.warn("The `deterministic` argument has no effect unless the "
1866                      "`num_parallel_calls` argument is specified.")
1867      return MapDataset(self, map_func, preserve_cardinality=True)
1868    else:
1869      return ParallelMapDataset(
1870          self,
1871          map_func,
1872          num_parallel_calls,
1873          deterministic,
1874          preserve_cardinality=True)
1875
1876  def flat_map(self, map_func):
1877    """Maps `map_func` across this dataset and flattens the result.
1878
1879    The type signature is:
1880
1881    ```
1882    def flat_map(
1883      self: Dataset[T],
1884      map_func: Callable[[T], Dataset[S]]
1885    ) -> Dataset[S]
1886    ```
1887
1888    Use `flat_map` if you want to make sure that the order of your dataset
1889    stays the same. For example, to flatten a dataset of batches into a
1890    dataset of their elements:
1891
1892    >>> dataset = tf.data.Dataset.from_tensor_slices(
1893    ...     [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1894    >>> dataset = dataset.flat_map(
1895    ...     lambda x: tf.data.Dataset.from_tensor_slices(x))
1896    >>> list(dataset.as_numpy_iterator())
1897    [1, 2, 3, 4, 5, 6, 7, 8, 9]
1898
1899    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
1900    `flat_map` produces the same output as
1901    `tf.data.Dataset.interleave(cycle_length=1)`
1902
1903    Args:
1904      map_func: A function mapping a dataset element to a dataset.
1905
1906    Returns:
1907      Dataset: A `Dataset`.
1908    """
1909    return FlatMapDataset(self, map_func)
1910
1911  def interleave(self,
1912                 map_func,
1913                 cycle_length=None,
1914                 block_length=None,
1915                 num_parallel_calls=None,
1916                 deterministic=None):
1917    """Maps `map_func` across this dataset, and interleaves the results.
1918
1919    The type signature is:
1920
1921    ```
1922    def interleave(
1923      self: Dataset[T],
1924      map_func: Callable[[T], Dataset[S]]
1925    ) -> Dataset[S]
1926    ```
1927
1928    For example, you can use `Dataset.interleave()` to process many input files
1929    concurrently:
1930
1931    >>> # Preprocess 4 files concurrently, and interleave blocks of 16 records
1932    >>> # from each file.
1933    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1934    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1935    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1936    >>> def parse_fn(filename):
1937    ...   return tf.data.Dataset.range(10)
1938    >>> dataset = dataset.interleave(lambda x:
1939    ...     tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
1940    ...     cycle_length=4, block_length=16)
1941
1942    The `cycle_length` and `block_length` arguments control the order in which
1943    elements are produced. `cycle_length` controls the number of input elements
1944    that are processed concurrently. If you set `cycle_length` to 1, this
1945    transformation will handle one input element at a time, and will produce
1946    identical results to `tf.data.Dataset.flat_map`. In general,
1947    this transformation will apply `map_func` to `cycle_length` input elements,
1948    open iterators on the returned `Dataset` objects, and cycle through them
1949    producing `block_length` consecutive elements from each iterator, and
1950    consuming the next input element each time it reaches the end of an
1951    iterator.
1952
1953    For example:
1954
1955    >>> dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
1956    >>> # NOTE: New lines indicate "block" boundaries.
1957    >>> dataset = dataset.interleave(
1958    ...     lambda x: Dataset.from_tensors(x).repeat(6),
1959    ...     cycle_length=2, block_length=4)
1960    >>> list(dataset.as_numpy_iterator())
1961    [1, 1, 1, 1,
1962     2, 2, 2, 2,
1963     1, 1,
1964     2, 2,
1965     3, 3, 3, 3,
1966     4, 4, 4, 4,
1967     3, 3,
1968     4, 4,
1969     5, 5, 5, 5,
1970     5, 5]
1971
1972    Note: The order of elements yielded by this transformation is
1973    deterministic, as long as `map_func` is a pure function and
1974    `deterministic=True`. If `map_func` contains any stateful operations, the
1975    order in which that state is accessed is undefined.
1976
1977    Performance can often be improved by setting `num_parallel_calls` so that
1978    `interleave` will use multiple threads to fetch elements. If determinism
1979    isn't required, it can also improve performance to set
1980    `deterministic=False`.
1981
1982    >>> filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
1983    ...              "/var/data/file3.txt", "/var/data/file4.txt"]
1984    >>> dataset = tf.data.Dataset.from_tensor_slices(filenames)
1985    >>> dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
1986    ...     cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE,
1987    ...     deterministic=False)
1988
1989    Args:
1990      map_func: A function that takes a dataset element and returns a
1991        `tf.data.Dataset`.
1992      cycle_length: (Optional.) The number of input elements that will be
1993        processed concurrently. If not set, the tf.data runtime decides what it
1994        should be based on available CPU. If `num_parallel_calls` is set to
1995        `tf.data.AUTOTUNE`, the `cycle_length` argument identifies
1996        the maximum degree of parallelism.
1997      block_length: (Optional.) The number of consecutive elements to produce
1998        from each input element before cycling to another input element. If not
1999        set, defaults to 1.
2000      num_parallel_calls: (Optional.) If specified, the implementation creates a
2001        threadpool, which is used to fetch inputs from cycle elements
2002        asynchronously and in parallel. The default behavior is to fetch inputs
2003        from cycle elements synchronously with no parallelism. If the value
2004        `tf.data.AUTOTUNE` is used, then the number of parallel
2005        calls is set dynamically based on available CPU.
2006      deterministic: (Optional.) When `num_parallel_calls` is specified, if this
2007        boolean is specified (`True` or `False`), it controls the order in which
2008        the transformation produces elements. If set to `False`, the
2009        transformation is allowed to yield elements out of order to trade
2010        determinism for performance. If not specified, the
2011        `tf.data.Options.deterministic` option (`True` by default) controls the
2012        behavior.
2013
2014    Returns:
2015      Dataset: A `Dataset`.
2016    """
2017    if block_length is None:
2018      block_length = 1
2019
2020    if cycle_length is None:
2021      cycle_length = AUTOTUNE
2022
2023    if num_parallel_calls is None or DEBUG_MODE:
2024      if deterministic is not None and not DEBUG_MODE:
2025        warnings.warn("The `deterministic` argument has no effect unless the "
2026                      "`num_parallel_calls` argument is specified.")
2027      return InterleaveDataset(self, map_func, cycle_length, block_length)
2028    else:
2029      return ParallelInterleaveDataset(
2030          self,
2031          map_func,
2032          cycle_length,
2033          block_length,
2034          num_parallel_calls,
2035          deterministic=deterministic)
2036
2037  def filter(self, predicate):
2038    """Filters this dataset according to `predicate`.
2039
2040    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
2041    >>> dataset = dataset.filter(lambda x: x < 3)
2042    >>> list(dataset.as_numpy_iterator())
2043    [1, 2]
2044    >>> # `tf.math.equal(x, y)` is required for equality comparison
2045    >>> def filter_fn(x):
2046    ...   return tf.math.equal(x, 1)
2047    >>> dataset = dataset.filter(filter_fn)
2048    >>> list(dataset.as_numpy_iterator())
2049    [1]
2050
2051    Args:
2052      predicate: A function mapping a dataset element to a boolean.
2053
2054    Returns:
2055      Dataset: The `Dataset` containing the elements of this dataset for which
2056          `predicate` is `True`.
2057    """
2058    return FilterDataset(self, predicate)
2059
2060  def apply(self, transformation_func):
2061    """Applies a transformation function to this dataset.
2062
2063    `apply` enables chaining of custom `Dataset` transformations, which are
2064    represented as functions that take one `Dataset` argument and return a
2065    transformed `Dataset`.
2066
2067    >>> dataset = tf.data.Dataset.range(100)
2068    >>> def dataset_fn(ds):
2069    ...   return ds.filter(lambda x: x < 5)
2070    >>> dataset = dataset.apply(dataset_fn)
2071    >>> list(dataset.as_numpy_iterator())
2072    [0, 1, 2, 3, 4]
2073
2074    Args:
2075      transformation_func: A function that takes one `Dataset` argument and
2076        returns a `Dataset`.
2077
2078    Returns:
2079      Dataset: The `Dataset` returned by applying `transformation_func` to this
2080          dataset.
2081    """
2082    dataset = transformation_func(self)
2083    if not isinstance(dataset, DatasetV2):
2084      raise TypeError(
2085          "`transformation_func` must return a Dataset. Got {}.".format(
2086              dataset))
2087    dataset._input_datasets = [self]  # pylint: disable=protected-access
2088    return dataset
2089
2090  def window(self, size, shift=None, stride=1, drop_remainder=False):
2091    """Returns a dataset of "windows".
2092
2093    Each "window" is a dataset that contains a subset of elements of the
2094    input dataset. These are finite datasets of size `size` (or possibly fewer
2095    if there are not enough input elements to fill the window and
2096    `drop_remainder` evaluates to `False`).
2097
2098    For example:
2099
2100    >>> dataset = tf.data.Dataset.range(7).window(3)
2101    >>> for window in dataset:
2102    ...   print(window)
2103    <...Dataset shapes: (), types: tf.int64>
2104    <...Dataset shapes: (), types: tf.int64>
2105    <...Dataset shapes: (), types: tf.int64>
2106
2107    Since windows are datasets, they can be iterated over:
2108
2109    >>> for window in dataset:
2110    ...   print([item.numpy() for item in window])
2111    [0, 1, 2]
2112    [3, 4, 5]
2113    [6]
2114
2115    #### Shift
2116
2117    The `shift` argument determines the number of input elements to shift
2118    between the start of each window. If windows and elements are both numbered
2119    starting at 0, the first element in window `k` will be element `k * shift`
2120    of the input dataset. In particular, the first element of the first window
2121    will always be the first element of the input dataset.
2122
2123    >>> dataset = tf.data.Dataset.range(7).window(3, shift=1,
2124    ...                                           drop_remainder=True)
2125    >>> for window in dataset:
2126    ...   print(list(window.as_numpy_iterator()))
2127    [0, 1, 2]
2128    [1, 2, 3]
2129    [2, 3, 4]
2130    [3, 4, 5]
2131    [4, 5, 6]
2132
2133    #### Stride
2134
2135    The `stride` argument determines the stride between input elements within a
2136    window.
2137
2138    >>> dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2,
2139    ...                                           drop_remainder=True)
2140    >>> for window in dataset:
2141    ...   print(list(window.as_numpy_iterator()))
2142    [0, 2, 4]
2143    [1, 3, 5]
2144    [2, 4, 6]
2145
2146    #### Nested elements
2147
2148    When the `window` transformation is applied to a dataset whos elements are
2149    nested structures, it produces a dataset where the elements have the same
2150    nested structure but each leaf is replaced by a window. In other words,
2151    the nesting is applied outside of the windows as opposed inside of them.
2152
2153    The type signature is:
2154
2155    ```
2156    def window(
2157        self: Dataset[Nest[T]], ...
2158    ) -> Dataset[Nest[Dataset[T]]]
2159    ```
2160
2161    Applying `window` to a `Dataset` of tuples gives a tuple of windows:
2162
2163    >>> dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5],
2164    ...                                               [6, 7, 8, 9, 10]))
2165    >>> dataset = dataset.window(2)
2166    >>> windows = next(iter(dataset))
2167    >>> windows
2168    (<...Dataset shapes: (), types: tf.int32>,
2169     <...Dataset shapes: (), types: tf.int32>)
2170
2171    >>> def to_numpy(ds):
2172    ...   return list(ds.as_numpy_iterator())
2173    >>>
2174    >>> for windows in dataset:
2175    ...   print(to_numpy(windows[0]), to_numpy(windows[1]))
2176    [1, 2] [6, 7]
2177    [3, 4] [8, 9]
2178    [5] [10]
2179
2180    Applying `window` to a `Dataset` of dictionaries gives a dictionary of
2181    `Datasets`:
2182
2183    >>> dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3],
2184    ...                                               'b': [4, 5, 6],
2185    ...                                               'c': [7, 8, 9]})
2186    >>> dataset = dataset.window(2)
2187    >>> def to_numpy(ds):
2188    ...   return list(ds.as_numpy_iterator())
2189    >>>
2190    >>> for windows in dataset:
2191    ...   print(tf.nest.map_structure(to_numpy, windows))
2192    {'a': [1, 2], 'b': [4, 5], 'c': [7, 8]}
2193    {'a': [3], 'b': [6], 'c': [9]}
2194
2195    #### Flatten a dataset of windows
2196
2197    The `Dataset.flat_map` and `Dataset.interleave` methods can be used to
2198    flatten a dataset of windows into a single dataset.
2199
2200    The argument to `flat_map` is a function that takes an element from the
2201    dataset and returns a `Dataset`. `flat_map` chains together the resulting
2202    datasets sequentially.
2203
2204    For example, to turn each window into a dense tensor:
2205
2206    >>> size = 3
2207    >>> dataset = tf.data.Dataset.range(7).window(size, shift=1,
2208    ...                                           drop_remainder=True)
2209    >>> batched = dataset.flat_map(lambda x:x.batch(3))
2210    >>> for batch in batched:
2211    ...   print(batch.numpy())
2212    [0 1 2]
2213    [1 2 3]
2214    [2 3 4]
2215    [3 4 5]
2216    [4 5 6]
2217
2218    Args:
2219      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
2220        of the input dataset to combine into a window. Must be positive.
2221      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2222        number of input elements by which the window moves in each iteration.
2223        Defaults to `size`. Must be positive.
2224      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
2225        stride of the input elements in the sliding window. Must be positive.
2226        The default value of 1 means "retain every input element".
2227      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2228        whether the last windows should be dropped if their size is smaller than
2229        `size`.
2230
2231    Returns:
2232      Dataset: A `Dataset` of (nests of) windows. Each window is a finite
2233        datasets of flat elements.
2234    """
2235    if shift is None:
2236      shift = size
2237    return WindowDataset(self, size, shift, stride, drop_remainder)
2238
2239  def reduce(self, initial_state, reduce_func):
2240    """Reduces the input dataset to a single element.
2241
2242    The transformation calls `reduce_func` successively on every element of
2243    the input dataset until the dataset is exhausted, aggregating information in
2244    its internal state. The `initial_state` argument is used for the initial
2245    state and the final state is returned as the result.
2246
2247    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
2248    5
2249    >>> tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
2250    10
2251
2252    Args:
2253      initial_state: An element representing the initial state of the
2254        transformation.
2255      reduce_func: A function that maps `(old_state, input_element)` to
2256        `new_state`. It must take two arguments and return a new element
2257        The structure of `new_state` must match the structure of
2258        `initial_state`.
2259
2260    Returns:
2261      A dataset element corresponding to the final state of the transformation.
2262
2263    """
2264
2265    with ops.name_scope("initial_state"):
2266      initial_state = structure.normalize_element(initial_state)
2267    state_structure = structure.type_spec_from_value(initial_state)
2268
2269    # Iteratively rerun the reduce function until reaching a fixed point on
2270    # `state_structure`.
2271    need_to_rerun = True
2272    while need_to_rerun:
2273
2274      wrapped_func = StructuredFunctionWrapper(
2275          reduce_func,
2276          "reduce()",
2277          input_structure=(state_structure, self.element_spec),
2278          add_to_graph=False)
2279
2280      # Extract and validate class information from the returned values.
2281      output_classes = wrapped_func.output_classes
2282      state_classes = nest.map_structure(
2283          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
2284          state_structure)
2285      for new_state_class, state_class in zip(
2286          nest.flatten(output_classes), nest.flatten(state_classes)):
2287        if not issubclass(new_state_class, state_class):
2288          raise TypeError(
2289              "The element classes for the new state must match the initial "
2290              "state. Expected %s; got %s." %
2291              (state_classes, wrapped_func.output_classes))
2292
2293      # Extract and validate type information from the returned values.
2294      output_types = wrapped_func.output_types
2295      state_types = nest.map_structure(
2296          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
2297          state_structure)
2298      for new_state_type, state_type in zip(
2299          nest.flatten(output_types), nest.flatten(state_types)):
2300        if new_state_type != state_type:
2301          raise TypeError(
2302              "The element types for the new state must match the initial "
2303              "state. Expected %s; got %s." %
2304              (state_types, wrapped_func.output_types))
2305
2306      # Extract shape information from the returned values.
2307      output_shapes = wrapped_func.output_shapes
2308      state_shapes = nest.map_structure(
2309          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
2310          state_structure)
2311      flat_state_shapes = nest.flatten(state_shapes)
2312      flat_new_state_shapes = nest.flatten(output_shapes)
2313      weakened_state_shapes = [
2314          original.most_specific_compatible_shape(new)
2315          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
2316      ]
2317
2318      need_to_rerun = False
2319      for original_shape, weakened_shape in zip(flat_state_shapes,
2320                                                weakened_state_shapes):
2321        if original_shape.ndims is not None and (
2322            weakened_shape.ndims is None or
2323            original_shape.as_list() != weakened_shape.as_list()):
2324          need_to_rerun = True
2325          break
2326
2327      if need_to_rerun:
2328        # TODO(b/110122868): Support a "most specific compatible structure"
2329        # method for combining structures, to avoid using legacy structures
2330        # here.
2331        state_structure = structure.convert_legacy_structure(
2332            state_types,
2333            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
2334            state_classes)
2335
2336    reduce_func = wrapped_func.function
2337    reduce_func.add_to_graph(ops.get_default_graph())
2338
2339    dataset = self._apply_debug_options()
2340
2341    # pylint: disable=protected-access
2342    return structure.from_compatible_tensor_list(
2343        state_structure,
2344        gen_dataset_ops.reduce_dataset(
2345            dataset._variant_tensor,
2346            structure.to_tensor_list(state_structure, initial_state),
2347            reduce_func.captured_inputs,
2348            f=reduce_func,
2349            output_shapes=structure.get_flat_tensor_shapes(state_structure),
2350            output_types=structure.get_flat_tensor_types(state_structure)))
2351
2352  def get_single_element(self):
2353    """Returns the single element of the `dataset` as a nested structure of tensors.
2354
2355    The function enables you to use a `tf.data.Dataset` in a stateless
2356    "tensor-in tensor-out" expression, without creating an iterator.
2357    This facilitates the ease of data transformation on tensors using the
2358    optimized `tf.data.Dataset` abstraction on top of them.
2359
2360    For example, lets consider a `preprocessing_fn` which would take as an
2361    input the raw features and returns the processed feature along with
2362    it's label.
2363
2364    ```python
2365    def preprocessing_fn(raw_feature):
2366      # ... the raw_feature is preprocessed as per the use-case
2367      return feature
2368
2369    raw_features = ...  # input batch of BATCH_SIZE elements.
2370    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2371              .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2372              .batch(BATCH_SIZE))
2373
2374    processed_features = dataset.get_single_element()
2375    ```
2376
2377    In the above example, the `raw_features` tensor of length=BATCH_SIZE
2378    was converted to a `tf.data.Dataset`. Next, each of the `raw_feature` was
2379    mapped using the `preprocessing_fn` and the processed features were
2380    grouped into a single batch. The final `dataset` contains only one element
2381    which is a batch of all the processed features.
2382
2383    NOTE: The `dataset` should contain only one element.
2384
2385    Now, instead of creating an iterator for the `dataset` and retrieving the
2386    batch of features, the `tf.data.get_single_element()` function is used
2387    to skip the iterator creation process and directly output the batch of
2388    features.
2389
2390    This can be particularly useful when your tensor transformations are
2391    expressed as `tf.data.Dataset` operations, and you want to use those
2392    transformations while serving your model.
2393
2394    #### Keras
2395
2396    ```python
2397
2398    model = ... # A pre-built or custom model
2399
2400    class PreprocessingModel(tf.keras.Model):
2401      def __init__(self, model):
2402        super().__init__(self)
2403        self.model = model
2404
2405      @tf.function(input_signature=[...])
2406      def serving_fn(self, data):
2407        ds = tf.data.Dataset.from_tensor_slices(data)
2408        ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2409        ds = ds.batch(batch_size=BATCH_SIZE)
2410        return tf.argmax(self.model(ds.get_single_element()), axis=-1)
2411
2412    preprocessing_model = PreprocessingModel(model)
2413    your_exported_model_dir = ... # save the model to this path.
2414    tf.saved_model.save(preprocessing_model, your_exported_model_dir,
2415                  signatures={'serving_default': preprocessing_model.serving_fn}
2416                  )
2417    ```
2418
2419    #### Estimator
2420
2421    In the case of estimators, you need to generally define a `serving_input_fn`
2422    which would require the features to be processed by the model while
2423    inferencing.
2424
2425    ```python
2426    def serving_input_fn():
2427
2428      raw_feature_spec = ... # Spec for the raw_features
2429      input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
2430          raw_feature_spec, default_batch_size=None)
2431      )
2432      serving_input_receiver = input_fn()
2433      raw_features = serving_input_receiver.features
2434
2435      def preprocessing_fn(raw_feature):
2436        # ... the raw_feature is preprocessed as per the use-case
2437        return feature
2438
2439      dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
2440                .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
2441                .batch(BATCH_SIZE))
2442
2443      processed_features = dataset.get_single_element()
2444
2445      # Please note that the value of `BATCH_SIZE` should be equal to
2446      # the size of the leading dimension of `raw_features`. This ensures
2447      # that `dataset` has only element, which is a pre-requisite for
2448      # using `dataset.get_single_element()`.
2449
2450      return tf.estimator.export.ServingInputReceiver(
2451          processed_features, serving_input_receiver.receiver_tensors)
2452
2453    estimator = ... # A pre-built or custom estimator
2454    estimator.export_saved_model(your_exported_model_dir, serving_input_fn)
2455    ```
2456
2457    Returns:
2458      A nested structure of `tf.Tensor` objects, corresponding to the single
2459      element of `dataset`.
2460
2461    Raises:
2462      InvalidArgumentError: (at runtime) if `dataset` does not contain exactly
2463        one element.
2464    """
2465
2466    return structure.from_compatible_tensor_list(
2467        self.element_spec,
2468        gen_dataset_ops.dataset_to_single_element(self._variant_tensor,
2469                                                  **self._flat_structure))  # pylint: disable=protected-access
2470
2471  def unbatch(self):
2472    """Splits elements of a dataset into multiple elements.
2473
2474    For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
2475    where `B` may vary for each input element, then for each element in the
2476    dataset, the unbatched dataset will contain `B` consecutive elements
2477    of shape `[a0, a1, ...]`.
2478
2479    >>> elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
2480    >>> dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
2481    >>> dataset = dataset.unbatch()
2482    >>> list(dataset.as_numpy_iterator())
2483    [1, 2, 3, 1, 2, 1, 2, 3, 4]
2484
2485    Note: `unbatch` requires a data copy to slice up the batched tensor into
2486    smaller, unbatched tensors. When optimizing performance, try to avoid
2487    unnecessary usage of `unbatch`.
2488
2489    Returns:
2490      A `Dataset`.
2491    """
2492    normalized_dataset = normalize_to_dense(self)
2493    return _UnbatchDataset(normalized_dataset)
2494
2495  def with_options(self, options):
2496    """Returns a new `tf.data.Dataset` with the given options set.
2497
2498    The options are "global" in the sense they apply to the entire dataset.
2499    If options are set multiple times, they are merged as long as different
2500    options do not use different non-default values.
2501
2502    >>> ds = tf.data.Dataset.range(5)
2503    >>> ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
2504    ...                    cycle_length=3,
2505    ...                    num_parallel_calls=3)
2506    >>> options = tf.data.Options()
2507    >>> # This will make the interleave order non-deterministic.
2508    >>> options.deterministic = False
2509    >>> ds = ds.with_options(options)
2510
2511    Args:
2512      options: A `tf.data.Options` that identifies the options the use.
2513
2514    Returns:
2515      Dataset: A `Dataset` with the given options.
2516
2517    Raises:
2518      ValueError: when an option is set more than once to a non-default value
2519    """
2520    return _OptionsDataset(self, options)
2521
2522  def cardinality(self):
2523    """Returns the cardinality of the dataset, if known.
2524
2525    `cardinality` may return `tf.data.INFINITE_CARDINALITY` if the dataset
2526    contains an infinite number of elements or `tf.data.UNKNOWN_CARDINALITY` if
2527    the analysis fails to determine the number of elements in the dataset
2528    (e.g. when the dataset source is a file).
2529
2530    >>> dataset = tf.data.Dataset.range(42)
2531    >>> print(dataset.cardinality().numpy())
2532    42
2533    >>> dataset = dataset.repeat()
2534    >>> cardinality = dataset.cardinality()
2535    >>> print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
2536    True
2537    >>> dataset = dataset.filter(lambda x: True)
2538    >>> cardinality = dataset.cardinality()
2539    >>> print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
2540    True
2541
2542    Returns:
2543      A scalar `tf.int64` `Tensor` representing the cardinality of the dataset.
2544      If the cardinality is infinite or unknown, `cardinality` returns the
2545      named constants `tf.data.INFINITE_CARDINALITY` and
2546      `tf.data.UNKNOWN_CARDINALITY` respectively.
2547    """
2548    return gen_dataset_ops.dataset_cardinality(self._variant_tensor)
2549
2550  def group_by_window(self,
2551                      key_func,
2552                      reduce_func,
2553                      window_size=None,
2554                      window_size_func=None):
2555    """Groups windows of elements by key and reduces them.
2556
2557    This transformation maps each consecutive element in a dataset to a key
2558    using `key_func` and groups the elements by key. It then applies
2559    `reduce_func` to at most `window_size_func(key)` elements matching the same
2560    key. All except the final window for each key will contain
2561    `window_size_func(key)` elements; the final window may be smaller.
2562
2563    You may provide either a constant `window_size` or a window size determined
2564    by the key through `window_size_func`.
2565
2566    >>> dataset = tf.data.Dataset.range(10)
2567    >>> window_size = 5
2568    >>> key_func = lambda x: x%2
2569    >>> reduce_func = lambda key, dataset: dataset.batch(window_size)
2570    >>> dataset = dataset.group_by_window(
2571    ...           key_func=key_func,
2572    ...           reduce_func=reduce_func,
2573    ...           window_size=window_size)
2574    >>> for elem in dataset.as_numpy_iterator():
2575    ...   print(elem)
2576    [0 2 4 6 8]
2577    [1 3 5 7 9]
2578
2579    Args:
2580      key_func: A function mapping a nested structure of tensors (having shapes
2581        and types defined by `self.output_shapes` and `self.output_types`) to a
2582        scalar `tf.int64` tensor.
2583      reduce_func: A function mapping a key and a dataset of up to `window_size`
2584        consecutive elements matching that key to another dataset.
2585      window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
2586        consecutive elements matching the same key to combine in a single batch,
2587        which will be passed to `reduce_func`. Mutually exclusive with
2588        `window_size_func`.
2589      window_size_func: A function mapping a key to a `tf.int64` scalar
2590        `tf.Tensor`, representing the number of consecutive elements matching
2591        the same key to combine in a single batch, which will be passed to
2592        `reduce_func`. Mutually exclusive with `window_size`.
2593
2594    Returns:
2595      A `Dataset`.
2596
2597    Raises:
2598      ValueError: if neither or both of {`window_size`, `window_size_func`} are
2599        passed.
2600    """
2601    if (window_size is not None and window_size_func or
2602        not (window_size is not None or window_size_func)):
2603      raise ValueError("Must pass either window_size or window_size_func.")
2604
2605    if window_size is not None:
2606
2607      def constant_window_func(unused_key):
2608        return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
2609
2610      window_size_func = constant_window_func
2611
2612    assert window_size_func is not None
2613
2614    return _GroupByWindowDataset(self, key_func, reduce_func, window_size_func)
2615
2616  def bucket_by_sequence_length(self,
2617                                element_length_func,
2618                                bucket_boundaries,
2619                                bucket_batch_sizes,
2620                                padded_shapes=None,
2621                                padding_values=None,
2622                                pad_to_bucket_boundary=False,
2623                                no_padding=False,
2624                                drop_remainder=False):
2625    """A transformation that buckets elements in a `Dataset` by length.
2626
2627    Elements of the `Dataset` are grouped together by length and then are padded
2628    and batched.
2629
2630    This is useful for sequence tasks in which the elements have variable
2631    length. Grouping together elements that have similar lengths reduces the
2632    total fraction of padding in a batch which increases training step
2633    efficiency.
2634
2635    Below is an example to bucketize the input data to the 3 buckets
2636    "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.
2637
2638    >>> elements = [
2639    ...   [0], [1, 2, 3, 4], [5, 6, 7],
2640    ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
2641    >>> dataset = tf.data.Dataset.from_generator(
2642    ...     lambda: elements, tf.int64, output_shapes=[None])
2643    >>> dataset = dataset.bucket_by_sequence_length(
2644    ...         element_length_func=lambda elem: tf.shape(elem)[0],
2645    ...         bucket_boundaries=[3, 5],
2646    ...         bucket_batch_sizes=[2, 2, 2])
2647    >>> for elem in dataset.as_numpy_iterator():
2648    ...   print(elem)
2649    [[1 2 3 4]
2650    [5 6 7 0]]
2651    [[ 7  8  9 10 11  0]
2652    [13 14 15 16 19 20]]
2653    [[ 0  0]
2654    [21 22]]
2655
2656    Args:
2657      element_length_func: function from element in `Dataset` to `tf.int32`,
2658        determines the length of the element, which will determine the bucket it
2659        goes into.
2660      bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
2661      bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
2662        `len(bucket_boundaries) + 1`.
2663      padded_shapes: Nested structure of `tf.TensorShape` to pass to
2664        `tf.data.Dataset.padded_batch`. If not provided, will use
2665        `dataset.output_shapes`, which will result in variable length dimensions
2666        being padded out to the maximum length in each batch.
2667      padding_values: Values to pad with, passed to
2668        `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
2669      pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
2670        size to maximum length in batch. If `True`, will pad dimensions with
2671        unknown size to bucket boundary minus 1 (i.e., the maximum length in
2672        each bucket), and caller must ensure that the source `Dataset` does not
2673        contain any elements with length longer than `max(bucket_boundaries)`.
2674      no_padding: `bool`, indicates whether to pad the batch features (features
2675        need to be either of type `tf.sparse.SparseTensor` or of same shape).
2676      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
2677        whether the last batch should be dropped in the case it has fewer than
2678        `batch_size` elements; the default behavior is not to drop the smaller
2679        batch.
2680
2681    Returns:
2682      A `Dataset`.
2683
2684    Raises:
2685      ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
2686    """
2687    if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
2688      raise ValueError(
2689          "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
2690
2691    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
2692
2693    def element_to_bucket_id(*args):
2694      """Return int64 id of the length bucket for this element."""
2695      seq_length = element_length_func(*args)
2696
2697      boundaries = list(bucket_boundaries)
2698      buckets_min = [np.iinfo(np.int32).min] + boundaries
2699      buckets_max = boundaries + [np.iinfo(np.int32).max]
2700      conditions_c = math_ops.logical_and(
2701          math_ops.less_equal(buckets_min, seq_length),
2702          math_ops.less(seq_length, buckets_max))
2703      bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
2704
2705      return bucket_id
2706
2707    def window_size_fn(bucket_id):
2708      # The window size is set to the batch size for this bucket
2709      window_size = batch_sizes[bucket_id]
2710      return window_size
2711
2712    def make_padded_shapes(shapes, none_filler=None):
2713      padded = []
2714      for shape in nest.flatten(shapes):
2715        shape = tensor_shape.TensorShape(shape)
2716        shape = [
2717            none_filler if tensor_shape.dimension_value(d) is None else d
2718            for d in shape
2719        ]
2720        padded.append(shape)
2721      return nest.pack_sequence_as(shapes, padded)
2722
2723    def batching_fn(bucket_id, grouped_dataset):
2724      """Batch elements in dataset."""
2725      batch_size = window_size_fn(bucket_id)
2726      if no_padding:
2727        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
2728      none_filler = None
2729      if pad_to_bucket_boundary:
2730        err_msg = ("When pad_to_bucket_boundary=True, elements must have "
2731                   "length < max(bucket_boundaries).")
2732        check = check_ops.assert_less(
2733            bucket_id,
2734            constant_op.constant(
2735                len(bucket_batch_sizes) - 1, dtype=dtypes.int64),
2736            message=err_msg)
2737        with ops.control_dependencies([check]):
2738          boundaries = constant_op.constant(
2739              bucket_boundaries, dtype=dtypes.int64)
2740          bucket_boundary = boundaries[bucket_id]
2741          none_filler = bucket_boundary - 1
2742      input_shapes = get_legacy_output_shapes(grouped_dataset)
2743      shapes = make_padded_shapes(
2744          padded_shapes or input_shapes, none_filler=none_filler)
2745      return grouped_dataset.padded_batch(
2746          batch_size, shapes, padding_values, drop_remainder=drop_remainder)
2747
2748    return self.group_by_window(
2749        key_func=element_to_bucket_id,
2750        reduce_func=batching_fn,
2751        window_size_func=window_size_fn)
2752
2753  @staticmethod
2754  def random(seed=None):
2755    """Creates a `Dataset` of pseudorandom values.
2756
2757    The dataset generates a sequence of uniformly distributed integer values.
2758
2759    >>> ds1 = tf.data.Dataset.random(seed=4).take(10)
2760    >>> ds2 = tf.data.Dataset.random(seed=4).take(10)
2761    >>> print(list(ds2.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
2762    True
2763
2764    Args:
2765      seed: (Optional) If specified, the dataset produces a deterministic
2766        sequence of values.
2767
2768    Returns:
2769      Dataset: A `Dataset`.
2770    """
2771    return RandomDataset(seed=seed)
2772
2773  def snapshot(self,
2774               path,
2775               compression="AUTO",
2776               reader_func=None,
2777               shard_func=None):
2778    """API to persist the output of the input dataset.
2779
2780    The snapshot API allows users to transparently persist the output of their
2781    preprocessing pipeline to disk, and materialize the pre-processed data on a
2782    different training run.
2783
2784    This API enables repeated preprocessing steps to be consolidated, and allows
2785    re-use of already processed data, trading off disk storage and network
2786    bandwidth for freeing up more valuable CPU resources and accelerator compute
2787    time.
2788
2789    https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
2790    has detailed design documentation of this feature.
2791
2792    Users can specify various options to control the behavior of snapshot,
2793    including how snapshots are read from and written to by passing in
2794    user-defined functions to the `reader_func` and `shard_func` parameters.
2795
2796    `shard_func` is a user specified function that maps input elements to
2797    snapshot shards.
2798
2799    Users may want to specify this function to control how snapshot files should
2800    be written to disk. Below is an example of how a potential `shard_func`
2801    could be written.
2802
2803    ```python
2804    dataset = ...
2805    dataset = dataset.enumerate()
2806    dataset = dataset.snapshot("/path/to/snapshot/dir",
2807        shard_func=lambda x, y: x % NUM_SHARDS, ...)
2808    dataset = dataset.map(lambda x, y: y)
2809    ```
2810
2811    `reader_func` is a user specified function that accepts a single argument:
2812    (1) a Dataset of Datasets, each representing a "split" of elements of the
2813    original dataset. The cardinality of the input dataset matches the
2814    number of the shards specified in the `shard_func` (see above). The function
2815    should return a Dataset of elements of the original dataset.
2816
2817    Users may want specify this function to control how snapshot files should be
2818    read from disk, including the amount of shuffling and parallelism.
2819
2820    Here is an example of a standard reader function a user can define. This
2821    function enables both dataset shuffling and parallel reading of datasets:
2822
2823    ```python
2824    def user_reader_func(datasets):
2825      # shuffle the datasets splits
2826      datasets = datasets.shuffle(NUM_CORES)
2827      # read datasets in parallel and interleave their elements
2828      return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)
2829
2830    dataset = dataset.snapshot("/path/to/snapshot/dir",
2831        reader_func=user_reader_func)
2832    ```
2833
2834    By default, snapshot parallelizes reads by the number of cores available on
2835    the system, but will not attempt to shuffle the data.
2836
2837    Args:
2838      path: Required. A directory to use for storing / loading the snapshot to /
2839        from.
2840      compression: Optional. The type of compression to apply to the snapshot
2841        written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
2842        Defaults to `AUTO`, which attempts to pick an appropriate compression
2843        algorithm for the dataset.
2844      reader_func: Optional. A function to control how to read data from
2845        snapshot shards.
2846      shard_func: Optional. A function to control how to shard data when writing
2847        a snapshot.
2848
2849    Returns:
2850      A `Dataset`.
2851    """
2852
2853    project_func = None
2854    input_dataset = self
2855    if shard_func is None:
2856      input_dataset = input_dataset.enumerate()
2857      # This sets the amount of parallelism based on the number of CPU cores on
2858      # the machine where this Python code is executed, which may differ from
2859      # the number of CPU cores where the input pipeline graph is actually
2860      # executed (e.g. remote Cloud TPU workers).
2861      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
2862      project_func = lambda _, elem: elem
2863    else:
2864      local_shard_func = shard_func
2865    dataset = _SnapshotDataset(
2866        input_dataset=input_dataset,
2867        path=path,
2868        compression=compression,
2869        reader_func=reader_func,
2870        # This will not do the right thing where the graph is built on a
2871        # different machine than the executor (e.g. Cloud TPUs).
2872        shard_func=local_shard_func)
2873    if project_func is not None:
2874      dataset = dataset.map(project_func)
2875    return dataset
2876
2877  def scan(self, initial_state, scan_func):
2878    """A transformation that scans a function across an input dataset.
2879
2880    This transformation is a stateful relative of `tf.data.Dataset.map`.
2881    In addition to mapping `scan_func` across the elements of the input dataset,
2882    `scan()` accumulates one or more state tensors, whose initial values are
2883    `initial_state`.
2884
2885    >>> dataset = tf.data.Dataset.range(10)
2886    >>> initial_state = tf.constant(0, dtype=tf.int64)
2887    >>> scan_func = lambda state, i: (state + i, state + i)
2888    >>> dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
2889    >>> list(dataset.as_numpy_iterator())
2890    [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
2891
2892    Args:
2893      initial_state: A nested structure of tensors, representing the initial
2894        state of the accumulator.
2895      scan_func: A function that maps `(old_state, input_element)` to
2896        `(new_state, output_element)`. It must take two arguments and return a
2897        pair of nested structures of tensors. The `new_state` must match the
2898        structure of `initial_state`.
2899
2900    Returns:
2901      A `Dataset`.
2902    """
2903
2904    return _ScanDataset(self, initial_state=initial_state, scan_func=scan_func)
2905
2906  def take_while(self, predicate):
2907    """A transformation that stops dataset iteration based on a `predicate`.
2908
2909    >>> dataset = tf.data.Dataset.range(10)
2910    >>> dataset = dataset.take_while(lambda x: x < 5)
2911    >>> list(dataset.as_numpy_iterator())
2912    [0, 1, 2, 3, 4]
2913
2914    Args:
2915      predicate: A function that maps a nested structure of tensors (having
2916        shapes and types defined by `self.output_shapes` and
2917        `self.output_types`) to a scalar `tf.bool` tensor.
2918
2919    Returns:
2920      A `Dataset`.
2921    """
2922
2923    return _TakeWhileDataset(self, predicate)
2924
2925  def unique(self):
2926    """A transformation that discards duplicate elements of a `Dataset`.
2927
2928    Use this transformation to produce a dataset that contains one instance of
2929    each unique element in the input. For example:
2930
2931    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
2932    >>> dataset = dataset.unique()
2933    >>> sorted(list(dataset.as_numpy_iterator()))
2934    [1, 2, 37]
2935
2936    Note: This transformation only supports datasets which fit into memory
2937    and have elements of either `tf.int32`, `tf.int64` or `tf.string` type.
2938
2939    Returns:
2940      A `Dataset`.
2941    """
2942
2943    return _UniqueDataset(self)
2944
2945  def rejection_resample(self,
2946                         class_func,
2947                         target_dist,
2948                         initial_dist=None,
2949                         seed=None):
2950    """A transformation that resamples a dataset to achieve a target distribution.
2951
2952    Lets consider the following example where a dataset with an initial data
2953    distribution of `init_dist` needs to be resampled into a dataset with
2954    `target_dist` distribution.
2955
2956    >>> import collections
2957    >>> initial_dist = [0.5, 0.5]
2958    >>> target_dist = [0.6, 0.4]
2959    >>> num_classes = len(initial_dist)
2960    >>> num_samples = 100000
2961    >>> data_np = np.random.choice(num_classes, num_samples, p=initial_dist)
2962    >>> dataset = tf.data.Dataset.from_tensor_slices(data_np)
2963    >>> x = collections.defaultdict(int)
2964    >>> for i in dataset:
2965    ...   x[i.numpy()] += 1
2966
2967    The value of `x` will be close to `{0: 50000, 1: 50000}` as per the
2968    `initial_dist` distribution.
2969
2970    >>> dataset = dataset.rejection_resample(
2971    ...    class_func=lambda x: x % 2,
2972    ...    target_dist=target_dist,
2973    ...    initial_dist=initial_dist)
2974
2975    >>> y = collections.defaultdict(int)
2976    >>> for i in dataset:
2977    ...   cls, _ = i
2978    ...   y[cls.numpy()] += 1
2979
2980    The value of `y` will be now be close to `{0: 75000, 1: 50000}` thus
2981    satisfying the `target_dist` distribution.
2982
2983    Args:
2984      class_func: A function mapping an element of the input dataset to a scalar
2985        `tf.int32` tensor. Values should be in `[0, num_classes)`.
2986      target_dist: A floating point type tensor, shaped `[num_classes]`.
2987      initial_dist: (Optional.)  A floating point type tensor, shaped
2988        `[num_classes]`.  If not provided, the true class distribution is
2989        estimated live in a streaming fashion.
2990      seed: (Optional.) Python integer seed for the resampler.
2991
2992    Returns:
2993      A `Dataset`
2994    """
2995
2996    target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
2997    target_dist_t = math_ops.cast(target_dist_t, dtypes.float32)
2998
2999    # Get initial distribution.
3000    if initial_dist is not None:
3001      initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
3002      initial_dist_t = math_ops.cast(initial_dist_t, dtypes.float32)
3003      acceptance_dist, prob_of_original = (
3004          _calculate_acceptance_probs_with_mixing(initial_dist_t,
3005                                                  target_dist_t))
3006      initial_dist_ds = DatasetV2.from_tensors(initial_dist_t).repeat()
3007      acceptance_dist_ds = DatasetV2.from_tensors(acceptance_dist).repeat()
3008      prob_of_original_ds = DatasetV2.from_tensors(prob_of_original).repeat()
3009    else:
3010      initial_dist_ds = _estimate_initial_dist_ds(target_dist_t,
3011                                                  self.map(class_func))
3012      acceptance_and_original_prob_ds = initial_dist_ds.map(
3013          lambda initial: _calculate_acceptance_probs_with_mixing(  # pylint: disable=g-long-lambda
3014              initial, target_dist_t))
3015      acceptance_dist_ds = acceptance_and_original_prob_ds.map(
3016          lambda accept_prob, _: accept_prob)
3017      prob_of_original_ds = acceptance_and_original_prob_ds.map(
3018          lambda _, prob_original: prob_original)
3019    filtered_ds = _filter_ds(self, acceptance_dist_ds, initial_dist_ds,
3020                             class_func, seed)
3021    # Prefetch filtered dataset for speed.
3022    filtered_ds = filtered_ds.prefetch(3)
3023
3024    prob_original_static = _get_prob_original_static(
3025        initial_dist_t, target_dist_t) if initial_dist is not None else None
3026
3027    def add_class_value(*x):
3028      if len(x) == 1:
3029        return class_func(*x), x[0]
3030      else:
3031        return class_func(*x), x
3032
3033    if prob_original_static == 1:
3034      return self.map(add_class_value)
3035    elif prob_original_static == 0:
3036      return filtered_ds
3037    else:
3038      return interleave_ops.sample_from_datasets(
3039          [self.map(add_class_value), filtered_ds],
3040          weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
3041          seed=seed,
3042          stop_on_empty_dataset=True)
3043
3044
3045@tf_export(v1=["data.Dataset"])
3046class DatasetV1(DatasetV2):
3047  """Represents a potentially large set of elements.
3048
3049  A `Dataset` can be used to represent an input pipeline as a
3050  collection of elements and a "logical plan" of transformations that act on
3051  those elements.
3052  """
3053
3054  def __init__(self):
3055    try:
3056      variant_tensor = self._as_variant_tensor()
3057    except AttributeError as e:
3058      if "_as_variant_tensor" in str(e):
3059        raise AttributeError("Please use _variant_tensor instead of "
3060                             "_as_variant_tensor() to obtain the variant "
3061                             "associated with a dataset")
3062      raise AttributeError("{}: A likely cause of this error is that the super "
3063                           "call for this dataset is not the last line of the "
3064                           "__init__ method. The base class causes the "
3065                           "_as_variant_tensor call in its constructor and "
3066                           "if that uses attributes defined in the __init__ "
3067                           "method, those attrs need to be defined before the "
3068                           "super call.".format(e))
3069    super(DatasetV1, self).__init__(variant_tensor)
3070
3071  @abc.abstractmethod
3072  def _as_variant_tensor(self):
3073    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
3074
3075    Returns:
3076      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
3077    """
3078    raise NotImplementedError("Dataset._as_variant_tensor")
3079
3080  @deprecation.deprecated(
3081      None, "This is a deprecated API that should only be used in TF 1 graph "
3082      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. In "
3083      "all other situations -- namely, eager mode and inside `tf.function` -- "
3084      "you can consume dataset elements using `for elem in dataset: ...` or "
3085      "by explicitly creating iterator via `iterator = iter(dataset)` and "
3086      "fetching its elements via `values = next(iterator)`. Furthermore, "
3087      "this API is not available in TF 2. During the transition from TF 1 "
3088      "to TF 2 you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)` "
3089      "to create a TF 1 graph mode style iterator for a dataset created "
3090      "through TF 2 APIs. Note that this should be a transient state of your "
3091      "code base as there are in general no guarantees about the "
3092      "interoperability of TF 1 and TF 2 code.")
3093  def make_one_shot_iterator(self):
3094    """Creates an iterator for elements of this dataset.
3095
3096    Note: The returned iterator will be initialized automatically.
3097    A "one-shot" iterator does not currently support re-initialization. For
3098    that see `make_initializable_iterator`.
3099
3100    Example:
3101
3102    ```python
3103    # Building graph ...
3104    dataset = ...
3105    next_value = dataset.make_one_shot_iterator().get_next()
3106
3107    # ... from within a session ...
3108    try:
3109      while True:
3110        value = sess.run(next_value)
3111        ...
3112    except tf.errors.OutOfRangeError:
3113        pass
3114    ```
3115
3116    Returns:
3117      An `tf.data.Iterator` for elements of this dataset.
3118    """
3119    return self._make_one_shot_iterator()
3120
3121  def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
3122    if context.executing_eagerly():
3123      with ops.colocate_with(self._variant_tensor):
3124        return iterator_ops.OwnedIterator(self)
3125
3126    _ensure_same_dataset_graph(self)
3127    # Some ops (e.g. dataset ops) are marked as stateful but are stil safe to
3128    # to capture by value. We must allowlist these ops so that the capturing
3129    # logic captures the ops instead of raising an exception.
3130    allowlisted_stateful_ops = traverse.obtain_capture_by_value_ops(self)
3131    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
3132
3133    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
3134    # a 0-argument function.
3135    @function.Defun(
3136        capture_by_value=True,
3137        allowlisted_stateful_ops=allowlisted_stateful_ops)
3138    def _make_dataset():
3139      """Factory function for a dataset."""
3140      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
3141      # enclosing graph, so if a graph-level seed is present we set the local
3142      # graph seed based on a combination of the graph- and op-level seeds.
3143      if graph_level_seed is not None:
3144        assert op_level_seed is not None
3145        core_random_seed.set_random_seed(
3146            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
3147
3148      dataset = self._apply_debug_options()
3149      return dataset._variant_tensor  # pylint: disable=protected-access
3150
3151    try:
3152      _make_dataset.add_to_graph(ops.get_default_graph())
3153    except ValueError as err:
3154      if "Cannot capture a stateful node" in str(err):
3155        raise ValueError(
3156            "Failed to create a one-shot iterator for a dataset. "
3157            "`Dataset.make_one_shot_iterator()` does not support datasets that "
3158            "capture stateful objects, such as a `Variable` or `LookupTable`. "
3159            "In these cases, use `Dataset.make_initializable_iterator()`. "
3160            "(Original error: %s)" % err)
3161      else:
3162        six.reraise(ValueError, err)
3163
3164    with ops.colocate_with(self._variant_tensor):
3165      # pylint: disable=protected-access
3166      return iterator_ops.Iterator(
3167          gen_dataset_ops.one_shot_iterator(
3168              dataset_factory=_make_dataset, **self._flat_structure), None,
3169          get_legacy_output_types(self), get_legacy_output_shapes(self),
3170          get_legacy_output_classes(self))
3171
3172  @deprecation.deprecated(
3173      None, "This is a deprecated API that should only be used in TF 1 graph "
3174      "mode and legacy TF 2 graph mode available through `tf.compat.v1`. "
3175      "In all other situations -- namely, eager mode and inside `tf.function` "
3176      "-- you can consume dataset elements using `for elem in dataset: ...` "
3177      "or by explicitly creating iterator via `iterator = iter(dataset)` "
3178      "and fetching its elements via `values = next(iterator)`. "
3179      "Furthermore, this API is not available in TF 2. During the transition "
3180      "from TF 1 to TF 2 you can use "
3181      "`tf.compat.v1.data.make_initializable_iterator(dataset)` to create a TF "
3182      "1 graph mode style iterator for a dataset created through TF 2 APIs. "
3183      "Note that this should be a transient state of your code base as there "
3184      "are in general no guarantees about the interoperability of TF 1 and TF "
3185      "2 code.")
3186  def make_initializable_iterator(self, shared_name=None):
3187    """Creates an iterator for elements of this dataset.
3188
3189    Note: The returned iterator will be in an uninitialized state,
3190    and you must run the `iterator.initializer` operation before using it:
3191
3192    ```python
3193    # Building graph ...
3194    dataset = ...
3195    iterator = dataset.make_initializable_iterator()
3196    next_value = iterator.get_next()  # This is a Tensor.
3197
3198    # ... from within a session ...
3199    sess.run(iterator.initializer)
3200    try:
3201      while True:
3202        value = sess.run(next_value)
3203        ...
3204    except tf.errors.OutOfRangeError:
3205        pass
3206    ```
3207
3208    Args:
3209      shared_name: (Optional.) If non-empty, the returned iterator will be
3210        shared under the given name across multiple sessions that share the same
3211        devices (e.g. when using a remote server).
3212
3213    Returns:
3214      A `tf.data.Iterator` for elements of this dataset.
3215
3216    Raises:
3217      RuntimeError: If eager execution is enabled.
3218    """
3219    return self._make_initializable_iterator(shared_name)
3220
3221  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
3222    if context.executing_eagerly():
3223      raise RuntimeError(
3224          "dataset.make_initializable_iterator is not supported when eager "
3225          "execution is enabled. Use `for element in dataset` instead.")
3226    _ensure_same_dataset_graph(self)
3227    dataset = self._apply_debug_options()
3228    if shared_name is None:
3229      shared_name = ""
3230
3231    with ops.colocate_with(self._variant_tensor):
3232      iterator_resource = gen_dataset_ops.iterator_v2(
3233          container="", shared_name=shared_name, **self._flat_structure)
3234
3235      initializer = gen_dataset_ops.make_iterator(
3236          dataset._variant_tensor,  # pylint: disable=protected-access
3237          iterator_resource)
3238
3239      # pylint: disable=protected-access
3240      return iterator_ops.Iterator(iterator_resource, initializer,
3241                                   get_legacy_output_types(dataset),
3242                                   get_legacy_output_shapes(dataset),
3243                                   get_legacy_output_classes(dataset))
3244
3245  @property
3246  @deprecation.deprecated(
3247      None, "Use `tf.compat.v1.data.get_output_classes(dataset)`.")
3248  def output_classes(self):
3249    """Returns the class of each component of an element of this dataset.
3250
3251    Returns:
3252      A (nested) structure of Python `type` objects corresponding to each
3253      component of an element of this dataset.
3254    """
3255    return nest.map_structure(
3256        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
3257        self.element_spec)
3258
3259  @property
3260  @deprecation.deprecated(
3261      None, "Use `tf.compat.v1.data.get_output_shapes(dataset)`.")
3262  def output_shapes(self):
3263    """Returns the shape of each component of an element of this dataset.
3264
3265    Returns:
3266      A (nested) structure of `tf.TensorShape` objects corresponding to each
3267      component of an element of this dataset.
3268    """
3269    return nest.map_structure(
3270        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
3271        self.element_spec)
3272
3273  @property
3274  @deprecation.deprecated(
3275      None, "Use `tf.compat.v1.data.get_output_types(dataset)`.")
3276  def output_types(self):
3277    """Returns the type of each component of an element of this dataset.
3278
3279    Returns:
3280      A (nested) structure of `tf.DType` objects corresponding to each component
3281      of an element of this dataset.
3282    """
3283    return nest.map_structure(
3284        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
3285        self.element_spec)
3286
3287  @property
3288  def element_spec(self):
3289    # TODO(b/110122868): Remove this override once all `Dataset` instances
3290    # implement `element_structure`.
3291    return structure.convert_legacy_structure(
3292        self.output_types, self.output_shapes, self.output_classes)
3293
3294  @staticmethod
3295  @functools.wraps(DatasetV2.from_tensors)
3296  def from_tensors(tensors):
3297    return DatasetV1Adapter(DatasetV2.from_tensors(tensors))
3298
3299  @staticmethod
3300  @functools.wraps(DatasetV2.from_tensor_slices)
3301  def from_tensor_slices(tensors):
3302    return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors))
3303
3304  @staticmethod
3305  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
3306  def from_sparse_tensor_slices(sparse_tensor):
3307    """Splits each rank-N `tf.sparse.SparseTensor` in this dataset row-wise.
3308
3309    Args:
3310      sparse_tensor: A `tf.sparse.SparseTensor`.
3311
3312    Returns:
3313      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
3314    """
3315    return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
3316
3317  @staticmethod
3318  @functools.wraps(DatasetV2.from_generator)
3319  @deprecation.deprecated_args(None, "Use output_signature instead",
3320                               "output_types", "output_shapes")
3321  def from_generator(generator,
3322                     output_types=None,
3323                     output_shapes=None,
3324                     args=None,
3325                     output_signature=None):
3326    # Calling DatasetV2.from_generator with output_shapes or output_types is
3327    # deprecated, but this is already checked by the decorator on this function.
3328    with deprecation.silence():
3329      return DatasetV1Adapter(
3330          DatasetV2.from_generator(generator, output_types, output_shapes, args,
3331                                   output_signature))
3332
3333  @staticmethod
3334  @functools.wraps(DatasetV2.range)
3335  def range(*args, **kwargs):
3336    return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
3337
3338  @staticmethod
3339  @functools.wraps(DatasetV2.zip)
3340  def zip(datasets):
3341    return DatasetV1Adapter(DatasetV2.zip(datasets))
3342
3343  @functools.wraps(DatasetV2.concatenate)
3344  def concatenate(self, dataset):
3345    return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset))
3346
3347  @functools.wraps(DatasetV2.prefetch)
3348  def prefetch(self, buffer_size):
3349    return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size))
3350
3351  @staticmethod
3352  @functools.wraps(DatasetV2.list_files)
3353  def list_files(file_pattern, shuffle=None, seed=None):
3354    return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed))
3355
3356  @functools.wraps(DatasetV2.repeat)
3357  def repeat(self, count=None):
3358    return DatasetV1Adapter(super(DatasetV1, self).repeat(count))
3359
3360  @functools.wraps(DatasetV2.shuffle)
3361  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
3362    return DatasetV1Adapter(super(DatasetV1, self).shuffle(
3363        buffer_size, seed, reshuffle_each_iteration))
3364
3365  @functools.wraps(DatasetV2.cache)
3366  def cache(self, filename=""):
3367    return DatasetV1Adapter(super(DatasetV1, self).cache(filename))
3368
3369  @functools.wraps(DatasetV2.take)
3370  def take(self, count):
3371    return DatasetV1Adapter(super(DatasetV1, self).take(count))
3372
3373  @functools.wraps(DatasetV2.skip)
3374  def skip(self, count):
3375    return DatasetV1Adapter(super(DatasetV1, self).skip(count))
3376
3377  @functools.wraps(DatasetV2.shard)
3378  def shard(self, num_shards, index):
3379    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
3380
3381  @functools.wraps(DatasetV2.batch)
3382  def batch(self,
3383            batch_size,
3384            drop_remainder=False,
3385            num_parallel_calls=None,
3386            deterministic=None):
3387    return DatasetV1Adapter(
3388        super(DatasetV1, self).batch(batch_size, drop_remainder,
3389                                     num_parallel_calls, deterministic))
3390
3391  @functools.wraps(DatasetV2.padded_batch)
3392  def padded_batch(self,
3393                   batch_size,
3394                   padded_shapes=None,
3395                   padding_values=None,
3396                   drop_remainder=False):
3397    return DatasetV1Adapter(
3398        super(DatasetV1, self).padded_batch(batch_size, padded_shapes,
3399                                            padding_values, drop_remainder))
3400
3401  @functools.wraps(DatasetV2.map)
3402  def map(self, map_func, num_parallel_calls=None, deterministic=None):
3403    if num_parallel_calls is None or DEBUG_MODE:
3404      return DatasetV1Adapter(
3405          MapDataset(self, map_func, preserve_cardinality=False))
3406    else:
3407      return DatasetV1Adapter(
3408          ParallelMapDataset(
3409              self,
3410              map_func,
3411              num_parallel_calls,
3412              deterministic,
3413              preserve_cardinality=False))
3414
3415  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()")
3416  def map_with_legacy_function(self,
3417                               map_func,
3418                               num_parallel_calls=None,
3419                               deterministic=None):
3420    """Maps `map_func` across the elements of this dataset.
3421
3422    Note: This is an escape hatch for existing uses of `map` that do not work
3423    with V2 functions. New uses are strongly discouraged and existing uses
3424    should migrate to `map` as this method will be removed in V2.
3425
3426    Args:
3427      map_func: A function mapping a (nested) structure of tensors (having
3428        shapes and types defined by `self.output_shapes` and
3429        `self.output_types`) to another (nested) structure of tensors.
3430      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
3431        representing the number elements to process asynchronously in parallel.
3432        If not specified, elements will be processed sequentially. If the value
3433        `tf.data.AUTOTUNE` is used, then the number of parallel
3434        calls is set dynamically based on available CPU.
3435      deterministic: (Optional.) When `num_parallel_calls` is specified, this
3436        boolean controls the order in which the transformation produces
3437        elements. If set to `False`, the transformation is allowed to yield
3438        elements out of order to trade determinism for performance. If not
3439        specified, the `tf.data.Options.deterministic` option (`True` by
3440        default) controls the behavior.
3441
3442    Returns:
3443      Dataset: A `Dataset`.
3444    """
3445    if num_parallel_calls is None:
3446      if deterministic is not None:
3447        warnings.warn("The `deterministic` argument has no effect unless the "
3448                      "`num_parallel_calls` argument is specified.")
3449      return DatasetV1Adapter(
3450          MapDataset(
3451              self,
3452              map_func,
3453              preserve_cardinality=False,
3454              use_legacy_function=True))
3455    else:
3456      return DatasetV1Adapter(
3457          ParallelMapDataset(
3458              self,
3459              map_func,
3460              num_parallel_calls,
3461              deterministic,
3462              preserve_cardinality=False,
3463              use_legacy_function=True))
3464
3465  @functools.wraps(DatasetV2.flat_map)
3466  def flat_map(self, map_func):
3467    return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func))
3468
3469  @functools.wraps(DatasetV2.interleave)
3470  def interleave(self,
3471                 map_func,
3472                 cycle_length=None,
3473                 block_length=None,
3474                 num_parallel_calls=None,
3475                 deterministic=None):
3476    return DatasetV1Adapter(
3477        super(DatasetV1, self).interleave(map_func, cycle_length, block_length,
3478                                          num_parallel_calls, deterministic))
3479
3480  @functools.wraps(DatasetV2.filter)
3481  def filter(self, predicate):
3482    return DatasetV1Adapter(super(DatasetV1, self).filter(predicate))
3483
3484  @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()")
3485  def filter_with_legacy_function(self, predicate):
3486    """Filters this dataset according to `predicate`.
3487
3488    Note: This is an escape hatch for existing uses of `filter` that do not work
3489    with V2 functions. New uses are strongly discouraged and existing uses
3490    should migrate to `filter` as this method will be removed in V2.
3491
3492    Args:
3493      predicate: A function mapping a (nested) structure of tensors (having
3494        shapes and types defined by `self.output_shapes` and
3495        `self.output_types`) to a scalar `tf.bool` tensor.
3496
3497    Returns:
3498      Dataset: The `Dataset` containing the elements of this dataset for which
3499          `predicate` is `True`.
3500    """
3501    return FilterDataset(self, predicate, use_legacy_function=True)
3502
3503  @functools.wraps(DatasetV2.apply)
3504  def apply(self, transformation_func):
3505    return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
3506
3507  @functools.wraps(DatasetV2.window)
3508  def window(self, size, shift=None, stride=1, drop_remainder=False):
3509    return DatasetV1Adapter(super(DatasetV1, self).window(
3510        size, shift, stride, drop_remainder))
3511
3512  @functools.wraps(DatasetV2.unbatch)
3513  def unbatch(self):
3514    return DatasetV1Adapter(super(DatasetV1, self).unbatch())
3515
3516  @functools.wraps(DatasetV2.with_options)
3517  def with_options(self, options):
3518    return DatasetV1Adapter(super(DatasetV1, self).with_options(options))
3519
3520
3521if tf2.enabled():
3522  Dataset = DatasetV2
3523else:
3524  Dataset = DatasetV1
3525
3526
3527class DatasetV1Adapter(DatasetV1):
3528  """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
3529
3530  def __init__(self, dataset):
3531    self._dataset = dataset
3532    super(DatasetV1Adapter, self).__init__()
3533
3534  def _as_variant_tensor(self):
3535    return self._dataset._variant_tensor  # pylint: disable=protected-access
3536
3537  def _inputs(self):
3538    return self._dataset._inputs()  # pylint: disable=protected-access
3539
3540  def _functions(self):
3541    return self._dataset._functions()  # pylint: disable=protected-access
3542
3543  def options(self):
3544    return self._dataset.options()
3545
3546  @property
3547  def element_spec(self):
3548    return self._dataset.element_spec  # pylint: disable=protected-access
3549
3550  def __iter__(self):
3551    return iter(self._dataset)
3552
3553
3554def _ensure_same_dataset_graph(dataset):
3555  """Walks the dataset graph to ensure all datasets come from the same graph."""
3556  # pylint: disable=protected-access
3557  current_graph = ops.get_default_graph()
3558  bfs_q = Queue.Queue()
3559  bfs_q.put(dataset)
3560  visited = []
3561  while not bfs_q.empty():
3562    ds = bfs_q.get()
3563    visited.append(ds)
3564    ds_graph = ds._graph
3565    if current_graph != ds_graph:
3566      raise ValueError(
3567          "The graph (" + str(current_graph) + ") of the iterator is different "
3568          "from the graph (" + str(ds_graph) + ") the dataset: " +
3569          str(ds._variant_tensor) + " was  created in. If you are using the "
3570          "Estimator API, make sure that no part of the dataset returned by "
3571          "the `input_fn` function is defined outside the `input_fn` function. "
3572          "Please ensure that all datasets in the pipeline are created in the "
3573          "same graph as the iterator.")
3574    for input_ds in ds._inputs():
3575      if input_ds not in visited:
3576        bfs_q.put(input_ds)
3577
3578
3579@tf_export(v1=["data.make_one_shot_iterator"])
3580def make_one_shot_iterator(dataset):
3581  """Creates an iterator for elements of `dataset`.
3582
3583  Note: The returned iterator will be initialized automatically.
3584  A "one-shot" iterator does not support re-initialization.
3585
3586  Args:
3587    dataset: A `tf.data.Dataset`.
3588
3589  Returns:
3590    A `tf.data.Iterator` for elements of `dataset`.
3591
3592  @compatibility(TF2)
3593  This is a legacy API for consuming dataset elements and should only be used
3594  during transition from TF 1 to TF 2. Note that using this API should be
3595  a transient state of your code base as there are in general no guarantees
3596  about the interoperability of TF 1 and TF 2 code.
3597
3598  In TF 2 datasets are Python iterables which means you can consume their
3599  elements using `for elem in dataset: ...` or by explicitly creating iterator
3600  via `iterator = iter(dataset)` and fetching its elements via
3601  `values = next(iterator)`.
3602  @end_compatibility
3603  """
3604  try:
3605    # Call the defined `_make_one_shot_iterator()` if there is one, because some
3606    # datasets (e.g. for prefetching) override its behavior.
3607    return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
3608  except AttributeError:
3609    return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
3610
3611
3612@tf_export(v1=["data.make_initializable_iterator"])
3613def make_initializable_iterator(dataset, shared_name=None):
3614  """Creates an iterator for elements of `dataset`.
3615
3616  Note: The returned iterator will be in an uninitialized state,
3617  and you must run the `iterator.initializer` operation before using it:
3618
3619  ```python
3620  dataset = ...
3621  iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
3622  # ...
3623  sess.run(iterator.initializer)
3624  ```
3625
3626  Args:
3627    dataset: A `tf.data.Dataset`.
3628    shared_name: (Optional.) If non-empty, the returned iterator will be shared
3629      under the given name across multiple sessions that share the same devices
3630      (e.g. when using a remote server).
3631
3632  Returns:
3633    A `tf.data.Iterator` for elements of `dataset`.
3634
3635  Raises:
3636    RuntimeError: If eager execution is enabled.
3637
3638  @compatibility(TF2)
3639  This is a legacy API for consuming dataset elements and should only be used
3640  during transition from TF 1 to TF 2. Note that using this API should be
3641  a transient state of your code base as there are in general no guarantees
3642  about the interoperability of TF 1 and TF 2 code.
3643
3644  In TF 2 datasets are Python iterables which means you can consume their
3645  elements using `for elem in dataset: ...` or by explicitly creating iterator
3646  via `iterator = iter(dataset)` and fetching its elements via
3647  `values = next(iterator)`.
3648  @end_compatibility
3649  """
3650  try:
3651    # Call the defined `_make_initializable_iterator()` if there is one, because
3652    # some datasets (e.g. for prefetching) override its behavior.
3653    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
3654  except AttributeError:
3655    return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
3656
3657
3658@tf_export("data.experimental.get_structure")
3659def get_structure(dataset_or_iterator):
3660  """Returns the type signature for elements of the input dataset / iterator.
3661
3662  Args:
3663    dataset_or_iterator: A `tf.data.Dataset` or an `tf.data.Iterator`.
3664
3665  Returns:
3666    A (nested) structure of `tf.TypeSpec` objects matching the structure of an
3667    element of `dataset_or_iterator` and specifying the type of individual
3668    components.
3669
3670  Raises:
3671    TypeError: If input is not a `tf.data.Dataset` or an `tf.data.Iterator`
3672      object.
3673  """
3674  try:
3675    return dataset_or_iterator.element_spec  # pylint: disable=protected-access
3676  except AttributeError:
3677    raise TypeError("`dataset_or_iterator` must be a `tf.data.Dataset` or "
3678                    "tf.data.Iterator object, but got %s." %
3679                    type(dataset_or_iterator))
3680
3681
3682@tf_export(v1=["data.get_output_classes"])
3683def get_legacy_output_classes(dataset_or_iterator):
3684  """Returns the output classes for elements of the input dataset / iterator.
3685
3686  Args:
3687    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
3688
3689  Returns:
3690    A (nested) structure of Python `type` objects matching the structure of the
3691    dataset / iterator elements and specifying the class of the individual
3692    components.
3693
3694  @compatibility(TF2)
3695  This is a legacy API for inspecting the type signature of dataset elements. In
3696  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
3697  @end_compatibility
3698  """
3699  return nest.map_structure(
3700      lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
3701      get_structure(dataset_or_iterator))
3702
3703
3704@tf_export(v1=["data.get_output_shapes"])
3705def get_legacy_output_shapes(dataset_or_iterator):
3706  """Returns the output shapes for elements of the input dataset / iterator.
3707
3708  Args:
3709    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
3710
3711  Returns:
3712    A (nested) structure of `tf.TensorShape` objects matching the structure of
3713    the dataset / iterator elements and specifying the shape of the individual
3714    components.
3715
3716  @compatibility(TF2)
3717  This is a legacy API for inspecting the type signature of dataset elements. In
3718  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
3719  @end_compatibility
3720  """
3721  return nest.map_structure(
3722      lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
3723      get_structure(dataset_or_iterator))
3724
3725
3726@tf_export(v1=["data.get_output_types"])
3727def get_legacy_output_types(dataset_or_iterator):
3728  """Returns the output shapes for elements of the input dataset / iterator.
3729
3730  Args:
3731    dataset_or_iterator: A `tf.data.Dataset` or `tf.data.Iterator`.
3732
3733  Returns:
3734    A (nested) structure of `tf.DType` objects matching the structure of
3735    dataset / iterator elements and specifying the shape of the individual
3736    components.
3737
3738  @compatibility(TF2)
3739  This is a legacy API for inspecting the type signature of dataset elements. In
3740  TF 2, you should use the `tf.data.Dataset.element_spec` attribute instead.
3741  @end_compatibility
3742  """
3743  return nest.map_structure(
3744      lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
3745      get_structure(dataset_or_iterator))
3746
3747
3748class DatasetSource(DatasetV2):
3749  """Abstract class representing a dataset with no inputs."""
3750
3751  def _inputs(self):
3752    return []
3753
3754
3755class UnaryDataset(DatasetV2):
3756  """Abstract class representing a dataset with one input."""
3757
3758  def __init__(self, input_dataset, variant_tensor):
3759    self._input_dataset = input_dataset
3760    super(UnaryDataset, self).__init__(variant_tensor)
3761
3762  def _inputs(self):
3763    return [self._input_dataset]
3764
3765
3766class UnaryUnchangedStructureDataset(UnaryDataset):
3767  """Represents a unary dataset with the same input and output structure."""
3768
3769  def __init__(self, input_dataset, variant_tensor):
3770    self._input_dataset = input_dataset
3771    super(UnaryUnchangedStructureDataset, self).__init__(
3772        input_dataset, variant_tensor)
3773
3774  @property
3775  def element_spec(self):
3776    return self._input_dataset.element_spec
3777
3778
3779class TensorDataset(DatasetSource):
3780  """A `Dataset` with a single element."""
3781
3782  def __init__(self, element):
3783    """See `Dataset.from_tensors()` for details."""
3784    element = structure.normalize_element(element)
3785    self._structure = structure.type_spec_from_value(element)
3786    self._tensors = structure.to_tensor_list(self._structure, element)
3787
3788    variant_tensor = gen_dataset_ops.tensor_dataset(
3789        self._tensors,
3790        output_shapes=structure.get_flat_tensor_shapes(self._structure))
3791    super(TensorDataset, self).__init__(variant_tensor)
3792
3793  @property
3794  def element_spec(self):
3795    return self._structure
3796
3797
3798class TensorSliceDataset(DatasetSource):
3799  """A `Dataset` of slices from a dataset element."""
3800
3801  def __init__(self, element):
3802    """See `Dataset.from_tensor_slices()` for details."""
3803    element = structure.normalize_element(element)
3804    batched_spec = structure.type_spec_from_value(element)
3805    self._tensors = structure.to_batched_tensor_list(batched_spec, element)
3806    self._structure = nest.map_structure(
3807        lambda component_spec: component_spec._unbatch(), batched_spec)  # pylint: disable=protected-access
3808
3809    batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
3810        self._tensors[0].get_shape()[0]))
3811    for t in self._tensors[1:]:
3812      batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
3813          tensor_shape.dimension_value(t.get_shape()[0])))
3814
3815    variant_tensor = gen_dataset_ops.tensor_slice_dataset(
3816        self._tensors,
3817        output_shapes=structure.get_flat_tensor_shapes(self._structure))
3818    super(TensorSliceDataset, self).__init__(variant_tensor)
3819
3820  @property
3821  def element_spec(self):
3822    return self._structure
3823
3824
3825class SparseTensorSliceDataset(DatasetSource):
3826  """A `Dataset` that splits a rank-N `tf.sparse.SparseTensor` into its rows."""
3827
3828  def __init__(self, sparse_tensor):
3829    """See `Dataset.from_sparse_tensor_slices()` for details."""
3830    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
3831      raise TypeError(
3832          "`sparse_tensor` must be a `tf.sparse.SparseTensor` object."
3833          "Was {}.".format(sparse_tensor))
3834    self._sparse_tensor = sparse_tensor
3835
3836    indices_shape = self._sparse_tensor.indices.get_shape()
3837    shape_shape = self._sparse_tensor.dense_shape.get_shape()
3838    rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
3839    self._structure = (tensor_spec.TensorSpec([None, rank], dtypes.int64),
3840                       tensor_spec.TensorSpec([None],
3841                                              self._sparse_tensor.dtype),
3842                       tensor_spec.TensorSpec([rank], dtypes.int64))
3843
3844    variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
3845        self._sparse_tensor.indices, self._sparse_tensor.values,
3846        self._sparse_tensor.dense_shape)
3847    super(SparseTensorSliceDataset, self).__init__(variant_tensor)
3848
3849  @property
3850  def element_spec(self):
3851    return self._structure
3852
3853
3854class _VariantDataset(DatasetV2):
3855  """A Dataset wrapper around a `tf.variant`-typed function argument."""
3856
3857  def __init__(self, dataset_variant, structure):
3858    self._structure = structure
3859    super(_VariantDataset, self).__init__(dataset_variant)
3860
3861  def _inputs(self):
3862    return []
3863
3864  @property
3865  def element_spec(self):
3866    return self._structure
3867
3868
3869class _NestedVariant(composite_tensor.CompositeTensor):
3870
3871  def __init__(self, variant_tensor, element_spec, dataset_shape):
3872    self._variant_tensor = variant_tensor
3873    self._element_spec = element_spec
3874    self._dataset_shape = dataset_shape
3875
3876  @property
3877  def _type_spec(self):
3878    return DatasetSpec(self._element_spec, self._dataset_shape)
3879
3880
3881@tf_export("data.experimental.from_variant")
3882def from_variant(variant, structure):
3883  """Constructs a dataset from the given variant and (nested) structure.
3884
3885  Args:
3886    variant: A scalar `tf.variant` tensor representing a dataset.
3887    structure: A (nested) structure of `tf.TypeSpec` objects representing the
3888      structure of each element in the dataset.
3889
3890  Returns:
3891    A `tf.data.Dataset` instance.
3892  """
3893  return _VariantDataset(variant, structure)  # pylint: disable=protected-access
3894
3895
3896@tf_export("data.experimental.to_variant")
3897def to_variant(dataset):
3898  """Returns a variant representing the given dataset.
3899
3900  Args:
3901    dataset: A `tf.data.Dataset`.
3902
3903  Returns:
3904    A scalar `tf.variant` tensor representing the given dataset.
3905  """
3906  return dataset._variant_tensor  # pylint: disable=protected-access
3907
3908
3909@tf_export(
3910    "data.DatasetSpec",
3911    v1=["data.DatasetSpec", "data.experimental.DatasetStructure"])
3912class DatasetSpec(type_spec.BatchableTypeSpec):
3913  """Type specification for `tf.data.Dataset`.
3914
3915  See `tf.TypeSpec` for more information about TensorFlow type specifications.
3916
3917  >>> dataset = tf.data.Dataset.range(3)
3918  >>> tf.data.DatasetSpec.from_value(dataset)
3919  DatasetSpec(TensorSpec(shape=(), dtype=tf.int64, name=None), TensorShape([]))
3920  """
3921
3922  __slots__ = ["_element_spec", "_dataset_shape"]
3923
3924  def __init__(self, element_spec, dataset_shape=()):
3925    self._element_spec = element_spec
3926    self._dataset_shape = tensor_shape.as_shape(dataset_shape)
3927
3928  @property
3929  def value_type(self):
3930    return Dataset
3931
3932  @property
3933  def element_spec(self):
3934    """The inner element spec."""
3935    return self._element_spec
3936
3937  def _serialize(self):
3938    return (self._element_spec, self._dataset_shape)
3939
3940  @property
3941  def _component_specs(self):
3942    return tensor_spec.TensorSpec(self._dataset_shape, dtypes.variant)
3943
3944  def _to_components(self, value):
3945    return value._variant_tensor  # pylint: disable=protected-access
3946
3947  def _from_components(self, components):
3948    # pylint: disable=protected-access
3949    if self._dataset_shape.ndims == 0:
3950      return _VariantDataset(components, self._element_spec)
3951    else:
3952      return _NestedVariant(components, self._element_spec, self._dataset_shape)
3953
3954  def _to_tensor_list(self, value):
3955    return [
3956        ops.convert_to_tensor(
3957            tf_nest.map_structure(lambda x: x._variant_tensor, value))  # pylint: disable=protected-access
3958    ]
3959
3960  @staticmethod
3961  def from_value(value):
3962    """Creates a `DatasetSpec` for the given `tf.data.Dataset` value."""
3963    return DatasetSpec(value.element_spec)  # pylint: disable=protected-access
3964
3965  def _batch(self, batch_size):
3966    return DatasetSpec(
3967        self._element_spec,
3968        tensor_shape.TensorShape([batch_size]).concatenate(self._dataset_shape))
3969
3970  def _unbatch(self):
3971    if self._dataset_shape.ndims == 0:
3972      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
3973    return DatasetSpec(self._element_spec, self._dataset_shape[1:])
3974
3975  def _to_batched_tensor_list(self, value):
3976    if self._dataset_shape.ndims == 0:
3977      raise ValueError("Unbatching a dataset is only supported for rank >= 1")
3978    return self._to_tensor_list(value)
3979
3980  def _to_legacy_output_types(self):
3981    return self
3982
3983  def _to_legacy_output_shapes(self):
3984    return self
3985
3986  def _to_legacy_output_classes(self):
3987    return self
3988
3989
3990class StructuredFunctionWrapper(object):
3991  """A function wrapper that supports structured arguments and return values."""
3992
3993  def __init__(self,
3994               func,
3995               transformation_name,
3996               dataset=None,
3997               input_classes=None,
3998               input_shapes=None,
3999               input_types=None,
4000               input_structure=None,
4001               add_to_graph=True,
4002               use_legacy_function=False,
4003               defun_kwargs=None):
4004    """Creates a new `StructuredFunctionWrapper` for the given function.
4005
4006    Args:
4007      func: A function from a (nested) structure to another (nested) structure.
4008      transformation_name: Human-readable name of the transformation in which
4009        this function is being instantiated, for error messages.
4010      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
4011        dataset will be assumed as the structure for `func` arguments; otherwise
4012        `input_classes`, `input_shapes`, and `input_types` must be defined.
4013      input_classes: (Optional.) A (nested) structure of `type`. If given, this
4014        argument defines the Python types for `func` arguments.
4015      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
4016        given, this argument defines the shapes and structure for `func`
4017        arguments.
4018      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
4019        this argument defines the element types and structure for `func`
4020        arguments.
4021      input_structure: (Optional.) A `Structure` object. If given, this argument
4022        defines the element types and structure for `func` arguments.
4023      add_to_graph: (Optional.) If `True`, the function will be added to the
4024        default graph, if it exists.
4025      use_legacy_function: (Optional.) A boolean that determines whether the
4026        function be created using `tensorflow.python.eager.function.defun`
4027        (default behavior) or `tensorflow.python.framework.function.Defun`
4028        (legacy behavior).
4029      defun_kwargs: (Optional.) A dictionary mapping string argument names to
4030        values. If supplied, will be passed to `function` as keyword arguments.
4031
4032    Raises:
4033      ValueError: If an invalid combination of `dataset`, `input_classes`,
4034        `input_shapes`, and `input_types` is passed.
4035    """
4036    # pylint: disable=protected-access
4037    if input_structure is None:
4038      if dataset is None:
4039        if input_classes is None or input_shapes is None or input_types is None:
4040          raise ValueError("Either `dataset`, `input_structure` or all of "
4041                           "`input_classes`, `input_shapes`, and `input_types` "
4042                           "must be specified.")
4043        self._input_structure = structure.convert_legacy_structure(
4044            input_types, input_shapes, input_classes)
4045      else:
4046        if not (input_classes is None and input_shapes is None and
4047                input_types is None):
4048          raise ValueError("Either `dataset`, `input_structure` or all of "
4049                           "`input_classes`, `input_shapes`, and `input_types` "
4050                           "must be specified.")
4051        self._input_structure = dataset.element_spec
4052    else:
4053      if not (dataset is None and input_classes is None and input_shapes is None
4054              and input_types is None):
4055        raise ValueError("Either `dataset`, `input_structure`, or all of "
4056                         "`input_classes`, `input_shapes`, and `input_types` "
4057                         "must be specified.")
4058      self._input_structure = input_structure
4059
4060    self._func = func
4061
4062    if defun_kwargs is None:
4063      defun_kwargs = {}
4064
4065    readable_transformation_name = transformation_name.replace(
4066        ".", "_")[:-2] if len(transformation_name) > 2 else ""
4067
4068    func_name = "_".join(
4069        [readable_transformation_name,
4070         function_utils.get_func_name(func)])
4071    # Sanitize function name to remove symbols that interfere with graph
4072    # construction.
4073    for symbol in ["<", ">", "\\", "'", " "]:
4074      func_name = func_name.replace(symbol, "")
4075
4076    ag_ctx = autograph_ctx.control_status_ctx()
4077
4078    def wrapper_helper(*args):
4079      """Wrapper for passing nested structures to and from tf.data functions."""
4080      nested_args = structure.from_compatible_tensor_list(
4081          self._input_structure, args)
4082      if not _should_unpack(nested_args):
4083        nested_args = (nested_args,)
4084      ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
4085      if _should_pack(ret):
4086        ret = tuple(ret)
4087
4088      try:
4089        self._output_structure = structure.type_spec_from_value(ret)
4090      except (ValueError, TypeError):
4091        six.reraise(
4092            TypeError,
4093            TypeError("Unsupported return value from function passed to "
4094                      "%s: %s." % (transformation_name, ret)),
4095            sys.exc_info()[2])
4096      return ret
4097
4098    def trace_legacy_function(defun_kwargs):
4099      @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
4100                      **defun_kwargs)
4101      def wrapped_fn(*args):
4102        ret = wrapper_helper(*args)
4103        return structure.to_tensor_list(self._output_structure, ret)
4104
4105      return lambda: wrapped_fn
4106
4107    def trace_py_function(defun_kwargs):
4108      # First we trace the function to infer the output structure.
4109      @eager_function.defun_with_attributes(
4110          input_signature=structure.get_flat_tensor_specs(
4111              self._input_structure),
4112          autograph=False,
4113          attributes=defun_kwargs)
4114      def unused(*args):  # pylint: disable=missing-docstring,unused-variable
4115        ret = wrapper_helper(*args)
4116        ret = structure.to_tensor_list(self._output_structure, ret)
4117        return [ops.convert_to_tensor(t) for t in ret]
4118
4119      _ = unused.get_concrete_function()
4120
4121      def py_function_wrapper(*args):
4122        nested_args = structure.from_compatible_tensor_list(
4123            self._input_structure, args)
4124        if not _should_unpack(nested_args):
4125          nested_args = (nested_args,)
4126        ret = self._func(*nested_args)
4127        if _should_pack(ret):
4128          ret = tuple(ret)
4129        ret = structure.to_tensor_list(self._output_structure, ret)
4130        return [ops.convert_to_tensor(t) for t in ret]
4131
4132      # Next we trace the function wrapped in `eager_py_func` to force eager
4133      # execution.
4134      @eager_function.defun_with_attributes(
4135          input_signature=structure.get_flat_tensor_specs(
4136              self._input_structure),
4137          autograph=False,
4138          attributes=defun_kwargs)
4139      def wrapped_fn(*args):  # pylint: disable=missing-docstring
4140        return script_ops.eager_py_func(
4141            py_function_wrapper, args,
4142            structure.get_flat_tensor_types(self._output_structure))
4143
4144      return wrapped_fn.get_concrete_function
4145
4146    def trace_tf_function(defun_kwargs):
4147      # Note: wrapper_helper will apply autograph based on context.
4148      @eager_function.defun_with_attributes(
4149          input_signature=structure.get_flat_tensor_specs(
4150              self._input_structure),
4151          autograph=False,
4152          attributes=defun_kwargs)
4153      def wrapped_fn(*args):  # pylint: disable=missing-docstring
4154        ret = wrapper_helper(*args)
4155        ret = structure.to_tensor_list(self._output_structure, ret)
4156        return [ops.convert_to_tensor(t) for t in ret]
4157
4158      return wrapped_fn.get_concrete_function
4159
4160    if use_legacy_function:
4161      defun_kwargs.update({"func_name": func_name + "_" + str(ops.uid())})
4162      fn_factory = trace_legacy_function(defun_kwargs)
4163    else:
4164      defun_kwargs.update({"func_name": func_name})
4165      defun_kwargs.update({"_tf_data_function": True})
4166      if DEBUG_MODE:
4167        fn_factory = trace_py_function(defun_kwargs)
4168      else:
4169        if def_function.functions_run_eagerly():
4170          warnings.warn(
4171              "Even though the `tf.config.experimental_run_functions_eagerly` "
4172              "option is set, this option does not apply to tf.data functions. "
4173              "To force eager execution of tf.data functions, please use "
4174              "`tf.data.experimental.enable_debug_mode()`.")
4175        fn_factory = trace_tf_function(defun_kwargs)
4176
4177    self._function = fn_factory()
4178    # There is no graph to add in eager mode.
4179    add_to_graph &= not context.executing_eagerly()
4180    # There are some lifetime issues when a legacy function is not added to a
4181    # out-living graph. It's already deprecated so de-prioritizing the fix.
4182    add_to_graph |= use_legacy_function
4183    if add_to_graph:
4184      self._function.add_to_graph(ops.get_default_graph())
4185
4186    if not use_legacy_function:
4187      outer_graph_seed = ops.get_default_graph().seed
4188      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
4189        if self._function.graph._seed_used:
4190          warnings.warn(
4191              "Seed %s from outer graph might be getting used by function %s, "
4192              "if the random op has not been provided any seed. Explicitly set "
4193              "the seed in the function if this is not the intended behavior."
4194              %(outer_graph_seed, func_name), stacklevel=4)
4195
4196  @property
4197  def output_structure(self):
4198    return self._output_structure
4199
4200  @property
4201  def output_classes(self):
4202    return nest.map_structure(
4203        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
4204        self._output_structure)
4205
4206  @property
4207  def output_shapes(self):
4208    return nest.map_structure(
4209        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
4210        self._output_structure)
4211
4212  @property
4213  def output_types(self):
4214    return nest.map_structure(
4215        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
4216        self._output_structure)
4217
4218  @property
4219  def function(self):
4220    return self._function
4221
4222
4223class _GeneratorDataset(DatasetSource):
4224  """A `Dataset` that generates elements by invoking a function."""
4225
4226  def __init__(self, init_args, init_func, next_func, finalize_func,
4227               output_signature):
4228    """Constructs a `_GeneratorDataset`.
4229
4230    Args:
4231      init_args: A (nested) structure representing the arguments to `init_func`.
4232      init_func: A TensorFlow function that will be called on `init_args` each
4233        time a C++ iterator over this dataset is constructed. Returns a (nested)
4234        structure representing the "state" of the dataset.
4235      next_func: A TensorFlow function that will be called on the result of
4236        `init_func` to produce each element, and that raises `OutOfRangeError`
4237        to terminate iteration.
4238      finalize_func: A TensorFlow function that will be called on the result of
4239        `init_func` immediately before a C++ iterator over this dataset is
4240        destroyed. The return value is ignored.
4241      output_signature: A (nested) structure of `tf.TypeSpec` objects describing
4242        the output of `next_func`.
4243    """
4244    self._init_args = init_args
4245
4246    self._init_structure = structure.type_spec_from_value(init_args)
4247
4248    self._init_func = StructuredFunctionWrapper(
4249        init_func,
4250        self._transformation_name(),
4251        input_structure=self._init_structure)
4252
4253    self._next_func = StructuredFunctionWrapper(
4254        next_func,
4255        self._transformation_name(),
4256        input_structure=self._init_func.output_structure)
4257
4258    self._finalize_func = StructuredFunctionWrapper(
4259        finalize_func,
4260        self._transformation_name(),
4261        input_structure=self._init_func.output_structure)
4262
4263    self._output_signature = output_signature
4264
4265    variant_tensor = gen_dataset_ops.generator_dataset(
4266        structure.to_tensor_list(self._init_structure, self._init_args) +
4267        self._init_func.function.captured_inputs,
4268        self._next_func.function.captured_inputs,
4269        self._finalize_func.function.captured_inputs,
4270        init_func=self._init_func.function,
4271        next_func=self._next_func.function,
4272        finalize_func=self._finalize_func.function,
4273        **self._flat_structure)
4274    super(_GeneratorDataset, self).__init__(variant_tensor)
4275
4276  @property
4277  def element_spec(self):
4278    return self._output_signature
4279
4280  def _transformation_name(self):
4281    return "Dataset.from_generator()"
4282
4283
4284class ZipDataset(DatasetV2):
4285  """A `Dataset` that zips its inputs together."""
4286
4287  def __init__(self, datasets):
4288    """See `Dataset.zip()` for details."""
4289    for ds in nest.flatten(datasets):
4290      if not isinstance(ds, DatasetV2):
4291        if isinstance(ds, list):
4292          message = ("The argument to `Dataset.zip()` must be a (nested) "
4293                     "structure of `Dataset` objects. Python `list` is not "
4294                     "supported, please use a `tuple` instead.")
4295        else:
4296          message = ("The argument to `Dataset.zip()` must be a (nested) "
4297                     "structure of `Dataset` objects.")
4298        raise TypeError(message)
4299    self._datasets = datasets
4300    self._structure = nest.pack_sequence_as(
4301        self._datasets,
4302        [ds.element_spec for ds in nest.flatten(self._datasets)])
4303    variant_tensor = gen_dataset_ops.zip_dataset(
4304        [ds._variant_tensor for ds in nest.flatten(self._datasets)],
4305        **self._flat_structure)
4306    super(ZipDataset, self).__init__(variant_tensor)
4307
4308  def _inputs(self):
4309    return nest.flatten(self._datasets)
4310
4311  @property
4312  def element_spec(self):
4313    return self._structure
4314
4315
4316class ConcatenateDataset(DatasetV2):
4317  """A `Dataset` that concatenates its input with given dataset."""
4318
4319  def __init__(self, input_dataset, dataset_to_concatenate):
4320    """See `Dataset.concatenate()` for details."""
4321    self._input_dataset = input_dataset
4322    self._dataset_to_concatenate = dataset_to_concatenate
4323
4324    output_types = get_legacy_output_types(input_dataset)
4325    if output_types != get_legacy_output_types(dataset_to_concatenate):
4326      raise TypeError(
4327          "Two datasets to concatenate have different types %s and %s" %
4328          (output_types, get_legacy_output_types(dataset_to_concatenate)))
4329
4330    output_classes = get_legacy_output_classes(input_dataset)
4331    if output_classes != get_legacy_output_classes(dataset_to_concatenate):
4332      raise TypeError(
4333          "Two datasets to concatenate have different classes %s and %s" %
4334          (output_classes, get_legacy_output_classes(dataset_to_concatenate)))
4335
4336    spec1 = input_dataset.element_spec
4337    spec2 = dataset_to_concatenate.element_spec
4338    self._structure = nest.pack_sequence_as(spec1, [
4339        ts1.most_specific_compatible_type(ts2)
4340        for (ts1, ts2) in zip(nest.flatten(spec1), nest.flatten(spec2))
4341    ])
4342
4343    self._input_datasets = [input_dataset, dataset_to_concatenate]
4344    # pylint: disable=protected-access
4345    variant_tensor = gen_dataset_ops.concatenate_dataset(
4346        input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
4347        **self._flat_structure)
4348    # pylint: enable=protected-access
4349    super(ConcatenateDataset, self).__init__(variant_tensor)
4350
4351  def _inputs(self):
4352    return self._input_datasets
4353
4354  @property
4355  def element_spec(self):
4356    return self._structure
4357
4358
4359class RepeatDataset(UnaryUnchangedStructureDataset):
4360  """A `Dataset` that repeats its input several times."""
4361
4362  def __init__(self, input_dataset, count):
4363    """See `Dataset.repeat()` for details."""
4364    self._input_dataset = input_dataset
4365    if count is None:
4366      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
4367    else:
4368      self._count = ops.convert_to_tensor(
4369          count, dtype=dtypes.int64, name="count")
4370    variant_tensor = gen_dataset_ops.repeat_dataset(
4371        input_dataset._variant_tensor,  # pylint: disable=protected-access
4372        count=self._count,
4373        **self._flat_structure)
4374    super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
4375
4376
4377class RangeDataset(DatasetSource):
4378  """A `Dataset` of a step separated range of values."""
4379
4380  def __init__(self, *args, **kwargs):
4381    """See `Dataset.range()` for details."""
4382    self._parse_args(*args, **kwargs)
4383    self._structure = tensor_spec.TensorSpec([], self._output_type)
4384    variant_tensor = gen_dataset_ops.range_dataset(
4385        start=self._start,
4386        stop=self._stop,
4387        step=self._step,
4388        **self._flat_structure)
4389    super(RangeDataset, self).__init__(variant_tensor)
4390
4391  def _parse_args(self, *args, **kwargs):
4392    """Parse arguments according to the same rules as the `range()` builtin."""
4393    if len(args) == 1:
4394      self._start = self._build_tensor(0, "start")
4395      self._stop = self._build_tensor(args[0], "stop")
4396      self._step = self._build_tensor(1, "step")
4397    elif len(args) == 2:
4398      self._start = self._build_tensor(args[0], "start")
4399      self._stop = self._build_tensor(args[1], "stop")
4400      self._step = self._build_tensor(1, "step")
4401    elif len(args) == 3:
4402      self._start = self._build_tensor(args[0], "start")
4403      self._stop = self._build_tensor(args[1], "stop")
4404      self._step = self._build_tensor(args[2], "step")
4405    else:
4406      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
4407    if "output_type" in kwargs:
4408      self._output_type = kwargs["output_type"]
4409    else:
4410      self._output_type = dtypes.int64
4411
4412  def _build_tensor(self, int64_value, name):
4413    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
4414
4415  @property
4416  def element_spec(self):
4417    return self._structure
4418
4419
4420class CacheDataset(UnaryUnchangedStructureDataset):
4421  """A `Dataset` that caches elements of its input."""
4422
4423  def __init__(self, input_dataset, filename):
4424    """See `Dataset.cache()` for details."""
4425    self._input_dataset = input_dataset
4426    self._filename = ops.convert_to_tensor(
4427        filename, dtype=dtypes.string, name="filename")
4428    if tf2.enabled() and (context.executing_eagerly() or ops.inside_function()):
4429      variant_tensor = gen_dataset_ops.cache_dataset_v2(
4430          input_dataset._variant_tensor,  # pylint: disable=protected-access
4431          filename=self._filename,
4432          cache=gen_dataset_ops.dummy_memory_cache(),
4433          **self._flat_structure)
4434    else:
4435      variant_tensor = gen_dataset_ops.cache_dataset(
4436          input_dataset._variant_tensor,  # pylint: disable=protected-access
4437          filename=self._filename,
4438          **self._flat_structure)
4439    super(CacheDataset, self).__init__(input_dataset, variant_tensor)
4440
4441
4442class ShuffleDataset(UnaryUnchangedStructureDataset):
4443  """A `Dataset` that randomly shuffles the elements of its input."""
4444
4445  def __init__(self,
4446               input_dataset,
4447               buffer_size,
4448               seed=None,
4449               reshuffle_each_iteration=None):
4450    """Randomly shuffles the elements of this dataset.
4451
4452    Args:
4453      input_dataset: The input dataset.
4454      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
4455        elements from this dataset from which the new dataset will sample.
4456      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
4457        seed that will be used to create the distribution. See
4458        `tf.random.set_seed` for behavior.
4459      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
4460        that the dataset should be pseudorandomly reshuffled each time it is
4461        iterated over. (Defaults to `True`.)
4462
4463    Returns:
4464      A `Dataset`.
4465
4466    Raises:
4467      ValueError: if invalid arguments are provided.
4468    """
4469    self._input_dataset = input_dataset
4470    self._buffer_size = ops.convert_to_tensor(
4471        buffer_size, dtype=dtypes.int64, name="buffer_size")
4472    self._seed, self._seed2 = random_seed.get_seed(seed)
4473    if reshuffle_each_iteration is None:
4474      reshuffle_each_iteration = True
4475    self._reshuffle_each_iteration = reshuffle_each_iteration
4476
4477    if (tf2.enabled() and
4478        (context.executing_eagerly() or ops.inside_function())):
4479      variant_tensor = gen_dataset_ops.shuffle_dataset_v3(
4480          input_dataset._variant_tensor,  # pylint: disable=protected-access
4481          buffer_size=self._buffer_size,
4482          seed=self._seed,
4483          seed2=self._seed2,
4484          seed_generator=gen_dataset_ops.dummy_seed_generator(),
4485          reshuffle_each_iteration=self._reshuffle_each_iteration,
4486          **self._flat_structure)
4487    else:
4488      variant_tensor = gen_dataset_ops.shuffle_dataset(
4489          input_dataset._variant_tensor,  # pylint: disable=protected-access
4490          buffer_size=self._buffer_size,
4491          seed=self._seed,
4492          seed2=self._seed2,
4493          reshuffle_each_iteration=self._reshuffle_each_iteration,
4494          **self._flat_structure)
4495    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
4496
4497
4498class TakeDataset(UnaryUnchangedStructureDataset):
4499  """A `Dataset` containing the first `count` elements from its input."""
4500
4501  def __init__(self, input_dataset, count):
4502    """See `Dataset.take()` for details."""
4503    self._input_dataset = input_dataset
4504    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
4505    variant_tensor = gen_dataset_ops.take_dataset(
4506        input_dataset._variant_tensor,  # pylint: disable=protected-access
4507        count=self._count,
4508        **self._flat_structure)
4509    super(TakeDataset, self).__init__(input_dataset, variant_tensor)
4510
4511
4512class SkipDataset(UnaryUnchangedStructureDataset):
4513  """A `Dataset` skipping the first `count` elements from its input."""
4514
4515  def __init__(self, input_dataset, count):
4516    """See `Dataset.skip()` for details."""
4517    self._input_dataset = input_dataset
4518    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
4519    variant_tensor = gen_dataset_ops.skip_dataset(
4520        input_dataset._variant_tensor,  # pylint: disable=protected-access
4521        count=self._count,
4522        **self._flat_structure)
4523    super(SkipDataset, self).__init__(input_dataset, variant_tensor)
4524
4525
4526class ShardDataset(UnaryUnchangedStructureDataset):
4527  """A `Dataset` for sharding its input."""
4528
4529  def __init__(self, input_dataset, num_shards, index):
4530    """See `Dataset.shard()` for details."""
4531    self._input_dataset = input_dataset
4532    self._num_shards = ops.convert_to_tensor(
4533        num_shards, dtype=dtypes.int64, name="num_shards")
4534    self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
4535    variant_tensor = gen_dataset_ops.shard_dataset(
4536        input_dataset._variant_tensor,  # pylint: disable=protected-access
4537        num_shards=self._num_shards,
4538        index=self._index,
4539        **self._flat_structure)
4540    super(ShardDataset, self).__init__(input_dataset, variant_tensor)
4541
4542
4543class BatchDataset(UnaryDataset):
4544  """A `Dataset` that batches contiguous elements from its input."""
4545
4546  def __init__(self, input_dataset, batch_size, drop_remainder):
4547    """See `Dataset.batch()` for details."""
4548    self._input_dataset = input_dataset
4549    self._batch_size = ops.convert_to_tensor(
4550        batch_size, dtype=dtypes.int64, name="batch_size")
4551    self._drop_remainder = ops.convert_to_tensor(
4552        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4553
4554    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
4555    # pylint: disable=protected-access
4556    if constant_drop_remainder:
4557      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
4558      # or `False` (explicitly retaining the remainder).
4559      # pylint: disable=g-long-lambda
4560      constant_batch_size = tensor_util.constant_value(self._batch_size)
4561      self._structure = nest.map_structure(
4562          lambda component_spec: component_spec._batch(constant_batch_size),
4563          input_dataset.element_spec)
4564    else:
4565      self._structure = nest.map_structure(
4566          lambda component_spec: component_spec._batch(None),
4567          input_dataset.element_spec)
4568    variant_tensor = gen_dataset_ops.batch_dataset_v2(
4569        input_dataset._variant_tensor,
4570        batch_size=self._batch_size,
4571        drop_remainder=self._drop_remainder,
4572        **self._flat_structure)
4573    super(BatchDataset, self).__init__(input_dataset, variant_tensor)
4574
4575  @property
4576  def element_spec(self):
4577    return self._structure
4578
4579
4580class ParallelBatchDataset(UnaryDataset):
4581  """A `Dataset` that batches contiguous elements from its input in parallel."""
4582
4583  def __init__(self, input_dataset, batch_size, drop_remainder,
4584               num_parallel_calls, deterministic):
4585    """See `Dataset.batch()` for details."""
4586    self._input_dataset = input_dataset
4587    self._batch_size = ops.convert_to_tensor(
4588        batch_size, dtype=dtypes.int64, name="batch_size")
4589    self._drop_remainder = ops.convert_to_tensor(
4590        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4591    self._num_parallel_calls = ops.convert_to_tensor(
4592        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
4593    if deterministic is None:
4594      self._deterministic = "default"
4595    elif deterministic:
4596      self._deterministic = "true"
4597    else:
4598      self._deterministic = "false"
4599
4600    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
4601    # pylint: disable=protected-access
4602    if constant_drop_remainder:
4603      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
4604      # or `False` (explicitly retaining the remainder).
4605      # pylint: disable=g-long-lambda
4606      constant_batch_size = tensor_util.constant_value(self._batch_size)
4607      self._structure = nest.map_structure(
4608          lambda component_spec: component_spec._batch(constant_batch_size),
4609          input_dataset.element_spec)
4610    else:
4611      self._structure = nest.map_structure(
4612          lambda component_spec: component_spec._batch(None),
4613          input_dataset.element_spec)
4614
4615    variant_tensor = gen_dataset_ops.parallel_batch_dataset(
4616        input_dataset._variant_tensor,
4617        batch_size=self._batch_size,
4618        num_parallel_calls=self._num_parallel_calls,
4619        drop_remainder=self._drop_remainder,
4620        deterministic=self._deterministic,
4621        **self._flat_structure)
4622
4623    super(ParallelBatchDataset, self).__init__(input_dataset, variant_tensor)
4624
4625  @property
4626  def element_spec(self):
4627    return self._structure
4628
4629
4630class _NumpyIterator(object):
4631  """Iterator over a dataset with elements converted to numpy."""
4632
4633  __slots__ = ["_iterator"]
4634
4635  def __init__(self, dataset):
4636    self._iterator = iter(dataset)
4637
4638  def __iter__(self):
4639    return self
4640
4641  def __next__(self):
4642
4643    def to_numpy(x):
4644      numpy = x._numpy()  # pylint: disable=protected-access
4645      if isinstance(numpy, np.ndarray):
4646        # `numpy` shares the same underlying buffer as the `x` Tensor.
4647        # Tensors are expected to be immutable, so we disable writes.
4648        numpy.setflags(write=False)
4649      return numpy
4650
4651    return nest.map_structure(to_numpy, next(self._iterator))
4652
4653  def next(self):
4654    return self.__next__()
4655
4656
4657class _VariantTracker(tracking.CapturableResource):
4658  """Allows export of functions capturing a Dataset in SavedModels.
4659
4660  When saving a SavedModel, `tf.saved_model.save` traverses the object
4661  graph. Since Datasets reference _VariantTracker objects, that traversal will
4662  find a _VariantTracker for each Dataset and so know how to save and restore
4663  functions which reference the Dataset's variant Tensor.
4664  """
4665
4666  def __init__(self, variant_tensor, resource_creator):
4667    """Record that `variant_tensor` is associated with `resource_creator`.
4668
4669    Args:
4670      variant_tensor: The variant-dtype Tensor associated with the Dataset. This
4671        Tensor will be a captured input to functions which use the Dataset, and
4672        is used by saving code to identify the corresponding _VariantTracker.
4673      resource_creator: A zero-argument function which creates a new
4674        variant-dtype Tensor. This function will be included in SavedModels and
4675        run to re-create the Dataset's variant Tensor on restore.
4676    """
4677    super(_VariantTracker, self).__init__(device="CPU")
4678    self._resource_handle = variant_tensor
4679    self._create_resource = resource_creator
4680
4681
4682def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
4683  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
4684
4685  Args:
4686    padded_shape: A `tf.TensorShape`.
4687    input_component_shape: A `tf.TensorShape`.
4688
4689  Returns:
4690    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
4691    `False`.
4692  """
4693
4694  if padded_shape.dims is None or input_component_shape.dims is None:
4695    return True
4696  if len(padded_shape.dims) != len(input_component_shape.dims):
4697    return False
4698  for padded_dim, input_dim in zip(
4699      padded_shape.dims, input_component_shape.dims):
4700    if (padded_dim.value is not None and input_dim.value is not None
4701        and padded_dim.value < input_dim.value):
4702      return False
4703  return True
4704
4705
4706def _padded_shape_to_tensor(padded_shape, input_component_shape):
4707  """Converts `padded_shape` to a `tf.Tensor` representing that shape.
4708
4709  Args:
4710    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
4711      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
4712    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
4713      be compatible.
4714
4715  Returns:
4716    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
4717
4718  Raises:
4719    ValueError: If `padded_shape` is not a shape or not compatible with
4720      `input_component_shape`.
4721    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
4722  """
4723  try:
4724    # Try to convert the `padded_shape` to a `tf.TensorShape`
4725    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
4726    # We will return the "canonical" tensor representation, which uses
4727    # `-1` in place of `None`.
4728    ret = ops.convert_to_tensor(
4729        [dim if dim is not None else -1
4730         for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
4731  except (TypeError, ValueError):
4732    # The argument was not trivially convertible to a
4733    # `tf.TensorShape`, so fall back on the conversion to tensor
4734    # machinery.
4735    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
4736    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
4737      six.reraise(ValueError, ValueError(
4738          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
4739          "shape was %s." % (padded_shape, ret.shape)), sys.exc_info()[2])
4740    if ret.dtype != dtypes.int64:
4741      six.reraise(
4742          TypeError,
4743          TypeError(
4744              "Padded shape %s must be a 1-D tensor of tf.int64 values, but "
4745              "its element type was %s." % (padded_shape, ret.dtype.name)),
4746          sys.exc_info()[2])
4747    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
4748
4749  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
4750                                          input_component_shape):
4751    raise ValueError("The padded shape %s is not compatible with the "
4752                     "corresponding input component shape %s."
4753                     % (padded_shape_as_shape, input_component_shape))
4754
4755  return ret
4756
4757
4758def _padding_value_to_tensor(value, output_type):
4759  """Converts the padding value to a tensor.
4760
4761  Args:
4762    value: The padding value.
4763    output_type: Its expected dtype.
4764
4765  Returns:
4766    A scalar `Tensor`.
4767
4768  Raises:
4769    ValueError: if the padding value is not a scalar.
4770    TypeError: if the padding value's type does not match `output_type`.
4771  """
4772  value = ops.convert_to_tensor(value, name="padding_value")
4773  if not value.shape.is_compatible_with(tensor_shape.TensorShape([])):
4774    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
4775  if value.dtype != output_type:
4776    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
4777                    (value, output_type))
4778  return value
4779
4780
4781def _padding_values_or_default(padding_values, input_dataset):
4782  """Returns padding values with None elements replaced with default values."""
4783
4784  def make_zero(t):
4785    if t.base_dtype == dtypes.string:
4786      return ""
4787    elif t.base_dtype == dtypes.variant:
4788      error_msg = ("Unable to create padding for field of type 'variant' "
4789                   "because t.base_type == dtypes.variant == "
4790                   "{}.".format(t.base_dtype))
4791      raise TypeError(error_msg)
4792    elif t.base_dtype == dtypes.bfloat16:
4793      # Special case `bfloat16` because it is not supported by NumPy.
4794      return constant_op.constant(0, dtype=dtypes.bfloat16)
4795    else:
4796      return np.zeros_like(t.as_numpy_dtype())
4797
4798  def value_or_default(value, default):
4799    return default if value is None else value
4800
4801  default_padding = nest.map_structure(
4802      make_zero,
4803      get_legacy_output_types(input_dataset))
4804  return nest.map_structure_up_to(padding_values, value_or_default,
4805                                  padding_values, default_padding)
4806
4807
4808class PaddedBatchDataset(UnaryDataset):
4809  """A `Dataset` that batches and pads contiguous elements from its input."""
4810
4811  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
4812               drop_remainder):
4813    """See `Dataset.batch()` for details."""
4814    self._input_dataset = input_dataset
4815
4816    def check_types(component_spec):
4817      if not isinstance(component_spec, tensor_spec.TensorSpec):
4818        raise TypeError("Padded batching of components of type ",
4819                        type(component_spec), " is not supported.")
4820
4821    nest.map_structure(check_types, input_dataset.element_spec)
4822    self._input_dataset = input_dataset
4823    self._batch_size = ops.convert_to_tensor(
4824        batch_size, dtype=dtypes.int64, name="batch_size")
4825    padding_values = _padding_values_or_default(padding_values, input_dataset)
4826
4827    input_shapes = get_legacy_output_shapes(input_dataset)
4828    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
4829
4830    flat_padded_shapes_as_tensors = []
4831
4832    for input_component_shape, padded_shape in zip(
4833        nest.flatten(input_shapes), flat_padded_shapes):
4834      flat_padded_shapes_as_tensors.append(
4835          _padded_shape_to_tensor(padded_shape, input_component_shape))
4836
4837    self._padded_shapes = nest.pack_sequence_as(input_shapes,
4838                                                flat_padded_shapes_as_tensors)
4839
4840    # If padding_values is a single element and input_shapes is a structure,
4841    # "broadcast" padding_values to the same structure as input_shapes.
4842    if nest.is_sequence(input_shapes) and not nest.is_sequence(padding_values):
4843      padding_values = nest.map_structure(lambda _: padding_values,
4844                                          input_shapes)
4845
4846    self._padding_values = nest.map_structure_up_to(
4847        input_shapes, _padding_value_to_tensor, padding_values,
4848        get_legacy_output_types(input_dataset))
4849    self._drop_remainder = ops.convert_to_tensor(
4850        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
4851
4852    def _padded_shape_to_batch_shape(s):
4853      return tensor_shape.TensorShape([
4854          tensor_util.constant_value(self._batch_size)
4855          if smart_cond.smart_constant_value(self._drop_remainder) else None
4856      ]).concatenate(tensor_util.constant_value_as_shape(s))
4857
4858    output_shapes = nest.map_structure(
4859        _padded_shape_to_batch_shape, self._padded_shapes)
4860    self._structure = structure.convert_legacy_structure(
4861        get_legacy_output_types(self._input_dataset), output_shapes,
4862        get_legacy_output_classes(self._input_dataset))
4863
4864    # pylint: disable=protected-access
4865    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
4866    if smart_cond.smart_constant_value(self._drop_remainder) is False:
4867      variant_tensor = gen_dataset_ops.padded_batch_dataset(
4868          input_dataset._variant_tensor,  # pylint: disable=protected-access
4869          batch_size=self._batch_size,
4870          padded_shapes=[
4871              ops.convert_to_tensor(s, dtype=dtypes.int64)
4872              for s in nest.flatten(self._padded_shapes)
4873          ],
4874          padding_values=nest.flatten(self._padding_values),
4875          output_shapes=structure.get_flat_tensor_shapes(self._structure))
4876    else:
4877      variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
4878          input_dataset._variant_tensor,  # pylint: disable=protected-access
4879          batch_size=self._batch_size,
4880          padded_shapes=[
4881              ops.convert_to_tensor(s, dtype=dtypes.int64)
4882              for s in nest.flatten(self._padded_shapes)
4883          ],
4884          padding_values=nest.flatten(self._padding_values),
4885          drop_remainder=self._drop_remainder,
4886          output_shapes=structure.get_flat_tensor_shapes(self._structure))
4887    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
4888
4889  @property
4890  def element_spec(self):
4891    return self._structure
4892
4893
4894def _should_pack(arg):
4895  """Determines whether the caller needs to pack the argument in a tuple.
4896
4897  If user-defined function returns a list of tensors, `nest.flatten()` and
4898  `ops.convert_to_tensor()` and would conspire to attempt to stack those tensors
4899  into a single tensor because the tf.data version of `nest.flatten()` does
4900  not recurse into lists. Since it is more likely that the list arose from
4901  returning the result of an operation (such as `tf.numpy_function()`) that
4902  returns a list of not-necessarily-stackable tensors, we treat the returned
4903  value as a `tuple` instead. A user wishing to pack the return value into a
4904  single tensor can use an explicit `tf.stack()` before returning.
4905
4906  Args:
4907    arg: argument to check
4908
4909  Returns:
4910    Indication of whether the caller needs to pack the argument in a tuple.
4911  """
4912  return isinstance(arg, list)
4913
4914
4915def _should_unpack(arg):
4916  """Determines whether the caller needs to unpack the argument from a tuple.
4917
4918  Args:
4919    arg: argument to check
4920
4921  Returns:
4922    Indication of whether the caller needs to unpack the argument from a tuple.
4923  """
4924  return type(arg) is tuple  # pylint: disable=unidiomatic-typecheck
4925
4926
4927class MapDataset(UnaryDataset):
4928  """A `Dataset` that maps a function over elements in its input."""
4929
4930  def __init__(self,
4931               input_dataset,
4932               map_func,
4933               use_inter_op_parallelism=True,
4934               preserve_cardinality=False,
4935               use_legacy_function=False):
4936    """See `Dataset.map()` for details."""
4937    self._input_dataset = input_dataset
4938    self._use_inter_op_parallelism = use_inter_op_parallelism
4939    self._preserve_cardinality = preserve_cardinality
4940    self._map_func = StructuredFunctionWrapper(
4941        map_func,
4942        self._transformation_name(),
4943        dataset=input_dataset,
4944        use_legacy_function=use_legacy_function)
4945    variant_tensor = gen_dataset_ops.map_dataset(
4946        input_dataset._variant_tensor,  # pylint: disable=protected-access
4947        self._map_func.function.captured_inputs,
4948        f=self._map_func.function,
4949        use_inter_op_parallelism=self._use_inter_op_parallelism,
4950        preserve_cardinality=self._preserve_cardinality,
4951        **self._flat_structure)
4952    super(MapDataset, self).__init__(input_dataset, variant_tensor)
4953
4954  def _functions(self):
4955    return [self._map_func]
4956
4957  @property
4958  def element_spec(self):
4959    return self._map_func.output_structure
4960
4961  def _transformation_name(self):
4962    return "Dataset.map()"
4963
4964
4965class ParallelMapDataset(UnaryDataset):
4966  """A `Dataset` that maps a function over elements in its input in parallel."""
4967
4968  def __init__(self,
4969               input_dataset,
4970               map_func,
4971               num_parallel_calls,
4972               deterministic,
4973               use_inter_op_parallelism=True,
4974               preserve_cardinality=False,
4975               use_legacy_function=False):
4976    """See `Dataset.map()` for details."""
4977    self._input_dataset = input_dataset
4978    self._use_inter_op_parallelism = use_inter_op_parallelism
4979    self._map_func = StructuredFunctionWrapper(
4980        map_func,
4981        self._transformation_name(),
4982        dataset=input_dataset,
4983        use_legacy_function=use_legacy_function)
4984    if deterministic is None:
4985      self._deterministic = "default"
4986    elif deterministic:
4987      self._deterministic = "true"
4988    else:
4989      self._deterministic = "false"
4990    self._preserve_cardinality = preserve_cardinality
4991    self._num_parallel_calls = ops.convert_to_tensor(
4992        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
4993    variant_tensor = gen_dataset_ops.parallel_map_dataset_v2(
4994        input_dataset._variant_tensor,  # pylint: disable=protected-access
4995        self._map_func.function.captured_inputs,
4996        f=self._map_func.function,
4997        num_parallel_calls=self._num_parallel_calls,
4998        deterministic=self._deterministic,
4999        use_inter_op_parallelism=self._use_inter_op_parallelism,
5000        preserve_cardinality=self._preserve_cardinality,
5001        **self._flat_structure)
5002    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)
5003
5004  def _functions(self):
5005    return [self._map_func]
5006
5007  @property
5008  def element_spec(self):
5009    return self._map_func.output_structure
5010
5011  def _transformation_name(self):
5012    return "Dataset.map()"
5013
5014
5015class FlatMapDataset(UnaryDataset):
5016  """A `Dataset` that maps a function over its input and flattens the result."""
5017
5018  def __init__(self, input_dataset, map_func):
5019    """See `Dataset.flat_map()` for details."""
5020    self._input_dataset = input_dataset
5021    self._map_func = StructuredFunctionWrapper(
5022        map_func, self._transformation_name(), dataset=input_dataset)
5023    if not isinstance(self._map_func.output_structure, DatasetSpec):
5024      raise TypeError(
5025          "`map_func` must return a `Dataset` object. Got {}".format(
5026              type(self._map_func.output_structure)))
5027    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5028    variant_tensor = gen_dataset_ops.flat_map_dataset(
5029        input_dataset._variant_tensor,  # pylint: disable=protected-access
5030        self._map_func.function.captured_inputs,
5031        f=self._map_func.function,
5032        **self._flat_structure)
5033    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)
5034
5035  def _functions(self):
5036    return [self._map_func]
5037
5038  @property
5039  def element_spec(self):
5040    return self._structure
5041
5042  def _transformation_name(self):
5043    return "Dataset.flat_map()"
5044
5045
5046class InterleaveDataset(UnaryDataset):
5047  """A `Dataset` that interleaves the result of transformed inputs."""
5048
5049  def __init__(self, input_dataset, map_func, cycle_length, block_length):
5050    """See `Dataset.interleave()` for details."""
5051
5052    self._input_dataset = input_dataset
5053    self._map_func = StructuredFunctionWrapper(
5054        map_func, self._transformation_name(), dataset=input_dataset)
5055    if not isinstance(self._map_func.output_structure, DatasetSpec):
5056      raise TypeError(
5057          "`map_func` must return a `Dataset` object. Got {}".format(
5058              type(self._map_func.output_structure)))
5059    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5060    self._cycle_length = ops.convert_to_tensor(
5061        cycle_length, dtype=dtypes.int64, name="cycle_length")
5062    self._block_length = ops.convert_to_tensor(
5063        block_length, dtype=dtypes.int64, name="block_length")
5064
5065    variant_tensor = gen_dataset_ops.interleave_dataset(
5066        input_dataset._variant_tensor,  # pylint: disable=protected-access
5067        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
5068        self._cycle_length,
5069        self._block_length,
5070        f=self._map_func.function,
5071        **self._flat_structure)
5072    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)
5073
5074  def _functions(self):
5075    return [self._map_func]
5076
5077  @property
5078  def element_spec(self):
5079    return self._structure
5080
5081  def _transformation_name(self):
5082    return "Dataset.interleave()"
5083
5084
5085class ParallelInterleaveDataset(UnaryDataset):
5086  """A `Dataset` that maps a function over its input and interleaves the result."""
5087
5088  def __init__(self,
5089               input_dataset,
5090               map_func,
5091               cycle_length,
5092               block_length,
5093               num_parallel_calls,
5094               buffer_output_elements=AUTOTUNE,
5095               prefetch_input_elements=AUTOTUNE,
5096               deterministic=None):
5097    """See `Dataset.interleave()` for details."""
5098    self._input_dataset = input_dataset
5099    self._map_func = StructuredFunctionWrapper(
5100        map_func, self._transformation_name(), dataset=input_dataset)
5101    if not isinstance(self._map_func.output_structure, DatasetSpec):
5102      raise TypeError(
5103          "`map_func` must return a `Dataset` object. Got {}".format(
5104              type(self._map_func.output_structure)))
5105    self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
5106    self._cycle_length = ops.convert_to_tensor(
5107        cycle_length, dtype=dtypes.int64, name="cycle_length")
5108    self._block_length = ops.convert_to_tensor(
5109        block_length, dtype=dtypes.int64, name="block_length")
5110    self._buffer_output_elements = ops.convert_to_tensor(
5111        buffer_output_elements,
5112        dtype=dtypes.int64,
5113        name="buffer_output_elements")
5114    self._prefetch_input_elements = ops.convert_to_tensor(
5115        prefetch_input_elements,
5116        dtype=dtypes.int64,
5117        name="prefetch_input_elements")
5118
5119    self._num_parallel_calls = ops.convert_to_tensor(
5120        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
5121    if deterministic is None:
5122      deterministic_string = "default"
5123    elif deterministic:
5124      deterministic_string = "true"
5125    else:
5126      deterministic_string = "false"
5127
5128    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
5129        input_dataset._variant_tensor,  # pylint: disable=protected-access
5130        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
5131        self._cycle_length,
5132        self._block_length,
5133        self._buffer_output_elements,
5134        self._prefetch_input_elements,
5135        self._num_parallel_calls,
5136        f=self._map_func.function,
5137        deterministic=deterministic_string,
5138        **self._flat_structure)
5139    super(ParallelInterleaveDataset, self).__init__(input_dataset,
5140                                                    variant_tensor)
5141
5142  def _functions(self):
5143    return [self._map_func]
5144
5145  @property
5146  def element_spec(self):
5147    return self._structure
5148
5149  def _transformation_name(self):
5150    return "Dataset.interleave()"
5151
5152
5153class FilterDataset(UnaryUnchangedStructureDataset):
5154  """A `Dataset` that filters its input according to a predicate function."""
5155
5156  def __init__(self, input_dataset, predicate, use_legacy_function=False):
5157    """See `Dataset.filter()` for details."""
5158    self._input_dataset = input_dataset
5159    wrapped_func = StructuredFunctionWrapper(
5160        predicate,
5161        self._transformation_name(),
5162        dataset=input_dataset,
5163        use_legacy_function=use_legacy_function)
5164    if not wrapped_func.output_structure.is_compatible_with(
5165        tensor_spec.TensorSpec([], dtypes.bool)):
5166      error_msg = ("`predicate` return type must be convertible to a scalar "
5167                   "boolean tensor. Was {}.").format(
5168                       wrapped_func.output_structure)
5169      raise ValueError(error_msg)
5170    self._predicate = wrapped_func
5171    variant_tensor = gen_dataset_ops.filter_dataset(
5172        input_dataset._variant_tensor,  # pylint: disable=protected-access
5173        other_arguments=self._predicate.function.captured_inputs,
5174        predicate=self._predicate.function,
5175        **self._flat_structure)
5176    super(FilterDataset, self).__init__(input_dataset, variant_tensor)
5177
5178  def _functions(self):
5179    return [self._predicate]
5180
5181  def _transformation_name(self):
5182    return "Dataset.filter()"
5183
5184
5185class PrefetchDataset(UnaryUnchangedStructureDataset):
5186  """A `Dataset` that asynchronously prefetches its input."""
5187
5188  def __init__(self, input_dataset, buffer_size, slack_period=None):
5189    """See `Dataset.prefetch()` for details.
5190
5191    Args:
5192      input_dataset: The input dataset.
5193      buffer_size: See `Dataset.prefetch()` for details.
5194      slack_period: (Optional.) An integer. If non-zero, determines the number
5195        of GetNext calls before injecting slack into the execution. This may
5196        reduce CPU contention at the start of a step. Note that a tensorflow
5197        user should not have to set this manually; enable this behavior
5198        automatically via `tf.data.Options.experimental_slack` instead. Defaults
5199        to None.
5200    """
5201    self._input_dataset = input_dataset
5202    if buffer_size is None:
5203      buffer_size = AUTOTUNE
5204    self._buffer_size = ops.convert_to_tensor(
5205        buffer_size, dtype=dtypes.int64, name="buffer_size")
5206    # pylint: disable=protected-access
5207    # We colocate the prefetch dataset with its input as this collocation only
5208    # happens automatically in graph mode.
5209    with ops.colocate_with(input_dataset._variant_tensor):
5210      variant_tensor = gen_dataset_ops.prefetch_dataset(
5211          input_dataset._variant_tensor,
5212          buffer_size=self._buffer_size,
5213          slack_period=slack_period,
5214          **self._flat_structure)
5215    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
5216
5217
5218class WindowDataset(UnaryDataset):
5219  """A dataset that creates window datasets from the input elements."""
5220
5221  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
5222    """See `window_dataset()` for more details."""
5223    self._input_dataset = input_dataset
5224    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
5225    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
5226    self._stride = ops.convert_to_tensor(
5227        stride, dtype=dtypes.int64, name="stride")
5228    self._drop_remainder = ops.convert_to_tensor(
5229        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
5230    self._structure = nest.pack_sequence_as(
5231        get_legacy_output_classes(input_dataset), [
5232            DatasetSpec(  # pylint: disable=g-complex-comprehension
5233                structure.convert_legacy_structure(
5234                    output_type, output_shape, output_class))
5235            for output_class, output_shape, output_type in zip(
5236                nest.flatten(get_legacy_output_classes(input_dataset)),
5237                nest.flatten(get_legacy_output_shapes(input_dataset)),
5238                nest.flatten(get_legacy_output_types(input_dataset)))
5239        ])
5240    variant_tensor = gen_dataset_ops.window_dataset(
5241        input_dataset._variant_tensor,  # pylint: disable=protected-access
5242        self._size,
5243        self._shift,
5244        self._stride,
5245        self._drop_remainder,
5246        **self._flat_structure)
5247    super(WindowDataset, self).__init__(input_dataset, variant_tensor)
5248
5249  @property
5250  def element_spec(self):
5251    return self._structure
5252
5253
5254class _OptionsDataset(UnaryUnchangedStructureDataset):
5255  """An identity `Dataset` that stores options."""
5256
5257  def __init__(self, input_dataset, options):
5258    # pylint: disable=protected-access
5259    self._input_dataset = input_dataset
5260    options_pb = dataset_options_pb2.Options()
5261    options_pb.CopyFrom(options._to_proto())
5262    with ops.colocate_with(input_dataset._variant_tensor):
5263      variant_tensor = gen_dataset_ops.options_dataset(
5264          input_dataset._variant_tensor,
5265          options_pb.SerializeToString(), **self._flat_structure)
5266    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)
5267
5268    if self._options_attr:
5269      self._options_attr._set_mutable(True)
5270      self._options_attr = self._options_attr.merge(options)
5271    else:
5272      self._options_attr = options
5273    self._options_attr._set_mutable(False)
5274
5275
5276def normalize_to_dense(dataset):
5277  """Normalizes non-tensor components in a dataset to dense representations.
5278
5279  This is necessary for dataset transformations that slice along the batch
5280  dimension and are oblivious to non-tensors, e.g. `unbatch`, `rebatch`.
5281
5282  Args:
5283    dataset: Dataset to normalize.
5284
5285  Returns:
5286    A dataset whose sparse and ragged tensors have been normalized to their
5287    dense representations.
5288  """
5289
5290  # NOTE(mrry): This leads to a somewhat inefficient re-encoding step for all
5291  # non-tensor components.
5292  #
5293  # TODO(mrry): Consider optimizing this if it turns out to be a bottleneck.
5294  if _should_unpack(dataset.element_spec):
5295
5296    def normalize(*args):
5297      return structure.to_batched_tensor_list(dataset.element_spec, tuple(args))
5298  else:
5299    def normalize(arg):
5300      return structure.to_batched_tensor_list(dataset.element_spec, arg)
5301
5302  normalized_dataset = dataset.map(normalize)
5303
5304  # NOTE(mrry): Our `map()` has lost information about the structure of
5305  # non-tensor components, so re-apply the structure of the original dataset.
5306  return _RestructuredDataset(normalized_dataset, dataset.element_spec)
5307
5308
5309class _RestructuredDataset(UnaryDataset):
5310  """An internal helper for changing the element spec of a dataset."""
5311
5312  def __init__(self, dataset, structure):
5313    self._input_dataset = dataset
5314    self._structure = structure
5315
5316    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
5317    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
5318
5319  @property
5320  def element_spec(self):
5321    return self._structure
5322
5323
5324class _UnbatchDataset(UnaryDataset):
5325  """A dataset that splits the elements of its input into multiple elements."""
5326
5327  def __init__(self, input_dataset):
5328    """See `unbatch()` for more details."""
5329    flat_shapes = input_dataset._flat_shapes  # pylint: disable=protected-access
5330    if any(s.ndims == 0 for s in flat_shapes):
5331      raise ValueError("Cannot unbatch an input with scalar components.")
5332    known_batch_dim = tensor_shape.Dimension(None)
5333    for s in flat_shapes:
5334      try:
5335        known_batch_dim = known_batch_dim.merge_with(s[0])
5336      except ValueError:
5337        raise ValueError("Cannot unbatch an input whose components have "
5338                         "different batch sizes.")
5339    self._input_dataset = input_dataset
5340    self._structure = nest.map_structure(
5341        lambda component_spec: component_spec._unbatch(),  # pylint: disable=protected-access
5342        get_structure(input_dataset))
5343    variant_tensor = ged_ops.unbatch_dataset(
5344        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5345        **self._flat_structure)
5346    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
5347
5348  @property
5349  def element_spec(self):
5350    return self._structure
5351
5352
5353class _GroupByWindowDataset(UnaryDataset):
5354  """A `Dataset` that groups its input and performs a windowed reduction."""
5355
5356  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
5357    """See `group_by_window()` for details."""
5358    self._input_dataset = input_dataset
5359    self._make_key_func(key_func, input_dataset)
5360    self._make_reduce_func(reduce_func, input_dataset)
5361    self._make_window_size_func(window_size_func)
5362    variant_tensor = ged_ops.group_by_window_dataset(
5363        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5364        self._key_func.function.captured_inputs,
5365        self._reduce_func.function.captured_inputs,
5366        self._window_size_func.function.captured_inputs,
5367        key_func=self._key_func.function,
5368        reduce_func=self._reduce_func.function,
5369        window_size_func=self._window_size_func.function,
5370        **self._flat_structure)
5371    super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)
5372
5373  def _make_window_size_func(self, window_size_func):
5374    """Make wrapping defun for window_size_func."""
5375
5376    def window_size_func_wrapper(key):
5377      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
5378
5379    self._window_size_func = StructuredFunctionWrapper(
5380        window_size_func_wrapper,
5381        self._transformation_name(),
5382        input_structure=tensor_spec.TensorSpec([], dtypes.int64))
5383    if not self._window_size_func.output_structure.is_compatible_with(
5384        tensor_spec.TensorSpec([], dtypes.int64)):
5385      raise ValueError(
5386          "`window_size_func` must return a single tf.int64 scalar tensor.")
5387
5388  def _make_key_func(self, key_func, input_dataset):
5389    """Make wrapping defun for key_func."""
5390
5391    def key_func_wrapper(*args):
5392      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
5393
5394    self._key_func = StructuredFunctionWrapper(
5395        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
5396    if not self._key_func.output_structure.is_compatible_with(
5397        tensor_spec.TensorSpec([], dtypes.int64)):
5398      raise ValueError(
5399          "`key_func` must return a single tf.int64 scalar tensor.")
5400
5401  def _make_reduce_func(self, reduce_func, input_dataset):
5402    """Make wrapping defun for reduce_func."""
5403    nested_dataset = DatasetSpec(input_dataset.element_spec)
5404    input_structure = (tensor_spec.TensorSpec([], dtypes.int64), nested_dataset)
5405    self._reduce_func = StructuredFunctionWrapper(
5406        reduce_func,
5407        self._transformation_name(),
5408        input_structure=input_structure)
5409    if not isinstance(self._reduce_func.output_structure, DatasetSpec):
5410      raise TypeError("`reduce_func` must return a `Dataset` object.")
5411    # pylint: disable=protected-access
5412    self._element_spec = (self._reduce_func.output_structure._element_spec)
5413
5414  @property
5415  def element_spec(self):
5416    return self._element_spec
5417
5418  def _functions(self):
5419    return [self._key_func, self._reduce_func, self._window_size_func]
5420
5421  def _transformation_name(self):
5422    return "Dataset.group_by_window()"
5423
5424
5425class RandomDataset(DatasetSource):
5426  """A `Dataset` of pseudorandom values."""
5427
5428  def __init__(self, seed=None):
5429    """A `Dataset` of pseudorandom values."""
5430    self._seed, self._seed2 = random_seed.get_seed(seed)
5431    variant_tensor = ged_ops.random_dataset(
5432        seed=self._seed, seed2=self._seed2, **self._flat_structure)
5433    super(RandomDataset, self).__init__(variant_tensor)
5434
5435  @property
5436  def element_spec(self):
5437    return tensor_spec.TensorSpec([], dtypes.int64)
5438
5439
5440def _get_prob_original_static(initial_dist_t, target_dist_t):
5441  """Returns the static probability of sampling from the original.
5442
5443  `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
5444  an Op that it isn't defined for. We have some custom logic to avoid this.
5445
5446  Args:
5447    initial_dist_t: A tensor of the initial distribution.
5448    target_dist_t: A tensor of the target distribution.
5449
5450  Returns:
5451    The probability of sampling from the original distribution as a constant,
5452    if it is a constant, or `None`.
5453  """
5454  init_static = tensor_util.constant_value(initial_dist_t)
5455  target_static = tensor_util.constant_value(target_dist_t)
5456
5457  if init_static is None or target_static is None:
5458    return None
5459  else:
5460    return np.min(target_static / init_static)
5461
5462
5463def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_func, seed):
5464  """Filters a dataset based on per-class acceptance probabilities.
5465
5466  Args:
5467    dataset: The dataset to be filtered.
5468    acceptance_dist_ds: A dataset of acceptance probabilities.
5469    initial_dist_ds: A dataset of the initial probability distribution, given or
5470      estimated.
5471    class_func: A function mapping an element of the input dataset to a scalar
5472      `tf.int32` tensor. Values should be in `[0, num_classes)`.
5473    seed: (Optional.) Python integer seed for the resampler.
5474
5475  Returns:
5476    A dataset of (class value, data) after filtering.
5477  """
5478
5479  def maybe_warn_on_large_rejection(accept_dist, initial_dist):
5480    proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
5481    return control_flow_ops.cond(
5482        math_ops.less(proportion_rejected, .5),
5483        lambda: accept_dist,
5484        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
5485            accept_dist, [proportion_rejected, initial_dist, accept_dist],
5486            message="Proportion of examples rejected by sampler is high: ",
5487            summarize=100,
5488            first_n=10))
5489
5490  acceptance_dist_ds = (
5491      DatasetV2.zip((acceptance_dist_ds,
5492                     initial_dist_ds)).map(maybe_warn_on_large_rejection))
5493
5494  def _gather_and_copy(acceptance_prob, data):
5495    if isinstance(data, tuple):
5496      class_val = class_func(*data)
5497    else:
5498      class_val = class_func(data)
5499    return class_val, array_ops.gather(acceptance_prob, class_val), data
5500
5501  current_probabilities_and_class_and_data_ds = DatasetV2.zip(
5502      (acceptance_dist_ds, dataset)).map(_gather_and_copy)
5503
5504  def _reject(unused_class_val, p, unused_data):
5505    return random_ops.random_uniform([], seed=seed, dtype=p.dtype) < p
5506
5507  filtered_ds = current_probabilities_and_class_and_data_ds.filter(_reject)
5508  return filtered_ds.map(lambda class_value, _, data: (class_value, data))
5509
5510
5511# pylint: disable=missing-function-docstring
5512def _estimate_initial_dist_ds(target_dist_t,
5513                              class_values_ds,
5514                              dist_estimation_batch_size=32,
5515                              smoothing_constant=10):
5516  num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
5517  initial_examples_per_class_seen = array_ops.fill([num_classes],
5518                                                   np.int64(smoothing_constant))
5519
5520  def update_estimate_and_tile(num_examples_per_class_seen, c):
5521    updated_examples_per_class_seen, dist = _estimate_data_distribution(
5522        c, num_examples_per_class_seen)
5523    tiled_dist = array_ops.tile(
5524        array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
5525    return updated_examples_per_class_seen, tiled_dist
5526
5527  initial_dist_ds = (
5528      class_values_ds.batch(dist_estimation_batch_size).scan(
5529          initial_examples_per_class_seen, update_estimate_and_tile).unbatch())
5530
5531  return initial_dist_ds
5532
5533
5534def _get_target_to_initial_ratio(initial_probs, target_probs):
5535  # Add tiny to initial_probs to avoid divide by zero.
5536  denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
5537  return target_probs / denom
5538
5539
5540def _estimate_data_distribution(c, num_examples_per_class_seen):
5541  """Estimate data distribution as labels are seen.
5542
5543  Args:
5544    c: The class labels.  Type `int32`, shape `[batch_size]`.
5545    num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, containing
5546      counts.
5547
5548  Returns:
5549    num_examples_per_lass_seen: Updated counts.  Type `int64`, shape
5550      `[num_classes]`.
5551    dist: The updated distribution.  Type `float32`, shape `[num_classes]`.
5552  """
5553  num_classes = num_examples_per_class_seen.get_shape()[0]
5554  # Update the class-count based on what labels are seen in batch.
5555  num_examples_per_class_seen = math_ops.add(
5556      num_examples_per_class_seen,
5557      math_ops.reduce_sum(
5558          array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
5559  init_prob_estimate = math_ops.truediv(
5560      num_examples_per_class_seen,
5561      math_ops.reduce_sum(num_examples_per_class_seen))
5562  dist = math_ops.cast(init_prob_estimate, dtypes.float32)
5563  return num_examples_per_class_seen, dist
5564
5565
5566def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
5567  """Calculates the acceptance probabilities and mixing ratio.
5568
5569  In this case, we assume that we can *either* sample from the original data
5570  distribution with probability `m`, or sample from a reshaped distribution
5571  that comes from rejection sampling on the original distribution. This
5572  rejection sampling is done on a per-class basis, with `a_i` representing the
5573  probability of accepting data from class `i`.
5574
5575  This method is based on solving the following analysis for the reshaped
5576  distribution:
5577
5578  Let F be the probability of a rejection (on any example).
5579  Let p_i be the proportion of examples in the data in class i (init_probs)
5580  Let a_i is the rate the rejection sampler should *accept* class i
5581  Let t_i is the target proportion in the minibatches for class i (target_probs)
5582
5583  ```
5584  F = sum_i(p_i * (1-a_i))
5585    = 1 - sum_i(p_i * a_i)     using sum_i(p_i) = 1
5586  ```
5587
5588  An example with class `i` will be accepted if `k` rejections occur, then an
5589  example with class `i` is seen by the rejector, and it is accepted. This can
5590  be written as follows:
5591
5592  ```
5593  t_i = sum_k=0^inf(F^k * p_i * a_i)
5594      = p_i * a_j / (1 - F)    using geometric series identity, since 0 <= F < 1
5595      = p_i * a_i / sum_j(p_j * a_j)        using F from above
5596  ```
5597
5598  Note that the following constraints hold:
5599  ```
5600  0 <= p_i <= 1, sum_i(p_i) = 1
5601  0 <= a_i <= 1
5602  0 <= t_i <= 1, sum_i(t_i) = 1
5603  ```
5604
5605  A solution for a_i in terms of the other variables is the following:
5606    ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
5607
5608  If we try to minimize the amount of data rejected, we get the following:
5609
5610  M_max = max_i [ t_i / p_i ]
5611  M_min = min_i [ t_i / p_i ]
5612
5613  The desired probability of accepting data if it comes from class `i`:
5614
5615  a_i = (t_i/p_i - m) / (M_max - m)
5616
5617  The desired probability of pulling a data element from the original dataset,
5618  rather than the filtered one:
5619
5620  m = M_min
5621
5622  Args:
5623    initial_probs: A Tensor of the initial probability distribution, given or
5624      estimated.
5625    target_probs: A Tensor of the corresponding classes.
5626
5627  Returns:
5628    (A 1D Tensor with the per-class acceptance probabilities, the desired
5629    probability of pull from the original distribution.)
5630  """
5631  ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
5632  max_ratio = math_ops.reduce_max(ratio_l)
5633  min_ratio = math_ops.reduce_min(ratio_l)
5634
5635  # Target prob to sample from original distribution.
5636  m = min_ratio
5637
5638  # TODO(joelshor): Simplify fraction, if possible.
5639  a_i = (ratio_l - m) / (max_ratio - m)
5640  return a_i, m
5641
5642
5643class _TakeWhileDataset(UnaryUnchangedStructureDataset):
5644  """A dataset that stops iteration when `predicate` returns false."""
5645
5646  def __init__(self, input_dataset, predicate):
5647    """See `take_while()` for details."""
5648
5649    self._input_dataset = input_dataset
5650    wrapped_func = StructuredFunctionWrapper(
5651        predicate, self._transformation_name(), dataset=self._input_dataset)
5652
5653    if not wrapped_func.output_structure.is_compatible_with(
5654        tensor_spec.TensorSpec([], dtypes.bool)):
5655      raise ValueError("`predicate` must return a scalar boolean tensor.")
5656
5657    self._predicate = wrapped_func
5658    var_tensor = ged_ops.take_while_dataset(
5659        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5660        other_arguments=self._predicate.function.captured_inputs,
5661        predicate=self._predicate.function,
5662        **self._flat_structure)
5663    super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
5664
5665  def _functions(self):
5666    return [self._predicate]
5667
5668  def _transformation_name(self):
5669    return "Dataset.take_while()"
5670
5671
5672class _UniqueDataset(UnaryUnchangedStructureDataset):
5673  """A `Dataset` contains the unique elements from its input."""
5674
5675  def __init__(self, input_dataset):
5676    """See `unique()` for details."""
5677    self._input_dataset = input_dataset
5678    if get_legacy_output_types(input_dataset) not in (dtypes.int32,
5679                                                      dtypes.int64,
5680                                                      dtypes.string):
5681      raise TypeError(
5682          "`tf.data.Dataset.unique()` only supports inputs with a single "
5683          "`tf.int32`, `tf.int64`, or `tf.string` component.")
5684    variant_tensor = ged_ops.unique_dataset(
5685        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
5686        **self._flat_structure)
5687    super(_UniqueDataset, self).__init__(input_dataset, variant_tensor)
5688
5689
5690def _collect_resource_inputs(op):
5691  """Collects resource inputs for the given ops (and its variant inputs)."""
5692
5693  def _process(op_queue, seen_ops):
5694    """Processes the next element of the op queue.
5695
5696    Args:
5697      op_queue: Queue of Dataset operations to process.
5698      seen_ops: Already processed set of Operations.
5699
5700    Returns:
5701      A 2-tuple containing sets of resource handles. The first tuple entry
5702      contains read-only handles and the second entry contains read-write
5703      handles.
5704    """
5705
5706    reads = []
5707    writes = []
5708    op = op_queue.pop()
5709    if op in seen_ops:
5710      return reads, writes
5711    seen_ops.add(op)
5712    # TODO(b/150139257): All resource inputs are in writes right now since we
5713    # have not updated the functional ops to set the special attribute that ACD
5714    # uses to figure out which of the op's inputs are read-only.
5715    reads, writes = acd_utils.get_read_write_resource_inputs(op)
5716    # Conservatively assume that any variant inputs are datasets.
5717    op_queue.extend(t.op for t in op.inputs if t.dtype == dtypes.variant)
5718    return reads, writes
5719
5720  op_queue = [op]
5721  seen_ops = set()
5722  all_reads = []
5723  all_writes = []
5724  while op_queue:
5725    reads, writes = _process(op_queue, seen_ops)
5726    all_reads.extend(reads)
5727    all_writes.extend(writes)
5728
5729  return all_reads, all_writes
5730
5731
5732class _SnapshotDataset(UnaryUnchangedStructureDataset):
5733  """A dataset that allows saving and re-use of already processed data."""
5734
5735  def __init__(self,
5736               input_dataset,
5737               path,
5738               shard_func,
5739               compression=None,
5740               reader_func=None,
5741               pending_snapshot_expiry_seconds=None,
5742               use_legacy_function=False):
5743
5744    if reader_func is None:
5745      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
5746          lambda x: x,
5747          cycle_length=multiprocessing.cpu_count(),
5748          num_parallel_calls=AUTOTUNE)
5749
5750    self._input_dataset = input_dataset
5751    self._path = path
5752    self._compression = compression
5753
5754    self._reader_func = StructuredFunctionWrapper(
5755        reader_func,
5756        self._transformation_name() + ".reader_func",
5757        # Dataset of datasets of input elements
5758        input_structure=DatasetSpec(DatasetSpec(input_dataset.element_spec)),
5759        use_legacy_function=use_legacy_function)
5760    self._shard_func = StructuredFunctionWrapper(
5761        shard_func,
5762        self._transformation_name() + ".shard_func",
5763        dataset=input_dataset,
5764        use_legacy_function=use_legacy_function)
5765
5766    if ((not self._shard_func.output_structure.is_compatible_with(
5767        tensor_spec.TensorSpec([], dtypes.int32))) and
5768        (not self._shard_func.output_structure.is_compatible_with(
5769            tensor_spec.TensorSpec([], dtypes.int64)))):
5770      raise TypeError(
5771          "shard_func must return a 0-dimension tensor containing an int.")
5772
5773    variant_tensor = ged_ops.snapshot_dataset_v2(
5774        input_dataset._variant_tensor,  # pylint: disable=protected-access
5775        path,
5776        self._reader_func.function.captured_inputs,
5777        self._shard_func.function.captured_inputs,
5778        compression=compression,
5779        reader_func=self._reader_func.function,
5780        shard_func=self._shard_func.function,
5781        **self._flat_structure)
5782    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
5783
5784  def _functions(self):
5785    return [self._reader_func, self._shard_func]
5786
5787  def _transformation_name(self):
5788    return "Dataset.snapshot()"
5789
5790
5791class _ScanDataset(UnaryDataset):
5792  """A dataset that scans a function across its input."""
5793
5794  def __init__(self,
5795               input_dataset,
5796               initial_state,
5797               scan_func,
5798               use_default_device=None):
5799    """See `scan()` for details."""
5800    self._input_dataset = input_dataset
5801    self._initial_state = structure.normalize_element(initial_state)
5802
5803    # Compute initial values for the state classes, shapes and types based on
5804    # the initial state. The shapes may be refined by running `tf_scan_func` one
5805    # or more times below.
5806    self._state_structure = structure.type_spec_from_value(self._initial_state)
5807
5808    # Iteratively rerun the scan function until reaching a fixed point on
5809    # `self._state_shapes`.
5810    need_to_rerun = True
5811    while need_to_rerun:
5812
5813      wrapped_func = StructuredFunctionWrapper(
5814          scan_func,
5815          self._transformation_name(),
5816          input_structure=(self._state_structure, input_dataset.element_spec),
5817          add_to_graph=False)
5818      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
5819              and len(wrapped_func.output_types) == 2):
5820        raise TypeError("The scan function must return a pair comprising the "
5821                        "new state and the output value.")
5822
5823      new_state_classes, self._output_classes = wrapped_func.output_classes
5824
5825      # Extract and validate class information from the returned values.
5826      new_state_classes, output_classes = wrapped_func.output_classes
5827      old_state_classes = nest.map_structure(
5828          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
5829          self._state_structure)
5830      for new_state_class, old_state_class in zip(
5831          nest.flatten(new_state_classes), nest.flatten(old_state_classes)):
5832        if not issubclass(new_state_class, old_state_class):
5833          raise TypeError(
5834              "The element classes for the new state must match the initial "
5835              "state. Expected %s; got %s." %
5836              (old_state_classes, new_state_classes))
5837
5838      # Extract and validate type information from the returned values.
5839      new_state_types, output_types = wrapped_func.output_types
5840      old_state_types = nest.map_structure(
5841          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
5842          self._state_structure)
5843      for new_state_type, old_state_type in zip(
5844          nest.flatten(new_state_types), nest.flatten(old_state_types)):
5845        if new_state_type != old_state_type:
5846          raise TypeError(
5847              "The element types for the new state must match the initial "
5848              "state. Expected %s; got %s." %
5849              (old_state_types, new_state_types))
5850
5851      # Extract shape information from the returned values.
5852      new_state_shapes, output_shapes = wrapped_func.output_shapes
5853      old_state_shapes = nest.map_structure(
5854          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
5855          self._state_structure)
5856      self._element_spec = structure.convert_legacy_structure(
5857          output_types, output_shapes, output_classes)
5858
5859      flat_state_shapes = nest.flatten(old_state_shapes)
5860      flat_new_state_shapes = nest.flatten(new_state_shapes)
5861      weakened_state_shapes = [
5862          original.most_specific_compatible_shape(new)
5863          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
5864      ]
5865
5866      need_to_rerun = False
5867      for original_shape, weakened_shape in zip(flat_state_shapes,
5868                                                weakened_state_shapes):
5869        if original_shape.ndims is not None and (
5870            weakened_shape.ndims is None or
5871            original_shape.as_list() != weakened_shape.as_list()):
5872          need_to_rerun = True
5873          break
5874
5875      if need_to_rerun:
5876        # TODO(b/110122868): Support a "most specific compatible structure"
5877        # method for combining structures, to avoid using legacy structures
5878        # in this method.
5879        self._state_structure = structure.convert_legacy_structure(
5880            old_state_types,
5881            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
5882            old_state_classes)
5883
5884    self._scan_func = wrapped_func
5885    self._scan_func.function.add_to_graph(ops.get_default_graph())
5886    # pylint: disable=protected-access
5887    if use_default_device is not None:
5888      variant_tensor = ged_ops.scan_dataset(
5889          self._input_dataset._variant_tensor,
5890          structure.to_tensor_list(self._state_structure, self._initial_state),
5891          self._scan_func.function.captured_inputs,
5892          f=self._scan_func.function,
5893          preserve_cardinality=True,
5894          use_default_device=use_default_device,
5895          **self._flat_structure)
5896    else:
5897      variant_tensor = ged_ops.scan_dataset(
5898          self._input_dataset._variant_tensor,
5899          structure.to_tensor_list(self._state_structure, self._initial_state),
5900          self._scan_func.function.captured_inputs,
5901          f=self._scan_func.function,
5902          preserve_cardinality=True,
5903          **self._flat_structure)
5904    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
5905
5906  def _functions(self):
5907    return [self._scan_func]
5908
5909  @property
5910  def element_spec(self):
5911    return self._element_spec
5912
5913  def _transformation_name(self):
5914    return "Dataset.scan()"
5915
5916
5917@auto_control_deps.register_acd_resource_resolver
5918def _resource_resolver(op, resource_reads, resource_writes):
5919  """Updates resource inputs for tf.data ops with indirect dependencies."""
5920
5921  updated = False
5922  if op.type in [
5923      "DatasetToSingleElement", "DatasetToTFRecord", "ReduceDataset"
5924  ]:
5925    reads, writes = _collect_resource_inputs(op)
5926    for inp in reads:
5927      if inp not in resource_reads:
5928        updated = True
5929        resource_reads.add(inp)
5930    for inp in writes:
5931      if inp not in resource_writes:
5932        updated = True
5933        resource_writes.add(inp)
5934
5935  if op.type in [
5936      "IteratorGetNext", "IteratorGetNextSync", "IteratorGetNextAsOptional"
5937  ]:
5938    iterator_resource = op.inputs[0]
5939    make_iterator_ops = [
5940        op for op in iterator_resource.consumers() if op.type == "MakeIterator"
5941    ]
5942
5943    if len(make_iterator_ops) == 1:
5944      reads, writes = _collect_resource_inputs(make_iterator_ops[0])
5945      for inp in reads:
5946        if inp not in resource_reads:
5947          updated = True
5948          resource_reads.add(inp)
5949      for inp in writes:
5950        if inp not in resource_writes:
5951          updated = True
5952          resource_writes.add(inp)
5953
5954  return updated
5955
5956
5957DEBUG_MODE = False
5958
5959
5960@tf_export("data.experimental.enable_debug_mode")
5961def enable_debug_mode():
5962  """Enables debug mode for tf.data.
5963
5964  Example usage with pdb module:
5965  ```
5966  import tensorflow as tf
5967  import pdb
5968
5969  tf.data.experimental.enable_debug_mode()
5970
5971  def func(x):
5972    # Python 3.7 and older requires `pdb.Pdb(nosigint=True).set_trace()`
5973    pdb.set_trace()
5974    x = x + 1
5975    return x
5976
5977  dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
5978  dataset = dataset.map(func)
5979
5980  for item in dataset:
5981    print(item)
5982  ```
5983
5984  The effect of debug mode is two-fold:
5985
5986  1) Any transformations that would introduce asynchrony, parallelism, or
5987  non-determinism to the input pipeline execution will be forced to execute
5988  synchronously, sequentially, and deterministically.
5989
5990  2) Any user-defined functions passed into tf.data transformations such as
5991  `map` will be wrapped in `tf.py_function` so that their body is executed
5992  "eagerly" as a Python function as opposed to a traced TensorFlow graph, which
5993  is the default behavior. Note that even when debug mode is enabled, the
5994  user-defined function is still traced  to infer the shape and type of its
5995  outputs; as a consequence, any `print` statements or breakpoints will be
5996  triggered once during the tracing before the actual execution of the input
5997  pipeline.
5998
5999  NOTE: As the debug mode setting affects the construction of the tf.data input
6000  pipeline, it should be enabled before any tf.data definitions.
6001
6002  Raises:
6003    ValueError: When invoked from graph mode.
6004  """
6005  if context.executing_eagerly():
6006    toggle_debug_mode(True)
6007  else:
6008    raise ValueError("Debug mode is only supported in eager mode.")
6009
6010
6011def toggle_debug_mode(debug_mode):
6012  global DEBUG_MODE
6013  DEBUG_MODE = debug_mode
6014