1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import functools
22import threading
23import warnings
24
25import numpy as np
26import six
27from six.moves import queue as Queue  # pylint: disable=redefined-builtin
28
29
30from tensorflow.python.compat import compat
31from tensorflow.python.data.experimental.ops import optimization_options
32from tensorflow.python.data.experimental.ops import stats_options
33from tensorflow.python.data.experimental.ops import threading_options
34from tensorflow.python.data.ops import iterator_ops
35from tensorflow.python.data.util import nest
36from tensorflow.python.data.util import options as options_lib
37from tensorflow.python.data.util import random_seed
38from tensorflow.python.data.util import sparse
39from tensorflow.python.data.util import structure as structure_lib
40from tensorflow.python.data.util import traverse
41from tensorflow.python.eager import context
42from tensorflow.python.eager import function as eager_function
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import dtypes
45from tensorflow.python.framework import function
46from tensorflow.python.framework import ops
47from tensorflow.python.framework import random_seed as core_random_seed
48from tensorflow.python.framework import smart_cond
49from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
50from tensorflow.python.framework import tensor_shape
51from tensorflow.python.framework import tensor_spec
52from tensorflow.python.framework import tensor_util
53from tensorflow.python.ops import array_ops
54from tensorflow.python.ops import control_flow_ops
55from tensorflow.python.ops import gen_dataset_ops
56from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
57from tensorflow.python.ops import gen_io_ops
58from tensorflow.python.ops import math_ops
59from tensorflow.python.ops import script_ops
60from tensorflow.python.ops import string_ops
61from tensorflow.python.platform import tf_logging as logging
62from tensorflow.python.training.tracking import tracking
63from tensorflow.python.util import deprecation
64from tensorflow.python.util import function_utils
65from tensorflow.python.util.tf_export import tf_export
66
67
68ops.NotDifferentiable("ReduceDataset")
69
70
71@tf_export("data.Dataset", v1=[])
72@six.add_metaclass(abc.ABCMeta)
73class DatasetV2(object):
74  """Represents a potentially large set of elements.
75
76  A `Dataset` can be used to represent an input pipeline as a
77  collection of elements (nested structures of tensors) and a "logical
78  plan" of transformations that act on those elements.
79  """
80
81  def __init__(self, variant_tensor):
82    """Creates a DatasetV2 object.
83
84    This constructor differs between DatasetV1 and DatasetV2: DatasetV1 takes
85    no arguments, whereas DatasetV2 expects subclasses to create a
86    variant_tensor and pass it to the super() call.
87
88    Args:
89      variant_tensor: A DT_VARIANT tensor that represents the dataset.
90    """
91    self._variant_tensor_attr = variant_tensor
92    self._graph_attr = ops.get_default_graph()
93
94  @property
95  def _variant_tensor(self):
96    return self._variant_tensor_attr
97
98  @_variant_tensor.setter
99  def _variant_tensor(self, _):
100    raise ValueError("The _variant_tensor property is read-only")
101
102  def _as_serialized_graph(self):
103    """Produces serialized graph representation of the dataset.
104
105    Returns:
106      A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
107      serialized graph.
108    """
109    return gen_dataset_ops.dataset_to_graph(self._variant_tensor)
110
111  @abc.abstractmethod
112  def _inputs(self):
113    """Returns a list of the input datasets of the dataset."""
114
115    raise NotImplementedError("Dataset._inputs")
116
117  @property
118  def _graph(self):
119    return self._graph_attr
120
121  @_graph.setter
122  def _graph(self, _):
123    raise ValueError("The _graph property is read-only")
124
125  def _has_captured_ref(self):
126    """Whether this dataset uses a function that captures ref variables.
127
128    Returns:
129      A boolean, which if true indicates that the dataset or one of its inputs
130      uses a function that captures ref variables.
131    """
132    if context.executing_eagerly():
133      # RefVariables are not supported in eager mode
134      return False
135
136    def is_tensor_or_parent_ref(tensor):
137      if tensor.dtype._is_ref_dtype:  # pylint: disable=protected-access
138        return True
139      return any([is_tensor_or_parent_ref(x) for x in tensor.op.inputs])
140
141    for fn in self._functions():
142      if any([is_tensor_or_parent_ref(t) for t in fn.function.captured_inputs]):
143        return True
144
145    return any(
146        [input_dataset._has_captured_ref() for input_dataset in self._inputs()])  # pylint: disable=protected-access
147
148  # TODO(jsimsa): Change this to be the transitive closure of functions used
149  # by this dataset and its inputs.
150  def _functions(self):
151    """Returns a list of functions associated with this dataset.
152
153    Returns:
154      A list of `StructuredFunctionWrapper` objects.
155    """
156    return []
157
158  def options(self):
159    """Returns the options for this dataset and its inputs.
160
161    Returns:
162      A `tf.data.Options` object representing the dataset options.
163    """
164    options = Options()
165    for input_dataset in self._inputs():
166      input_options = input_dataset.options()
167      if input_options is not None:
168        options = options.merge(input_options)
169    return options
170
171  def _apply_options(self):
172    """Apply options, such as optimization configuration, to the dataset."""
173
174    dataset = self
175    options = self.options()
176    if options.experimental_threading is not None:
177      t_options = options.experimental_threading
178      if t_options.max_intra_op_parallelism is not None:
179        dataset = _MaxIntraOpParallelismDataset(
180            dataset, t_options.max_intra_op_parallelism)
181      if t_options.private_threadpool_size is not None:
182        dataset = _PrivateThreadPoolDataset(dataset,
183                                            t_options.private_threadpool_size)
184    static_optimizations = options._static_optimizations()  # pylint: disable=protected-access
185    if static_optimizations:
186      if self._has_captured_ref():
187        warnings.warn(
188            "tf.data static optimizations are not compatible with tf.Variable. "
189            "The following optimizations will be disabled: %s. To enable "
190            "optimizations, use resource variables instead by calling "
191            "`tf.enable_resource_variables()` at the start of the program." %
192            ", ".join(static_optimizations))
193      else:
194        dataset = _OptimizeDataset(dataset, static_optimizations)
195
196    autotune = True
197    cpu_budget = 0  # Indicates that all CPU cores should be used.
198    if options.experimental_optimization is not None:
199      if options.experimental_optimization.autotune is False:  # pylint: disable=g-bool-id-comparison
200        autotune = False
201      if options.experimental_optimization.autotune_cpu_budget is not None:
202        cpu_budget = options.experimental_optimization.autotune_cpu_budget
203
204    if autotune:
205      dataset = _ModelDataset(dataset, cpu_budget)
206
207    if options.experimental_stats and options.experimental_stats.aggregator:  # pylint: disable=line-too-long
208      dataset = _SetStatsAggregatorDataset(  # pylint: disable=protected-access
209          dataset, options.experimental_stats.aggregator,
210          options.experimental_stats.prefix,
211          options.experimental_stats.counter_prefix)
212    return dataset
213
214  def __iter__(self):
215    """Creates an `Iterator` for enumerating the elements of this dataset.
216
217    The returned iterator implements the Python iterator protocol and therefore
218    can only be used in eager mode.
219
220    Returns:
221      An `Iterator` over the elements of this dataset.
222
223    Raises:
224      RuntimeError: If eager execution is not enabled.
225    """
226    if context.executing_eagerly():
227      return iterator_ops.EagerIterator(self)
228    else:
229      raise RuntimeError("dataset.__iter__() is only supported when eager "
230                         "execution is enabled.")
231
232  @abc.abstractproperty
233  def _element_structure(self):
234    """The structure of an element of this dataset.
235
236    Returns:
237      A `Structure` object representing the structure of an element of this
238      dataset.
239    """
240    raise NotImplementedError("Dataset._element_structure")
241
242  def __repr__(self):
243    output_shapes = nest.map_structure(str, get_legacy_output_shapes(self))
244    output_shapes = str(output_shapes).replace("'", "")
245    output_types = nest.map_structure(repr, get_legacy_output_types(self))
246    output_types = str(output_types).replace("'", "")
247    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
248                                            output_types))
249
250  @staticmethod
251  def from_tensors(tensors):
252    """Creates a `Dataset` with a single element, comprising the given tensors.
253
254    Note that if `tensors` contains a NumPy array, and eager execution is not
255    enabled, the values will be embedded in the graph as one or more
256    `tf.constant` operations. For large datasets (> 1 GB), this can waste
257    memory and run into byte limits of graph serialization. If `tensors`
258    contains one or more large NumPy arrays, consider the alternative described
259    in [this
260    guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
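
    For example, a minimal sketch (the element structure below is purely
    illustrative):

    ```python
    # A dataset containing exactly one element: a (features, label) pair.
    dataset = tf.data.Dataset.from_tensors(([1.0, 2.0, 3.0], 0))
    ```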
261
262    Args:
263      tensors: A nested structure of tensors.
264
265    Returns:
266      Dataset: A `Dataset`.
267    """
268    return TensorDataset(tensors)
269
270  @staticmethod
271  def from_tensor_slices(tensors):
272    """Creates a `Dataset` whose elements are slices of the given tensors.
273
274    Note that if `tensors` contains a NumPy array, and eager execution is not
275    enabled, the values will be embedded in the graph as one or more
276    `tf.constant` operations. For large datasets (> 1 GB), this can waste
277    memory and run into byte limits of graph serialization. If `tensors`
278    contains one or more large NumPy arrays, consider the alternative described
279    in [this guide](
280    https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
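
    For example, a minimal sketch (the values are illustrative only):

    ```python
    # Slicing along the 0th dimension yields three scalar elements: 1, 2, 3.
    dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    ```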
281
282    Args:
283      tensors: A nested structure of tensors, each having the same size in the
284        0th dimension.
285
286    Returns:
287      Dataset: A `Dataset`.
288    """
289    return TensorSliceDataset(tensors)
290
291  class _GeneratorState(object):
292    """Stores outstanding iterators created from a Python generator.
293
294    This class keeps track of potentially multiple iterators that may have
295    been created from a generator, e.g. in the case that the dataset is
296    repeated, or nested within a parallel computation.
297    """
298
299    def __init__(self, generator):
300      self._generator = generator
301      self._lock = threading.Lock()
302      self._next_id = 0  # GUARDED_BY(self._lock)
303      self._args = {}
304      self._iterators = {}
305
306    def get_next_id(self, *args):
307      with self._lock:
308        ret = self._next_id
309        self._next_id += 1
310      self._args[ret] = args
311      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
312      # casting in `py_func()` will create an array of `np.int32` on Windows,
313      # leading to a runtime error.
314      return np.array(ret, dtype=np.int64)
315
316    def get_iterator(self, iterator_id):
317      try:
318        return self._iterators[iterator_id]
319      except KeyError:
320        iterator = iter(self._generator(*self._args.pop(iterator_id)))
321        self._iterators[iterator_id] = iterator
322        return iterator
323
324    def iterator_completed(self, iterator_id):
325      del self._iterators[iterator_id]
326
327  @staticmethod
328  def from_generator(generator, output_types, output_shapes=None, args=None):
329    """Creates a `Dataset` whose elements are generated by `generator`.
330
331    The `generator` argument must be a callable object that returns
332    an object that supports the `iter()` protocol (e.g. a generator function).
333    The elements generated by `generator` must be compatible with the given
334    `output_types` and (optional) `output_shapes` arguments.
335
336    For example:
337
338    ```python
339    import itertools
340    tf.enable_eager_execution()
341
342    def gen():
343      for i in itertools.count(1):
344        yield (i, [1] * i)
345
346    ds = tf.data.Dataset.from_generator(
347        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
348
349    for value in ds.take(2):
350      print(value)
351    # (1, array([1]))
352    # (2, array([1, 1]))
353    ```
354
355    NOTE: The current implementation of `Dataset.from_generator()` uses
356    `tf.py_func` and inherits the same constraints. In particular, it
357    requires the `Dataset`- and `Iterator`-related operations to be placed
358    on a device in the same process as the Python program that called
359    `Dataset.from_generator()`. The body of `generator` will not be
360    serialized in a `GraphDef`, and you should not use this method if you
361    need to serialize your model and restore it in a different environment.
362
363    NOTE: If `generator` depends on mutable global variables or other external
364    state, be aware that the runtime may invoke `generator` multiple times
365    (in order to support repeating the `Dataset`) and at any time
366    between the call to `Dataset.from_generator()` and the production of the
367    first element from the generator. Mutating global variables or external
368    state can cause undefined behavior, and we recommend that you explicitly
369    cache any external state in `generator` before calling
370    `Dataset.from_generator()`.
371
372    Args:
373      generator: A callable object that returns an object that supports the
374        `iter()` protocol. If `args` is not specified, `generator` must take
375        no arguments; otherwise it must take as many arguments as there are
376        values in `args`.
377      output_types: A nested structure of `tf.DType` objects corresponding to
378        each component of an element yielded by `generator`.
379      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
380        objects corresponding to each component of an element yielded by
381        `generator`.
382      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
383        and passed to `generator` as NumPy-array arguments.
384
385    Returns:
386      Dataset: A `Dataset`.
387    """
388    if not callable(generator):
389      raise TypeError("`generator` must be callable.")
390    if output_shapes is None:
391      output_shapes = nest.map_structure(
392          lambda _: tensor_shape.TensorShape(None), output_types)
393    else:
394      output_shapes = nest.map_structure_up_to(
395          output_types, tensor_shape.as_shape, output_shapes)
396    if args is None:
397      args = ()
398    else:
399      args = tuple(ops.convert_n_to_tensor(args, name="args"))
400
401    flattened_types = [dtypes.as_dtype(dt) for dt in nest.flatten(output_types)]
402    flattened_shapes = nest.flatten(output_shapes)
403
404    generator_state = DatasetV2._GeneratorState(generator)
405
406    def get_iterator_id_fn(unused_dummy):
407      """Creates a unique `iterator_id` for each pass over the dataset.
408
409      The returned `iterator_id` disambiguates between multiple concurrently
410      existing iterators.
411
412      Args:
413        unused_dummy: Ignored value.
414
415      Returns:
416        A `tf.int64` tensor whose value uniquely identifies an iterator in
417        `generator_state`.
418      """
419      return script_ops.py_func(
420          generator_state.get_next_id, args, dtypes.int64, stateful=True)
421
422    def generator_next_fn(iterator_id_t):
423      """Generates the next element from iterator with ID `iterator_id_t`.
424
425      We map this function across an infinite repetition of the
426      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
427
428      Args:
429        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
430          the iterator in `generator_state` from which to generate an element.
431
432      Returns:
433        A nested structure of tensors representing an element from the iterator.
434      """
435
436      def generator_py_func(iterator_id):
437        """A `py_func` that will be called to invoke the iterator."""
438        # `next()` raises `StopIteration` when there are no more
439        # elements remaining to be generated.
440        values = next(generator_state.get_iterator(iterator_id))
441
442        # Use the same _convert function from the py_func() implementation to
443        # convert the returned values to arrays early, so that we can inspect
444        # their values.
445        try:
446          flattened_values = nest.flatten_up_to(output_types, values)
447        except (TypeError, ValueError):
448          raise TypeError(
449              "`generator` yielded an element that did not match the expected "
450              "structure. The expected structure was %s, but the yielded "
451              "element was %s." % (output_types, values))
452        ret_arrays = []
453        for ret, dtype in zip(flattened_values, flattened_types):
454          try:
455            ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
456                ret, dtype=dtype.as_numpy_dtype))
457          except (TypeError, ValueError):
458            raise TypeError(
459                "`generator` yielded an element that could not be converted to "
460                "the expected type. The expected type was %s, but the yielded "
461                "element was %s." % (dtype.name, ret))
462
463        # Additional type and shape checking to ensure that the components
464        # of the generated element match the `output_types` and `output_shapes`
465        # arguments.
466        for (ret_array, expected_dtype, expected_shape) in zip(
467            ret_arrays, flattened_types, flattened_shapes):
468          if ret_array.dtype != expected_dtype.as_numpy_dtype:
469            raise TypeError(
470                "`generator` yielded an element of type %s where an element "
471                "of type %s was expected." % (ret_array.dtype,
472                                              expected_dtype.as_numpy_dtype))
473          if not expected_shape.is_compatible_with(ret_array.shape):
474            raise ValueError(
475                "`generator` yielded an element of shape %s where an element "
476                "of shape %s was expected." % (ret_array.shape, expected_shape))
477
478        return ret_arrays
479
480      flat_values = script_ops.py_func(
481          generator_py_func, [iterator_id_t], flattened_types, stateful=True)
482
483      # The `py_func()` op drops the inferred shapes, so we add them back in
484      # here.
485      if output_shapes is not None:
486        for ret_t, shape in zip(flat_values, flattened_shapes):
487          ret_t.set_shape(shape)
488
489      return nest.pack_sequence_as(output_types, flat_values)
490
491    def finalize_fn(iterator_id_t):
492      """Releases host-side state for the iterator with ID `iterator_id_t`."""
493
494      def finalize_py_func(iterator_id):
495        generator_state.iterator_completed(iterator_id)
496        # We return a dummy value so that the `finalize_fn` has a valid
497        # signature.
498        # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
499        # casting in `py_func()` will create an array of `np.int32` on Windows,
500        # leading to a runtime error.
501        return np.array(0, dtype=np.int64)
502
503      return script_ops.py_func(
504          finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True)
505
506    # This function associates each traversal of `generator` with a unique
507    # iterator ID.
508    def flat_map_fn(dummy_arg):
509      # The `get_iterator_id_fn` gets a unique ID for the current instance
510      # of the generator.
511      # The `generator_next_fn` gets the next element from the iterator with the
512      # given ID, and raises StopIteration when that iterator contains no
513      # more elements.
514      return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
515                               finalize_fn)
516
517    # A single-element dataset that, each time it is evaluated, contains a
518    # freshly-generated and unique (for the returned dataset) int64
519    # ID that will be used to identify the appropriate Python state, which
520    # is encapsulated in `generator_state`, and captured in
521    # `get_iterator_id_map_fn`.
522    dummy = 0
523    id_dataset = Dataset.from_tensors(dummy)
524
525    # A dataset that contains all of the elements generated by a
526    # single iterator created from `generator`, identified by the
527    # iterator ID contained in `id_dataset`. Lifting the iteration
528    # into a flat_map here enables multiple repetitions and/or nested
529    # versions of the returned dataset to be created, because it forces
530    # the generation of a new ID for each version.
531    return id_dataset.flat_map(flat_map_fn)
532
533  @staticmethod
534  def range(*args):
535    """Creates a `Dataset` of a step-separated range of values.
536
537    For example:
538
539    ```python
540    Dataset.range(5) == [0, 1, 2, 3, 4]
541    Dataset.range(2, 5) == [2, 3, 4]
542    Dataset.range(1, 5, 2) == [1, 3]
543    Dataset.range(1, 5, -2) == []
544    Dataset.range(5, 1) == []
545    Dataset.range(5, 1, -2) == [5, 3]
546    ```
547
548    Args:
549      *args: follows the same semantics as python's xrange.
550        len(args) == 1 -> start = 0, stop = args[0], step = 1
551        len(args) == 2 -> start = args[0], stop = args[1], step = 1
552        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]
553
554    Returns:
555      Dataset: A `RangeDataset`.
556
557    Raises:
558      ValueError: if len(args) == 0.
559    """
560    return RangeDataset(*args)
561
562  @staticmethod
563  def zip(datasets):
564    """Creates a `Dataset` by zipping together the given datasets.
565
566    This method has similar semantics to the built-in `zip()` function
567    in Python, with the main difference being that the `datasets`
568    argument can be an arbitrary nested structure of `Dataset` objects.
569    For example:
570
571    ```python
572    # NOTE: The following examples use `{ ... }` to represent the
573    # contents of a dataset.
574    a = { 1, 2, 3 }
575    b = { 4, 5, 6 }
576    c = { (7, 8), (9, 10), (11, 12) }
577    d = { 13, 14 }
578
579    # The nested structure of the `datasets` argument determines the
580    # structure of elements in the resulting dataset.
581    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
582    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
583
584    # The `datasets` argument may contain an arbitrary number of
585    # datasets.
586    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
587                                (2, 5, (9, 10)),
588                                (3, 6, (11, 12)) }
589
590    # The number of elements in the resulting dataset is the same as
591    # the size of the smallest dataset in `datasets`.
592    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
593    ```
594
595    Args:
596      datasets: A nested structure of datasets.
597
598    Returns:
599      Dataset: A `Dataset`.
600    """
601    return ZipDataset(datasets)
602
603  def concatenate(self, dataset):
604    """Creates a `Dataset` by concatenating given dataset with this dataset.
605
606    ```python
607    # NOTE: The following examples use `{ ... }` to represent the
608    # contents of a dataset.
609    a = { 1, 2, 3 }
610    b = { 4, 5, 6, 7 }
611
612    # Input dataset and dataset to be concatenated should have same
613    # nested structures and output types.
614    # c = { (8, 9), (10, 11), (12, 13) }
615    # d = { 14.0, 15.0, 16.0 }
616    # a.concatenate(c) and a.concatenate(d) would result in error.
617
618    a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
619    ```
620
621    Args:
622      dataset: `Dataset` to be concatenated.
623
624    Returns:
625      Dataset: A `Dataset`.
626    """
627    return ConcatenateDataset(self, dataset)
628
629  def prefetch(self, buffer_size):
630    """Creates a `Dataset` that prefetches elements from this dataset.
631
632    Args:
633      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
634        maximum number of elements that will be buffered when prefetching.
635
636    Returns:
637      Dataset: A `Dataset`.
638    """
639    return PrefetchDataset(self, buffer_size)
640
641  @staticmethod
642  def list_files(file_pattern, shuffle=None, seed=None):
643    """A dataset of all files matching one or more glob patterns.
644
645    NOTE: The default behavior of this method is to return filenames in
646    a non-deterministic random shuffled order. Pass a `seed` or `shuffle=False`
647    to get results in a deterministic order.
648
649    Example:
650      If we had the following files on our filesystem:
651        - /path/to/dir/a.txt
652        - /path/to/dir/b.py
653        - /path/to/dir/c.py
654      If we pass "/path/to/dir/*.py" as the file_pattern, the dataset would
655      produce:
656        - /path/to/dir/b.py
657        - /path/to/dir/c.py
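
      A minimal sketch of the above (the path is hypothetical):

      ```python
      dataset = tf.data.Dataset.list_files("/path/to/dir/*.py", shuffle=False)
      # Yields the matching filenames as `tf.string` tensors, in a
      # deterministic order because `shuffle=False`.
      ```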
658
659    Args:
660      file_pattern: A string, a list of strings, or a `tf.Tensor` of string type
661        (scalar or vector), representing the filename glob (i.e. shell wildcard)
662        pattern(s) that will be matched.
663      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
664        Defaults to `True`.
665      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
666        seed that will be used to create the distribution. See
667        `tf.set_random_seed` for behavior.
668
669    Returns:
670      Dataset: A `Dataset` of strings corresponding to file names.
671    """
672    with ops.name_scope("list_files"):
673      if shuffle is None:
674        shuffle = True
675      file_pattern = ops.convert_to_tensor(
676          file_pattern, dtype=dtypes.string, name="file_pattern")
677      matching_files = gen_io_ops.matching_files(file_pattern)
678
679      # Raise an exception if `file_pattern` does not match any files.
680      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
681                                   name="match_not_empty")
682
683      message = math_ops.add(
684          "No files matched pattern: ",
685          string_ops.reduce_join(file_pattern, separator=", "), name="message")
686
687      assert_not_empty = control_flow_ops.Assert(
688          condition, [message], summarize=1, name="assert_not_empty")
689      with ops.control_dependencies([assert_not_empty]):
690        matching_files = array_ops.identity(matching_files)
691
692      dataset = Dataset.from_tensor_slices(matching_files)
693      if shuffle:
694        # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
695        # list of files might be empty.
696        buffer_size = math_ops.maximum(
697            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
698        dataset = dataset.shuffle(buffer_size, seed=seed)
699      return dataset
700
701  def repeat(self, count=None):
702    """Repeats this dataset `count` times.
703
704    NOTE: If this dataset is a function of global state (e.g. a random number
705    generator), then different repetitions may produce different elements.
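
    For example, a minimal sketch:

    ```python
    # Yields 0, 1, 2, 0, 1, 2.
    dataset = tf.data.Dataset.range(3).repeat(2)
    ```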
706
707    Args:
708      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
709        number of times the dataset should be repeated. The default behavior
710        (if `count` is `None` or `-1`) is for the dataset to be repeated
711        indefinitely.
712
713    Returns:
714      Dataset: A `Dataset`.
715    """
716    return RepeatDataset(self, count)
717
718  def _enumerate(self, start=0):
719
720    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
721    return Dataset.zip((Dataset.range(start, max_value), self))
722
723  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
724    """Randomly shuffles the elements of this dataset.
725
726    This dataset fills a buffer with `buffer_size` elements, then randomly
727    samples elements from this buffer, replacing the selected elements with new
728    elements. For perfect shuffling, a buffer size greater than or equal to the
729    full size of the dataset is required.
730
731    For instance, if your dataset contains 10,000 elements but `buffer_size` is
732    set to 1,000, then `shuffle` will initially select a random element from
733    only the first 1,000 elements in the buffer. Once an element is selected,
734    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
735    maintaining the 1,000 element buffer.
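
    For example, a minimal sketch (the buffer size and seed are arbitrary
    illustrative choices):

    ```python
    # Each output element is drawn at random from a buffer of at most 3
    # not-yet-emitted elements.
    dataset = tf.data.Dataset.range(10).shuffle(buffer_size=3, seed=42)
    ```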
736
737    Args:
738      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
739        number of elements from this dataset from which the new
740        dataset will sample.
741      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
742        random seed that will be used to create the distribution. See
743        `tf.set_random_seed` for behavior.
744      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
745        that the dataset should be pseudorandomly reshuffled each time it is
746        iterated over. (Defaults to `True`.)
747
748    Returns:
749      Dataset: A `Dataset`.
750    """
751    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
752
753  def cache(self, filename=""):
754    """Caches the elements in this dataset.
755
756    Args:
757      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
758        directory on the filesystem to use for caching tensors in this Dataset.
759        If a filename is not provided, the dataset will be cached in memory.
760
761    Returns:
762      Dataset: A `Dataset`.
763    """
764    return CacheDataset(self, filename)
765
766  def take(self, count):
767    """Creates a `Dataset` with at most `count` elements from this dataset.
768
769    Args:
770      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
771        elements of this dataset that should be taken to form the new dataset.
772        If `count` is -1, or if `count` is greater than the size of this
773        dataset, the new dataset will contain all elements of this dataset.
774
775    Returns:
776      Dataset: A `Dataset`.
777    """
778    return TakeDataset(self, count)
779
780  def skip(self, count):
781    """Creates a `Dataset` that skips `count` elements from this dataset.
782
783    Args:
784      count: A `tf.int64` scalar `tf.Tensor`, representing the number
785        of elements of this dataset that should be skipped to form the
786        new dataset.  If `count` is greater than the size of this
787        dataset, the new dataset will contain no elements.  If `count`
788        is -1, skips the entire dataset.
789
790    Returns:
791      Dataset: A `Dataset`.
792    """
793    return SkipDataset(self, count)
794
795  def shard(self, num_shards, index):
796    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
797
798    This dataset operator is very useful when running distributed training, as
799    it allows each worker to read a unique subset.
800
801    When reading a single input file, you can skip elements as follows:
802
803    ```python
804    d = tf.data.TFRecordDataset(input_file)
805    d = d.shard(num_workers, worker_index)
806    d = d.repeat(num_epochs)
807    d = d.shuffle(shuffle_buffer_size)
808    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
809    ```
810
811    Important caveats:
812
813    - Be sure to shard before you use any randomizing operator (such as
814      shuffle).
815    - Generally it is best if the shard operator is used early in the dataset
816      pipeline. For example, when reading from a set of TFRecord files, shard
817      before converting the dataset to input samples. This avoids reading every
818      file on every worker. The following is an example of an efficient
819      sharding strategy within a complete pipeline:
820
821    ```python
822    d = Dataset.list_files(pattern)
823    d = d.shard(num_workers, worker_index)
824    d = d.repeat(num_epochs)
825    d = d.shuffle(shuffle_buffer_size)
826    d = d.interleave(tf.data.TFRecordDataset,
827                     cycle_length=num_readers, block_length=1)
828    d = d.map(parser_fn, num_parallel_calls=num_map_threads)
829    ```
830
831    Args:
832      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
833        shards operating in parallel.
834      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
835
836    Returns:
837      Dataset: A `Dataset`.
838
839    Raises:
840      InvalidArgumentError: if `num_shards` or `index` are illegal values.
841        Note: error checking is done on a best-effort basis, and errors aren't
842        guaranteed to be caught upon dataset creation. (e.g. passing in a
843        placeholder tensor bypasses the early checking, and will instead result
844        in an error during a session.run call.)
845    """
846    return ShardDataset(self, num_shards, index)
847
848  def batch(self, batch_size, drop_remainder=False):
849    """Combines consecutive elements of this dataset into batches.
850
851    The tensors in the resulting element will have an additional outer
852    dimension, which will be `batch_size` (or `N % batch_size` for the last
853    element if `batch_size` does not divide the number of input elements `N`
854    evenly and `drop_remainder` is `False`). If your program depends on the
855    batches having the same outer dimension, you should set the `drop_remainder`
856    argument to `True` to prevent the smaller batch from being produced.
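
    For example, a minimal sketch:

    ```python
    # Yields [0, 1, 2], [3, 4, 5], [6, 7]; with `drop_remainder=True` the
    # final partial batch [6, 7] would be dropped.
    dataset = tf.data.Dataset.range(8).batch(3)
    ```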
857
858    Args:
859      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
860        consecutive elements of this dataset to combine in a single batch.
861      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
862        whether the last batch should be dropped in the case it has fewer than
863        `batch_size` elements; the default behavior is not to drop the smaller
864        batch.
865
866    Returns:
867      Dataset: A `Dataset`.
868    """
869    return BatchDataset(self, batch_size, drop_remainder)
870
871  def padded_batch(self,
872                   batch_size,
873                   padded_shapes,
874                   padding_values=None,
875                   drop_remainder=False):
876    """Combines consecutive elements of this dataset into padded batches.
877
878    This transformation combines multiple consecutive elements of the input
879    dataset into a single element.
880
881    Like `tf.data.Dataset.batch`, the tensors in the resulting element will
882    have an additional outer dimension, which will be `batch_size` (or
883    `N % batch_size` for the last element if `batch_size` does not divide the
884    number of input elements `N` evenly and `drop_remainder` is `False`). If
885    your program depends on the batches having the same outer dimension, you
886    should set the `drop_remainder` argument to `True` to prevent the smaller
887    batch from being produced.
888
889    Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
890    different shapes, and this transformation will pad each component to the
891    respective shape in `padded_shapes`. The `padded_shapes` argument
892    determines the resulting shape for each dimension of each component in an
893    output element:
894
895    * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component
896      will be padded out to that length in that dimension.
897    * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component
898      will be padded out to the maximum length of all elements in that
899      dimension.
900
901    See also `tf.data.experimental.dense_to_sparse_batch`, which combines
902    elements that may have different shapes into a `tf.SparseTensor`.
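
    For example, a minimal sketch (the ragged input values are illustrative):

    ```python
    # Elements are vectors of lengths 1, 2, 3, 4.
    dataset = tf.data.Dataset.range(4).map(lambda x: tf.fill([x + 1], x))
    # Each batch of 2 is padded with zeros to the length of the longest
    # vector it contains: {[[0, 0], [1, 1]], [[2, 2, 2, 0], [3, 3, 3, 3]]}.
    dataset = dataset.padded_batch(2, padded_shapes=[None])
    ```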
903
904    Args:
905      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
906        consecutive elements of this dataset to combine in a single batch.
907      padded_shapes: A nested structure of `tf.TensorShape` or
908        `tf.int64` vector tensor-like objects representing the shape
909        to which the respective component of each input element should
910        be padded prior to batching. Any unknown dimensions
911        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
912        tensor-like object) will be padded to the maximum size of that
913        dimension in each batch.
914      padding_values: (Optional.) A nested structure of scalar-shaped
915        `tf.Tensor`, representing the padding values to use for the
916        respective components.  Defaults are `0` for numeric types and
917        the empty string for string types.
918      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
919        whether the last batch should be dropped in the case it has fewer than
920        `batch_size` elements; the default behavior is not to drop the smaller
921        batch.
922
923    Returns:
924      Dataset: A `Dataset`.
925    """
926    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
927                              drop_remainder)
928
929  def map(self, map_func, num_parallel_calls=None):
930    """Maps `map_func` across the elements of this dataset.
931
932    This transformation applies `map_func` to each element of this dataset, and
933    returns a new dataset containing the transformed elements, in the same
934    order as they appeared in the input.
935
936    For example:
937
938    ```python
939    # NOTE: The following examples use `{ ... }` to represent the
940    # contents of a dataset.
941    a = { 1, 2, 3, 4, 5 }
942
943    a.map(lambda x: x + 1) == { 2, 3, 4, 5, 6 }
944    ```
945
946    The input signature of `map_func` is determined by the structure of each
947    element in this dataset. For example:
948
949    ```python
950    # Each element is a `tf.Tensor` object.
951    a = { 1, 2, 3, 4, 5 }
952    # `map_func` takes a single argument of type `tf.Tensor` with the same
953    # shape and dtype.
954    result = a.map(lambda x: ...)
955
956    # Each element is a tuple containing two `tf.Tensor` objects.
957    b = { (1, "foo"), (2, "bar"), (3, "baz") }
958    # `map_func` takes two arguments of type `tf.Tensor`.
959    result = b.map(lambda x_int, y_str: ...)
960
961    # Each element is a dictionary mapping strings to `tf.Tensor` objects.
962    c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
963    # `map_func` takes a single argument of type `dict` with the same keys as
964    # the elements.
965    result = c.map(lambda d: ...)
966    ```
967
968    The value or values returned by `map_func` determine the structure of each
969    element in the returned dataset.
970
971    ```python
972    # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
973    def f(...):
974      return tf.constant(37.0)
975    result = dataset.map(f)
976    result.output_classes == tf.Tensor
977    result.output_types == tf.float32
978    result.output_shapes == []  # scalar
979
980    # `map_func` returns two `tf.Tensor` objects.
981    def g(...):
982      return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
983    result = dataset.map(g)
984    result.output_classes == (tf.Tensor, tf.Tensor)
985    result.output_types == (tf.float32, tf.string)
986    result.output_shapes == ([], [3])
987
988    # Python primitives, lists, and NumPy arrays are implicitly converted to
989    # `tf.Tensor`.
990    def h(...):
991      return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0], dtype=np.float64)
992    result = dataset.map(h)
993    result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
994    result.output_types == (tf.float32, tf.string, tf.float64)
995    result.output_shapes == ([], [3], [2])
996
997    # `map_func` can return nested structures.
998    def i(...):
999      return {"a": 37.0, "b": [42, 16]}, "foo"
1000    result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
1001    result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
1002    result.output_shapes == ({"a": [], "b": [2]}, [])
1003    ```
1004
1005    In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
1006    return `tf.SparseTensor` objects.
1007
1008    Args:
1009      map_func: A function mapping a nested structure of tensors (having
1010        shapes and types defined by `self.output_shapes` and
1011        `self.output_types`) to another nested structure of tensors.
1012      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
1013        representing the number of elements to process asynchronously in parallel.
1014        If not specified, elements will be processed sequentially. If the value
1015        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
1016        calls is set dynamically based on available CPU.
1017
1018    Returns:
1019      Dataset: A `Dataset`.
1020    """
1021    if num_parallel_calls is None:
1022      return MapDataset(self, map_func, preserve_cardinality=True)
1023    else:
1024      return ParallelMapDataset(
1025          self, map_func, num_parallel_calls, preserve_cardinality=True)
1026
1027  def flat_map(self, map_func):
1028    """Maps `map_func` across this dataset and flattens the result.
1029
1030    Use `flat_map` if you want to make sure that the order of your dataset
1031    stays the same. For example, to flatten a dataset of batches into a
1032    dataset of their elements:
1033
1034    ```python
1035    # NOTE: The following examples use `{ ... }` to represent the
1036    # contents of a dataset. '[...]' represents a tensor.
1037    a = {[1,2,3,4,5], [6,7,8,9], [10]}
1038
1039    a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
1040      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
1041    ```
1042
1043    `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
1044    `flat_map` produces the same output as
1045    `tf.data.Dataset.interleave(cycle_length=1)`.
1046
1047    Args:
1048      map_func: A function mapping a nested structure of tensors (having shapes
1049        and types defined by `self.output_shapes` and `self.output_types`) to a
1050        `Dataset`.
1051
1052    Returns:
1053      Dataset: A `Dataset`.
1054    """
1055    return FlatMapDataset(self, map_func)
1056
1057  def interleave(self,
1058                 map_func,
1059                 cycle_length,
1060                 block_length=1,
1061                 num_parallel_calls=None):
1062    """Maps `map_func` across this dataset, and interleaves the results.
1063
1064    For example, you can use `Dataset.interleave()` to process many input files
1065    concurrently:
1066
1067    ```python
1068    # Preprocess 4 files concurrently, and interleave blocks of 16 records from
1069    # each file.
1070    filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
1071    dataset = (Dataset.from_tensor_slices(filenames)
1072               .interleave(lambda x:
1073                   TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
1074                   cycle_length=4, block_length=16))
1075    ```
1076
1077    The `cycle_length` and `block_length` arguments control the order in which
1078    elements are produced. `cycle_length` controls the number of input elements
1079    that are processed concurrently. If you set `cycle_length` to 1, this
1080    transformation will handle one input element at a time, and will produce
1081    identical results to `tf.data.Dataset.flat_map`. In general,
1082    this transformation will apply `map_func` to `cycle_length` input elements,
1083    open iterators on the returned `Dataset` objects, and cycle through them
1084    producing `block_length` consecutive elements from each iterator, and
1085    consuming the next input element each time it reaches the end of an
1086    iterator.
1087
1088    For example:
1089
1090    ```python
1091    # NOTE: The following examples use `{ ... }` to represent the
1092    # contents of a dataset.
1093    a = { 1, 2, 3, 4, 5 }
1094
1095    # NOTE: New lines indicate "block" boundaries.
1096    a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
1097                 cycle_length=2, block_length=4) == {
1098        1, 1, 1, 1,
1099        2, 2, 2, 2,
1100        1, 1,
1101        2, 2,
1102        3, 3, 3, 3,
1103        4, 4, 4, 4,
1104        3, 3,
1105        4, 4,
1106        5, 5, 5, 5,
1107        5, 5,
1108    }
1109    ```
1110
1111    NOTE: The order of elements yielded by this transformation is
1112    deterministic, as long as `map_func` is a pure function. If
1113    `map_func` contains any stateful operations, the order in which
1114    that state is accessed is undefined.
1115
1116    Args:
1117      map_func: A function mapping a nested structure of tensors (having shapes
1118        and types defined by `self.output_shapes` and `self.output_types`) to a
1119        `Dataset`.
1120      cycle_length: The number of elements from this dataset that will be
1121        processed concurrently.
1122      block_length: The number of consecutive elements to produce from each
1123        input element before cycling to another input element.
1124      num_parallel_calls: (Optional.) If specified, the implementation creates
1125        a threadpool, which is used to fetch inputs from cycle elements
1126        asynchronously and in parallel. The default behavior is to fetch inputs
1127        from cycle elements synchronously with no parallelism. If the value
1128        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
1129        calls is set dynamically based on available CPU.
1130
1131    Returns:
1132      Dataset: A `Dataset`.
1133    """
1134    if num_parallel_calls is None:
1135      return InterleaveDataset(self, map_func, cycle_length, block_length)
1136    else:
1137      return ParallelInterleaveDataset(self, map_func, cycle_length,
1138                                       block_length, num_parallel_calls)
1139
1140  def filter(self, predicate):
1141    """Filters this dataset according to `predicate`.
1142
1143    ```python
1144    d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
1145
1146    d = d.filter(lambda x: x < 3) # [1, 2]
1147
1148    # `tf.math.equal(x, y)` is required for equality comparison
1149    def filter_fn(x):
1150      return tf.math.equal(x, 1)
1151
1152    d = d.filter(filter_fn) # [1]
1153    ```
1154
1155    Args:
1156      predicate: A function mapping a nested structure of tensors (having shapes
1157        and types defined by `self.output_shapes` and `self.output_types`) to a
1158        scalar `tf.bool` tensor.
1159
1160    Returns:
1161      Dataset: The `Dataset` containing the elements of this dataset for which
1162          `predicate` is `True`.
1163    """
1164    return FilterDataset(self, predicate)
1165
1166  def apply(self, transformation_func):
1167    """Applies a transformation function to this dataset.
1168
1169    `apply` enables chaining of custom `Dataset` transformations, which are
1170    represented as functions that take one `Dataset` argument and return a
1171    transformed `Dataset`.
1172
1173    For example:
1174
1175    ```python
1176    dataset = (dataset.map(lambda x: x ** 2)
1177               .apply(group_by_window(key_func, reduce_func, window_size))
1178               .map(lambda x: x ** 3))
1179    ```
1180
1181    Args:
1182      transformation_func: A function that takes one `Dataset` argument and
1183        returns a `Dataset`.
1184
1185    Returns:
1186      Dataset: The `Dataset` returned by applying `transformation_func` to this
1187          dataset.
1188    """
1189    dataset = transformation_func(self)
1190    if not isinstance(dataset, DatasetV2):
1191      raise TypeError(
1192          "`transformation_func` must return a Dataset. Got {}.".format(
1193              dataset))
1194    dataset._input_datasets = [self]  # pylint: disable=protected-access
1195    return dataset
1196
1197  def window(self, size, shift=None, stride=1, drop_remainder=False):
1198    """Combines input elements into a dataset of windows.
1199
1200    Each window is a dataset itself and contains `size` elements (or
1201    possibly fewer if there are not enough input elements to fill the window
1202    and `drop_remainder` evaluates to false).
1203
1204    The `stride` argument determines the stride of the input elements,
1205    and the `shift` argument determines the shift of the window.
1206
1207    For example:
1208    - `tf.data.Dataset.range(7).window(2)` produces
1209      `{{0, 1}, {2, 3}, {4, 5}, {6}}`
1210    - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
1211      `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
1212    - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
1213      `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
1214
1215    Args:
1216      size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
1217        of the input dataset to combine into a window.
1218      shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1219        forward shift of the sliding window in each iteration. Defaults to
1220        `size`.
1221      stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1222        stride of the input elements in the sliding window.
1223      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
1224        whether a window should be dropped in case its size is smaller than
1225        `size`.
1226
1227    Returns:
1228      Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
1229        the same structure as this dataset, but a finite subsequence of its
1230        elements.
1231    """
1232    if shift is None:
1233      shift = size
1234    return WindowDataset(self, size, shift, stride, drop_remainder)
1235
1236  def reduce(self, initial_state, reduce_func):
1237    """Reduces the input dataset to a single element.
1238
1239    The transformation calls `reduce_func` successively on every element of
1240    the input dataset until the dataset is exhausted, aggregating information in
1241    its internal state. The `initial_state` argument is used for the initial
1242    state and the final state is returned as the result.
1243
1244    For example:
1245    - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
1246      produces `5`
1247    - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
1248      produces `10`
1249
1250    Args:
1251      initial_state: A nested structure of tensors, representing the initial
1252        state of the transformation.
1253      reduce_func: A function that maps `(old_state, input_element)` to
1254        `new_state`. It must take two arguments and return a nested structure
1255        of tensors. The structure of `new_state` must match the structure of
1256        `initial_state`.
1257
1258    Returns:
1259      A nested structure of `tf.Tensor` objects, corresponding to the final
1260      state of the transformation.
1261
1262    """
1263
1264    with ops.name_scope("initial_state"):
1265      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
1266      # values to tensors.
1267      initial_state = nest.pack_sequence_as(initial_state, [
1268          sparse_tensor_lib.SparseTensor.from_value(t)
1269          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
1270              t, name="component_%d" % i)
1271          for i, t in enumerate(nest.flatten(initial_state))
1272      ])
1273
1274    # Compute initial values for the state classes, shapes and types based on
1275    # the initial state.
1276    state_structure = structure_lib.Structure.from_value(initial_state)
1277
1278    # Iteratively rerun the reduce function until reaching a fixed point on
1279    # `state_structure`.
1280    need_to_rerun = True
1281    while need_to_rerun:
1282
1283      wrapped_func = StructuredFunctionWrapper(
1284          reduce_func,
1285          "reduce()",
1286          input_structure=structure_lib.NestedStructure(
1287              (state_structure, self._element_structure)),
1288          add_to_graph=False)
1289
1290      # Extract and validate class information from the returned values.
1291      output_classes = wrapped_func.output_classes
1292      state_classes = state_structure._to_legacy_output_classes()  # pylint: disable=protected-access
1293      for new_state_class, state_class in zip(
1294          nest.flatten(output_classes), nest.flatten(state_classes)):
1295        if not issubclass(new_state_class, state_class):
1296          raise TypeError(
1297              "The element classes for the new state must match the initial "
1298              "state. Expected %s; got %s." % (state_classes,
1299                                               wrapped_func.output_classes))
1300
1301      # Extract and validate type information from the returned values.
1302      output_types = wrapped_func.output_types
1303      state_types = state_structure._to_legacy_output_types()  # pylint: disable=protected-access
1304      for new_state_type, state_type in zip(
1305          nest.flatten(output_types), nest.flatten(state_types)):
1306        if new_state_type != state_type:
1307          raise TypeError(
1308              "The element types for the new state must match the initial "
1309              "state. Expected %s; got %s." % (state_types,
1310                                               wrapped_func.output_types))
1311
1312      # Extract shape information from the returned values.
1313      output_shapes = wrapped_func.output_shapes
1314      state_shapes = state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
1315      flat_state_shapes = nest.flatten(state_shapes)
1316      flat_new_state_shapes = nest.flatten(output_shapes)
1317      weakened_state_shapes = [
1318          original.most_specific_compatible_shape(new)
1319          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
1320      ]
1321
1322      need_to_rerun = False
1323      for original_shape, weakened_shape in zip(flat_state_shapes,
1324                                                weakened_state_shapes):
1325        if original_shape.ndims is not None and (
1326            weakened_shape.ndims is None or
1327            original_shape.as_list() != weakened_shape.as_list()):
1328          need_to_rerun = True
1329          break
1330
1331      if need_to_rerun:
1332        # TODO(b/110122868): Support a "most specific compatible structure"
1333        # method for combining structures, to avoid using legacy structures
1334        # here.
1335        state_structure = structure_lib.convert_legacy_structure(
1336            state_types,
1337            nest.pack_sequence_as(state_shapes, weakened_state_shapes),
1338            state_classes)
1339
1340    reduce_func = wrapped_func.function
1341    reduce_func.add_to_graph(ops.get_default_graph())
1342
1343    # pylint: disable=protected-access
1344    return state_structure._from_compatible_tensor_list(
1345        gen_dataset_ops.reduce_dataset(
1346            self._variant_tensor,
1347            state_structure._to_tensor_list(initial_state),
1348            reduce_func.captured_inputs,
1349            f=reduce_func,
1350            output_shapes=state_structure._flat_shapes,
1351            output_types=state_structure._flat_types))
1352
1353  def with_options(self, options):
1354    """Returns a new `tf.data.Dataset` with the given options set.
1355
1356    The options are "global" in the sense they apply to the entire dataset.
1357    If options are set multiple times, they are merged as long as the same
1358    option is not set to different non-default values.
1359
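    For example (a minimal sketch; `dataset` is assumed to be an existing
    `tf.data.Dataset`):

    ```python
    options = tf.data.Options()
    options.experimental_deterministic = False
    dataset = dataset.with_options(options)
    ```
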
1360    Args:
1361      options: A `tf.data.Options` that identifies the options to use.
1362
1363    Returns:
1364      Dataset: A `Dataset` with the given options.
1365
1366    Raises:
1367      ValueError: When an option is set more than once to a non-default value.
1368    """
1369    return _OptionsDataset(self, options)
1370
1371
1372@tf_export(v1=["data.Dataset"])
1373class DatasetV1(DatasetV2):
1374  """Represents a potentially large set of elements.
1375
1376  A `Dataset` can be used to represent an input pipeline as a
1377  collection of elements (nested structures of tensors) and a "logical
1378  plan" of transformations that act on those elements.
1379  """
1380
1381  def __init__(self):
1382    try:
1383      variant_tensor = self._as_variant_tensor()
1384    except AttributeError as e:
1385      if "_as_variant_tensor" in str(e):
1386        raise AttributeError("Please use _variant_tensor instead of "
1387                             "_as_variant_tensor() to obtain the variant "
1388                             "associated with a dataset")
1389      raise AttributeError("A likely cause of this error is that the super "
1390                           "call for this dataset is not the last line of the "
1391                           "__init__ method. The base class invokes "
1392                           "_as_variant_tensor() in its constructor, and "
1393                           "if that call uses attributes defined in the "
1394                           "__init__ method, those attributes must be "
1395                           "defined before the super call.")
1396    super(DatasetV1, self).__init__(variant_tensor)
1397
1398  @abc.abstractmethod
1399  def _as_variant_tensor(self):
1400    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
1401
1402    Returns:
1403      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
1404    """
1405    raise NotImplementedError("Dataset._as_variant_tensor")
1406
1407  @deprecation.deprecated(
1408      None, "Use `for ... in dataset:` to iterate over a dataset. If using "
1409      "`tf.estimator`, return the `Dataset` object directly from your input "
1410      "function. As a last resort, you can use "
1411      "`tf.compat.v1.data.make_one_shot_iterator(dataset)`.")
1412  def make_one_shot_iterator(self):
1413    """Creates an `Iterator` for enumerating the elements of this dataset.
1414
1415    Note: The returned iterator will be initialized automatically.
1416    A "one-shot" iterator does not currently support re-initialization.
1417
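    For example (a minimal graph-mode sketch; `dataset` and `sess` are assumed
    to exist):

    ```python
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    # ...
    print(sess.run(next_element))
    ```
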
1418    Returns:
1419      An `Iterator` over the elements of this dataset.
1420    """
1421    return self._make_one_shot_iterator()
1422
1423  def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
1424    if context.executing_eagerly():
1425      return iterator_ops.EagerIterator(self)
1426
1427    _ensure_same_dataset_graph(self)
1428    # Now that we create datasets at Python object creation time, the
1429    # capture-by-value _make_dataset() function would try to capture these
1430    # variant tensor dataset inputs, which are marked as stateful ops and
1431    # would raise an error if we tried to capture them. We therefore
1432    # traverse the graph to find all such ops and whitelist them, so that
1433    # the capturing logic recreates these ops (as it did previously)
1434    # instead of raising an error.
1435    all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
1436    graph_level_seed, op_level_seed = core_random_seed.get_seed(None)
1437
1438    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
1439    # a 0-argument function.
1440    @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
1441    def _make_dataset():
1442      """Factory function for a dataset."""
1443      # NOTE(mrry): `Defun` does not capture the graph-level seed from the
1444      # enclosing graph, so if a graph-level seed is present we set the local
1445      # graph seed based on a combination of the graph- and op-level seeds.
1446      if graph_level_seed is not None:
1447        assert op_level_seed is not None
1448        core_random_seed.set_random_seed(
1449            (graph_level_seed + 87654321 * op_level_seed) % (2 ** 63 - 1))
1450
1451      dataset = self._apply_options()
1452      return dataset._variant_tensor  # pylint: disable=protected-access
1453
1454    try:
1455      _make_dataset.add_to_graph(ops.get_default_graph())
1456    except ValueError as err:
1457      if "Cannot capture a stateful node" in str(err):
1458        raise ValueError(
1459            "Failed to create a one-shot iterator for a dataset. "
1460            "`Dataset.make_one_shot_iterator()` does not support datasets that "
1461            "capture stateful objects, such as a `Variable` or `LookupTable`. "
1462            "In these cases, use `Dataset.make_initializable_iterator()`. "
1463            "(Original error: %s)" % err)
1464      else:
1465        six.reraise(ValueError, err)
1466
1467    # pylint: disable=protected-access
1468    return iterator_ops.Iterator(
1469        gen_dataset_ops.one_shot_iterator(
1470            dataset_factory=_make_dataset, **flat_structure(self)),
1471        None, get_legacy_output_types(self), get_legacy_output_shapes(self),
1472        get_legacy_output_classes(self))
1473
1474  @deprecation.deprecated(
1475      None, "Use `for ... in dataset:` to iterate over a dataset. If using "
1476      "`tf.estimator`, return the `Dataset` object directly from your input "
1477      "function. As a last resort, you can use "
1478      "`tf.compat.v1.data.make_initializable_iterator(dataset)`.")
1479  def make_initializable_iterator(self, shared_name=None):
1480    """Creates an `Iterator` for enumerating the elements of this dataset.
1481
1482    Note: The returned iterator will be in an uninitialized state,
1483    and you must run the `iterator.initializer` operation before using it:
1484
1485    ```python
1486    dataset = ...
1487    iterator = dataset.make_initializable_iterator()
1488    # ...
1489    sess.run(iterator.initializer)
1490    ```
1491
1492    Args:
1493      shared_name: (Optional.) If non-empty, the returned iterator will be
1494        shared under the given name across multiple sessions that share the
1495        same devices (e.g. when using a remote server).
1496
1497    Returns:
1498      An `Iterator` over the elements of this dataset.
1499
1500    Raises:
1501      RuntimeError: If eager execution is enabled.
1502    """
1503
1504    return self._make_initializable_iterator(shared_name)
1505
1506  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=missing-docstring
1507    if context.executing_eagerly():
1508      raise RuntimeError(
1509          "dataset.make_initializable_iterator is not supported when eager "
1510          "execution is enabled.")
1511    _ensure_same_dataset_graph(self)
1512    dataset = self._apply_options()
1513    if shared_name is None:
1514      shared_name = ""
1515    if compat.forward_compatible(2018, 8, 3):
1516      iterator_resource = gen_dataset_ops.iterator_v2(
1517          container="", shared_name=shared_name, **flat_structure(self))
1518    else:
1519      iterator_resource = gen_dataset_ops.iterator(
1520          container="", shared_name=shared_name, **flat_structure(self))
1521    with ops.colocate_with(iterator_resource):
1522      initializer = gen_dataset_ops.make_iterator(
1523          dataset._variant_tensor,  # pylint: disable=protected-access
1524          iterator_resource)
1525    # pylint: disable=protected-access
1526    return iterator_ops.Iterator(
1527        iterator_resource, initializer, get_legacy_output_types(dataset),
1528        get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
1529
1530  @property
1531  def output_classes(self):
1532    """Returns the class of each component of an element of this dataset.
1533
1534    The expected values are `tf.Tensor` and `tf.SparseTensor`.
1535
1536    Returns:
1537      A nested structure of Python `type` objects corresponding to each
1538      component of an element of this dataset.
1539    """
1540    return self._element_structure._to_legacy_output_classes()  # pylint: disable=protected-access
1541
1542  @property
1543  def output_shapes(self):
1544    """Returns the shape of each component of an element of this dataset.
1545
1546    Returns:
1547      A nested structure of `tf.TensorShape` objects corresponding to each
1548      component of an element of this dataset.
1549    """
1550    return self._element_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
1551
1552  @property
1553  def output_types(self):
1554    """Returns the type of each component of an element of this dataset.
1555
1556    Returns:
1557      A nested structure of `tf.DType` objects corresponding to each component
1558      of an element of this dataset.
1559    """
1560    return self._element_structure._to_legacy_output_types()  # pylint: disable=protected-access
1561
1562  @property
1563  def _element_structure(self):
1564    # TODO(b/110122868): Remove this override once all `Dataset` instances
1565    # implement `element_structure`.
1566    return structure_lib.convert_legacy_structure(
1567        self.output_types, self.output_shapes, self.output_classes)
1568
1569  @staticmethod
1570  @functools.wraps(DatasetV2.from_tensors)
1571  def from_tensors(tensors):
1572    return DatasetV1Adapter(DatasetV2.from_tensors(tensors))
1573
1574  @staticmethod
1575  @functools.wraps(DatasetV2.from_tensor_slices)
1576  def from_tensor_slices(tensors):
1577    return DatasetV1Adapter(DatasetV2.from_tensor_slices(tensors))
1578
1579  @staticmethod
1580  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
1581  def from_sparse_tensor_slices(sparse_tensor):
1582    """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
1583
1584    Args:
1585      sparse_tensor: A `tf.SparseTensor`.
1586
1587    Returns:
1588      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
1589    """
1590    return DatasetV1Adapter(SparseTensorSliceDataset(sparse_tensor))
1591
1592  @staticmethod
1593  @functools.wraps(DatasetV2.from_generator)
1594  def from_generator(generator, output_types, output_shapes=None, args=None):
1595    return DatasetV1Adapter(DatasetV2.from_generator(
1596        generator, output_types, output_shapes, args))
1597
1598  @staticmethod
1599  @functools.wraps(DatasetV2.range)
1600  def range(*args):
1601    return DatasetV1Adapter(DatasetV2.range(*args))
1602
1603  @staticmethod
1604  @functools.wraps(DatasetV2.zip)
1605  def zip(datasets):
1606    return DatasetV1Adapter(DatasetV2.zip(datasets))
1607
1608  @functools.wraps(DatasetV2.concatenate)
1609  def concatenate(self, dataset):
1610    return DatasetV1Adapter(super(DatasetV1, self).concatenate(dataset))
1611
1612  @functools.wraps(DatasetV2.prefetch)
1613  def prefetch(self, buffer_size):
1614    return DatasetV1Adapter(super(DatasetV1, self).prefetch(buffer_size))
1615
1616  @staticmethod
1617  @functools.wraps(DatasetV2.list_files)
1618  def list_files(file_pattern, shuffle=None, seed=None):
1619    return DatasetV1Adapter(DatasetV2.list_files(file_pattern, shuffle, seed))
1620
1621  @functools.wraps(DatasetV2.repeat)
1622  def repeat(self, count=None):
1623    return DatasetV1Adapter(super(DatasetV1, self).repeat(count))
1624
1625  @functools.wraps(DatasetV2.shuffle)
1626  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
1627    return DatasetV1Adapter(super(DatasetV1, self).shuffle(
1628        buffer_size, seed, reshuffle_each_iteration))
1629
1630  @functools.wraps(DatasetV2.cache)
1631  def cache(self, filename=""):
1632    return DatasetV1Adapter(super(DatasetV1, self).cache(filename))
1633
1634  @functools.wraps(DatasetV2.take)
1635  def take(self, count):
1636    return DatasetV1Adapter(super(DatasetV1, self).take(count))
1637
1638  @functools.wraps(DatasetV2.skip)
1639  def skip(self, count):
1640    return DatasetV1Adapter(super(DatasetV1, self).skip(count))
1641
1642  @functools.wraps(DatasetV2.shard)
1643  def shard(self, num_shards, index):
1644    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
1645
1646  @functools.wraps(DatasetV2.batch)
1647  def batch(self, batch_size, drop_remainder=False):
1648    return DatasetV1Adapter(super(DatasetV1, self).batch(
1649        batch_size, drop_remainder))
1650
1651  @functools.wraps(DatasetV2.padded_batch)
1652  def padded_batch(self,
1653                   batch_size,
1654                   padded_shapes,
1655                   padding_values=None,
1656                   drop_remainder=False):
1657    return DatasetV1Adapter(super(DatasetV1, self).padded_batch(
1658        batch_size, padded_shapes, padding_values, drop_remainder))
1659
1660  @functools.wraps(DatasetV2.map)
1661  def map(self, map_func, num_parallel_calls=None):
1662    if num_parallel_calls is None:
1663      return DatasetV1Adapter(
1664          MapDataset(self, map_func, preserve_cardinality=False))
1665    else:
1666      return DatasetV1Adapter(
1667          ParallelMapDataset(
1668              self, map_func, num_parallel_calls, preserve_cardinality=False))
1669
1670  @deprecation.deprecated(None, "Use `tf.data.Dataset.map()`.")
1671  def map_with_legacy_function(self, map_func, num_parallel_calls=None):
1672    """Maps `map_func` across the elements of this dataset.
1673
1674    NOTE: This is an escape hatch for existing uses of `map` that do not work
1675    with V2 functions. New uses are strongly discouraged and existing uses
1676    should migrate to `map` as this method will be removed in V2.
1677
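    For example (a minimal sketch):

    ```python
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.map_with_legacy_function(lambda x: x * 2)
    ```
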
1678    Args:
1679      map_func: A function mapping a nested structure of tensors (having shapes
1680        and types defined by `self.output_shapes` and `self.output_types`) to
1681        another nested structure of tensors.
1682      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
1683        representing the number of elements to process asynchronously in parallel.
1684        If not specified, elements will be processed sequentially. If the value
1685        `tf.data.experimental.AUTOTUNE` is used, then the number of parallel
1686        calls is set dynamically based on available CPU.
1687
1688    Returns:
1689      Dataset: A `Dataset`.
1690    """
1691    if num_parallel_calls is None:
1692      return DatasetV1Adapter(
1693          MapDataset(
1694              self,
1695              map_func,
1696              preserve_cardinality=False,
1697              use_legacy_function=True))
1698    else:
1699      return DatasetV1Adapter(
1700          ParallelMapDataset(
1701              self,
1702              map_func,
1703              num_parallel_calls,
1704              preserve_cardinality=False,
1705              use_legacy_function=True))
1706
1707  @functools.wraps(DatasetV2.flat_map)
1708  def flat_map(self, map_func):
1709    return DatasetV1Adapter(super(DatasetV1, self).flat_map(map_func))
1710
1711  @functools.wraps(DatasetV2.interleave)
1712  def interleave(self,
1713                 map_func,
1714                 cycle_length,
1715                 block_length=1,
1716                 num_parallel_calls=None):
1717    return DatasetV1Adapter(super(DatasetV1, self).interleave(
1718        map_func, cycle_length, block_length, num_parallel_calls))
1719
1720  @functools.wraps(DatasetV2.filter)
1721  def filter(self, predicate):
1722    return DatasetV1Adapter(super(DatasetV1, self).filter(predicate))
1723
1724  @deprecation.deprecated(None, "Use `tf.data.Dataset.filter()`.")
1725  def filter_with_legacy_function(self, predicate):
1726    """Filters this dataset according to `predicate`.
1727
1728    NOTE: This is an escape hatch for existing uses of `filter` that do not work
1729    with V2 functions. New uses are strongly discouraged and existing uses
1730    should migrate to `filter` as this method will be removed in V2.
1731
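    For example (a minimal sketch):

    ```python
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.filter_with_legacy_function(lambda x: x > 2)
    ```
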
1732    Args:
1733      predicate: A function mapping a nested structure of tensors (having shapes
1734        and types defined by `self.output_shapes` and `self.output_types`) to a
1735        scalar `tf.bool` tensor.
1736
1737    Returns:
1738      Dataset: The `Dataset` containing the elements of this dataset for which
1739          `predicate` is `True`.
1740    """
1741    return FilterDataset(self, predicate, use_legacy_function=True)
1742
1743  @functools.wraps(DatasetV2.apply)
1744  def apply(self, transformation_func):
1745    return DatasetV1Adapter(super(DatasetV1, self).apply(transformation_func))
1746
1747  @functools.wraps(DatasetV2.window)
1748  def window(self, size, shift=None, stride=1, drop_remainder=False):
1749    return DatasetV1Adapter(super(DatasetV1, self).window(
1750        size, shift, stride, drop_remainder))
1751
1752  @functools.wraps(DatasetV2.with_options)
1753  def with_options(self, options):
1754    return DatasetV1Adapter(super(DatasetV1, self).with_options(options))
1755
1756
1757# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
1758# this alias in place.
1759Dataset = DatasetV1
1760
1761
1762class DatasetV1Adapter(DatasetV1):
1763  """Wraps a V2 `Dataset` object in the `tf.compat.v1.data.Dataset` API."""
1764
1765  def __init__(self, dataset):
1766    self._dataset = dataset
1767    super(DatasetV1Adapter, self).__init__()
1768
1769  def _as_variant_tensor(self):
1770    return self._dataset._variant_tensor  # pylint: disable=protected-access
1771
1772  def _has_captured_ref(self):
1773    return self._dataset._has_captured_ref()  # pylint: disable=protected-access
1774
1775  def _inputs(self):
1776    return self._dataset._inputs()  # pylint: disable=protected-access
1777
1778  def options(self):
1779    return self._dataset.options()
1780
1781  @property
1782  def _element_structure(self):
1783    return self._dataset._element_structure  # pylint: disable=protected-access
1784
1785  def __iter__(self):
1786    return iter(self._dataset)
1787
1788
1789def _ensure_same_dataset_graph(dataset):
1790  """Walks the dataset graph to ensure all datasets come from the same graph."""
1791  current_graph = ops.get_default_graph()
1792  bfs_q = Queue.Queue()
1793  bfs_q.put(dataset)  # pylint: disable=protected-access
1794  visited = []
1795  while not bfs_q.empty():
1796    ds = bfs_q.get()
1797    visited.append(ds)
1798    ds_graph = ds._graph  # pylint: disable=protected-access
1799    if current_graph != ds_graph:
1800      logging.warning("The graph (" + str(current_graph) + ") of the iterator "
1801                      "is different from the graph (" + str(ds_graph) + ") in "
1802                      "which the dataset " + str(ds._variant_tensor) + " was "  # pylint: disable=protected-access
1803                      "created. If you are using the Estimator API, "
1804                      "make sure that no part of the dataset returned by the "
1805                      "`input_fn` function is defined outside the `input_fn` "
1806                      "function. Please ensure that all datasets in the "
1807                      "pipeline are created in the same graph as the iterator. "
1808                      "NOTE: This warning will become an error in future "
1809                      "versions of TensorFlow.")
1810    for input_ds in ds._inputs():  # pylint: disable=protected-access
1811      if input_ds not in visited:
1812        bfs_q.put(input_ds)
1813
1814
1815@tf_export(v1=["data.make_one_shot_iterator"])
1816def make_one_shot_iterator(dataset):
1817  """Creates a `tf.data.Iterator` for enumerating the elements of a dataset.
1818
1819  Note: The returned iterator will be initialized automatically.
1820  A "one-shot" iterator does not support re-initialization.
1821
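  For example (a minimal graph-mode sketch):

  ```python
  dataset = ...
  iterator = tf.data.make_one_shot_iterator(dataset)
  # ...
  sess.run(iterator.get_next())
  ```
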
1822  Args:
1823    dataset: A `tf.data.Dataset`.
1824
1825  Returns:
1826    A `tf.data.Iterator` over the elements of this dataset.
1827  """
1828  try:
1829    # Call the defined `_make_one_shot_iterator()` if there is one, because some
1830    # datasets (e.g. for prefetching) override its behavior.
1831    return dataset._make_one_shot_iterator()  # pylint: disable=protected-access
1832  except AttributeError:
1833    return DatasetV1Adapter(dataset)._make_one_shot_iterator()  # pylint: disable=protected-access
1834
1835
1836@tf_export(v1=["data.make_initializable_iterator"])
1837def make_initializable_iterator(dataset, shared_name=None):
1838  """Creates a `tf.data.Iterator` for enumerating the elements of a dataset.
1839
1840  Note: The returned iterator will be in an uninitialized state,
1841  and you must run the `iterator.initializer` operation before using it:
1842
1843  ```python
1844  dataset = ...
1845  iterator = tf.data.make_initializable_iterator(dataset)
1846  # ...
1847  sess.run(iterator.initializer)
1848  ```
1849
1850  Args:
1851    dataset: A `tf.data.Dataset`.
1852    shared_name: (Optional.) If non-empty, the returned iterator will be
1853      shared under the given name across multiple sessions that share the
1854      same devices (e.g. when using a remote server).
1855
1856  Returns:
1857    A `tf.data.Iterator` over the elements of `dataset`.
1858
1859  Raises:
1860    RuntimeError: If eager execution is enabled.
1861  """
1862  try:
1863    # Call the defined `_make_initializable_iterator()` if there is one, because
1864    # some datasets (e.g. for prefetching) override its behavior.
1865    return dataset._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
1866  except AttributeError:
1867    return DatasetV1Adapter(dataset)._make_initializable_iterator(shared_name)  # pylint: disable=protected-access
1868
1869
1870# TODO(b/110122868): Replace this method with a public API for reflecting on
1871# dataset structure.
1872def get_structure(dataset_or_iterator):
1873  """Returns the `tf.data.experimental.Structure` of a `Dataset` or `Iterator`.
1874
1875  Args:
1876    dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1877      `EagerIterator`.
1878
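  For example (a minimal sketch):

  ```python
  structure = get_structure(tf.data.Dataset.range(10))
  # `structure` describes a scalar `tf.int64` element.
  ```
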
1879  Returns:
1880    A `tf.data.experimental.Structure` representing the structure of the
1881    elements of `dataset_or_iterator`.
1882
1883  Raises:
1884    TypeError: If `dataset_or_iterator` is not a dataset or iterator object.
1885  """
1886  try:
1887    ret = dataset_or_iterator._element_structure  # pylint: disable=protected-access
1888    if isinstance(ret, structure_lib.Structure):
1889      return ret
1890  except AttributeError:
1891    pass
1892  raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator object, "
1893                  "but got %s." % type(dataset_or_iterator))
1894
1895
1896# TODO(b/110122868): Remove all uses of this method.
1897def get_legacy_output_shapes(dataset_or_iterator):
1898  """Returns the output shapes of a `Dataset` or `Iterator`.
1899
1900  This utility method replaces the deprecated-in-V2
1901  `tf.compat.v1.Dataset.output_shapes` property.
1902
1903  Args:
1904    dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1905      `EagerIterator`.
1906
1907  Returns:
1908    A nested structure of `tf.TensorShape` objects corresponding to each
1909    component of an element of the given dataset or iterator.
1910  """
1911  return get_structure(dataset_or_iterator)._to_legacy_output_shapes()  # pylint: disable=protected-access
1912
1913
1914# TODO(b/110122868): Remove all uses of this method.
1915def get_legacy_output_types(dataset_or_iterator):
1916  """Returns the output shapes of a `Dataset` or `Iterator`.
1917
1918  This utility method replaces the deprecated-in-V2
1919  `tf.compat.v1.Dataset.output_types` property.
1920
1921  Args:
1922    dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1923      `EagerIterator`.
1924
1925  Returns:
1926    A nested structure of `tf.DType` objects corresponding to each component
1927    of an element of this dataset.
1928  """
1929  return get_structure(dataset_or_iterator)._to_legacy_output_types()  # pylint: disable=protected-access
1930
1931
1932# TODO(b/110122868): Remove all uses of this method.
1933def get_legacy_output_classes(dataset_or_iterator):
1934  """Returns the output classes of a `Dataset` or `Iterator`.
1935
1936  This utility method replaces the deprecated-in-V2
1937  `tf.compat.v1.Dataset.output_classes` property.
1938
1939  Args:
1940    dataset_or_iterator: A `tf.data.Dataset`, `tf.data.Iterator`, or
1941      `EagerIterator`.
1942
1943  Returns:
1944    A nested structure of Python `type` or `tf.data.experimental.Structure`
1945    objects corresponding to each component of an element of this dataset.
1946  """
1947  return get_structure(dataset_or_iterator)._to_legacy_output_classes()  # pylint: disable=protected-access
1948
1949
1950@tf_export("data.Options")
1951class Options(options_lib.OptionsBase):
1952  """Represents options for tf.data.Dataset.
1953
1954  An `Options` object can be, for instance, used to control which static
1955  optimizations to apply or whether to use performance modeling to dynamically
1956  tune the parallelism of operations such as `tf.data.Dataset.map` or
1957  `tf.data.Dataset.interleave`.
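
  For example (a minimal sketch; `dataset` is assumed to be an existing
  `tf.data.Dataset`):

  ```python
  options = tf.data.Options()
  options.experimental_numa_aware = True
  dataset = dataset.with_options(options)
  ```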
1958  """
1959
1960  experimental_deterministic = options_lib.create_option(
1961      name="experimental_deterministic",
1962      ty=bool,
1963      docstring=
1964      "Whether the outputs need to be produced in deterministic order. If None,"
1965      " defaults to True.")
1966
1967  experimental_numa_aware = options_lib.create_option(
1968      name="experimental_numa_aware",
1969      ty=bool,
1970      docstring=
1971      "Whether to use NUMA-aware operations. If None, defaults to False.")
1972
1973  experimental_optimization = options_lib.create_option(
1974      name="experimental_optimization",
1975      ty=optimization_options.OptimizationOptions,
1976      docstring=
1977      "The optimization options associated with the dataset. See "
1978      "`tf.data.experimental.OptimizationOptions` for more details.",
1979      default_factory=optimization_options.OptimizationOptions)
1980
1981  experimental_stats = options_lib.create_option(
1982      name="experimental_stats",
1983      ty=stats_options.StatsOptions,
1984      docstring=
1985      "The statistics options associated with the dataset. See "
1986      "`tf.data.experimental.StatsOptions` for more details.",
1987      default_factory=stats_options.StatsOptions)
1988
1989  experimental_threading = options_lib.create_option(
1990      name="experimental_threading",
1991      ty=threading_options.ThreadingOptions,
1992      docstring=
1993      "The threading options associated with the dataset. See "
1994      "`tf.data.experimental.ThreadingOptions` for more details.",
1995      default_factory=threading_options.ThreadingOptions)
1996
1997  def _static_optimizations(self):
1998    """Produces the list of enabled static optimizations."""
1999
2000    result = []
2001    result.extend(self.experimental_optimization._static_optimizations())  # pylint: disable=protected-access
2002
2003    if self.experimental_numa_aware:
2004      result.append("make_numa_aware")
2005    if self.experimental_deterministic is False:
2006      result.append("make_sloppy")
2007    exp_stats_options = self.experimental_stats
2008    if exp_stats_options and exp_stats_options.latency_all_edges:
2009      result.append("latency_all_edges")
2010    return result
2011
2012  def merge(self, options):
2013    """Merges itself with the given `tf.data.Options`.
2014
2015    The given `tf.data.Options` can be merged as long as there does not exist an
2016    attribute that is set to different values in `self` and `options`.
2017
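    For example (a minimal sketch):

    ```python
    options1 = tf.data.Options()
    options1.experimental_deterministic = False
    options2 = tf.data.Options()
    options2.experimental_numa_aware = True
    merged = options1.merge(options2)  # carries both non-default settings
    ```
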
2018    Args:
2019      options: a `tf.data.Options` to merge with
2020
2021    Raises:
2022      ValueError: if the given `tf.data.Options` cannot be merged
2023
2024    Returns:
2025      New `tf.data.Options()` object which is the result of merging self with
2026      the input `tf.data.Options`.
2027    """
2028    return options_lib.merge_options(self, options)
2029
2030
2031class DatasetSource(DatasetV2):
2032  """Abstract class representing a dataset with no inputs."""
2033
2034  def _inputs(self):
2035    return []
2036
2037
2038class UnaryDataset(DatasetV2):
2039  """Abstract class representing a dataset with one input."""
2040
2041  def __init__(self, input_dataset, variant_tensor):
2042    self._input_dataset = input_dataset
2043    super(UnaryDataset, self).__init__(variant_tensor)
2044
2045  def _inputs(self):
2046    return [self._input_dataset]
2047
2048
2049class UnaryUnchangedStructureDataset(UnaryDataset):
2050  """Represents a unary dataset with the same input and output structure."""
2051
2052  def __init__(self, input_dataset, variant_tensor):
2053    self._input_dataset = input_dataset
2054    super(UnaryUnchangedStructureDataset, self).__init__(
2055        input_dataset, variant_tensor)
2056
2057  @property
2058  def _element_structure(self):
2059    return self._input_dataset._element_structure  # pylint: disable=protected-access
2060
2061
2062class TensorDataset(DatasetSource):
2063  """A `Dataset` with a single element, viz. a nested structure of tensors."""
2064
2065  def __init__(self, tensors):
2066    """See `Dataset.from_tensors()` for details."""
2067    with ops.name_scope("tensors"):
2068      tensors = nest.pack_sequence_as(tensors, [
2069          sparse_tensor_lib.SparseTensor.from_value(t)
2070          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
2071              t, name="component_%d" % i)
2072          for i, t in enumerate(nest.flatten(tensors))
2073      ])
2074    self._structure = structure_lib.Structure.from_value(tensors)
2075    self._tensors = self._structure._to_tensor_list(tensors)  # pylint: disable=protected-access
2076
2077    variant_tensor = gen_dataset_ops.tensor_dataset(
2078        self._tensors, output_shapes=self._structure._flat_shapes)  # pylint: disable=protected-access
2079    super(TensorDataset, self).__init__(variant_tensor)
2080
2081  @property
2082  def _element_structure(self):
2083    return self._structure
2084
2085
2086class TensorSliceDataset(DatasetSource):
2087  """A `Dataset` of slices from a nested structure of tensors."""
2088
2089  def __init__(self, tensors):
2090    """See `Dataset.from_tensor_slices()` for details."""
2091    with ops.name_scope("tensors"):
2092      tensors = nest.pack_sequence_as(tensors, [
2093          sparse_tensor_lib.SparseTensor.from_value(t)
2094          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
2095              t, name="component_%d" % i)
2096          for i, t in enumerate(nest.flatten(tensors))
2097      ])
2098
2099    batched_structure = structure_lib.Structure.from_value(tensors)
2100    # pylint: disable=protected-access
2101    self._tensors = batched_structure._to_batched_tensor_list(tensors)
2102    self._structure = batched_structure._unbatch()
2103    # pylint: enable=protected-access
2104
2105    batch_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
2106        self._tensors[0].get_shape()[0]))
2107    for t in self._tensors[1:]:
2108      batch_dim.assert_is_compatible_with(tensor_shape.Dimension(
2109          tensor_shape.dimension_value(t.get_shape()[0])))
2110
2111    variant_tensor = gen_dataset_ops.tensor_slice_dataset(
2112        self._tensors, output_shapes=self._structure._flat_shapes)  # pylint: disable=protected-access
2113    super(TensorSliceDataset, self).__init__(variant_tensor)
2114
2115  @property
2116  def _element_structure(self):
2117    return self._structure
2118
2119
2120class SparseTensorSliceDataset(DatasetSource):
2121  """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
2122
2123  def __init__(self, sparse_tensor):
2124    """See `Dataset.from_sparse_tensor_slices()` for details."""
2125    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
2126      raise TypeError(
2127          "`sparse_tensor` must be a `tf.SparseTensor` object. Was {}.".format(
2128              sparse_tensor))
2129    self._sparse_tensor = sparse_tensor
2130
2131    indices_shape = self._sparse_tensor.indices.get_shape()
2132    shape_shape = self._sparse_tensor.dense_shape.get_shape()
2133    rank = (indices_shape.dims[1] - 1).merge_with(shape_shape.dims[0] - 1)
2134    self._structure = structure_lib.NestedStructure(
2135        (structure_lib.TensorStructure(dtypes.int64, [None, rank]),
2136         structure_lib.TensorStructure(self._sparse_tensor.dtype, [None]),
2137         structure_lib.TensorStructure(dtypes.int64, [rank])))
2138
2139    variant_tensor = gen_dataset_ops.sparse_tensor_slice_dataset(
2140        self._sparse_tensor.indices, self._sparse_tensor.values,
2141        self._sparse_tensor.dense_shape)
2142    super(SparseTensorSliceDataset, self).__init__(variant_tensor)
2143
2144  @property
2145  def _element_structure(self):
2146    return self._structure
2147
2148
2149class _VariantDataset(DatasetV2):
2150  """A Dataset wrapper around a `tf.variant`-typed function argument."""
2151
2152  def __init__(self, dataset_variant, structure):
2153    self._structure = structure
2154    super(_VariantDataset, self).__init__(dataset_variant)
2155
2156  def _inputs(self):
2157    return []
2158
2159  @property
2160  def _element_structure(self):
2161    return self._structure
2162
2163
2164@tf_export("data.experimental.DatasetStructure")
2165class DatasetStructure(structure_lib.Structure):
2166  """Represents a `Dataset` of structured values."""
2167
2168  def __init__(self, element_structure):
2169    self._element_structure = element_structure
2170
2171  @property
2172  def _flat_shapes(self):
2173    return [tensor_shape.scalar()]
2174
2175  @property
2176  def _flat_types(self):
2177    return [dtypes.variant]
2178
2179  def is_compatible_with(self, other):
2180    # pylint: disable=protected-access
2181    return (isinstance(other, DatasetStructure) and
2182            self._element_structure.is_compatible_with(
2183                other._element_structure))
2184
2185  def _to_tensor_list(self, value):
2186    return [value._variant_tensor]  # pylint: disable=protected-access
2187
2188  def _to_batched_tensor_list(self, value):
2189    raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
2190
2191  def _from_tensor_list(self, flat_value):
2192    if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
2193        not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
2194      raise ValueError(
2195          "DatasetStructure corresponds to a single tf.variant scalar.")
2196    return self._from_compatible_tensor_list(flat_value)
2197
2198  def _from_compatible_tensor_list(self, flat_value):
2199    # pylint: disable=protected-access
2200    return _VariantDataset(flat_value[0], self._element_structure)
2201
2202  @staticmethod
2203  def from_value(value):
2204    return DatasetStructure(value._element_structure)  # pylint: disable=protected-access
2205
2206  def _to_legacy_output_types(self):
2207    return self
2208
2209  def _to_legacy_output_shapes(self):
2210    return self
2211
2212  def _to_legacy_output_classes(self):
2213    return self
2214
2215  def _batch(self, batch_size):
2216    raise NotImplementedError("Batching for `tf.data.Dataset` objects.")
2217
2218  def _unbatch(self):
2219    raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
2220
2221
2222# pylint: disable=protected-access
2223structure_lib.Structure._register_custom_converter(DatasetV2,
2224                                                   DatasetStructure.from_value)
2225# pylint: enable=protected-access
2226
2227
2228class StructuredFunctionWrapper(object):
2229  """A function wrapper that supports structured arguments and return values."""
2230
2231  # pylint: disable=protected-access
2232  def __init__(self,
2233               func,
2234               transformation_name,
2235               dataset=None,
2236               input_classes=None,
2237               input_shapes=None,
2238               input_types=None,
2239               input_structure=None,
2240               add_to_graph=True,
2241               use_legacy_function=False,
2242               defun_kwargs=None):
2243    """Creates a new `StructuredFunctionWrapper` for the given function.
2244
2245    Args:
2246      func: A function from a nested structure to another nested structure.
2247      transformation_name: Human-readable name of the transformation in which
2248        this function is being instantiated, for error messages.
2249      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
2250        dataset will be assumed as the structure for `func` arguments; otherwise
2251        `input_classes`, `input_shapes`, and `input_types` must be defined.
2252      input_classes: (Optional.) A nested structure of `type`. If given, this
2253        argument defines the Python types for `func` arguments.
2254      input_shapes: (Optional.) A nested structure of `tf.TensorShape`. If
2255        given, this argument defines the shapes and structure for `func`
2256        arguments.
2257      input_types: (Optional.) A nested structure of `tf.DType`. If given, this
2258        argument defines the element types and structure for `func` arguments.
2259      input_structure: (Optional.) A `Structure` object. If given, this argument
2260        defines the element types and structure for `func` arguments.
2261      add_to_graph: (Optional.) If `True`, the function will be added to the
2262        default graph.
2263      use_legacy_function: (Optional.) A boolean that determines whether the
2264        function will be created using `tensorflow.python.eager.function.defun`
2265        (default behavior) or `tensorflow.python.framework.function.Defun`
2266        (legacy behavior).
2267      defun_kwargs: (Optional.) A dictionary mapping string argument names to
2268        values. If supplied, will be passed to `function` as keyword arguments.
2269
2270    Raises:
2271      ValueError: If an invalid combination of `dataset`, `input_classes`,
2272        `input_shapes`, and `input_types` is passed.
2273    """
2274    if input_structure is None:
2275      if dataset is None:
2276        if input_classes is None or input_shapes is None or input_types is None:
2277          raise ValueError("Either `dataset`, `input_structure` or all of "
2278                           "`input_classes`, `input_shapes`, and `input_types` "
2279                           "must be specified.")
2280        self._input_structure = structure_lib.convert_legacy_structure(
2281            input_types, input_shapes, input_classes)
2282      else:
2283        if not (input_classes is None and input_shapes is None and
2284                input_types is None):
2285          raise ValueError("Either `dataset`, `input_structure` or all of "
2286                           "`input_classes`, `input_shapes`, and `input_types` "
2287                           "must be specified.")
2288        self._input_structure = dataset._element_structure
2289    else:
2290      if not (dataset is None and input_classes is None and input_shapes is None
2291              and input_types is None):
2292        raise ValueError("Either `dataset`, `input_structure`, or all of "
2293                         "`input_classes`, `input_shapes`, and `input_types` "
2294                         "must be specified.")
2295      self._input_structure = input_structure
2296
2297    if defun_kwargs is None:
2298      defun_kwargs = {}
2299
2300    readable_transformation_name = transformation_name.replace(
2301        ".", "_")[:-2] if len(transformation_name) > 2 else ""
2302
2303    func_name = "_".join(
2304        [readable_transformation_name,
2305         function_utils.get_func_name(func)])
2306
2307    def _warn_if_collections(transformation_name):
2308      """Prints a warning if the given graph uses common graph collections.
2309
2310      NOTE(mrry): Currently a warning is only generated for resources. Any
2311      variables created will be automatically hoisted out to the outermost scope
2312      using `init_scope()`. Some collections (such as for control-flow contexts)
2313      are benign and should not generate a warning.
2314
2315      Args:
2316        transformation_name: A human-readable name for the transformation.
2317      """
2318      warnings.warn("Creating resources inside a function passed to %s "
2319                    "is not supported. Create each resource outside the "
2320                    "function, and capture it inside the function to use it." %
2321                    transformation_name, stacklevel=5)
2322
2323    def _wrapper_helper(*args):
2324      """Wrapper for passing nested structures to and from tf.data functions."""
2325      nested_args = self._input_structure._from_compatible_tensor_list(args)
2326      if not _should_unpack_args(nested_args):
2327        nested_args = (nested_args,)
2328
2329      ret = func(*nested_args)
2330      # If `func` returns a list of tensors, `nest.flatten()` and
2331      # `ops.convert_to_tensor()` would conspire to attempt to stack
2332      # those tensors into a single tensor, because the customized
2333      # version of `nest.flatten()` does not recurse into lists. Since
2334      # it is more likely that the list arose from returning the
2335      # result of an operation (such as `tf.py_func()`) that returns a
2336      # list of not-necessarily-stackable tensors, we treat the
2337      # returned value as a `tuple` instead. A user wishing to pack
2338      # the return value into a single tensor can use an explicit
2339      # `tf.stack()` before returning.
2340      if isinstance(ret, list):
2341        ret = tuple(ret)
2342
2343      try:
2344        self._output_structure = structure_lib.Structure.from_value(ret)
2345      except (ValueError, TypeError):
2346        raise TypeError("Unsupported return value from function passed to "
2347                        "%s: %s." % (transformation_name, ret))
2348      return ret
2349
2350    if use_legacy_function:
2351      func_name = func_name + "_" + str(ops.uid())
2352
2353      @function.Defun(
2354          *self._input_structure._flat_types,
2355          func_name=func_name,
2356          **defun_kwargs)
2357      def wrapper_fn(*args):
2358        ret = _wrapper_helper(*args)
2359        # _warn_if_collections(transformation_name, ops.get_default_graph(), 0)
2360        return self._output_structure._to_tensor_list(ret)
2361
2362      self._function = wrapper_fn
2363      resource_tracker = tracking.ResourceTracker()
2364      with tracking.resource_tracker_scope(resource_tracker):
2365        if add_to_graph:
2366          self._function.add_to_graph(ops.get_default_graph())
2367        else:
2368          # Use the private method that will execute `wrapper_fn` but delay
2369          # adding it to the graph in case (e.g.) we need to rerun the function.
2370          self._function._create_definition_if_needed()
2371      if resource_tracker.resources:
2372        _warn_if_collections(transformation_name)
2373
2374    else:
2375      defun_kwargs.update({"func_name": func_name})
2376
2377      # TODO(b/124254153): Enable autograph once the overhead is low enough.
2378      # TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
2379      @eager_function.defun_with_attributes(
2380          input_signature=[
2381              tensor_spec.TensorSpec(input_shape, input_type)  # pylint: disable=g-complex-comprehension
2382              for input_shape, input_type in zip(
2383                  self._input_structure._flat_shapes,
2384                  self._input_structure._flat_types)
2385          ],
2386          autograph=False,
2387          attributes=defun_kwargs)
2388      def wrapper_fn(*args):  # pylint: disable=missing-docstring
2389        ret = _wrapper_helper(*args)
2390        ret = self._output_structure._to_tensor_list(ret)
2391        return [ops.convert_to_tensor(t) for t in ret]
2392
2393      resource_tracker = tracking.ResourceTracker()
2394      with tracking.resource_tracker_scope(resource_tracker):
2395        self._function = wrapper_fn._get_concrete_function_internal()
2396        if add_to_graph:
2397          self._function.add_to_graph(ops.get_default_graph())
2398      if resource_tracker.resources:
2399        _warn_if_collections(transformation_name)
2400
2401      outer_graph_seed = ops.get_default_graph().seed
2402      if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
2403        if self._function.graph._seed_used:
2404          warnings.warn(
2405              "Seed %s from outer graph might be getting used by function %s, "
2406              "if the random op has not been provided any seed. Explicitly set "
2407              "the seed in the function if this is not the intended behavior."
2408              % (outer_graph_seed, func_name), stacklevel=4)
2409  # pylint: enable=protected-access
2410
2411  @property
2412  def output_structure(self):
2413    return self._output_structure
2414
2415  @property
2416  def output_classes(self):
2417    return self._output_structure._to_legacy_output_classes()  # pylint: disable=protected-access
2418
2419  @property
2420  def output_shapes(self):
2421    return self._output_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
2422
2423  @property
2424  def output_types(self):
2425    return self._output_structure._to_legacy_output_types()  # pylint: disable=protected-access
2426
2427  @property
2428  def function(self):
2429    return self._function
2430
2431
2432def flat_structure(dataset):
2433  """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops.
2434
2435  Most Dataset op constructors expect `output_shapes` and `output_types`
2436  arguments that represent the flattened structure of an element. This helper
2437  function generates these attrs as a keyword argument dictionary, allowing
2438  `Dataset._variant_tensor` implementations to pass
2439  `**flat_structure(self)` to the op constructor.
2440
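  For example (a sketch mirroring `TakeDataset.__init__` later in this module):

  ```python
  variant_tensor = gen_dataset_ops.take_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      count=self._count,
      **flat_structure(self))
  ```
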
2441  Args:
2442    dataset: A `tf.data.Dataset`.
2443
2444  Returns:
2445    A dictionary of keyword arguments that can be passed to many Dataset op
2446    constructors.
2447  """
2448  # pylint: disable=protected-access
2449  structure = dataset._element_structure
2450  return {
2451      "output_shapes": structure._flat_shapes,
2452      "output_types": structure._flat_types,
2453  }
2454
2455
2456class _GeneratorDataset(DatasetSource):
2457  """A `Dataset` that generates elements by invoking a function."""
2458
2459  def __init__(self, init_args, init_func, next_func, finalize_func):
2460    """Constructs a `_GeneratorDataset`.
2461
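    For example (a minimal sketch of the expected function protocol; the
    bodies are elided):

    ```python
    def init_func(args):
      # Build and return the dataset's "state" from `args`.
      ...

    def next_func(state):
      # Return the next element, or raise `tf.errors.OutOfRangeError` when
      # iteration should terminate.
      ...

    def finalize_func(state):
      # Release any resources held by `state`; the return value is ignored.
      ...
    ```
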
2462    Args:
2463      init_args: A nested structure representing the arguments to `init_func`.
2464      init_func: A TensorFlow function that will be called on `init_args` each
2465        time a C++ iterator over this dataset is constructed. Returns a nested
2466        structure representing the "state" of the dataset.
2467      next_func: A TensorFlow function that will be called on the result of
2468        `init_func` to produce each element, and that raises `OutOfRangeError`
2469        to terminate iteration.
2470      finalize_func: A TensorFlow function that will be called on the result of
2471        `init_func` immediately before a C++ iterator over this dataset is
2472        destroyed. The return value is ignored.
2473    """
2474    self._init_args = init_args
2475
2476    self._init_structure = structure_lib.Structure.from_value(init_args)
2477
2478    self._init_func = StructuredFunctionWrapper(
2479        init_func,
2480        self._transformation_name(),
2481        input_structure=self._init_structure)
2482
2483    self._next_func = StructuredFunctionWrapper(
2484        next_func,
2485        self._transformation_name(),
2486        input_structure=self._init_func.output_structure)
2487
2488    self._finalize_func = StructuredFunctionWrapper(
2489        finalize_func,
2490        self._transformation_name(),
2491        input_structure=self._init_func.output_structure)
2492    variant_tensor = gen_dataset_ops.generator_dataset(
2493        self._init_structure._to_tensor_list(self._init_args)  # pylint: disable=protected-access
2494        + self._init_func.function.captured_inputs,
2495        self._next_func.function.captured_inputs,
2496        self._finalize_func.function.captured_inputs,
2497        init_func=self._init_func.function,
2498        next_func=self._next_func.function,
2499        finalize_func=self._finalize_func.function,
2500        **flat_structure(self))
2501    super(_GeneratorDataset, self).__init__(variant_tensor)
2502
2503  @property
2504  def _element_structure(self):
2505    return self._next_func.output_structure
2506
2507  def _transformation_name(self):
2508    return "Dataset.from_generator()"
2509
2510
2511class ZipDataset(DatasetV2):
2512  """A `Dataset` that zips its inputs together."""
2513
2514  def __init__(self, datasets):
2515    """See `Dataset.zip()` for details."""
2516    for ds in nest.flatten(datasets):
2517      if not isinstance(ds, DatasetV2):
2518        if isinstance(ds, list):
2519          message = ("The argument to `Dataset.zip()` must be a nested "
2520                     "structure of `Dataset` objects. Nested structures do not "
2521                     "support Python lists; please use a tuple instead.")
2522        else:
2523          message = ("The argument to `Dataset.zip()` must be a nested "
2524                     "structure of `Dataset` objects.")
2525        raise TypeError(message)
2526    self._datasets = datasets
2527    self._structure = structure_lib.NestedStructure(
2528        nest.pack_sequence_as(
2529            self._datasets,
2530            [ds._element_structure for ds in nest.flatten(self._datasets)]))  # pylint: disable=protected-access
2531
2532    # pylint: disable=protected-access
2533    variant_tensor = gen_dataset_ops.zip_dataset(
2534        [ds._variant_tensor for ds in nest.flatten(self._datasets)],
2535        **flat_structure(self))
2536    # pylint: enable=protected-access
2537    super(ZipDataset, self).__init__(variant_tensor)
2538
2539  def _inputs(self):
2540    return nest.flatten(self._datasets)
2541
2542  @property
2543  def _element_structure(self):
2544    return self._structure
2545
2546
2547class ConcatenateDataset(DatasetV2):
2548  """A `Dataset` that concatenates its input with given dataset."""
2549
2550  def __init__(self, input_dataset, dataset_to_concatenate):
2551    """See `Dataset.concatenate()` for details."""
2552    self._input_dataset = input_dataset
2553    self._dataset_to_concatenate = dataset_to_concatenate
2554
2555    output_types = get_legacy_output_types(input_dataset)
2556    if output_types != get_legacy_output_types(dataset_to_concatenate):
2557      raise TypeError(
2558          "Two datasets to concatenate have different types %s and %s" %
2559          (output_types, get_legacy_output_types(dataset_to_concatenate)))
2560
2561    output_classes = get_legacy_output_classes(input_dataset)
2562    if output_classes != get_legacy_output_classes(dataset_to_concatenate):
2563      raise TypeError(
2564          "Two datasets to concatenate have different classes %s and %s" %
2565          (output_classes, get_legacy_output_classes(dataset_to_concatenate)))
2566
2567    input_shapes = get_legacy_output_shapes(self._input_dataset)
2568    output_shapes = nest.pack_sequence_as(input_shapes, [
2569        ts1.most_specific_compatible_shape(ts2)
2570        for (ts1, ts2) in zip(
2571            nest.flatten(input_shapes),
2572            nest.flatten(get_legacy_output_shapes(
2573                self._dataset_to_concatenate)))
2574    ])
2575
2576    self._structure = structure_lib.convert_legacy_structure(
2577        output_types, output_shapes, output_classes)
2578
2579    self._input_datasets = [input_dataset, dataset_to_concatenate]
2580    # pylint: disable=protected-access
2581    variant_tensor = gen_dataset_ops.concatenate_dataset(
2582        input_dataset._variant_tensor, dataset_to_concatenate._variant_tensor,
2583        **flat_structure(self))
2584    # pylint: enable=protected-access
2585    super(ConcatenateDataset, self).__init__(variant_tensor)
2586
2587  def _inputs(self):
2588    return self._input_datasets
2589
2590  @property
2591  def _element_structure(self):
2592    return self._structure
2593
2594
2595class RepeatDataset(UnaryUnchangedStructureDataset):
2596  """A `Dataset` that repeats its input several times."""
2597
2598  def __init__(self, input_dataset, count):
2599    """See `Dataset.repeat()` for details."""
2600    self._input_dataset = input_dataset
2601    if count is None:
2602      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
2603    else:
2604      self._count = ops.convert_to_tensor(
2605          count, dtype=dtypes.int64, name="count")
2606    variant_tensor = gen_dataset_ops.repeat_dataset(
2607        input_dataset._variant_tensor,  # pylint: disable=protected-access
2608        count=self._count,
2609        **flat_structure(self))
2610    super(RepeatDataset, self).__init__(input_dataset, variant_tensor)
2611
2612
2613class RangeDataset(DatasetSource):
2614  """A `Dataset` of a step separated range of values."""
2615
2616  def __init__(self, *args):
2617    """See `Dataset.range()` for details."""
2618    self._parse_args(*args)
2619    variant_tensor = gen_dataset_ops.range_dataset(
2620        start=self._start,
2621        stop=self._stop,
2622        step=self._step,
2623        **flat_structure(self))
2624    super(RangeDataset, self).__init__(variant_tensor)
2625
2626  def _parse_args(self, *args):
2627    """Parse arguments according to the same rules as the `range()` builtin."""
2628    if len(args) == 1:
2629      self._start = self._build_tensor(0, "start")
2630      self._stop = self._build_tensor(args[0], "stop")
2631      self._step = self._build_tensor(1, "step")
2632    elif len(args) == 2:
2633      self._start = self._build_tensor(args[0], "start")
2634      self._stop = self._build_tensor(args[1], "stop")
2635      self._step = self._build_tensor(1, "step")
2636    elif len(args) == 3:
2637      self._start = self._build_tensor(args[0], "start")
2638      self._stop = self._build_tensor(args[1], "stop")
2639      self._step = self._build_tensor(args[2], "step")
2640    else:
2641      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
2642
2643  def _build_tensor(self, int64_value, name):
2644    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
2645
2646  @property
2647  def _element_structure(self):
2648    return structure_lib.TensorStructure(dtypes.int64, [])
2649
2650
2651class CacheDataset(UnaryUnchangedStructureDataset):
2652  """A `Dataset` that caches elements of its input."""
2653
2654  def __init__(self, input_dataset, filename):
2655    """See `Dataset.cache()` for details."""
2656    self._input_dataset = input_dataset
2657    self._filename = ops.convert_to_tensor(
2658        filename, dtype=dtypes.string, name="filename")
2659    variant_tensor = gen_dataset_ops.cache_dataset(
2660        input_dataset._variant_tensor,  # pylint: disable=protected-access
2661        filename=self._filename,
2662        **flat_structure(self))
2663    super(CacheDataset, self).__init__(input_dataset, variant_tensor)


class ShuffleDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that randomly shuffles the elements of its input."""

  def __init__(self,
               input_dataset,
               buffer_size,
               seed=None,
               reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    Args:
      input_dataset: The input dataset.
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
        number of elements from this dataset from which the new
        dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        `tf.set_random_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Raises:
      ValueError: if invalid arguments are provided.
    """
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    self._seed, self._seed2 = random_seed.get_seed(seed)

    if reshuffle_each_iteration is None:
      self._reshuffle_each_iteration = True
    else:
      self._reshuffle_each_iteration = reshuffle_each_iteration
    variant_tensor = gen_dataset_ops.shuffle_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        seed=self._seed,
        seed2=self._seed2,
        reshuffle_each_iteration=self._reshuffle_each_iteration,
        **flat_structure(self))
    super(ShuffleDataset, self).__init__(input_dataset, variant_tensor)
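
# Rough intuition (informal): the transformation keeps a buffer of
# `buffer_size` elements and samples uniformly from it, so a buffer at least
# as large as the dataset gives a full uniform shuffle, e.g.
#
#   ds = DatasetV2.range(100).shuffle(buffer_size=100, seed=42)
#
# With `reshuffle_each_iteration=True` (the default) a different order is
# produced on every pass over the dataset.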


class TakeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` containing the first `count` elements from its input."""

  def __init__(self, input_dataset, count):
    """See `Dataset.take()` for details."""
    self._input_dataset = input_dataset
    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
    variant_tensor = gen_dataset_ops.take_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        count=self._count,
        **flat_structure(self))
    super(TakeDataset, self).__init__(input_dataset, variant_tensor)


class SkipDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` skipping the first `count` elements from its input."""

  def __init__(self, input_dataset, count):
    """See `Dataset.skip()` for details."""
    self._input_dataset = input_dataset
    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
    variant_tensor = gen_dataset_ops.skip_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        count=self._count,
        **flat_structure(self))
    super(SkipDataset, self).__init__(input_dataset, variant_tensor)
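
# Informal sketch of the two transformations above:
#
#   DatasetV2.range(10).take(3)   # 0, 1, 2
#   DatasetV2.range(10).skip(3)   # 3, 4, ..., 9
#
# A `count` of -1 takes (respectively skips) all remaining elements.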


class ShardDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` for sharding its input."""

  def __init__(self, input_dataset, num_shards, index):
    """See `Dataset.shard()` for details."""
    self._input_dataset = input_dataset
    self._num_shards = ops.convert_to_tensor(
        num_shards, dtype=dtypes.int64, name="num_shards")
    self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
    variant_tensor = gen_dataset_ops.shard_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_shards=self._num_shards,
        index=self._index,
        **flat_structure(self))
    super(ShardDataset, self).__init__(input_dataset, variant_tensor)
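
# Informal sketch: sharding keeps every `num_shards`-th element, starting at
# `index`, which is useful for splitting a dataset across workers, e.g.
#
#   DatasetV2.range(10).shard(num_shards=3, index=0)  # 0, 3, 6, 9
#   DatasetV2.range(10).shard(num_shards=3, index=1)  # 1, 4, 7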


class BatchDataset(UnaryDataset):
  """A `Dataset` that batches contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, drop_remainder):
    """See `Dataset.batch()` for details."""
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    # NOTE(mrry): `constant_drop_remainder` may be `True`, `None` (statically
    # unknown), or `False` (explicitly retaining the remainder); only when it
    # is statically `True` can the batch dimension be fixed.
    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder)
    # pylint: disable=protected-access
    if constant_drop_remainder:
      self._structure = input_dataset._element_structure._batch(
          tensor_util.constant_value(self._batch_size))
    else:
      self._structure = input_dataset._element_structure._batch(None)
    variant_tensor = gen_dataset_ops.batch_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_size=self._batch_size,
        drop_remainder=self._drop_remainder,
        **flat_structure(self))
    super(BatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure
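
# Informal sketch of how `drop_remainder` affects the static shape: with
# `drop_remainder=True` and a constant batch size, the leading (batch)
# dimension of each component is statically known; otherwise it is `None`:
#
#   DatasetV2.range(10).batch(4)                       # shapes (?,)
#   DatasetV2.range(10).batch(4, drop_remainder=True)  # shapes (4,)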


def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
  """Returns `True` if `input_component_shape` can be padded to `padded_shape`.

  Args:
    padded_shape: A `tf.TensorShape`.
    input_component_shape: A `tf.TensorShape`.

  Returns:
    `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
    `False`.
  """

  if padded_shape.dims is None or input_component_shape.dims is None:
    return True
  if len(padded_shape.dims) != len(input_component_shape.dims):
    return False
  for padded_dim, input_dim in zip(
      padded_shape.dims, input_component_shape.dims):
    if (padded_dim.value is not None and input_dim.value is not None
        and padded_dim.value < input_dim.value):
      return False
  return True
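
# For instance (informal): padding shape [None, 5] is compatible with an input
# component shape [3, 4] (rank matches and no known padded dimension is
# smaller), whereas [2, 5] is not, because 2 < 3 in the first dimension.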


def _padded_shape_to_tensor(padded_shape, input_component_shape):
  """Converts `padded_shape` to a `tf.Tensor` representing that shape.

  Args:
    padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
      sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
    input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
      be compatible.

  Returns:
    A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.

  Raises:
    ValueError: If `padded_shape` is not a shape or not compatible with
      `input_component_shape`.
    TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
  """
  try:
    # Try to convert the `padded_shape` to a `tf.TensorShape`
    padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
    # We will return the "canonical" tensor representation, which uses
    # `-1` in place of `None`.
    ret = ops.convert_to_tensor(
        [dim if dim is not None else -1
         for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
  except (TypeError, ValueError):
    # The argument was not trivially convertible to a
    # `tf.TensorShape`, so fall back on the conversion to tensor
    # machinery.
    ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
    if ret.shape.dims is not None and len(ret.shape.dims) != 1:
      raise ValueError(
          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
          "shape was %s." % (padded_shape, ret.shape))
    if ret.dtype != dtypes.int64:
      raise TypeError(
          "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
          "element type was %s." % (padded_shape, ret.dtype.name))
    padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)

  if not _is_padded_shape_compatible_with(padded_shape_as_shape,
                                          input_component_shape):
    raise ValueError("The padded shape %s is not compatible with the "
                     "corresponding input component shape %s."
                     % (padded_shape_as_shape, input_component_shape))

  return ret
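
# In the canonical representation above, unknown dimensions become -1, so an
# informal example: a padded shape of `[None, 3]` (or
# `tf.TensorShape([None, 3])`) is converted to the int64 tensor `[-1, 3]`.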


def _padding_value_to_tensor(value, output_type):
  """Converts the padding value to a tensor.

  Args:
    value: The padding value.
    output_type: The expected dtype of the padding value.

  Returns:
    A scalar `Tensor`.

  Raises:
    ValueError: if the padding value is not a scalar.
    TypeError: if the padding value's type does not match `output_type`.
  """
  value = ops.convert_to_tensor(value, name="padding_value")
  if not value.shape.is_compatible_with(tensor_shape.scalar()):
    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
  if value.dtype != output_type:
    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
                    (value, output_type))
  return value


def _default_padding(input_dataset):
  """Returns default padding tensors in a structure matching `input_dataset`."""
  def make_zero(t):
    if t.base_dtype == dtypes.string:
      return ""
    elif t.base_dtype == dtypes.variant:
      raise TypeError(
          "Unable to create default padding for a component of type "
          "'variant' (dtype %s)." % t.base_dtype)
    else:
      return np.zeros_like(t.as_numpy_dtype())

  return nest.map_structure(
      make_zero, get_legacy_output_types(input_dataset))
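
# Informally, the defaults produced above are "" for string components and a
# zero of the appropriate numpy dtype for numeric components, arranged in the
# same nested structure as the dataset's (legacy) output types.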


class PaddedBatchDataset(UnaryDataset):
  """A `Dataset` that batches and pads contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
               drop_remainder):
    """See `Dataset.padded_batch()` for details."""
    self._input_dataset = input_dataset
    if sparse.any_sparse(get_legacy_output_classes(input_dataset)):
      # TODO(b/63669786): support batching of sparse tensors
      raise TypeError(
          "Batching of padded sparse tensors is not currently supported")
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    padding_values = (
        padding_values
        if padding_values is not None else _default_padding(input_dataset))

    input_shapes = get_legacy_output_shapes(input_dataset)
    flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)

    flat_padded_shapes_as_tensors = []

    for input_component_shape, padded_shape in zip(
        nest.flatten(input_shapes), flat_padded_shapes):
      flat_padded_shapes_as_tensors.append(
          _padded_shape_to_tensor(padded_shape, input_component_shape))

    self._padded_shapes = nest.pack_sequence_as(input_shapes,
                                                flat_padded_shapes_as_tensors)

    self._padding_values = nest.map_structure_up_to(
        input_shapes, _padding_value_to_tensor, padding_values,
        get_legacy_output_types(input_dataset))
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    def _padded_shape_to_batch_shape(s):
      return tensor_shape.vector(
          tensor_util.constant_value(self._batch_size) if smart_cond.
          smart_constant_value(self._drop_remainder) else None).concatenate(
              tensor_util.constant_value_as_shape(s))

    output_shapes = nest.map_structure(
        _padded_shape_to_batch_shape, self._padded_shapes)
    self._structure = structure_lib.convert_legacy_structure(
        get_legacy_output_types(self._input_dataset), output_shapes,
        get_legacy_output_classes(self._input_dataset))

    # pylint: disable=protected-access
    # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
    if smart_cond.smart_constant_value(self._drop_remainder) is False:
      variant_tensor = gen_dataset_ops.padded_batch_dataset(
          input_dataset._variant_tensor,
          batch_size=self._batch_size,
          padded_shapes=[
              ops.convert_to_tensor(s, dtype=dtypes.int64)
              for s in nest.flatten(self._padded_shapes)
          ],
          padding_values=nest.flatten(self._padding_values),
          output_shapes=self._structure._flat_shapes)
    else:
      variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
          input_dataset._variant_tensor,
          batch_size=self._batch_size,
          padded_shapes=[
              ops.convert_to_tensor(s, dtype=dtypes.int64)
              for s in nest.flatten(self._padded_shapes)
          ],
          padding_values=nest.flatten(self._padding_values),
          drop_remainder=self._drop_remainder,
          output_shapes=self._structure._flat_shapes)
    super(PaddedBatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure
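
# A hedged usage sketch (normally reached via `Dataset.padded_batch()`): given
# a dataset of variable-length int64 vectors, a padded shape of [None] pads
# each component to the longest vector in its batch, using the default padding
# value (0 for numeric types) unless `padding_values` is given, e.g.
#
#   ds = DatasetV2.range(6).map(
#       lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
#   ds = ds.padded_batch(3, padded_shapes=[None])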


def _should_unpack_args(args):
  """Returns `True` if `args` should be `*args` when passed to a callable."""
  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck
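
# Informally: a plain tuple element such as (images, labels) is unpacked and
# passed as `fn(images, labels)`; lists, dicts, namedtuple subclasses, and
# single tensors are passed to the callable as a single argument.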


class MapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self,
               input_dataset,
               map_func,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._preserve_cardinality = preserve_cardinality
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    variant_tensor = gen_dataset_ops.map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **flat_structure(self))
    super(MapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"


class ParallelMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self,
               input_dataset,
               map_func,
               num_parallel_calls,
               use_inter_op_parallelism=True,
               preserve_cardinality=False,
               use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism
    self._map_func = StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
    self._preserve_cardinality = preserve_cardinality
    variant_tensor = gen_dataset_ops.parallel_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        num_parallel_calls=self._num_parallel_calls,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        preserve_cardinality=self._preserve_cardinality,
        **flat_structure(self))
    super(ParallelMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._map_func.output_structure

  def _transformation_name(self):
    return "Dataset.map()"
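
# Informal note: `Dataset.map()` constructs `MapDataset` when
# `num_parallel_calls` is `None` and `ParallelMapDataset` otherwise, e.g.
#
#   ds = DatasetV2.range(10).map(lambda x: x * 2)
#   ds = DatasetV2.range(10).map(lambda x: x * 2, num_parallel_calls=4)
#
# The element structure of the result is whatever `map_func` returns, as
# captured by the `StructuredFunctionWrapper`.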


class FlatMapDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **flat_structure(self))
    super(FlatMapDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.flat_map()"
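
# Informal sketch: the `map_func` must itself return a `Dataset`, and the
# resulting datasets are concatenated in order, e.g.
#
#   ds = DatasetV2.from_tensor_slices([[1, 2], [3, 4]])
#   ds = ds.flat_map(lambda row: DatasetV2.from_tensor_slices(row))
#   # yields 1, 2, 3, 4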


class InterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

    variant_tensor = gen_dataset_ops.interleave_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        f=self._map_func.function,
        **flat_structure(self))
    super(InterleaveDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"
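
# Informal intuition: like `flat_map`, `map_func` must return a `Dataset`, but
# instead of concatenating, up to `cycle_length` of the returned datasets are
# consumed in a round-robin fashion, `block_length` consecutive elements at a
# time, e.g.
#
#   ds = DatasetV2.from_tensor_slices([0, 10]).interleave(
#       lambda x: DatasetV2.range(x, x + 4), cycle_length=2, block_length=2)
#   # yields 0, 1, 10, 11, 2, 3, 12, 13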


class ParallelInterleaveDataset(UnaryDataset):
  """A `Dataset` that maps a function over its input and interleaves the result
  in parallel.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               num_parallel_calls):
    """See `Dataset.interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, DatasetStructure):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(self._map_func.output_structure)))
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        self._num_parallel_calls,
        f=self._map_func.function,
        **flat_structure(self))
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "Dataset.interleave()"


class FilterDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that filters its input according to a predicate function."""

  def __init__(self, input_dataset, predicate, use_legacy_function=False):
    """See `Dataset.filter()` for details."""
    self._input_dataset = input_dataset
    wrapped_func = StructuredFunctionWrapper(
        predicate,
        self._transformation_name(),
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    if not wrapped_func.output_structure.is_compatible_with(
        structure_lib.TensorStructure(dtypes.bool, [])):
      error_msg = ("`predicate` return type must be convertible to a scalar "
                   "boolean tensor. Was {}.").format(
                       wrapped_func.output_structure)
      raise ValueError(error_msg)
    self._predicate = wrapped_func
    variant_tensor = gen_dataset_ops.filter_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **flat_structure(self))
    super(FilterDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._predicate]

  def _transformation_name(self):
    return "Dataset.filter()"
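
# Informal sketch: the predicate must evaluate to a scalar boolean per element,
# and only elements for which it is `True` are kept, e.g.
#
#   ds = DatasetV2.range(10).filter(lambda x: math_ops.equal(x % 2, 0))
#   # yields 0, 2, 4, 6, 8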


class PrefetchDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size):
    """See `Dataset.prefetch()` for details."""
    self._input_dataset = input_dataset
    if buffer_size is None:
      buffer_size = -1  # This is the sentinel for auto-tuning.
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")
    variant_tensor = gen_dataset_ops.prefetch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        **flat_structure(self))
    super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
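
# Informal note: passing `buffer_size=None` maps to the -1 sentinel
# (`tf.data.experimental.AUTOTUNE`), letting the runtime pick the buffer size:
#
#   ds = DatasetV2.range(10).map(lambda x: x + 1).prefetch(None)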


class WindowDataset(UnaryDataset):
  """A dataset that creates window datasets from the input elements."""

  def __init__(self, input_dataset, size, shift, stride, drop_remainder):
    """See `window_dataset()` for more details."""
    self._input_dataset = input_dataset
    self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
    self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
    self._stride = ops.convert_to_tensor(
        stride, dtype=dtypes.int64, name="stride")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    nest_of_structures = nest.pack_sequence_as(
        get_legacy_output_classes(input_dataset),
        [
            DatasetStructure(structure_lib.convert_legacy_structure(
                output_type, output_shape, output_class))
            for output_class, output_shape, output_type in zip(
                nest.flatten(get_legacy_output_classes(input_dataset)),
                nest.flatten(get_legacy_output_shapes(input_dataset)),
                nest.flatten(get_legacy_output_types(input_dataset)))
        ])
    self._structure = structure_lib.NestedStructure(nest_of_structures)
    variant_tensor = gen_dataset_ops.window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._size,
        self._shift,
        self._stride,
        self._drop_remainder,
        **flat_structure(self))
    super(WindowDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def _element_structure(self):
    return self._structure
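
# Informal sketch: each output element is itself a (nested structure of)
# dataset(s) containing a window of `size` input elements; `shift` is the
# offset between the starts of consecutive windows and `stride` the step
# between elements within a window, e.g.
#
#   ds = DatasetV2.range(7).window(size=3, shift=2, stride=1,
#                                  drop_remainder=True)
#   # windows: {0, 1, 2}, {2, 3, 4}, {4, 5, 6}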


class _OptionsDataset(UnaryUnchangedStructureDataset):
  """An identity `Dataset` that stores options."""

  def __init__(self, input_dataset, options):
    self._input_dataset = input_dataset
    self._options = input_dataset.options()
    if self._options:
      self._options = self._options.merge(options)
    else:
      self._options = options
    variant_tensor = input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_OptionsDataset, self).__init__(input_dataset, variant_tensor)

  def options(self):
    return self._options
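
# Informal note: this wrapper reuses the input's variant tensor unchanged, so
# it only records options (merging them with any options already set upstream)
# for later use when the pipeline is finalized; it performs no per-element
# work.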


class _ModelDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and models performance."""

  def __init__(self, input_dataset, cpu_budget):
    self._input_dataset = input_dataset
    variant_tensor = gen_dataset_ops.model_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        cpu_budget=cpu_budget,
        **flat_structure(self))
    super(_ModelDataset, self).__init__(input_dataset, variant_tensor)


class _OptimizeDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and applies optimizations."""

  def __init__(self, input_dataset, optimizations):
    self._input_dataset = input_dataset
    if optimizations is None:
      optimizations = []
    self._optimizations = ops.convert_to_tensor(
        optimizations, dtype=dtypes.string, name="optimizations")
    variant_tensor = gen_dataset_ops.optimize_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._optimizations,
        **flat_structure(self))
    super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)


class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and sets a stats aggregator."""

  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
    self._input_dataset = input_dataset
    self._stats_aggregator = aggregator
    self._prefix = prefix
    self._counter_prefix = counter_prefix
    variant_tensor = ged_ops.experimental_set_stats_aggregator_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._stats_aggregator._resource,  # pylint: disable=protected-access
        self._prefix,
        self._counter_prefix,
        **flat_structure(self))
    super(_SetStatsAggregatorDataset, self).__init__(input_dataset,
                                                     variant_tensor)


class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""

  def __init__(self, input_dataset, max_intra_op_parallelism):
    self._input_dataset = input_dataset
    self._max_intra_op_parallelism = ops.convert_to_tensor(
        max_intra_op_parallelism,
        dtype=dtypes.int64,
        name="max_intra_op_parallelism")
    variant_tensor = ged_ops.experimental_max_intra_op_parallelism_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._max_intra_op_parallelism,
        **flat_structure(self))
    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset,
                                                        variant_tensor)


class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, setting a private threadpool."""

  def __init__(self, input_dataset, num_threads):
    self._input_dataset = input_dataset
    self._num_threads = ops.convert_to_tensor(
        num_threads, dtype=dtypes.int64, name="num_threads")
    variant_tensor = ged_ops.experimental_private_thread_pool_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._num_threads,
        **flat_structure(self))
    super(_PrivateThreadPoolDataset, self).__init__(input_dataset,
                                                    variant_tensor)