# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Iterators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import threading
import warnings

import six

from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple
# times, e.g. when you are distributing different elements to multiple
# devices in a single step. However, a common pitfall arises when
# users call `Iterator.get_next()` in each iteration of their training
# loop. `Iterator.get_next()` adds ops to the graph, and executing
# each op allocates resources (including threads); as a consequence,
# invoking it in every iteration of a training loop causes slowdown
# and eventual resource exhaustion. To guard against this outcome, we
# log a warning when the number of uses crosses a threshold of suspicion.
GET_NEXT_CALL_WARNING_THRESHOLD = 32

GET_NEXT_CALL_WARNING_MESSAGE = (
    "An unusually high number of `Iterator.get_next()` calls was detected. "
    "This often indicates that `Iterator.get_next()` is being called inside "
    "a training loop, which will cause gradual slowdown and eventual resource "
    "exhaustion. If this is the case, restructure your code to call "
    "`next_element = iterator.get_next()` once outside the loop, and use "
    "`next_element` as the input to some computation that is invoked inside "
    "the loop.")

# NOTE(jsimsa): Threshold used as a heuristic to check for infinite loop during
# tf.function tracing.
GET_NEXT_CALL_ERROR_THRESHOLD = 32

GET_NEXT_CALL_ERROR_MESSAGE = (
    "An unusually high number of `tf.data.Iterator.get_next()` calls was "
    "detected. This suggests that the `for elem in dataset: ...` idiom is used "
    "within tf.function with AutoGraph disabled. This idiom is only supported "
    "when AutoGraph is enabled.")

# Collection of all IteratorResources in the `Graph`.
GLOBAL_ITERATORS = "iterators"


autograph_ctx = lazy_loader.LazyLoader(
    "autograph_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")


def _device_stack_is_empty():
  """Returns True if no device has been explicitly requested."""
  if context.executing_eagerly():
    return context.context().device_name is None
  # pylint: disable=protected-access
  device_stack = ops.get_default_graph()._device_functions_outer_to_inner
  # pylint: enable=protected-access
  return not bool(device_stack)


@tf_export(v1=["data.Iterator"])
class Iterator(trackable.Trackable):
  """Represents the state of iterating through a `Dataset`."""

  def __init__(self, iterator_resource, initializer, output_types,
               output_shapes, output_classes):
    """Creates a new iterator from the given iterator resource.

    Note: Most users will not call this initializer directly, and will
    instead use `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this iterator.
      output_shapes: A (nested) structure of `tf.TensorShape` objects
        corresponding to each component of an element of this iterator.
      output_classes: A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes`, or
        `output_classes` is `None`.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    if (output_types is None or output_shapes is None
        or output_classes is None):
      raise ValueError("If `structure` is not specified, all of "
                       "`output_types`, `output_shapes`, and `output_classes`"
                       " must be specified.")
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self._flat_tensor_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    self._flat_tensor_types = structure.get_flat_tensor_types(
        self._element_spec)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)

  @staticmethod
  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(
            output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @staticmethod
  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you
    would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.compat.v1.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
        a handle produced by the `Iterator.string_handle()` method.
      output_types: A (nested) structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A (nested) structure of `tf.TensorShape`
        objects corresponding to each component of an element of this dataset.
        If omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A (nested) structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(output_types,
                                               tensor_shape.as_shape,
                                               output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=structure.get_flat_tensor_types(output_structure),
        output_shapes=structure.get_flat_tensor_shapes(output_structure))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)

  @property
  def initializer(self):
    """A `tf.Operation` that should be run to initialize this iterator.

    Returns:
      A `tf.Operation` that should be run to initialize this iterator

    Raises:
      ValueError: If this iterator initializes itself automatically.
    """
    if self._initializer is not None:
      return self._initializer
    else:
      # TODO(mrry): Consider whether one-shot iterators should have
      # initializers that simply reset their state to the beginning.
      raise ValueError("Iterator does not have an initializer.")

  def make_initializer(self, dataset, name=None):
    """Returns a `tf.Operation` that initializes this iterator on `dataset`.

    Args:
      dataset: A `Dataset` whose `element_spec` is compatible with this
        iterator.
      name: (Optional.) A name for the created operation.

    Returns:
      A `tf.Operation` that can be run to initialize this iterator on the given
      `dataset`.

    Raises:
      TypeError: If `dataset` and this iterator do not have a compatible
        `element_spec`.
    """
    with ops.name_scope(name, "make_initializer") as name:
      # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due
      # to that creating a circular dependency.
      # pylint: disable=protected-access
      dataset_output_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),
          dataset.element_spec)
      dataset_output_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),
          dataset.element_spec)
      dataset_output_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),
          dataset.element_spec)
      # pylint: enable=protected-access

      nest.assert_same_structure(self.output_types, dataset_output_types)
      nest.assert_same_structure(self.output_shapes, dataset_output_shapes)
      for iterator_class, dataset_class in zip(
          nest.flatten(self.output_classes),
          nest.flatten(dataset_output_classes)):
        if iterator_class is not dataset_class:
          raise TypeError(
              "Expected output classes %r but got dataset with output class %r."
              % (self.output_classes, dataset_output_classes))
      for iterator_dtype, dataset_dtype in zip(
          nest.flatten(self.output_types), nest.flatten(dataset_output_types)):
        if iterator_dtype != dataset_dtype:
          raise TypeError(
              "Expected output types %r but got dataset with output types %r." %
              (self.output_types, dataset_output_types))
      for iterator_shape, dataset_shape in zip(
          nest.flatten(self.output_shapes), nest.flatten(
              dataset_output_shapes)):
        if not iterator_shape.is_compatible_with(dataset_shape):
          raise TypeError("Expected output shapes compatible with %r but got "
                          "dataset with output shapes %r." %
                          (self.output_shapes, dataset_output_shapes))

      # TODO(b/169442955): Investigate the need for this colocation constraint.
      with ops.colocate_with(self._iterator_resource):
        # pylint: disable=protected-access
        return gen_dataset_ops.make_iterator(
            dataset._variant_tensor, self._iterator_resource, name=name)

  def get_next(self, name=None):
    """Returns the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    `tf.Session.run` on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    `tf.errors.OutOfRangeError`. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.compat.v1.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      flat_ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_tensor_types,
          output_shapes=self._flat_tensor_shapes,
          name=name)
      return structure.from_tensor_list(self._element_spec, flat_ret)

  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`."""
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def string_handle(self, name=None):
    """Returns a string-valued `tf.Tensor` that represents this iterator.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.string`.
    """
    if name is None:
      return self._string_handle
    else:
      return gen_dataset_ops.iterator_to_string_handle(
          self._iterator_resource, name=name)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator.

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator and specifying the type of individual components.
    """

    return self._element_spec

  def _gather_saveables_for_checkpoint(self):
    # Expose the iterator state to the object-based checkpointing machinery.

    def _saveable_factory(name):
      return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}


