1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for Iterators.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import threading 21import warnings 22 23from tensorflow.python.compat import compat 24from tensorflow.python.data.ops import optional_ops 25from tensorflow.python.data.util import nest 26from tensorflow.python.data.util import structure as structure_lib 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import errors 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.ops import gen_dataset_ops 33from tensorflow.python.ops import resource_variable_ops 34from tensorflow.python.training.saver import BaseSaverBuilder 35from tensorflow.python.training.tracking import base as trackable 36from tensorflow.python.util.tf_export import tf_export 37 38 39# NOTE(mrry): It is legitimate to call `Iterator.get_next()` multiple 40# times, e.g. when you are distributing different elements to multiple 41# devices in a single step. However, a common pitfall arises when 42# users call `Iterator.get_next()` in each iteration of their training 43# loop. 
`Iterator.get_next()` adds ops to the graph, and executing 44# each op allocates resources (including threads); as a consequence, 45# invoking it in every iteration of a training loop causes slowdown 46# and eventual resource exhaustion. To guard against this outcome, we 47# log a warning when the number of uses crosses a threshold of suspicion. 48GET_NEXT_CALL_WARNING_THRESHOLD = 32 49 50GET_NEXT_CALL_WARNING_MESSAGE = ( 51 "An unusually high number of `Iterator.get_next()` calls was detected. " 52 "This often indicates that `Iterator.get_next()` is being called inside " 53 "a training loop, which will cause gradual slowdown and eventual resource " 54 "exhaustion. If this is the case, restructure your code to call " 55 "`next_element = iterator.get_next()` once outside the loop, and use " 56 "`next_element` as the input to some computation that is invoked inside " 57 "the loop.") 58 59# Collection of all IteratorResources in the `Graph`. 60GLOBAL_ITERATORS = "iterators" 61 62 63def _device_stack_is_empty(): 64 # pylint: disable=protected-access 65 device_stack = ops.get_default_graph()._device_functions_outer_to_inner 66 # pylint: enable=protected-access 67 return not bool(device_stack) 68 69 70@tf_export(v1=["data.Iterator"]) 71class Iterator(trackable.Trackable): 72 """Represents the state of iterating through a `Dataset`.""" 73 74 def __init__(self, iterator_resource, initializer, output_types, 75 output_shapes, output_classes): 76 """Creates a new iterator from the given iterator resource. 77 78 Note: Most users will not call this initializer directly, and will 79 instead use `Dataset.make_initializable_iterator()` or 80 `Dataset.make_one_shot_iterator()`. 81 82 Args: 83 iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the 84 iterator. 85 initializer: A `tf.Operation` that should be run to initialize this 86 iterator. 87 output_types: A nested structure of `tf.DType` objects corresponding to 88 each component of an element of this iterator. 
89 output_shapes: A nested structure of `tf.TensorShape` objects 90 corresponding to each component of an element of this iterator. 91 output_classes: A nested structure of Python `type` objects corresponding 92 to each component of an element of this iterator. 93 """ 94 self._iterator_resource = iterator_resource 95 self._initializer = initializer 96 97 if (output_types is None or output_shapes is None 98 or output_classes is None): 99 raise ValueError("If `structure` is not specified, all of " 100 "`output_types`, `output_shapes`, and `output_classes`" 101 " must be specified.") 102 self._structure = structure_lib.convert_legacy_structure( 103 output_types, output_shapes, output_classes) 104 105 self._string_handle = gen_dataset_ops.iterator_to_string_handle( 106 self._iterator_resource) 107 self._get_next_call_count = 0 108 ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource) 109 110 @staticmethod 111 def from_structure(output_types, 112 output_shapes=None, 113 shared_name=None, 114 output_classes=None): 115 """Creates a new, uninitialized `Iterator` with the given structure. 116 117 This iterator-constructing method can be used to create an iterator that 118 is reusable with many different datasets. 119 120 The returned iterator is not bound to a particular dataset, and it has 121 no `initializer`. To initialize the iterator, run the operation returned by 122 `Iterator.make_initializer(dataset)`. 123 124 The following is an example 125 126 ```python 127 iterator = Iterator.from_structure(tf.int64, tf.TensorShape([])) 128 129 dataset_range = Dataset.range(10) 130 range_initializer = iterator.make_initializer(dataset_range) 131 132 dataset_evens = dataset_range.filter(lambda x: x % 2 == 0) 133 evens_initializer = iterator.make_initializer(dataset_evens) 134 135 # Define a model based on the iterator; in this example, the model_fn 136 # is expected to take scalar tf.int64 Tensors as input (see 137 # the definition of 'iterator' above). 
138 prediction, loss = model_fn(iterator.get_next()) 139 140 # Train for `num_epochs`, where for each epoch, we first iterate over 141 # dataset_range, and then iterate over dataset_evens. 142 for _ in range(num_epochs): 143 # Initialize the iterator to `dataset_range` 144 sess.run(range_initializer) 145 while True: 146 try: 147 pred, loss_val = sess.run([prediction, loss]) 148 except tf.errors.OutOfRangeError: 149 break 150 151 # Initialize the iterator to `dataset_evens` 152 sess.run(evens_initializer) 153 while True: 154 try: 155 pred, loss_val = sess.run([prediction, loss]) 156 except tf.errors.OutOfRangeError: 157 break 158 ``` 159 160 Args: 161 output_types: A nested structure of `tf.DType` objects corresponding to 162 each component of an element of this dataset. 163 output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects 164 corresponding to each component of an element of this dataset. If 165 omitted, each component will have an unconstrainted shape. 166 shared_name: (Optional.) If non-empty, this iterator will be shared under 167 the given name across multiple sessions that share the same devices 168 (e.g. when using a remote server). 169 output_classes: (Optional.) A nested structure of Python `type` objects 170 corresponding to each component of an element of this iterator. If 171 omitted, each component is assumed to be of type `tf.Tensor`. 172 173 Returns: 174 An `Iterator`. 175 176 Raises: 177 TypeError: If the structures of `output_shapes` and `output_types` are 178 not the same. 
179 """ 180 output_types = nest.map_structure(dtypes.as_dtype, output_types) 181 if output_shapes is None: 182 output_shapes = nest.map_structure( 183 lambda _: tensor_shape.TensorShape(None), output_types) 184 else: 185 output_shapes = nest.map_structure_up_to( 186 output_types, tensor_shape.as_shape, output_shapes) 187 if output_classes is None: 188 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 189 nest.assert_same_structure(output_types, output_shapes) 190 output_structure = structure_lib.convert_legacy_structure( 191 output_types, output_shapes, output_classes) 192 if shared_name is None: 193 shared_name = "" 194 # pylint: disable=protected-access 195 if compat.forward_compatible(2018, 8, 3): 196 if _device_stack_is_empty(): 197 with ops.device("/cpu:0"): 198 iterator_resource = gen_dataset_ops.iterator_v2( 199 container="", 200 shared_name=shared_name, 201 output_types=output_structure._flat_types, 202 output_shapes=output_structure._flat_shapes) 203 else: 204 iterator_resource = gen_dataset_ops.iterator_v2( 205 container="", 206 shared_name=shared_name, 207 output_types=output_structure._flat_types, 208 output_shapes=output_structure._flat_shapes) 209 else: 210 iterator_resource = gen_dataset_ops.iterator( 211 container="", 212 shared_name=shared_name, 213 output_types=output_structure._flat_types, 214 output_shapes=output_structure._flat_shapes) 215 # pylint: enable=protected-access 216 return Iterator(iterator_resource, None, output_types, output_shapes, 217 output_classes) 218 219 @staticmethod 220 def from_string_handle(string_handle, 221 output_types, 222 output_shapes=None, 223 output_classes=None): 224 """Creates a new, uninitialized `Iterator` based on the given handle. 225 226 This method allows you to define a "feedable" iterator where you can choose 227 between concrete iterators by feeding a value in a `tf.Session.run` call. 
228 In that case, `string_handle` would be a `tf.placeholder`, and you would 229 feed it with the value of `tf.data.Iterator.string_handle` in each step. 230 231 For example, if you had two iterators that marked the current position in 232 a training dataset and a test dataset, you could choose which to use in 233 each step as follows: 234 235 ```python 236 train_iterator = tf.data.Dataset(...).make_one_shot_iterator() 237 train_iterator_handle = sess.run(train_iterator.string_handle()) 238 239 test_iterator = tf.data.Dataset(...).make_one_shot_iterator() 240 test_iterator_handle = sess.run(test_iterator.string_handle()) 241 242 handle = tf.placeholder(tf.string, shape=[]) 243 iterator = tf.data.Iterator.from_string_handle( 244 handle, train_iterator.output_types) 245 246 next_element = iterator.get_next() 247 loss = f(next_element) 248 249 train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) 250 test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) 251 ``` 252 253 Args: 254 string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates 255 to a handle produced by the `Iterator.string_handle()` method. 256 output_types: A nested structure of `tf.DType` objects corresponding to 257 each component of an element of this dataset. 258 output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects 259 corresponding to each component of an element of this dataset. If 260 omitted, each component will have an unconstrainted shape. 261 output_classes: (Optional.) A nested structure of Python `type` objects 262 corresponding to each component of an element of this iterator. If 263 omitted, each component is assumed to be of type `tf.Tensor`. 264 265 Returns: 266 An `Iterator`. 
267 """ 268 output_types = nest.map_structure(dtypes.as_dtype, output_types) 269 if output_shapes is None: 270 output_shapes = nest.map_structure( 271 lambda _: tensor_shape.TensorShape(None), output_types) 272 else: 273 output_shapes = nest.map_structure_up_to( 274 output_types, tensor_shape.as_shape, output_shapes) 275 if output_classes is None: 276 output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) 277 nest.assert_same_structure(output_types, output_shapes) 278 output_structure = structure_lib.convert_legacy_structure( 279 output_types, output_shapes, output_classes) 280 string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) 281 # pylint: disable=protected-access 282 if compat.forward_compatible(2018, 8, 3): 283 if _device_stack_is_empty(): 284 with ops.device("/cpu:0"): 285 iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( 286 string_handle, 287 output_types=output_structure._flat_types, 288 output_shapes=output_structure._flat_shapes) 289 else: 290 iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( 291 string_handle, 292 output_types=output_structure._flat_types, 293 output_shapes=output_structure._flat_shapes) 294 else: 295 iterator_resource = gen_dataset_ops.iterator_from_string_handle( 296 string_handle, 297 output_types=output_structure._flat_types, 298 output_shapes=output_structure._flat_shapes) 299 # pylint: enable=protected-access 300 return Iterator(iterator_resource, None, output_types, output_shapes, 301 output_classes) 302 303 @property 304 def initializer(self): 305 """A `tf.Operation` that should be run to initialize this iterator. 306 307 Returns: 308 A `tf.Operation` that should be run to initialize this iterator 309 310 Raises: 311 ValueError: If this iterator initializes itself automatically. 
312 """ 313 if self._initializer is not None: 314 return self._initializer 315 else: 316 # TODO(mrry): Consider whether one-shot iterators should have 317 # initializers that simply reset their state to the beginning. 318 raise ValueError("Iterator does not have an initializer.") 319 320 def make_initializer(self, dataset, name=None): 321 """Returns a `tf.Operation` that initializes this iterator on `dataset`. 322 323 Args: 324 dataset: A `Dataset` with compatible structure to this iterator. 325 name: (Optional.) A name for the created operation. 326 327 Returns: 328 A `tf.Operation` that can be run to initialize this iterator on the given 329 `dataset`. 330 331 Raises: 332 TypeError: If `dataset` and this iterator do not have a compatible 333 element structure. 334 """ 335 with ops.name_scope(name, "make_initializer") as name: 336 # pylint: disable=protected-access 337 # NOTE(mrry): Cannot depend on `dataset_ops.get_legacy_output*()` due 338 # to that creating a circular dependency. 339 dataset_output_types = ( 340 dataset._element_structure._to_legacy_output_types()) 341 dataset_output_shapes = ( 342 dataset._element_structure._to_legacy_output_shapes()) 343 dataset_output_classes = ( 344 dataset._element_structure._to_legacy_output_classes()) 345 # pylint: enable=protected-access 346 347 nest.assert_same_structure(self.output_types, dataset_output_types) 348 nest.assert_same_structure(self.output_shapes, dataset_output_shapes) 349 for iterator_class, dataset_class in zip( 350 nest.flatten(self.output_classes), 351 nest.flatten(dataset_output_classes)): 352 if iterator_class is not dataset_class: 353 raise TypeError( 354 "Expected output classes %r but got dataset with output class %r." 
355 % (self.output_classes, dataset_output_classes)) 356 for iterator_dtype, dataset_dtype in zip( 357 nest.flatten(self.output_types), nest.flatten(dataset_output_types)): 358 if iterator_dtype != dataset_dtype: 359 raise TypeError( 360 "Expected output types %r but got dataset with output types %r." % 361 (self.output_types, dataset_output_types)) 362 for iterator_shape, dataset_shape in zip( 363 nest.flatten(self.output_shapes), nest.flatten( 364 dataset_output_shapes)): 365 if not iterator_shape.is_compatible_with(dataset_shape): 366 raise TypeError("Expected output shapes compatible with %r but got " 367 "dataset with output shapes %r." % 368 (self.output_shapes, dataset_output_shapes)) 369 with ops.colocate_with(self._iterator_resource): 370 return gen_dataset_ops.make_iterator( 371 dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access 372 373 def get_next(self, name=None): 374 """Returns a nested structure of `tf.Tensor`s representing the next element. 375 376 In graph mode, you should typically call this method *once* and use its 377 result as the input to another computation. A typical loop will then call 378 `tf.Session.run` on the result of that computation. The loop will terminate 379 when the `Iterator.get_next()` operation raises 380 `tf.errors.OutOfRangeError`. The following skeleton shows how to use 381 this method when building a training loop: 382 383 ```python 384 dataset = ... # A `tf.data.Dataset` object. 385 iterator = dataset.make_initializable_iterator() 386 next_element = iterator.get_next() 387 388 # Build a TensorFlow graph that does something with each element. 389 loss = model_function(next_element) 390 optimizer = ... # A `tf.train.Optimizer` object. 
391 train_op = optimizer.minimize(loss) 392 393 with tf.Session() as sess: 394 try: 395 while True: 396 sess.run(train_op) 397 except tf.errors.OutOfRangeError: 398 pass 399 ``` 400 401 NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. 402 when you are distributing different elements to multiple devices in a single 403 step. However, a common pitfall arises when users call `Iterator.get_next()` 404 in each iteration of their training loop. `Iterator.get_next()` adds ops to 405 the graph, and executing each op allocates resources (including threads); as 406 a consequence, invoking it in every iteration of a training loop causes 407 slowdown and eventual resource exhaustion. To guard against this outcome, we 408 log a warning when the number of uses crosses a fixed threshold of 409 suspiciousness. 410 411 Args: 412 name: (Optional.) A name for the created operation. 413 414 Returns: 415 A nested structure of `tf.Tensor` objects. 416 """ 417 self._get_next_call_count += 1 418 if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: 419 warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) 420 421 # pylint: disable=protected-access 422 flat_ret = gen_dataset_ops.iterator_get_next( 423 self._iterator_resource, 424 output_types=self._structure._flat_types, 425 output_shapes=self._structure._flat_shapes, name=name) 426 return self._structure._from_tensor_list(flat_ret) 427 428 def string_handle(self, name=None): 429 """Returns a string-valued `tf.Tensor` that represents this iterator. 430 431 Args: 432 name: (Optional.) A name for the created operation. 433 434 Returns: 435 A scalar `tf.Tensor` of type `tf.string`. 436 """ 437 if name is None: 438 return self._string_handle 439 else: 440 return gen_dataset_ops.iterator_to_string_handle( 441 self._iterator_resource, name=name) 442 443 @property 444 def output_classes(self): 445 """Returns the class of each component of an element of this iterator. 
446 447 The expected values are `tf.Tensor` and `tf.SparseTensor`. 448 449 Returns: 450 A nested structure of Python `type` objects corresponding to each 451 component of an element of this dataset. 452 """ 453 return self._structure._to_legacy_output_classes() # pylint: disable=protected-access 454 455 @property 456 def output_shapes(self): 457 """Returns the shape of each component of an element of this iterator. 458 459 Returns: 460 A nested structure of `tf.TensorShape` objects corresponding to each 461 component of an element of this dataset. 462 """ 463 return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access 464 465 @property 466 def output_types(self): 467 """Returns the type of each component of an element of this iterator. 468 469 Returns: 470 A nested structure of `tf.DType` objects corresponding to each component 471 of an element of this dataset. 472 """ 473 return self._structure._to_legacy_output_types() # pylint: disable=protected-access 474 475 @property 476 def _element_structure(self): 477 """The structure of an element of this iterator. 478 479 Returns: 480 A `Structure` object representing the structure of the components of this 481 optional. 482 """ 483 return self._structure 484 485 def _gather_saveables_for_checkpoint(self): 486 487 def _saveable_factory(name): 488 return _IteratorSaveable(self._iterator_resource, name) 489 490 return {"ITERATOR": _saveable_factory} 491 492 493_uid_counter = 0 494_uid_lock = threading.Lock() 495 496 497def _generate_shared_name(prefix): 498 with _uid_lock: 499 global _uid_counter 500 uid = _uid_counter 501 _uid_counter += 1 502 return "{}{}".format(prefix, uid) 503 504 505class EagerIterator(trackable.Trackable): 506 """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" 507 508 def __init__(self, dataset): 509 """Creates a new iterator over the given dataset. 
510 511 For example: 512 ```python 513 dataset = tf.data.Dataset.range(4) 514 for x in Iterator(dataset): 515 print(x) 516 ``` 517 518 Tensors produced will be placed on the device on which this iterator object 519 was created. 520 521 Args: 522 dataset: A `tf.data.Dataset` object. 523 524 Raises: 525 RuntimeError: When invoked without eager execution enabled. 526 """ 527 528 if not context.executing_eagerly(): 529 raise RuntimeError( 530 "{} objects can only be used when eager execution is enabled, use " 531 "tf.data.Dataset.make_initializable_iterator or " 532 "tf.data.Dataset.make_one_shot_iterator for graph construction". 533 format(type(self))) 534 self._device = context.context().device_name 535 with ops.device("/cpu:0"): 536 # pylint: disable=protected-access 537 dataset = dataset._apply_options() 538 ds_variant = dataset._variant_tensor 539 self._structure = dataset._element_structure 540 self._flat_output_types = self._structure._flat_types 541 self._flat_output_shapes = self._structure._flat_shapes 542 with ops.colocate_with(ds_variant): 543 self._iterator_resource = gen_dataset_ops.anonymous_iterator( 544 output_types=self._flat_output_types, 545 output_shapes=self._flat_output_shapes) 546 gen_dataset_ops.make_iterator(ds_variant, self._iterator_resource) 547 # Delete the resource when this object is deleted 548 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 549 handle=self._iterator_resource, handle_device=self._device) 550 # pylint: enable=protected-access 551 552 def __iter__(self): 553 return self 554 555 def __next__(self): # For Python 3 compatibility 556 return self.next() 557 558 def _next_internal(self): 559 """Returns a nested structure of `tf.Tensor`s containing the next element. 560 """ 561 # This runs in sync mode as iterators use an error status to communicate 562 # that there is no more data to iterate over. 
563 # TODO(b/77291417): Fix 564 with context.execution_mode(context.SYNC): 565 with ops.device(self._device): 566 # TODO(ashankar): Consider removing this ops.device() contextmanager 567 # and instead mimic ops placement in graphs: Operations on resource 568 # handles execute on the same device as where the resource is placed. 569 # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` 570 # because in eager mode this code will run synchronously on the calling 571 # thread. Therefore we do not need to make a defensive context switch 572 # to a background thread, and can achieve a small constant performance 573 # boost by invoking the iterator synchronously. 574 ret = gen_dataset_ops.iterator_get_next_sync( 575 self._iterator_resource, 576 output_types=self._flat_output_types, 577 output_shapes=self._flat_output_shapes) 578 579 return self._structure._from_compatible_tensor_list(ret) # pylint: disable=protected-access 580 581 def next(self): 582 """Returns a nested structure of `tf.Tensor`s containing the next element. 583 """ 584 try: 585 return self._next_internal() 586 except errors.OutOfRangeError: 587 raise StopIteration 588 589 @property 590 def output_classes(self): 591 """Returns the class of each component of an element of this iterator. 592 593 The expected values are `tf.Tensor` and `tf.SparseTensor`. 594 595 Returns: 596 A nested structure of Python `type` objects corresponding to each 597 component of an element of this dataset. 598 """ 599 return self._structure._to_legacy_output_classes() # pylint: disable=protected-access 600 601 @property 602 def output_shapes(self): 603 """Returns the shape of each component of an element of this iterator. 604 605 Returns: 606 A nested structure of `tf.TensorShape` objects corresponding to each 607 component of an element of this dataset. 
608 """ 609 return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access 610 611 @property 612 def output_types(self): 613 """Returns the type of each component of an element of this iterator. 614 615 Returns: 616 A nested structure of `tf.DType` objects corresponding to each component 617 of an element of this dataset. 618 """ 619 return self._structure._to_legacy_output_types() # pylint: disable=protected-access 620 621 @property 622 def _element_structure(self): 623 """The structure of an element of this iterator. 624 625 Returns: 626 A `Structure` object representing the structure of the components of this 627 optional. 628 """ 629 return self._structure 630 631 def get_next(self, name=None): 632 """Returns a nested structure of `tf.Tensor`s containing the next element. 633 634 Args: 635 name: (Optional.) A name for the created operation. Currently unused. 636 637 Returns: 638 A nested structure of `tf.Tensor` objects. 639 640 Raises: 641 `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. 642 """ 643 del name 644 return self._next_internal() 645 646 def _gather_saveables_for_checkpoint(self): 647 648 def _saveable_factory(name): 649 return _IteratorSaveable(self._iterator_resource, name) 650 651 return {"ITERATOR": _saveable_factory} 652 653 654# TODO(b/71645805): Expose trackable stateful objects from dataset 655# attributes(potential). 
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints an iterator resource's state.

  Saving uses a `serialize_iterator` op; restoring feeds the saved tensor to
  a `deserialize_iterator` op colocated with the resource.
  """

  def __init__(self, iterator_resource, name):
    state_tensor = gen_dataset_ops.serialize_iterator(iterator_resource)
    # A single spec: the serialized state, stored under "<name>_STATE".
    state_spec = BaseSaverBuilder.SaveSpec(state_tensor, "", name + "_STATE")
    super(_IteratorSaveable, self).__init__(iterator_resource, [state_spec],
                                            name)

  def restore(self, restored_tensors, restored_shapes):
    """Returns an op restoring the iterator from `restored_tensors[0]`.

    `restored_shapes` is not used.
    """
    saved_state = restored_tensors[0]
    with ops.colocate_with(self.op):
      return gen_dataset_ops.deserialize_iterator(self.op, saved_state)


def get_next_as_optional(iterator):
  """Returns an `Optional` that contains the next value from the iterator.

  If `iterator` has reached the end of the sequence, the returned `Optional`
  will have no value.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` object representing the next value from the iterator (if it
    has one) or no value.
  """
  # pylint: disable=protected-access
  structure = iterator._element_structure
  optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
      iterator._iterator_resource,
      output_types=structure._flat_types,
      output_shapes=structure._flat_shapes)
  return optional_ops._OptionalImpl(optional_variant, structure)
  # pylint: enable=protected-access