• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Iterators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import threading
22import warnings
23
24import six
25
26from tensorflow.python.data.ops import optional_ops
27from tensorflow.python.data.ops import options as options_lib
28from tensorflow.python.data.util import nest
29from tensorflow.python.data.util import structure
30from tensorflow.python.eager import context
31from tensorflow.python.framework import composite_tensor
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import tensor_spec
37from tensorflow.python.framework import type_spec
38from tensorflow.python.ops import gen_dataset_ops
39from tensorflow.python.training.saver import BaseSaverBuilder
40from tensorflow.python.training.tracking import base as trackable
41from tensorflow.python.util import _pywrap_utils
42from tensorflow.python.util import deprecation
43from tensorflow.python.util import lazy_loader
44from tensorflow.python.util.compat import collections_abc
45from tensorflow.python.util.tf_export import tf_export
46
47
# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
# times, e.g. when you are distributing different elements to multiple
# devices in a single step. However, a common pitfall arises when
# users call `Iterator.get_next()` in each iteration of their training
# loop. `Iterator.get_next()` adds ops to the graph, and executing
# each op allocates resources (including threads); as a consequence,
# invoking it in every iteration of a training loop causes slowdown
# and eventual resource exhaustion. To guard against this outcome, we
# log a warning when the number of uses crosses a threshold of suspicion.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside "
    "a training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"

# `autograph_ctx` is loaded lazily: the module is only imported on first
# attribute access rather than at module load time.
autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")
85
86
def _device_stack_is_empty():
  """Returns True if no device placement is in effect in the current context."""
  if context.executing_eagerly():
    return context.context().device_name is None
  # pylint: disable=protected-access
  graph_device_fns = ops.get_default_graph()._device_functions_outer_to_inner
  # pylint: enable=protected-access
  return not graph_device_fns
94
95
@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A (nested) structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes`, or
        `output_classes` is `None`.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError("If `structure` is not specified, all of "
                       "`output_types`, `output_shapes`, and `output_classes`"
                       " must be specified.")
    # Convert the legacy (types, shapes, classes) triple into a single
    # structured element spec and cache its flat tensor representation.
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    # Counter used to warn when `get_next()` appears to be called in a loop.
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      # Default to fully-unknown shapes for every component.
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(
            output_structure))
    # `initializer` is None: the caller must bind a dataset with
    # `make_initializer()` before use.
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you
    would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
        a handle produced by the `Iterator.string_handle()` method.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      # Default to fully-unknown shapes for every component.
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(output_structure))
    # `initializer` is None: the underlying iterator is initialized elsewhere
    # and only referenced via the string handle.
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is not None:
      return self._initializer
    else:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError("Iterator does not have an initializer.")

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` whose `element_spec` is compatible with this
        iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        `element_spec`.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
      # to that creating a circular dependency.
      # pylint: disable=protected-access
      dataset_output_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),
          dataset.element_spec)
      dataset_output_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),
          dataset.element_spec)
      dataset_output_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),
          dataset.element_spec)
      # pylint: enable=protected-access

      # Validate compatibility: identical structure, identical classes and
      # dtypes, and shapes that are at least compatible (not necessarily
      # equal).
      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              "Expected output classes %r but got dataset with output class %r."
              % (self.output_classes, dataset_output_classes))
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              "Expected output types %r but got dataset with output types %r." %
              (self.output_types, dataset_output_types))
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError("Expected output shapes compatible with %r but got "
                          "dataset with output shapes %r." %
                          (self.output_shapes, dataset_output_shapes))

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)

  def get_next(self, name=None):
    """Returns the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      flat_ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_tensor_types,
          output_shapes=self._flat_tensor_shapes,
          name=name)
      # Re-pack the flat tensor list into the structured element.
      return structure.from_tensor_list(self._element_spec, flat_ret)

  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      # Reuse the handle created in `__init__` when no custom name is needed.
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator and specifying the type of individual components.
    """

    return self._element_spec

  def _gather_saveables_for_checkpoint(self):
    # Expose the iterator resource as a saveable so that iterator state can
    # be captured in checkpoints.

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}
530
531
532_uid_counter = 0
533_uid_lock = threading.Lock()
534
535
536def _generate_shared_name(prefix):
537  with _uid_lock:
538    global _uid_counter
539    uid = _uid_counter
540    _uid_counter += 1
541  return "{}{}".format(prefix, uid)
542
543
class IteratorResourceDeleter(object):
  """Deletes an iterator resource handle when garbage collected.

  An alternative to defining a `__del__` method on the iterator object
  itself: even if the parent object is part of a reference cycle, the cycle
  remains collectable because this helper holds no reference back to it.
  """

  __slots__ = ["_deleter", "_handle", "_eager_mode"]

  def __init__(self, handle, deleter):
    self._handle = handle
    self._deleter = deleter
    # Record the execution mode at creation time so deletion can replay it.
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    mode = context.eager_mode if self._eager_mode else context.graph_mode
    with mode():
      gen_dataset_ops.delete_iterator(
          handle=self._handle, deleter=self._deleter)