_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  """Returns a unique name of the form `<prefix><uid>` (thread-safe)."""
  with _uid_lock:
    global _uid_counter
    uid = _uid_counter
    _uid_counter += 1
  return "{}{}".format(prefix, uid)


class IteratorResourceDeleter(object):
  """An object which cleans up an iterator resource handle.

  An alternative to defining a __del__ method on an object. Even if the parent
  object is part of a reference cycle, the cycle will be collectable.
  """

  __slots__ = ["_deleter", "_handle", "_eager_mode"]

  def __init__(self, handle, deleter):
    self._deleter = deleter
    self._handle = handle
    self._eager_mode = context.executing_eagerly()

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if self._eager_mode:
      with context.eager_mode():
        gen_dataset_ops.delete_iterator(
            handle=self._handle, deleter=self._deleter)
    else:
      with context.graph_mode():
        gen_dataset_ops.delete_iterator(
            handle=self._handle, deleter=self._deleter)


@tf_export("data.Iterator", v1=[])
@six.add_metaclass(abc.ABCMeta)
class IteratorBase(collections_abc.Iterator, trackable.Trackable,
                   composite_tensor.CompositeTensor):
  """Represents an iterator of a `tf.data.Dataset`.

  `tf.data.Iterator` is the primary mechanism for enumerating elements of a
  `tf.data.Dataset`. It supports the Python Iterator protocol, which means
  it can be iterated over using a for-loop:

  >>> dataset = tf.data.Dataset.range(2)
  >>> for element in dataset:
  ...   print(element)
  tf.Tensor(0, shape=(), dtype=int64)
  tf.Tensor(1, shape=(), dtype=int64)

  or by fetching individual elements explicitly via `get_next()`:

  >>> dataset = tf.data.Dataset.range(2)
  >>> iterator = iter(dataset)
  >>> print(iterator.get_next())
  tf.Tensor(0, shape=(), dtype=int64)
  >>> print(iterator.get_next())
  tf.Tensor(1, shape=(), dtype=int64)

  In addition, non-raising iteration is supported via `get_next_as_optional()`,
  which returns the next element (if available) wrapped in a
  `tf.experimental.Optional`.

  >>> dataset = tf.data.Dataset.from_tensors(42)
  >>> iterator = iter(dataset)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> optional = iterator.get_next_as_optional()
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this iterator.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> iterator.element_spec
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    For more information,
    read [this guide](https://www.tensorflow.org/guide/data#dataset_structure).

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this iterator, specifying the type of individual components.
    """
    raise NotImplementedError("Iterator.element_spec")

  @abc.abstractmethod
  def get_next(self):
    """Returns the next element.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> print(iterator.get_next())
    tf.Tensor(42, shape=(), dtype=int32)

    Returns:
      A (nested) structure of values matching `tf.data.Iterator.element_spec`.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError("Iterator.get_next()")

  @abc.abstractmethod
  def get_next_as_optional(self):
    """Returns the next element wrapped in `tf.experimental.Optional`.

    If the iterator has reached the end of the sequence, the returned
    `tf.experimental.Optional` will have no value.

    >>> dataset = tf.data.Dataset.from_tensors(42)
    >>> iterator = iter(dataset)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)
    >>> optional = iterator.get_next_as_optional()
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Returns:
      A `tf.experimental.Optional` object representing the next element.
    """
    raise NotImplementedError("Iterator.get_next_as_optional()")


class OwnedIterator(IteratorBase):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset.

  The iterator resource created through `OwnedIterator` is owned by the Python
  object and the life time of the underlying resource is tied to the life time
  of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use
  in eager mode and inside of tf.functions.
  """

  def __init__(self, dataset=None, components=None, element_spec=None):
    """Creates a new iterator from the given dataset.

    If `dataset` is not specified, the iterator will be created from the given
    tensor components and element structure. In particular, the alternative for
    constructing the iterator is used when the iterator is reconstructed from
    its `CompositeTensor` representation.

    Args:
      dataset: A `tf.data.Dataset` object.
      components: Tensor components to construct the iterator from.
      element_spec: A (nested) structure of `TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      ValueError: If `dataset` is not provided and either `components` or
        `element_spec` is not provided. Or `dataset` is provided and either
        `components` and `element_spec` is provided.
    """
    super(OwnedIterator, self).__init__()
    error_message = ("Either `dataset` or both `components` and "
                     "`element_spec` need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      # pylint: disable=protected-access
      self._element_spec = element_spec
      self._flat_output_types = structure.get_flat_tensor_types(
          self._element_spec)
      self._flat_output_shapes = structure.get_flat_tensor_shapes(
          self._element_spec)
      self._iterator_resource, self._deleter = components
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      self._create_iterator(dataset)

    self._get_next_call_count = 0

  def _create_iterator(self, dataset):
    """Creates and initializes an anonymous iterator resource for `dataset`."""
    # pylint: disable=protected-access
    dataset = dataset._apply_debug_options()

    # Store dataset reference to ensure that dataset is alive when this iterator
    # is being used. For example, `tf.data.Dataset.from_generator` registers
    # a few py_funcs that are needed in `self._next_internal`. If the dataset
    # is deleted, this iterator crashes on `self.__next__(...)` call.
    self._dataset = dataset

    ds_variant = dataset._variant_tensor
    self._element_spec = dataset.element_spec
    self._flat_output_types = structure.get_flat_tensor_types(
        self._element_spec)
    self._flat_output_shapes = structure.get_flat_tensor_shapes(
        self._element_spec)
    with ops.colocate_with(ds_variant):
      self._iterator_resource, self._deleter = (
          gen_dataset_ops.anonymous_iterator_v2(
              output_types=self._flat_output_types,
              output_shapes=self._flat_output_shapes))
      gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = IteratorResourceDeleter(
          handle=self._iterator_resource,
          deleter=self._deleter)

  def __iter__(self):
    """Supports the Python Iterator protocol; an iterator is its own iterator."""
    return self

  def next(self):  # For Python 2 compatibility
    return self.__next__()

  def _next_internal(self):
    """Fetches the next element, guarding against untraced infinite loops."""
    autograph_status = autograph_ctx.control_status_ctx().status
    autograph_disabled = autograph_status == autograph_ctx.Status.DISABLED
    if not context.executing_eagerly() and autograph_disabled:
      self._get_next_call_count += 1
      if self._get_next_call_count > GET_NEXT_CALL_ERROR_THRESHOLD:
        raise ValueError(GET_NEXT_CALL_ERROR_MESSAGE)

    if not context.executing_eagerly():
      # TODO(b/169442955): Investigate the need for this colocation constraint.
      with ops.colocate_with(self._iterator_resource):
        ret = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      return structure.from_compatible_tensor_list(self._element_spec, ret)

    # TODO(b/77291417): This runs in sync mode as iterators use an error status
    # to communicate that there is no more data to iterate over.
    with context.execution_mode(context.SYNC):
      ret = gen_dataset_ops.iterator_get_next(
          self._iterator_resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

      try:
        # Fast path for the case `self._structure` is not a nested structure.
        return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
      except AttributeError:
        return structure.from_compatible_tensor_list(self._element_spec, ret)

  @property
  def _type_spec(self):
    return IteratorSpec(self.element_spec)

  def __next__(self):
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      raise StopIteration

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.")
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`.

    Returns:
      A (nested) structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.")
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  @deprecation.deprecated(
      None, "Use `tf.compat.v1.data.get_output_types(iterator)`.")
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A (nested) structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def element_spec(self):
    return self._element_spec

  def get_next(self):
    return self._next_internal()

  def get_next_as_optional(self):
    # TODO(b/169442955): Investigate the need for this colocation constraint.
    with ops.colocate_with(self._iterator_resource):
      # pylint: disable=protected-access
      return optional_ops._OptionalImpl(
          gen_dataset_ops.iterator_get_next_as_optional(
              self._iterator_resource,
              output_types=structure.get_flat_tensor_types(self.element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self.element_spec)), self.element_spec)

  def _gather_saveables_for_checkpoint(self):
    # NOTE(review): `self._dataset` only exists when this iterator was built
    # from a dataset (not from components) — confirm callers only checkpoint
    # dataset-built iterators.

    def _saveable_factory(name):
      """Returns a SaveableObject for serialization/deserialization."""
      policy = None
      if self._dataset:
        policy = self._dataset.options().experimental_external_state_policy
      if policy:
        return _IteratorSaveable(
            self._iterator_resource,
            name,
            external_state_policy=policy)
      else:
        return _IteratorSaveable(self._iterator_resource, name)

    return {"ITERATOR": _saveable_factory}


@tf_export("data.IteratorSpec", v1=[])
class IteratorSpec(type_spec.TypeSpec):
  """Type specification for `tf.data.Iterator`.

  For instance, `tf.data.IteratorSpec` can be used to define a tf.function that
  takes `tf.data.Iterator` as an input argument:

  >>> @tf.function(input_signature=[tf.data.IteratorSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def square(iterator):
  ...   x = iterator.get_next()
  ...   return x * x
  >>> dataset = tf.data.Dataset.from_tensors(5)
  >>> iterator = iter(dataset)
  >>> print(square(iterator))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `tf.TypeSpec` objects that represents
      the type specification of the iterator elements.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedIterator

  def _serialize(self):
    return (self._element_spec,)

  @property
  def _component_specs(self):
    # An `OwnedIterator` decomposes into its resource handle and deleter.
    return (
        tensor_spec.TensorSpec([], dtypes.resource),
        tensor_spec.TensorSpec([], dtypes.variant),
    )

  def _to_components(self, value):
    return (value._iterator_resource, value._deleter)  # pylint: disable=protected-access

  def _from_components(self, components):
    return OwnedIterator(
        dataset=None,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    return IteratorSpec(value.element_spec)  # pylint: disable=protected-access


# TODO(b/71645805): Expose trackable stateful objects from dataset.
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """`SaveableObject` that saves and restores iterator state.

  The iterator's state is captured by serializing the resource into a single
  variant tensor, checkpointed under the key `<name>_STATE`, and restored by
  deserializing that tensor back into the same resource.
  """

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=options_lib.ExternalStatePolicy.FAIL):
    """Creates a saveable wrapping `iterator_resource`.

    Args:
      iterator_resource: Resource tensor of the iterator to save/restore.
      name: Checkpoint name for this saveable.
      external_state_policy: An `options_lib.ExternalStatePolicy` value passed
        to the serialize op; defaults to `FAIL`.
    """
    # Snapshot the iterator's current state as a variant tensor.
    serialized = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    save_spec = BaseSaverBuilder.SaveSpec(
        serialized, "", name + "_STATE", device=iterator_resource.device)
    super(_IteratorSaveable, self).__init__(iterator_resource, [save_spec],
                                            name)

  def restore(self, restored_tensors, restored_shapes):
    # Run the deserialize op colocated with the iterator resource it restores.
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])


@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  # Deprecated free-function form; simply delegates to the instance method.
  return iterator.get_next_as_optional()


_pywrap_utils.RegisterType("OwnedIterator", OwnedIterator)