# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iteration over tf.data.Datasets when eager execution is enabled."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops

_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  with _uid_lock:
    global _uid_counter
    uid = _uid_counter
    _uid_counter += 1
  return "{}{}".format(prefix, uid)


class Iterator(object):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""

  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator
    object was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.in_eager_mode():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted.
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    self._device = context.context().device_name
    self._buffer_resource_handle = None
    if not context.context().device_spec.device_type:
      is_remote_device = False
    else:
      is_remote_device = context.context().device_spec.device_type != "CPU"
    if is_remote_device:
      with ops.device("/device:CPU:0"):
        iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
            self._resource)

        @function.Defun(dtypes.string)
        def remote_fn(h):
          remote_iterator = iterator_ops.Iterator.from_string_handle(
              h, self._output_types, self._output_shapes)
          return remote_iterator.get_next()

        remote_fn.add_to_graph(None)
        target = constant_op.constant("/device:CPU:0")
      with ops.device(self._device):
        self._buffer_resource_handle = prefetching_ops.function_buffering_resource(  # pylint: disable=line-too-long
            string_arg=iter_string_handle,
            f=remote_fn,
            target_device=target,
            buffer_size=10,
            thread_pool_size=1,
            container="",
            shared_name=_generate_shared_name("function_buffer_resource"))
        self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(  # pylint: disable=line-too-long
            handle=self._buffer_resource_handle,
            handle_device=self._device)

  def __iter__(self):
    return self

  def __next__(self):  # For Python 3 compatibility
    return self.next()

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is not None:
        ret = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      else:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types, ret), self._output_types,
        self._output_shapes, self._output_classes)

  def next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      raise StopIteration

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_classes

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_shapes

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return self._output_types

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation. Currently unused.

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      tf.errors.OutOfRangeError: If the end of the dataset has been reached.
    """
    del name
    return self._next_internal()
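

# Usage sketch (not part of the public API): a minimal example of driving the
# `Iterator` class above. It assumes a contrib-era TensorFlow build in which
# eager execution is switched on via `tf.contrib.eager.enable_eager_execution`;
# run this module directly to try it. Importing the module is unaffected.
if __name__ == "__main__":
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  tf.contrib.eager.enable_eager_execution()
  # Each `element` below is a concrete tf.Tensor value, not a symbolic node.
  example_dataset = tf.data.Dataset.range(4).map(lambda x: x * 2)
  for element in Iterator(example_dataset):
    print(element)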