568
569
@tf_export("data.Iterator", v1=[])
@six.add_metaclass(abc.ABCMeta)
class IteratorBase(collections_abc.Iterator, trackable.Trackable,
                   composite_tensor.CompositeTensor):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")
666
667
668class OwnedIterator(IteratorBase):
669  """An iterator producing tf.Tensor objects from a tf.data.Dataset.
670
671  The iterator resource  created through `OwnedIterator` is owned by the Python
672  object and the life time of the underlying resource is tied to the life time
673  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use
674  in eager mode and inside of tf.functions.
675  """
676
677  def __init__(self, dataset=None, components=None, element_spec=None):
678    """Creates a new iterator from the given dataset.
679
680    If `dataset` is not specified, the iterator will be created from the given
681    tensor components and element structure. In particular, the alternative for
682    constructing the iterator is used when the iterator is reconstructed from
683    it `CompositeTensor` representation.
684
685    Args:
686      dataset: A `tf.data.Dataset` object.
687      components: Tensor components to construct the iterator from.
688      element_spec: A (nested) structure of `TypeSpec` objects that
689        represents the type specification of elements of the iterator.
690
691    Raises:
692      ValueError: If `dataset` is not provided and either `components` or
693        `element_spec` is not provided. Or `dataset` is provided and either
694        `components` and `element_spec` is provided.
695    """
696    super(OwnedIterator, self).__init__()
697    error_message = ("Either `dataset` or both `components` and "
698                     "`element_spec` need to be provided.")
699
700    if dataset is None:
701      if (components is None or element_spec is None):
702        raise ValueError(error_message)
703      # pylint: disable=protected-access
704      self._element_spec = element_spec
705      self._flat_output_types = structure.get_flat_tensor_types(
706          self._element_spec)
707      self._flat_output_shapes = structure.get_flat_tensor_shapes(
708          self._element_spec)
709      self._iterator_resource, self._deleter = components
710    else:
711      if (components is not None or element_spec is not None):
712        raise ValueError(error_message)
713      self._create_iterator(dataset)
714
715    self._get_next_call_count = 0
716
717  def _create_iterator(self, dataset):
718    # pylint: disable=protected-access
719    dataset = dataset._apply_debug_options()
720
721    # Store dataset reference to ensure that dataset is alive when this iterator
722    # is being used. For example, `tf.data.Dataset.from_generator` registers
723    # a few py_funcs that are needed in `self._next_internal`.  If the dataset
724    # is deleted, this iterator crashes on `self.__next__(...)` call.
725    self._dataset = dataset
726
727    ds_variant = dataset._variant_tensor
728    self._element_spec = dataset.element_spec
729    self._flat_output_types = structure.get_flat_tensor_types(
730        self._element_spec)
731    self._flat_output_shapes = structure.get_flat_tensor_shapes(
732        self._element_spec)
733    with ops.colocate_with(ds_variant):
734      self._iterator_resource, self._deleter = (
735          gen_dataset_ops.anonymous_iterator_v2(
736              output_types=self._flat_output_types,
737              output_shapes=self._flat_output_shapes))
738      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
739      # Delete the resource when this object is deleted
740      self._resource_deleter = IteratorResourceDeleter(
741          handle=self._iterator_resource,
742          deleter=self._deleter)
743
744  def __iter__(self):
745    return self
746
747  def next(self):  # For Python 2 compatibility
748    return self.__next__()
749
  def _next_internal(self):
    """Returns a (nested) structure of tensors containing the next element.

    Raises:
      ValueError: If `get_next` has been called more than
        `GET_NEXT_CALL_ERROR_THRESHOLD` times in graph mode with AutoGraph
        disabled (presumably to catch per-call retracing; see
        GET_NEXT_CALL_ERROR_MESSAGE for the exact guidance — defined elsewhere
        in this file).
    """
    autograph_status = autograph_ctx.control_status_ctx().status
    autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
    # Only count calls made while building a graph with AutoGraph explicitly
    # disabled; eager calls are expected to repeat and are not counted.
    if not context.executing_eagerly() and autograph_disabled:
      self._get_next_call_count += 1
      if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
        raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

    # Graph-mode path: build the get-next op and convert its flat tensor list
    # back into the structured element.
    if not context.executing_eagerly():
      # TODO(b/169442955): Investigate the need for this colocation constraint.
      with ops.colocate_with(self._iterator_resource):
        ret = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      return structure.from_compatible_tensor_list(self._element_spec, ret)

    # Eager path.
    # TODO(b/77291417): This runs in sync mode as iterators use an error status
    # to communicate that there is no more data to iterate over.
    with context.execution_mode(context.SYNC):
      ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

      try:
        # Fast path for the case `self._structure` is not a nested structure.
        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
      except AttributeError:
        # Fall back to the generic conversion when the spec does not expose
        # the fast-path method.
        return structure.from_compatible_tensor_list(self._element_spec, ret)
