1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for Iterators.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import threading 22import warnings 23 24import six 25 26from tensorflow.python.data.experimental.ops import distribute_options 27from tensorflow.python.data.ops import optional_ops 28from tensorflow.python.data.util import nest 29from tensorflow.python.data.util import structure 30from tensorflow.python.eager import context 31from tensorflow.python.framework import composite_tensor 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_shape 36from tensorflow.python.framework import tensor_spec 37from tensorflow.python.framework import type_spec 38from tensorflow.python.ops import gen_dataset_ops 39from tensorflow.python.training.saver import BaseSaverBuilder 40from tensorflow.python.training.tracking import base as trackable 41from tensorflow.python.util import deprecation 42from tensorflow.python.util.compat import collections_abc 43from tensorflow.python.util.tf_export import tf_export 44 45 46# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple 47# 
times, e.g. when you are distributing different elements to multiple 48# devices in a single step. However, a common pitfall arises when 49# users call `Iterator.get_next()` in each iteration of their training 50# loop. `Iterator.get_next()` adds ops to the graph, and executing 51# each op allocates resources (including threads); as a consequence, 52# invoking it in every iteration of a training loop causes slowdown 53# and eventual resource exhaustion. To guard against this outcome, we 54# log a warning when the number of uses crosses a threshold of suspicion. 55GET_NEXT_CALL_WARNING_THRESHOLD = 32 56 57GET_NEXT_CALL_WARNING_MESSAGE = ( 58 "An unusually high number of `Iterator.get_next()` calls was detected. " 59 "This often indicates that `Iterator.get_next()` is being called inside " 60 "a training loop, which will cause gradual slowdown and eventual resource " 61 "exhaustion. If this is the case, restructure your code to call " 62 "`next_element = iterator.get_next()` once outside the loop, and use " 63 "`next_element` as the input to some computation that is invoked inside " 64 "the loop.") 65 66# Collection of all IteratorResources in the `Graph`. 67GLOBAL_ITERATORS = "iterators" 68 69 70def _device_stack_is_empty(): 71 if context.executing_eagerly(): 72 return context.context().device_name is None 73 # pylint: disable=protected-access 74 device_stack = ops.get_default_graph()._device_functions_outer_to_inner 75 # pylint: enable=protected-access 76 return not bool(device_stack) 77 78 79@tf_export(v1=["data.Iterator"]) 80class Iterator(trackable.Trackable): 81 """Represents the state of iterating through a `Dataset`.""" 82 83 def __init__(self, iterator_resource, initializer, output_types, 84 output_shapes, output_classes): 85 """Creates a new iterator from the given iterator resource. 86 87 Note: Most users will not call this initializer directly, and will 88 instead use `Dataset.make_initializable_iterator()` or 89 `Dataset.make_one_shot_iterator()`. 
90 91 Args: 92 iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the 93 iterator. 94 initializer: A `tf.Operation` that should be run to initialize this 95 iterator. 96 output_types: A nested structure of `tf.DType` objects corresponding to 97 each component of an element of this iterator. 98 output_shapes: A nested structure of `tf.TensorShape` objects 99 corresponding to each component of an element of this iterator. 100 output_classes: A nested structure of Python `type` objects corresponding 101 to each component of an element of this iterator. 102 """ 103 self._iterator_resource = iterator_resource 104 self._initializer = initializer 105 106 if (output_types is None or output_shapes is None 107 or output_classes is None): 108 raise ValueError("If `structure` is not specified, all of " 109 "`output_types`, `output_shapes`, and `output_classes`" 110 " must be specified.") 111 self._element_spec = structure.convert_legacy_structure( 112 output_types, output_shapes, output_classes) 113 self._flat_tensor_shapes = structure.get_flat_tensor_shapes( 114 self._element_spec) 115 self._flat_tensor_types = structure.get_flat_tensor_types( 116 self._element_spec) 117 118 self._string_handle = gen_dataset_ops.iterator_to_string_handle( 119 self._iterator_resource) 120 self._get_next_call_count = 0 121 ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource) 122 123 @staticmethod 124 def from_structure(output_types, 125 output_shapes=None, 126 shared_name=None, 127 output_classes=None): 128 """Creates a new, uninitialized `Iterator` with the given structure. 129 130 This iterator-constructing method can be used to create an iterator that 131 is reusable with many different datasets. 132 133 The returned iterator is not bound to a particular dataset, and it has 134 no `initializer`. To initialize the iterator, run the operation returned by 135 `Iterator.make_initializer(dataset)`. 
136 137 The following is an example 138 139 ```python 140 iterator = Iterator.from_structure(tf.int64, tf.TensorShape([])) 141 142 dataset_range = Dataset.range(10) 143 range_initializer = iterator.make_initializer(dataset_range) 144 145 dataset_evens = dataset_range.filter(lambda x: x % 2 == 0) 146 evens_initializer = iterator.make_initializer(dataset_evens) 147 148 # Define a model based on the iterator; in this example, the model_fn 149 # is expected to take scalar tf.int64 Tensors as input (see 150 # the definition of 'iterator' above). 151 prediction, loss = model_fn(iterator.get_next()) 152 153 # Train for `num_epochs`, where for each epoch, we first iterate over 154 # dataset_range, and then iterate over dataset_evens. 155 for _ in range(num_epochs): 156 # Initialize the iterator to `dataset_range` 157 sess.run(range_initializer) 158 while True: 159 try: 160 pred, loss_val = sess.run([prediction, loss]) 161 except tf.errors.OutOfRangeError: 162 break 163 164 # Initialize the iterator to `dataset_evens` 165 sess.run(evens_initializer) 166 while True: 167 try: 168 pred, loss_val = sess.run([prediction, loss]) 169 except tf.errors.OutOfRangeError: 170 break 171 ``` 172 173 Args: 174 output_types: A nested structure of `tf.DType` objects corresponding to 175 each component of an element of this dataset. 176 output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects 177 corresponding to each component of an element of this dataset. If 178 omitted, each component will have an unconstrainted shape. 179 shared_name: (Optional.) If non-empty, this iterator will be shared under 180 the given name across multiple sessions that share the same devices 181 (e.g. when using a remote server). 182 output_classes: (Optional.) A nested structure of Python `type` objects 183 corresponding to each component of an element of this iterator. If 184 omitted, each component is assumed to be of type `tf.Tensor`. 185 186 Returns: 187 An `Iterator`. 
188 189 Raises: 190 TypeError: If the structures of `output_shapes` and `output_types` are 191 not the same. 192 """ 193 output_types = nest.map_structure(dtypes.as_dtype, output_types) 194 if output_shapes is None: 195 output_shapes = nest.map_structure( 196 lambda _: tensor_shape.TensorShape(None), output_types) 197 else: 198 output_shapes = nest.map_structure_up_to(output_types, 199 tensor_shape.as_shape, 200 output_shapes) 201 if output_classes is None: 202 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 203 nest.assert_same_structure(output_types, output_shapes) 204 output_structure = structure.convert_legacy_structure( 205 output_types, output_shapes, output_classes) 206 if shared_name is None: 207 shared_name = "" 208 iterator_resource = gen_dataset_ops.iterator_v2( 209 container="", 210 shared_name=shared_name, 211 output_types=structure.get_flat_tensor_types(output_structure), 212 output_shapes=structure.get_flat_tensor_shapes( 213 output_structure)) 214 return Iterator(iterator_resource, None, output_types, output_shapes, 215 output_classes) 216 217 @staticmethod 218 def from_string_handle(string_handle, 219 output_types, 220 output_shapes=None, 221 output_classes=None): 222 """Creates a new, uninitialized `Iterator` based on the given handle. 223 224 This method allows you to define a "feedable" iterator where you can choose 225 between concrete iterators by feeding a value in a `tf.Session.run` call. 226 In that case, `string_handle` would be a `tf.compat.v1.placeholder`, and you 227 would 228 feed it with the value of `tf.data.Iterator.string_handle` in each step. 
229 230 For example, if you had two iterators that marked the current position in 231 a training dataset and a test dataset, you could choose which to use in 232 each step as follows: 233 234 ```python 235 train_iterator = tf.data.Dataset(...).make_one_shot_iterator() 236 train_iterator_handle = sess.run(train_iterator.string_handle()) 237 238 test_iterator = tf.data.Dataset(...).make_one_shot_iterator() 239 test_iterator_handle = sess.run(test_iterator.string_handle()) 240 241 handle = tf.compat.v1.placeholder(tf.string, shape=[]) 242 iterator = tf.data.Iterator.from_string_handle( 243 handle, train_iterator.output_types) 244 245 next_element = iterator.get_next() 246 loss = f(next_element) 247 248 train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) 249 test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) 250 ``` 251 252 Args: 253 string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to 254 a handle produced by the `Iterator.string_handle()` method. 255 output_types: A nested structure of `tf.DType` objects corresponding to 256 each component of an element of this dataset. 257 output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects 258 corresponding to each component of an element of this dataset. If 259 omitted, each component will have an unconstrainted shape. 260 output_classes: (Optional.) A nested structure of Python `type` objects 261 corresponding to each component of an element of this iterator. If 262 omitted, each component is assumed to be of type `tf.Tensor`. 263 264 Returns: 265 An `Iterator`. 
266 """ 267 output_types = nest.map_structure(dtypes.as_dtype, output_types) 268 if output_shapes is None: 269 output_shapes = nest.map_structure( 270 lambda _: tensor_shape.TensorShape(None), output_types) 271 else: 272 output_shapes = nest.map_structure_up_to(output_types, 273 tensor_shape.as_shape, 274 output_shapes) 275 if output_classes is None: 276 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 277 nest.assert_same_structure(output_types, output_shapes) 278 output_structure = structure.convert_legacy_structure( 279 output_types, output_shapes, output_classes) 280 string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) 281 iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( 282 string_handle, 283 output_types=structure.get_flat_tensor_types(output_structure), 284 output_shapes=structure.get_flat_tensor_shapes(output_structure)) 285 return Iterator(iterator_resource, None, output_types, output_shapes, 286 output_classes) 287 288 @property 289 def initializer(self): 290 """A `tf.Operation` that should be run to initialize this iterator. 291 292 Returns: 293 A `tf.Operation` that should be run to initialize this iterator 294 295 Raises: 296 ValueError: If this iterator initializes itself automatically. 297 """ 298 if self._initializer is not None: 299 return self._initializer 300 else: 301 # TODO(mrry): Consider whether one-shot iterators should have 302 # initializers that simply reset their state to the beginning. 303 raise ValueError("Iterator does not have an initializer.") 304 305 def make_initializer(self, dataset, name=None): 306 """Returns a `tf.Operation` that initializes this iterator on `dataset`. 307 308 Args: 309 dataset: A `Dataset` with compatible structure to this iterator. 310 name: (Optional.) A name for the created operation. 311 312 Returns: 313 A `tf.Operation` that can be run to initialize this iterator on the given 314 `dataset`. 
315 316 Raises: 317 TypeError: If `dataset` and this iterator do not have a compatible 318 element structure. 319 """ 320 with ops.name_scope(name, "make_initializer") as name: 321 # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due 322 # to that creating a circular dependency. 323 # pylint: disable=protected-access 324 dataset_output_types = nest.map_structure( 325 lambda component_spec: component_spec._to_legacy_output_types(), 326 dataset.element_spec) 327 dataset_output_shapes = nest.map_structure( 328 lambda component_spec: component_spec._to_legacy_output_shapes(), 329 dataset.element_spec) 330 dataset_output_classes = nest.map_structure( 331 lambda component_spec: component_spec._to_legacy_output_classes(), 332 dataset.element_spec) 333 # pylint: enable=protected-access 334 335 nest.assert_same_structure(self.output_types, dataset_output_types) 336 nest.assert_same_structure(self.output_shapes, dataset_output_shapes) 337 for iterator_class, dataset_class in zip( 338 nest.flatten(self.output_classes), 339 nest.flatten(dataset_output_classes)): 340 if iterator_class is not dataset_class: 341 raise TypeError( 342 "Expected output classes %r but got dataset with output class %r." 343 % (self.output_classes, dataset_output_classes)) 344 for iterator_dtype, dataset_dtype in zip( 345 nest.flatten(self.output_types), nest.flatten(dataset_output_types)): 346 if iterator_dtype != dataset_dtype: 347 raise TypeError( 348 "Expected output types %r but got dataset with output types %r." % 349 (self.output_types, dataset_output_types)) 350 for iterator_shape, dataset_shape in zip( 351 nest.flatten(self.output_shapes), nest.flatten( 352 dataset_output_shapes)): 353 if not iterator_shape.is_compatible_with(dataset_shape): 354 raise TypeError("Expected output shapes compatible with %r but got " 355 "dataset with output shapes %r." 
% 356 (self.output_shapes, dataset_output_shapes)) 357 358 # TODO(b/169442955): Investigate the need for this colocation constraint. 359 with ops.colocate_with(self._iterator_resource): 360 # pylint: disable=protected-access 361 return gen_dataset_ops.make_iterator( 362 dataset._variant_tensor, self._iterator_resource, name=name) 363 364 def get_next(self, name=None): 365 """Returns a nested structure of `tf.Tensor`s representing the next element. 366 367 In graph mode, you should typically call this method *once* and use its 368 result as the input to another computation. A typical loop will then call 369 `tf.Session.run` on the result of that computation. The loop will terminate 370 when the `Iterator.get_next()` operation raises 371 `tf.errors.OutOfRangeError`. The following skeleton shows how to use 372 this method when building a training loop: 373 374 ```python 375 dataset = ... # A `tf.data.Dataset` object. 376 iterator = dataset.make_initializable_iterator() 377 next_element = iterator.get_next() 378 379 # Build a TensorFlow graph that does something with each element. 380 loss = model_function(next_element) 381 optimizer = ... # A `tf.compat.v1.train.Optimizer` object. 382 train_op = optimizer.minimize(loss) 383 384 with tf.compat.v1.Session() as sess: 385 try: 386 while True: 387 sess.run(train_op) 388 except tf.errors.OutOfRangeError: 389 pass 390 ``` 391 392 NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. 393 when you are distributing different elements to multiple devices in a single 394 step. However, a common pitfall arises when users call `Iterator.get_next()` 395 in each iteration of their training loop. `Iterator.get_next()` adds ops to 396 the graph, and executing each op allocates resources (including threads); as 397 a consequence, invoking it in every iteration of a training loop causes 398 slowdown and eventual resource exhaustion. 
To guard against this outcome, we 399 log a warning when the number of uses crosses a fixed threshold of 400 suspiciousness. 401 402 Args: 403 name: (Optional.) A name for the created operation. 404 405 Returns: 406 A nested structure of `tf.Tensor` objects. 407 """ 408 self._get_next_call_count += 1 409 if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: 410 warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) 411 412 # TODO(b/169442955): Investigate the need for this colocation constraint. 413 with ops.colocate_with(self._iterator_resource): 414 # pylint: disable=protected-access 415 flat_ret = gen_dataset_ops.iterator_get_next( 416 self._iterator_resource, 417 output_types=self._flat_tensor_types, 418 output_shapes=self._flat_tensor_shapes, 419 name=name) 420 return structure.from_tensor_list(self._element_spec, flat_ret) 421 422 def get_next_as_optional(self): 423 # TODO(b/169442955): Investigate the need for this colocation constraint. 424 with ops.colocate_with(self._iterator_resource): 425 # pylint: disable=protected-access 426 return optional_ops._OptionalImpl( 427 gen_dataset_ops.iterator_get_next_as_optional( 428 self._iterator_resource, 429 output_types=structure.get_flat_tensor_types(self.element_spec), 430 output_shapes=structure.get_flat_tensor_shapes( 431 self.element_spec)), self.element_spec) 432 433 def string_handle(self, name=None): 434 """Returns a string-valued `tf.Tensor` that represents this iterator. 435 436 Args: 437 name: (Optional.) A name for the created operation. 438 439 Returns: 440 A scalar `tf.Tensor` of type `tf.string`. 441 """ 442 if name is None: 443 return self._string_handle 444 else: 445 return gen_dataset_ops.iterator_to_string_handle( 446 self._iterator_resource, name=name) 447 448 @property 449 @deprecation.deprecated( 450 None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.") 451 def output_classes(self): 452 """Returns the class of each component of an element of this iterator. 
453 454 The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`. 455 456 Returns: 457 A nested structure of Python `type` objects corresponding to each 458 component of an element of this dataset. 459 """ 460 return nest.map_structure( 461 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 462 self._element_spec) 463 464 @property 465 @deprecation.deprecated( 466 None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.") 467 def output_shapes(self): 468 """Returns the shape of each component of an element of this iterator. 469 470 Returns: 471 A nested structure of `tf.TensorShape` objects corresponding to each 472 component of an element of this dataset. 473 """ 474 return nest.map_structure( 475 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 476 self._element_spec) 477 478 @property 479 @deprecation.deprecated( 480 None, "Use `tf.compat.v1.data.get_output_types(iterator)`.") 481 def output_types(self): 482 """Returns the type of each component of an element of this iterator. 483 484 Returns: 485 A nested structure of `tf.DType` objects corresponding to each component 486 of an element of this dataset. 
487 """ 488 return nest.map_structure( 489 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 490 self._element_spec) 491 492 @property 493 def element_spec(self): 494 return self._element_spec 495 496 def _gather_saveables_for_checkpoint(self): 497 498 def _saveable_factory(name): 499 return _IteratorSaveable(self._iterator_resource, name) 500 501 return {"ITERATOR": _saveable_factory} 502 503 504_uid_counter = 0 505_uid_lock = threading.Lock() 506 507 508def _generate_shared_name(prefix): 509 with _uid_lock: 510 global _uid_counter 511 uid = _uid_counter 512 _uid_counter += 1 513 return "{}{}".format(prefix, uid) 514 515 516class IteratorResourceDeleter(object): 517 """An object which cleans up an iterator resource handle. 518 519 An alternative to defining a __del__ method on an object. Even if the parent 520 object is part of a reference cycle, the cycle will be collectable. 521 """ 522 523 __slots__ = ["_deleter", "_handle", "_eager_mode"] 524 525 def __init__(self, handle, deleter): 526 self._deleter = deleter 527 self._handle = handle 528 self._eager_mode = context.executing_eagerly() 529 530 def __del__(self): 531 # Make sure the resource is deleted in the same mode as it was created in. 532 if self._eager_mode: 533 with context.eager_mode(): 534 gen_dataset_ops.delete_iterator( 535 handle=self._handle, deleter=self._deleter) 536 else: 537 with context.graph_mode(): 538 gen_dataset_ops.delete_iterator( 539 handle=self._handle, deleter=self._deleter) 540 541 542@tf_export("data.Iterator", v1=[]) 543@six.add_metaclass(abc.ABCMeta) 544class IteratorBase(collections_abc.Iterator, trackable.Trackable, 545 composite_tensor.CompositeTensor): 546 """Represents an iterator of a `tf.data.Dataset`. 547 548 `tf.data.Iterator` is the primary mechanism for enumerating elements of a 549 `tf.data.Dataset`. 
It supports the Python Iterator protocol, which means 550 it can be iterated over using a for-loop: 551 552 >>> dataset = tf.data.Dataset.range(2) 553 >>> for element in dataset: 554 ... print(element) 555 tf.Tensor(0, shape=(), dtype=int64) 556 tf.Tensor(1, shape=(), dtype=int64) 557 558 or by fetching individual elements explicitly via `get_next()`: 559 560 >>> dataset = tf.data.Dataset.range(2) 561 >>> iterator = iter(dataset) 562 >>> print(iterator.get_next()) 563 tf.Tensor(0, shape=(), dtype=int64) 564 >>> print(iterator.get_next()) 565 tf.Tensor(1, shape=(), dtype=int64) 566 567 In addition, non-raising iteration is supported via `get_next_as_optional()`, 568 which returns the next element (if available) wrapped in a 569 `tf.experimental.Optional`. 570 571 >>> dataset = tf.data.Dataset.from_tensors(42) 572 >>> iterator = iter(dataset) 573 >>> optional = iterator.get_next_as_optional() 574 >>> print(optional.has_value()) 575 tf.Tensor(True, shape=(), dtype=bool) 576 >>> optional = iterator.get_next_as_optional() 577 >>> print(optional.has_value()) 578 tf.Tensor(False, shape=(), dtype=bool) 579 """ 580 581 @abc.abstractproperty 582 def element_spec(self): 583 """The type specification of an element of this iterator. 584 585 >>> dataset = tf.data.Dataset.from_tensors(42) 586 >>> iterator = iter(dataset) 587 >>> iterator.element_spec 588 tf.TensorSpec(shape=(), dtype=tf.int32, name=None) 589 590 Returns: 591 A nested structure of `tf.TypeSpec` objects matching the structure of an 592 element of this iterator, specifying the type of individual components. 593 """ 594 raise NotImplementedError("Iterator.element_spec") 595 596 @abc.abstractmethod 597 def get_next(self): 598 """Returns a nested structure of `tf.Tensor`s containing the next element. 
599 600 >>> dataset = tf.data.Dataset.from_tensors(42) 601 >>> iterator = iter(dataset) 602 >>> print(iterator.get_next()) 603 tf.Tensor(42, shape=(), dtype=int32) 604 605 Returns: 606 A nested structure of `tf.Tensor` objects. 607 608 Raises: 609 `tf.errors.OutOfRangeError`: If the end of the iterator has been reached. 610 """ 611 raise NotImplementedError("Iterator.get_next()") 612 613 @abc.abstractmethod 614 def get_next_as_optional(self): 615 """Returns a `tf.experimental.Optional` which contains the next element. 616 617 If the iterator has reached the end of the sequence, the returned 618 `tf.experimental.Optional` will have no value. 619 620 >>> dataset = tf.data.Dataset.from_tensors(42) 621 >>> iterator = iter(dataset) 622 >>> optional = iterator.get_next_as_optional() 623 >>> print(optional.has_value()) 624 tf.Tensor(True, shape=(), dtype=bool) 625 >>> print(optional.get_value()) 626 tf.Tensor(42, shape=(), dtype=int32) 627 >>> optional = iterator.get_next_as_optional() 628 >>> print(optional.has_value()) 629 tf.Tensor(False, shape=(), dtype=bool) 630 631 Returns: 632 A `tf.experimental.Optional` object representing the next element. 633 """ 634 raise NotImplementedError("Iterator.get_next_as_optional()") 635 636 637class OwnedIterator(IteratorBase): 638 """An iterator producing tf.Tensor objects from a tf.data.Dataset. 639 640 The iterator resource created through `OwnedIterator` is owned by the Python 641 object and the life time of the underlying resource is tied to the life time 642 of the `OwnedIterator` object. This makes `OwnedIterator` appropriate for use 643 in eager mode and inside of tf.functions. 644 """ 645 646 def __init__(self, dataset=None, components=None, element_spec=None): 647 """Creates a new iterator from the given dataset. 648 649 If `dataset` is not specified, the iterator will be created from the given 650 tensor components and element structure. 
In particular, the alternative for 651 constructing the iterator is used when the iterator is reconstructed from 652 it `CompositeTensor` representation. 653 654 Args: 655 dataset: A `tf.data.Dataset` object. 656 components: Tensor components to construct the iterator from. 657 element_spec: A nested structure of `TypeSpec` objects that 658 represents the type specification of elements of the iterator. 659 660 Raises: 661 ValueError: If `dataset` is not provided and either `components` or 662 `element_spec` is not provided. Or `dataset` is provided and either 663 `components` and `element_spec` is provided. 664 """ 665 super(OwnedIterator, self).__init__() 666 error_message = ("Either `dataset` or both `components` and " 667 "`element_spec` need to be provided.") 668 669 if dataset is None: 670 if (components is None or element_spec is None): 671 raise ValueError(error_message) 672 # pylint: disable=protected-access 673 self._element_spec = element_spec 674 self._flat_output_types = structure.get_flat_tensor_types( 675 self._element_spec) 676 self._flat_output_shapes = structure.get_flat_tensor_shapes( 677 self._element_spec) 678 self._iterator_resource, self._deleter = components 679 else: 680 if (components is not None or element_spec is not None): 681 raise ValueError(error_message) 682 self._create_iterator(dataset) 683 684 def _create_iterator(self, dataset): 685 # pylint: disable=protected-access 686 dataset = dataset._apply_options() 687 688 # Store dataset reference to ensure that dataset is alive when this iterator 689 # is being used. For example, `tf.data.Dataset.from_generator` registers 690 # a few py_funcs that are needed in `self._next_internal`. If the dataset 691 # is deleted, this iterator crashes on `self.__next__(...)` call. 
692 self._dataset = dataset 693 694 ds_variant = dataset._variant_tensor 695 self._element_spec = dataset.element_spec 696 self._flat_output_types = structure.get_flat_tensor_types( 697 self._element_spec) 698 self._flat_output_shapes = structure.get_flat_tensor_shapes( 699 self._element_spec) 700 with ops.colocate_with(ds_variant): 701 self._iterator_resource, self._deleter = ( 702 gen_dataset_ops.anonymous_iterator_v2( 703 output_types=self._flat_output_types, 704 output_shapes=self._flat_output_shapes)) 705 gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource) 706 # Delete the resource when this object is deleted 707 self._resource_deleter = IteratorResourceDeleter( 708 handle=self._iterator_resource, 709 deleter=self._deleter) 710 711 def __iter__(self): 712 return self 713 714 def next(self): # For Python 2 compatibility 715 return self.__next__() 716 717 def _next_internal(self): 718 if not context.executing_eagerly(): 719 # TODO(b/169442955): Investigate the need for this colocation constraint. 720 with ops.colocate_with(self._iterator_resource): 721 ret = gen_dataset_ops.iterator_get_next( 722 self._iterator_resource, 723 output_types=self._flat_output_types, 724 output_shapes=self._flat_output_shapes) 725 return structure.from_compatible_tensor_list(self._element_spec, ret) 726 727 # TODO(b/77291417): This runs in sync mode as iterators use an error status 728 # to communicate that there is no more data to iterate over. 729 with context.execution_mode(context.SYNC): 730 ret = gen_dataset_ops.iterator_get_next( 731 self._iterator_resource, 732 output_types=self._flat_output_types, 733 output_shapes=self._flat_output_shapes) 734 735 try: 736 # Fast path for the case `self._structure` is not a nested structure. 
737 return self._element_spec._from_compatible_tensor_list(ret) # pylint: disable=protected-access 738 except AttributeError: 739 return structure.from_compatible_tensor_list(self._element_spec, ret) 740 741 @property 742 def _type_spec(self): 743 return IteratorSpec(self.element_spec) 744 745 def __next__(self): 746 try: 747 return self._next_internal() 748 except errors.OutOfRangeError: 749 raise StopIteration 750 751 @property 752 @deprecation.deprecated( 753 None, "Use `tf.compat.v1.data.get_output_classes(iterator)`.") 754 def output_classes(self): 755 """Returns the class of each component of an element of this iterator. 756 757 The expected values are `tf.Tensor` and `tf.sparse.SparseTensor`. 758 759 Returns: 760 A nested structure of Python `type` objects corresponding to each 761 component of an element of this dataset. 762 """ 763 return nest.map_structure( 764 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 765 self._element_spec) 766 767 @property 768 @deprecation.deprecated( 769 None, "Use `tf.compat.v1.data.get_output_shapes(iterator)`.") 770 def output_shapes(self): 771 """Returns the shape of each component of an element of this iterator. 772 773 Returns: 774 A nested structure of `tf.TensorShape` objects corresponding to each 775 component of an element of this dataset. 776 """ 777 return nest.map_structure( 778 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 779 self._element_spec) 780 781 @property 782 @deprecation.deprecated( 783 None, "Use `tf.compat.v1.data.get_output_types(iterator)`.") 784 def output_types(self): 785 """Returns the type of each component of an element of this iterator. 786 787 Returns: 788 A nested structure of `tf.DType` objects corresponding to each component 789 of an element of this dataset. 
790 """ 791 return nest.map_structure( 792 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 793 self._element_spec) 794 795 @property 796 def element_spec(self): 797 return self._element_spec 798 799 def get_next(self): 800 return self._next_internal() 801 802 def get_next_as_optional(self): 803 # TODO(b/169442955): Investigate the need for this colocation constraint. 804 with ops.colocate_with(self._iterator_resource): 805 # pylint: disable=protected-access 806 return optional_ops._OptionalImpl( 807 gen_dataset_ops.iterator_get_next_as_optional( 808 self._iterator_resource, 809 output_types=structure.get_flat_tensor_types(self.element_spec), 810 output_shapes=structure.get_flat_tensor_shapes( 811 self.element_spec)), self.element_spec) 812 813 def _gather_saveables_for_checkpoint(self): 814 815 def _saveable_factory(name): 816 """Returns a SaveableObject for serialization/deserialization.""" 817 policy = None 818 if self._dataset: 819 policy = self._dataset.options().experimental_external_state_policy 820 if policy: 821 return _IteratorSaveable( 822 self._iterator_resource, 823 name, 824 external_state_policy=policy) 825 else: 826 return _IteratorSaveable(self._iterator_resource, name) 827 828 return {"ITERATOR": _saveable_factory} 829 830 831@tf_export("data.IteratorSpec", v1=[]) 832class IteratorSpec(type_spec.TypeSpec): 833 """Type specification for `tf.data.Iterator`. 834 835 For instance, `tf.data.IteratorSpec` can be used to define a tf.function that 836 takes `tf.data.Iterator` as an input argument: 837 838 >>> @tf.function(input_signature=[tf.data.IteratorSpec( 839 ... tf.TensorSpec(shape=(), dtype=tf.int32, name=None))]) 840 ... def square(iterator): 841 ... x = iterator.get_next() 842 ... 
return x * x 843 >>> dataset = tf.data.Dataset.from_tensors(5) 844 >>> iterator = iter(dataset) 845 >>> print(square(iterator)) 846 tf.Tensor(25, shape=(), dtype=int32) 847 848 Attributes: 849 element_spec: A nested structure of `TypeSpec` objects that represents the 850 type specification of the iterator elements. 851 """ 852 853 __slots__ = ["_element_spec"] 854 855 def __init__(self, element_spec): 856 self._element_spec = element_spec 857 858 @property 859 def value_type(self): 860 return OwnedIterator 861 862 def _serialize(self): 863 return (self._element_spec,) 864 865 @property 866 def _component_specs(self): 867 return ( 868 tensor_spec.TensorSpec([], dtypes.resource), 869 tensor_spec.TensorSpec([], dtypes.variant), 870 ) 871 872 def _to_components(self, value): 873 return (value._iterator_resource, value._deleter) # pylint: disable=protected-access 874 875 def _from_components(self, components): 876 return OwnedIterator( 877 dataset=None, 878 components=components, 879 element_spec=self._element_spec) 880 881 @staticmethod 882 def from_value(value): 883 return IteratorSpec(value.element_spec) # pylint: disable=protected-access 884 885 886# TODO(b/71645805): Expose trackable stateful objects from dataset. 
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject for saving/restoring iterator state."""

  def __init__(
      self,
      iterator_resource,
      name,
      external_state_policy=distribute_options.ExternalStatePolicy.FAIL):
    # Serialize the iterator's current state into a variant tensor; the
    # policy controls how external state (e.g. generator state) is handled.
    serialized_iterator = gen_dataset_ops.serialize_iterator(
        iterator_resource, external_state_policy=external_state_policy.value)
    specs = [
        BaseSaverBuilder.SaveSpec(
            serialized_iterator,
            "",
            name + "_STATE",
            device=iterator_resource.device)
    ]
    super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])


@deprecation.deprecated(
    None, "Use `tf.data.Iterator.get_next_as_optional()` instead.")
@tf_export("data.experimental.get_next_as_optional")
def get_next_as_optional(iterator):
  """Returns a `tf.experimental.Optional` with the next element of the iterator.

  If the iterator has reached the end of the sequence, the returned
  `tf.experimental.Optional` will have no value.

  Args:
    iterator: A `tf.data.Iterator`.

  Returns:
    A `tf.experimental.Optional` object which either contains the next element
    of the iterator (if it exists) or no value.
  """
  return iterator.get_next_as_optional()