# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Iterators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import threading
import warnings

import six

from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
# times, e.g. when you are distributing different elements to multiple
# devices in a single step. However, a common pitfall arises when
# users call `Iterator.get_next()` in each iteration of their training
# loop. `Iterator.get_next()` adds ops to the graph, and executing
# each op allocates resources (including threads); as a consequence,
# invoking it in every iteration of a training loop causes slowdown
# and eventual resource exhaustion. To guard against this outcome, we
# log a warning when the number of uses crosses a threshold of suspicion.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside "
    "a training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")
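
#
# For illustration only (a minimal sketch, not part of this module), the
# recommended pattern is to create the `get_next()` op once, outside the
# training loop, and reuse the resulting tensors on every step:
#
#   dataset = tf.compat.v1.data.Dataset.range(100)
#   iterator = dataset.make_one_shot_iterator()
#   next_element = iterator.get_next()  # Created once, outside the loop.
#   with tf.compat.v1.Session() as sess:
#     while True:
#       try:
#         sess.run(next_element)  # The same op is executed on every step.
#       except tf.errors.OutOfRangeError:
#         break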

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"


def _device_stack_is_empty():
  if context.executing_eagerly():
    return context.context().device_name is None
  # pylint: disable=protected-access
  device_stack = ops.get_default_graph()._device_functions_outer_to_inner
  # pylint: enable=protected-access
  return not bool(device_stack)


@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects corresponding
        to each component of an element of this iterator.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError("If `structure` is not specified, all of "
                       "`output_types`, `output_shapes`, and `output_classes`"
                       " must be specified.")
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(
            output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and
    you would feed it with the value of `tf.data.Iterator.string_handle` in
    each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
        a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator.

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is not None:
      return self._initializer
    else:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError("Iterator does not have an initializer.")

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

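    For example (a minimal sketch, assuming a `tf.compat.v1.Session` named
    `sess`):

    ```python
    iterator = tf.data.Iterator.from_structure(tf.int64, tf.TensorShape([]))
    dataset = tf.data.Dataset.range(5)

    init_op = iterator.make_initializer(dataset)
    next_element = iterator.get_next()

    sess.run(init_op)
    print(sess.run(next_element))  # ==> 0
    ```
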
    Args:
      dataset: A `Dataset` with compatible structure to this iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        element structure.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` here,
      # because that would create a circular dependency.
      # pylint: disable=protected-access
      dataset_output_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),
          dataset.element_spec)
      dataset_output_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),
          dataset.element_spec)
      dataset_output_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),
          dataset.element_spec)
      # pylint: enable=protected-access

      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              "Expected output classes %r but got dataset with output classes "
              "%r." % (self.output_classes, dataset_output_classes))
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              "Expected output types %r but got dataset with output types %r." %
              (self.output_types, dataset_output_types))
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError("Expected output shapes compatible with %r but got "
                          "dataset with output shapes %r." %
                          (self.output_shapes, dataset_output_shapes))

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          dataset._variant_tensor, self._iterator_resource, name=name)

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      flat_ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_tensor_types,
          output_shapes=self._flat_tensor_shapes,
          name=name)
      return structure.from_tensor_list(self._element_spec, flat_ret)

  def get_next_as_optional(self):
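    """Returns a `tf.experimental.Optional` which contains the next element.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """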
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

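    For example (a minimal sketch, assuming `sess` is a `tf.compat.v1.Session`
    and `iterator` was created from a dataset):

    ```python
    handle = sess.run(iterator.string_handle())
    # `handle` can now be fed to an iterator created with
    # `tf.data.Iterator.from_string_handle(...)`.
    ```
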
    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
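    """The type specification of an element of this iterator.

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """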
    return self._element_spec

  def _gather_saveables_for_checkpoint(self):

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}


_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  with _uid_lock:
    global _uid_counter
    uid = _uid_counter
    _uid_counter += 1
  return "{}{}".format(prefix, uid)


class IteratorResourceDeleter(object):
  """An object which cleans up an iterator resource handle.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectable.
  """

  __slots__ = ["_deleter", "_handle", "_eager_mode"]

  def __init__(self, handle, deleter):
    self._deleter = deleter
    self._handle = handle
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if self._eager_mode:
      with context.eager_mode():
        gen_dataset_ops.delete_iterator(
            handle=self._handle, deleter=self._deleter)
    else:
      with context.graph_mode():
        gen_dataset_ops.delete_iterator(
            handle=self._handle, deleter=self._deleter)


@tf_export("data.Iterator", v1=[])
@six.add_metaclass(abc.ABCMeta)
class IteratorBase(collections_abc.Iterator, trackable.Trackable,
                   composite_tensor.CompositeTensor):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns a `tf.experimental.Optional` which contains the next element.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")


class OwnedIterator(IteratorBase):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset.

  The iterator resource created through `OwnedIterator` is owned by the Python
  object, and the lifetime of the underlying resource is tied to the lifetime
  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use
  in eager mode and inside of `tf.function`s.
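
  For example (a minimal sketch):

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)  # Creates an `OwnedIterator`.
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)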
  """

  def __init__(self, dataset=None, components=None, element_spec=None):
    """Creates a new iterator from the given dataset.

    If `dataset` is not specified, the iterator will be created from the given
    tensor components and element structure. In particular, this alternative
    way of constructing the iterator is used when the iterator is reconstructed
    from its `CompositeTensor` representation.

    Args:
      dataset: A `tf.data.Dataset` object.
      components: Tensor components to construct the iterator from.
      element_spec: A nested structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      ValueError: If `dataset` is not provided and either `components` or
        `element_spec` is not provided, or if `dataset` is provided and either
        `components` or `element_spec` is also provided.
    """
    super(OwnedIterator, self).__init__()
    error_message = ("Either `dataset` or both `components` and "
                     "`element_spec` need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      # pylint: disable=protected-access
      self._element_spec = element_spec
      self._flat_output_types = structure.get_flat_tensor_types(
          self._element_spec)
      self._flat_output_shapes = structure.get_flat_tensor_shapes(
          self._element_spec)
      self._iterator_resource, self._deleter = components
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      self._create_iterator(dataset)

  def _create_iterator(self, dataset):
    # pylint: disable=protected-access
    dataset = dataset._apply_options()

    # Store dataset reference to ensure that dataset is alive when this iterator
    # is being used. For example, `tf.data.Dataset.from_generator` registers
    # a few py_funcs that are needed in `self._next_internal`.  If the dataset
    # is deleted, this iterator crashes on `self.__next__(...)` call.
    self._dataset = dataset

    ds_variant = dataset._variant_tensor
    self._element_spec = dataset.element_spec
    self._flat_output_types = structure.get_flat_tensor_types(
        self._element_spec)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    with ops.colocate_with(ds_variant):
      self._iterator_resource, self._deleter = (
          gen_dataset_ops.anonymous_iterator_v2(
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = IteratorResourceDeleter(
          handle=self._iterator_resource,
          deleter=self._deleter)

  def __iter__(self):
    return self

  def next(self):  # For Python 2 compatibility
    return self.__next__()

  def _next_internal(self):
    if not context.executing_eagerly():
      # TODO(b/169442955): Investigate the need for this colocation constraint.
      with ops.colocate_with(self._iterator_resource):
        ret = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      return structure.from_compatible_tensor_list(self._element_spec, ret)

    # TODO(b/77291417): This runs in sync mode as iterators use an error status
    # to communicate that there is no more data to iterate over.
    with context.execution_mode(context.SYNC):
      ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

      try:
        # Fast path for the case `self._structure` is not a nested structure.
        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
      except AttributeError:
        return structure.from_compatible_tensor_list(self._element_spec, ret)

  @property
  def _type_spec(self):
    return IteratorSpec(self.element_spec)

  def __next__(self):
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      raise StopIteration

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    return self._element_spec

  def get_next(self):
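    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """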
    return self._next_internal()

  def get_next_as_optional(self):
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def _gather_saveables_for_checkpoint(self):

    def _saveable_factory(name):
      """Returns a SaveableObject for serialization/deserialization."""
      policy = None
      if self._dataset:
        policy = self._dataset.options().experimental_external_state_policy
      if policy:
        return _IteratorSaveable(
            self._iterator_resource,
            name,
            external_state_policy=policy)
      else:
        return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}
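
  # For illustration only (a minimal sketch, not part of this class): the
  # saveables returned above let iterator state be checkpointed through the
  # trackable API, e.g.:
  #
  #   dataset = tf.data.Dataset.range(10)
  #   iterator = iter(dataset)
  #   ckpt = tf.train.Checkpoint(iterator=iterator)
  #   save_path = ckpt.save("/tmp/iterator_ckpt")  # Hypothetical path.
  #   ckpt.restore(save_path)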


@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A nested structure of `TypeSpec` objects that represents the
      type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedIterator

  def _serialize(self):
    return (self._element_spec,)

  @property
  def _component_specs(self):
    return (
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant),
    )

  def _to_components(self, value):
    return (value._iterator_resource, value._deleter)  # pylint: disable=protected-access

  def _from_components(self, components):
    return OwnedIterator(
        dataset=None,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access


# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=distribute_options.ExternalStatePolicy.FAIL):
    serialized_iterator = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    specs = [
        BaseSaverBuilder.SaveSpec(
            serialized_iterator,
            "",
            name + "_STATE",
            device=iterator_resource.device)
    ]
    super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])


@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
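
  For example (a minimal sketch):

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = tf.data.experimental.get_next_as_optional(iterator)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)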
  """
  return iterator.get_next_as_optional()
928