780
781  @property
782  def _type_spec(self):
783    return IteratorSpec(self.element_spec)
784
785  def __next__(self):
786    try:
787      return self._next_internal()
788    except errors.OutOfRangeError:
789      raise StopIteration
790
791  @property
792  @deprecation.deprecated(
793      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
794  def output_classes(self):
795    """Returns the class of each component of an element of this iterator.
796
797    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.
798
799    Returns:
800      A (nested) structure of Python `type` objects corresponding to each
801      component of an element of this dataset.
802    """
803    return nest.map_structure(
804        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
805        self._element_spec)
806
807  @property
808  @deprecation.deprecated(
809      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
810  def output_shapes(self):
811    """Returns the shape of each component of an element of this iterator.
812
813    Returns:
814      A (nested) structure of `tf.TensorShape` objects corresponding to each
815      component of an element of this dataset.
816    """
817    return nest.map_structure(
818        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
819        self._element_spec)
820
821  @property
822  @deprecation.deprecated(
823      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
824  def output_types(self):
825    """Returns the type of each component of an element of this iterator.
826
827    Returns:
828      A (nested) structure of `tf.DType` objects corresponding to each component
829      of an element of this dataset.
830    """
831    return nest.map_structure(
832        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
833        self._element_spec)
834
835  @property
836  def element_spec(self):
837    return self._element_spec
838
839  def get_next(self):
840    return self._next_internal()
841
842  def get_next_as_optional(self):
843    # TODO(b/169442955): Investigate the need for this colocation constraint.
844    with ops.colocate_with(self._iterator_resource):
845      # pylint: disable=protected-access
846      return optional_ops._OptionalImpl(
847          gen_dataset_ops.iterator_get_next_as_optional(
848              self._iterator_resource,
849              output_types=structure.get_flat_tensor_types(self.element_spec),
850              output_shapes=structure.get_flat_tensor_shapes(
851                  self.element_spec)), self.element_spec)
852
853  def _gather_saveables_for_checkpoint(self):
854
855    def _saveable_factory(name):
856      """Returns a SaveableObject for serialization/deserialization."""
857      policy = None
858      if self._dataset:
859        policy = self._dataset.options().experimental_external_state_policy
860      if policy:
861        return _IteratorSaveable(
862            self._iterator_resource,
863            name,
864            external_state_policy=policy)
865      else:
866        return _IteratorSaveable(self._iterator_resource, name)
867
868    return {"ITERATOR": _saveable_factory}
869
870
871@tf_export("data.IteratorSpec", v1=[])
872class IteratorSpec(type_spec.TypeSpec):
873  """Type specification for `tf.data.Iterator`.
874
875  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
876  takes `tf.data.Iterator` as an input argument:
877
878  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
879  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
880  ... def square(iterator):
881  ...   x = iterator.get_next()
882  ...   return x * x
883  >>> dataset = tf.data.Dataset.from_tensors(5)
884  >>> iterator = iter(dataset)
885  >>> print(square(iterator))
886  tf.Tensor(25, shape=(), dtype=int32)
887
888  Attributes:
889    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
890      the type specification of the iterator elements.
891  """
892
893  __slots__ = ["_element_spec"]
894
895  def __init__(self, element_spec):
896    self._element_spec = element_spec
897
898  @property
899  def value_type(self):
900    return OwnedIterator
901
902  def _serialize(self):
903    return (self._element_spec,)
904
905  @property
906  def _component_specs(self):
907    return (
908        tensor_spec.TensorSpec([], dtypes.resource),
909        tensor_spec.TensorSpec([], dtypes.variant),
910    )
911
912  def _to_components(self, value):
913    return (value._iterator_resource, value._deleter)  # pylint: disable=protected-access
914
915  def _from_components(self, components):
916    return OwnedIterator(
917        dataset=None,
918        components=components,
919        element_spec=self._element_spec)
920
921  @staticmethod
922  def from_value(value):
923    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access
924
925
# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(self,
               iterator_resource,
               name,
               external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    # Serialize the iterator's state into a single tensor; the policy controls
    # how external state (if any) is handled during serialization.
    serialized = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    save_spec = BaseSaverBuilder.SaveSpec(
        serialized,
        "",
        name + "_STATE",
        device=iterator_resource.device)
    super(_IteratorSaveable, self).__init__(iterator_resource, [save_spec],
                                            name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores the iterator's state from its serialized tensor."""
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
949
950
@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  The returned `tf.experimental.Optional` has no value when the iterator has
  reached the end of the sequence.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  # Thin deprecated wrapper; delegate to the equivalent instance method.
  return iterator.get_next_as_optional()
968
969
# Register OwnedIterator under the name "OwnedIterator" in the pywrap type
# registry (presumably so native-side type checks can recognize instances —
# see `_pywrap_utils`).
_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)
971