• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Iteration over tf.data.Datasets when eager execution is enabled."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import threading
23from tensorflow.contrib.data.python.ops import prefetching_ops
24from tensorflow.python.data.ops import iterator_ops
25from tensorflow.python.data.util import nest
26from tensorflow.python.data.util import sparse
27from tensorflow.python.eager import context
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import function
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import gen_dataset_ops
34from tensorflow.python.ops import resource_variable_ops
36_uid_counter = 0
37_uid_lock = threading.Lock()
40def _generate_shared_name(prefix):
41  with _uid_lock:
42    global _uid_counter
43    uid = _uid_counter
44    _uid_counter += 1
45  return "{}{}".format(prefix, uid)
class Iterator(object):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""

  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      # The iterator ops work on flat lists of dense tensors: sparse components
      # are represented by their "dense" (serialized) types and shapes, and are
      # turned back into `tf.SparseTensor`s in `_next_internal`.
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      # NOTE(review): each iterator gets a freshly generated `container` with an
      # empty `shared_name`, presumably so that distinct Iterator objects never
      # alias the same underlying resource -- confirm before changing.
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
    self._buffer_resource_handle = None
    # Any explicitly requested non-CPU device (e.g. a GPU) is treated as
    # "remote": elements are produced on the CPU and prefetched onto it.
    device_type = context.context().device_spec.device_type
    is_remote_device = bool(device_type) and device_type != "CPU"
    if is_remote_device:
      with ops.device("/device:CPU:0"):
        iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          # Rebuild the iterator from its string handle, including
          # `output_classes` so sparse components come back as
          # `tf.SparseTensor`s. They are then re-serialized and the whole
          # structure flattened, so the function returns a flat list of dense
          # tensors matching `self._flat_output_types`. (A Defun presumably
          # cannot return nested structures or `tf.SparseTensor`s directly.)
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self._output_types, self._output_shapes, self._output_classes)
          ret = remote_iterator.get_next()
          return nest.flatten(sparse.serialize_sparse_tensors(ret))

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
            string_arg=iter_string_handle,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            thread_pool_size=1,
            container="",
            shared_name=_generate_shared_name("function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
            handle=self._buffer_resource_handle,
            handle_device=self._device)

  def __iter__(self):
    return self

  def __next__(self):  # For Python 3 compatibility
    return self.next()

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is not None:
        # Remote device: pull the next (flat, dense) element from the
        # prefetching buffer filled by `remote_fn`.
        ret = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      else:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

    # Restore the dataset's nested structure and turn serialized sparse
    # components back into `tf.SparseTensor`s.
    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types, ret), self._output_types,
        self._output_shapes, self._output_classes)

  def next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Raises:
      StopIteration: If the end of the dataset has been reached (translated
        from `tf.errors.OutOfRangeError` for the Python iterator protocol).
    """
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      raise StopIteration

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_classes

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_shapes

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return self._output_types

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation. Currently unused.

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    """
    del name
    return self._next_internal()