# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Iterator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat


class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for one-shot, initializable, reinitializable and feedable iterators."""

  @combinations.generate(test_base.graph_only_combinations())
  def testNoGradients(self):
    """`get_next()` outputs have no gradient w.r.t. the dataset's inputs."""
    component = constant_op.constant([1.])
    side = constant_op.constant(0.)
    add = lambda x: x + side
    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
    value = dataset_ops.make_one_shot_iterator(dataset).get_next()
    self.assertIsNone(gradients_impl.gradients(value, component)[0])
    self.assertIsNone(gradients_impl.gradients(value, side)[0])
    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])

  @combinations.generate(test_base.graph_only_combinations())
  def testCapturingStateInOneShotRaisesException(self):
    """One-shot iterators reject datasets whose functions capture a variable."""
    var = variables.Variable(37.0, name="myvar")
    dataset = (
        dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
        .map(lambda x: x + var))
    with self.assertRaisesRegex(
        ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
        "datasets that capture stateful objects.+myvar"):
      dataset_ops.make_one_shot_iterator(dataset)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIterator(self):
    """A one-shot iterator yields every mapped element per repeat, then ends."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(14))
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorCaptureByValue(self):
    """Like testOneShotIterator, but the dataset is built from tensors."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    tensor_components = tuple([ops.convert_to_tensor(c) for c in components])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensor_slices(tensor_components)
        .map(_map_fn).repeat(14))
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.default_test_combinations())
  def testOneShotIteratorInsideContainer(self):
    """One-shot iterators in different `ops.container`s do not share state."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      iterator = dataset_ops.make_one_shot_iterator(
          dataset_ops.Dataset.from_tensor_slices(components)
          .map(_map_fn).repeat(14))
      return iterator.get_next()

    server = server_lib.Server.create_local_server()

    # Create two iterators within unique containers, and run them to
    # make sure that the resources aren't shared.
    #
    # The test below would fail if cname were the same across both
    # sessions.
    for j in range(2):
      with session.Session(server.target) as sess:
        cname = "iteration%d" % j
        with ops.container(cname):
          get_next = within_container()

        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorNonBlocking(self):
    """One-shot initialization does not deadlock and yields its element once."""
    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    # Create a session with a single thread to ensure that the
    # one-shot iterator initializer does not deadlock.
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, use_per_session_threads=True)
    with session.Session(config=config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Test with multiple threads invoking the one-shot iterator concurrently.
    with session.Session(config=config) as sess:
      results = []

      def consumer_thread():
        try:
          results.append(sess.run(next_element))
        except errors.OutOfRangeError:
          results.append(None)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # Exactly one thread receives the single element; the rest see the end.
      self.assertLen(results, num_threads)
      self.assertLen([None for r in results if r is None], num_threads - 1)
      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])

  @combinations.generate(test_base.graph_only_combinations())
  def testOneShotIteratorInitializerFails(self):
    """A failing one-shot initializer errors on every (concurrent) `get_next`."""
    # Define a dataset whose initialization will always fail.
    dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
    iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(next_element)

      # Test that subsequent attempts to use the iterator also fail.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(next_element)

    with self.cached_session() as sess:

      def consumer_thread():
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(next_element)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

  @combinations.generate(test_base.graph_only_combinations())
  def testSimpleSharedResource(self):
    """An iterator with a `shared_name` is visible across sessions/graphs."""
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      iterator = dataset_ops.make_initializable_iterator(
          dataset_ops.Dataset.from_tensors(
              components).map(lambda x, y, z: (x, y, z)),
          shared_name="shared_iterator")
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testNotInitializedError(self):
    """`get_next` on an uninitialized initializable iterator fails."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.from_tensors(components))
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaisesRegex(errors.FailedPreconditionError,
                                  "iterator has not been initialized"):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testReinitializableIterator(self):
    """A `from_structure` iterator can switch between compatible datasets."""
    dataset_3 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))
    iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(dataset_3), [None])

    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
    get_next = iterator.get_next()

    self.assertEqual(
        dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(
        dataset_ops.get_legacy_output_types(dataset_4),
        dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(
        [None], dataset_ops.get_legacy_output_shapes(iterator).as_list())

    with self.cached_session() as sess:
      # The iterator is initially uninitialized.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(get_next)

      # Initialize with one dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Initialize with a different dataset.
      sess.run(dataset_4_init_op)
      self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Reinitialize with the first dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  @combinations.generate(test_base.graph_only_combinations())
  def testReinitializableIteratorWithFunctions(self):
    """Reinitializing from generator datasets restarts iteration each time."""

    def g():
      for i in range(10):
        yield i

    iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      dataset_1 = dataset_ops.Dataset.from_generator(
          g, output_types=dtypes.int64)
      sess.run(iterator.make_initializer(dataset_1))
      for expected in range(10):
        self.assertEqual(expected, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

      dataset_2 = dataset_ops.Dataset.from_generator(
          g, output_types=dtypes.int64)
      sess.run(iterator.make_initializer(dataset_2))
      for expected in range(10):
        self.assertEqual(expected, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  @combinations.generate(test_base.default_test_combinations())
  def testReinitializableIteratorStaticErrors(self):
    """`from_structure`/`make_initializer` reject mismatched element specs."""
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator = iterator_ops.Iterator.from_structure(
          (dtypes.int64, dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
                                                     dtypes.float64))

    # Incompatible structure.
    with self.assertRaises(ValueError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float64),))))

    # Incompatible types.
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))))

    # Incompatible shapes.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))

  def _assertInterleavedHandleReads(self, sess, next_element,
                                    handle_placeholder, handle_3, handle_4):
    """Reads alternately through two fed string handles and checks values.

    Args:
      sess: Session used to run `next_element`.
      next_element: `get_next()` output of a feedable iterator built with
        `Iterator.from_string_handle(handle_placeholder, ...)`.
      handle_placeholder: Placeholder the iterator handle is fed into.
      handle_3: Evaluated handle of an iterator over [1, 2, 3].
      handle_4: Evaluated handle of an iterator over [10, 20, 30, 40].
    """
    # Each handle must resume its own iterator's position independently.
    expected_reads = [
        (10, handle_4),
        (1, handle_3),
        (20, handle_4),
        (2, handle_3),
        (30, handle_4),
        (3, handle_3),
        (40, handle_4),
    ]
    for expected, handle in expected_reads:
      self.assertEqual(
          expected,
          sess.run(next_element, feed_dict={handle_placeholder: handle}))
    # Both underlying iterators are now exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element, feed_dict={handle_placeholder: handle_3})
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element, feed_dict={handle_placeholder: handle_4})

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandle(self):
    """A feedable iterator reads from whichever iterator's handle is fed."""
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())
      self._assertInterleavedHandleReads(sess, next_element,
                                         handle_placeholder,
                                         iterator_3_handle, iterator_4_handle)

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleFuture(self):
    """Repeats the interleaved feedable-iterator scenario of the test above."""
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())
      self._assertInterleavedHandleReads(sess, next_element,
                                         handle_placeholder,
                                         iterator_3_handle, iterator_4_handle)

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleReuseTensorObject(self):
    """`string_handle()` is cached unless an explicit `name` is passed."""
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(dataset))

    created_ops = len(ops.get_default_graph().get_operations())

    self.assertIs(one_shot_iterator.string_handle(),
                  one_shot_iterator.string_handle())
    self.assertIs(initializable_iterator.string_handle(),
                  initializable_iterator.string_handle())
    self.assertIs(structure_iterator.string_handle(),
                  structure_iterator.string_handle())

    # Assert that getting the (default) string handle creates no ops.
    self.assertEqual(created_ops, len(ops.get_default_graph().get_operations()))

    # Specifying an explicit name will create a new op.
    handle_with_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", handle_with_name.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", handle_with_same_name.op.name)
    self.assertIsNot(handle_with_name, handle_with_same_name)

  @combinations.generate(test_base.graph_only_combinations())
  def testIteratorStringHandleError(self):
    """Feeding a handle with an incompatible type or shape fails at run time."""
    dataset_int_scalar = (
        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
    dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    feedable_int_any = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.cached_session() as sess:
      handle_int_scalar = sess.run(dataset_ops.make_one_shot_iterator(
          dataset_int_scalar).string_handle())
      handle_float_vector = sess.run(dataset_ops.make_one_shot_iterator(
          dataset_float_vector).string_handle())

      self.assertEqual(1,
                       sess.run(
                           feedable_int_scalar.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      self.assertEqual(2,
                       sess.run(
                           feedable_int_any.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      # Shape mismatch: a scalar iterator fed to a vector-shaped feedable one.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_int_scalar})

      # Type mismatch: a float iterator fed to an int32 feedable one.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_float_vector})

  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    """`remote_call` reads an iterator only on the device that holds it."""
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 3

    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
            })

  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    """A client dataset pulls elements from worker iterators via remote_call."""
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        src = dataset_ops.Dataset.from_tensor_slices([device])
        worker_itr = dataset_ops.make_one_shot_iterator(src)
        itr_handles.append(worker_itr.string_handle())
    # Every worker iterator has the same element spec. Capture it explicitly
    # instead of having `loading_func` read a loop variable that outlives the
    # loop (and whose name used to be rebound to the client iterator below).
    remote_output_types = dataset_ops.get_legacy_output_types(worker_itr)
    remote_output_shapes = dataset_ops.get_legacy_output_shapes(worker_itr)

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    @function.Defun(dtypes.string)
    def loading_func(h):
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, remote_output_types, remote_output_shapes)
      return remote_itr.get_next()

    def map_fn(target, handle):
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      client_itr = dataset_ops.make_initializable_iterator(client_dataset)
      n = client_itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(client_itr.initializer)
      # Each worker's iterator yields the name of the device it lives on.
      for expected in worker_devices:
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
    """A handle decoded to uint8 on GPU still drives a CPU `remote_call`."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
      iterator_3_handle = iterator_3.string_handle()

    def _encode_raw(byte_array):
      return bytes(bytearray(byte_array))

    @function.Defun(dtypes.uint8)
    def _remote_fn(h):
      # Re-assemble the raw uint8 bytes into the original string handle.
      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_ops.get_legacy_output_types(dataset_3),
          dataset_ops.get_legacy_output_shapes(dataset_3))
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      iterator_3_handle_uint8 = parsing_ops.decode_raw(
          input_bytes=iterator_3_handle, out_type=dtypes.uint8)
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle_uint8],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.cached_session() as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [1])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
            })

  @combinations.generate(test_base.graph_only_combinations())
  def testRepeatedGetNextWarning(self):
    """Each `get_next()` call past the warning threshold emits a warning."""
    iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
    with warnings.catch_warnings(record=True) as w:
      # Install the filter inside `catch_warnings` so the previous filter
      # state is restored on exit rather than leaking into later tests.
      warnings.simplefilter("always")
      for _ in range(100):
        iterator.get_next()
    self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w))
    for warning in w:
      self.assertIn(
          iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              expected_element_structure=tensor_spec.TensorSpec([],
                                                                dtypes.float32),
              expected_output_classes=ops.Tensor,
              expected_output_types=dtypes.float32,
              expected_output_shapes=[[]])))
  def testTensorIteratorStructure(self, expected_element_structure,
                                  expected_output_classes,
                                  expected_output_types,
                                  expected_output_shapes):
    """Iterator structure and legacy properties for a dense scalar element."""
    tf_value_fn = lambda: constant_op.constant(37.0)
    tf_value = tf_value_fn()
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensors(tf_value))

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(iterator), expected_element_structure))
    self.assertEqual(expected_output_classes,
                     dataset_ops.get_legacy_output_classes(iterator))
    self.assertEqual(expected_output_types,
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(expected_output_shapes,
                     dataset_ops.get_legacy_output_shapes(iterator))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              expected_element_structure=sparse_tensor.SparseTensorSpec(
                  [1], dtypes.int32),
              expected_output_classes=sparse_tensor.SparseTensor,
              expected_output_types=dtypes.int32,
              expected_output_shapes=[[1]])))
  def testSparseTensorIteratorStructure(self, expected_element_structure,
                                        expected_output_classes,
                                        expected_output_types,
                                        expected_output_shapes):
    """Iterator structure and legacy properties for a SparseTensor element."""

    def tf_value_fn():
      return sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1])

    tf_value = tf_value_fn()
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensors(tf_value))

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(iterator), expected_element_structure))
    self.assertEqual(expected_output_classes,
                     dataset_ops.get_legacy_output_classes(iterator))
    self.assertEqual(expected_output_types,
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(expected_output_shapes,
                     dataset_ops.get_legacy_output_shapes(iterator))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              expected_element_structure={
                  "a":
                      tensor_spec.TensorSpec([], dtypes.float32),
                  "b": (tensor_spec.TensorSpec([1], dtypes.string),
                        tensor_spec.TensorSpec([], dtypes.string))
              },
              expected_output_classes={
                  "a": ops.Tensor,
                  "b": (ops.Tensor, ops.Tensor)
              },
              expected_output_types={
                  "a": dtypes.float32,
                  "b": (dtypes.string, dtypes.string)
              },
              expected_output_shapes={
                  "a": [],
                  "b": ([1], [])
              })))
  def testNestedTensorIteratorStructure(self, expected_element_structure,
                                        expected_output_classes,
                                        expected_output_types,
                                        expected_output_shapes):
    """Iterator structure and legacy properties for a nested dict element."""

    def tf_value_fn():
      return {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }

    tf_value = tf_value_fn()
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensors(tf_value))

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(iterator), expected_element_structure))
    self.assertEqual(expected_output_classes,
                     dataset_ops.get_legacy_output_classes(iterator))
    self.assertEqual(expected_output_types,
                     dataset_ops.get_legacy_output_types(iterator))
    self.assertEqual(expected_output_shapes,
                     dataset_ops.get_legacy_output_shapes(iterator))

  @combinations.generate(test_base.default_test_combinations())
  def testIteratorGetNextName(self):
    """`get_next(name=...)` names the resulting op."""
    with ops.Graph().as_default():
      iterator = dataset_ops.make_one_shot_iterator(
          dataset_ops.Dataset.from_tensors(37.0))
      next_element = iterator.get_next(name="overridden_name")
      self.assertEqual("overridden_name", next_element.op.name)

  @combinations.generate(
      combinations.combine(
          tf_api_version=[1, 2],
          mode="eager",
          execution_mode=[context.ASYNC, context.SYNC]))
  def testIteratorEagerIteration(self, execution_mode):
    """Python iteration over a dataset works in sync and async eager modes."""
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      iterator = iter(dataset)
      for foo in iterator:
        self.assertEqual(val, foo.numpy())
        val += 1

  @combinations.generate(test_base.eager_only_combinations())
  def testOwnedIteratorFunction(self):
    """An iterator created inside a `tf.function` can be advanced there."""

    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

    @def_function.function
    def fn():
      dataset = dataset_ops.Dataset.range(10)
      iterator = iter(dataset)
      for _ in range(10):
        queue.enqueue(next(iterator))

    fn()

    for i in range(10):
      self.assertEqual(queue.dequeue().numpy(), i)

  @combinations.generate(test_base.eager_only_combinations())
  def testOwnedIteratorFunctionError(self):
    """A function that raises still finalizes the iterator it created."""
    # In this test we verify that a function that raises an error ends up
    # properly deallocating the iterator resource.

    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
    queue.enqueue(0)

    def init_fn(n):
      return n

    def next_fn(_):
      ds = dataset_ops.Dataset.range(0)
      return next(iter(ds))

    def finalize_fn(n):
      # Enqueues a marker so the test can observe that finalization ran.
      queue.enqueue(0)
      return n

    @def_function.function
    def fn():
      output_signature = tensor_spec.TensorSpec((), dtypes.int64)
      dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
                                              output_signature)
      iterator = iter(dataset)
      next(iterator)

    with self.assertRaises(errors.OutOfRangeError):
      fn()

    self.assertEqual(queue.size().numpy(), 2)

  @combinations.generate(test_base.eager_only_combinations())
  def testLimitedRetracing(self):
    """Passing different same-typed iterators to a function traces it once."""
    trace_count = [0]

    @def_function.function
    def f(iterator):
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in iterator:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(iter(dataset))), 10)
      self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
      self.assertEqual(trace_count[0], 1)

  @combinations.generate(test_base.eager_only_combinations())
  def testNestedFunctionsIteratorResource(self):
    """An iterator can be passed into a nested `tf.function` and advanced."""

    @def_function.function
    def sum_dataset(ds):
      it = iter(ds)

      @def_function.function
      def next_element(it):
        return next(it)

      total = 0
      for _ in range(10):
        total += next_element(it)
      return total

    ds = dataset_ops.Dataset.range(10)
    self.assertEqual(sum_dataset(ds).numpy(), 45)
    self.assertEqual(sum_dataset(ds).numpy(), 45)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedAutomaticControlDependencies(self):
    """Variable updates in a dataset `map` are visible when `fn` returns."""
    counter_var = variables.Variable(0)

    def map_fn(x):
      counter_var.assign_add(1)
      return x

    def dataset_fn():
      return dataset_ops.Dataset.range(10).map(map_fn)

    @def_function.function
    def fn():
      it = iter(dataset_fn())
      for _ in range(10):
        _ = next(it)
      return counter_var

    self.evaluate(counter_var.initializer)
    self.assertEqual(self.evaluate(fn()), 10)


if __name__ == "__main__":
  test.main()