• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Iterator`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import warnings
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.core.protobuf import cluster_pb2
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.python.client import session
28from tensorflow.python.data.kernel_tests import test_base
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.data.ops import iterator_ops
31from tensorflow.python.data.util import structure
32from tensorflow.python.eager import context
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import combinations
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import function
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import sparse_tensor
41from tensorflow.python.framework import tensor_spec
42from tensorflow.python.framework import test_util
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import data_flow_ops
45from tensorflow.python.ops import functional_ops
46from tensorflow.python.ops import gradients_impl
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import parsing_ops
49from tensorflow.python.ops import script_ops
50from tensorflow.python.ops import variables
51from tensorflow.python.platform import test
52from tensorflow.python.training import server_lib
53from tensorflow.python.util import compat
54
55
56class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
57
58  @combinations.generate(test_base.graph_only_combinations())
59  def testNoGradients(self):
60    component = constant_op.constant([1.])
61    side = constant_op.constant(0.)
62    add = lambda x: x + side
63    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
64    value = dataset_ops.make_one_shot_iterator(dataset).get_next()
65    self.assertIsNone(gradients_impl.gradients(value, component)[0])
66    self.assertIsNone(gradients_impl.gradients(value, side)[0])
67    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])
68
69  @combinations.generate(test_base.graph_only_combinations())
70  def testCapturingStateInOneShotRaisesException(self):
71    var = variables.Variable(37.0, name="myvar")
72    dataset = (
73        dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
74        .map(lambda x: x + var))
75    with self.assertRaisesRegex(
76        ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
77        "datasets that capture stateful objects.+myvar"):
78      dataset_ops.make_one_shot_iterator(dataset)
79
80  @combinations.generate(test_base.graph_only_combinations())
81  def testOneShotIterator(self):
82    components = (np.arange(7),
83                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
84                  np.array(37.0) * np.arange(7))
85
86    def _map_fn(x, y, z):
87      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
88
89    iterator = dataset_ops.make_one_shot_iterator(
90        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
91        .repeat(14))
92    get_next = iterator.get_next()
93
94    self.assertEqual([c.shape[1:] for c in components],
95                     [t.shape for t in get_next])
96
97    with self.cached_session() as sess:
98      for _ in range(14):
99        for i in range(7):
100          result = sess.run(get_next)
101          for component, result_component in zip(components, result):
102            self.assertAllEqual(component[i]**2, result_component)
103      with self.assertRaises(errors.OutOfRangeError):
104        sess.run(get_next)
105
106  @combinations.generate(test_base.graph_only_combinations())
107  def testOneShotIteratorCaptureByValue(self):
108    components = (np.arange(7),
109                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
110                  np.array(37.0) * np.arange(7))
111    tensor_components = tuple([ops.convert_to_tensor(c) for c in components])
112
113    def _map_fn(x, y, z):
114      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
115
116    iterator = dataset_ops.make_one_shot_iterator(
117        dataset_ops.Dataset.from_tensor_slices(tensor_components)
118        .map(_map_fn).repeat(14))
119    get_next = iterator.get_next()
120
121    self.assertEqual([c.shape[1:] for c in components],
122                     [t.shape for t in get_next])
123
124    with self.cached_session() as sess:
125      for _ in range(14):
126        for i in range(7):
127          result = sess.run(get_next)
128          for component, result_component in zip(components, result):
129            self.assertAllEqual(component[i]**2, result_component)
130      with self.assertRaises(errors.OutOfRangeError):
131        sess.run(get_next)
132
133  @combinations.generate(test_base.default_test_combinations())
134  def testOneShotIteratorInsideContainer(self):
135    components = (np.arange(7),
136                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
137                  np.array(37.0) * np.arange(7))
138
139    def within_container():
140
141      def _map_fn(x, y, z):
142        return math_ops.square(x), math_ops.square(y), math_ops.square(z)
143
144      iterator = dataset_ops.make_one_shot_iterator(
145          dataset_ops.Dataset.from_tensor_slices(components)
146          .map(_map_fn).repeat(14))
147      return iterator.get_next()
148
149    server = server_lib.Server.create_local_server()
150
151    # Create two iterators within unique containers, and run them to
152    # make sure that the resources aren't shared.
153    #
154    # The test below would fail if cname were the same across both
155    # sessions.
156    for j in range(2):
157      with session.Session(server.target) as sess:
158        cname = "iteration%d" % j
159        with ops.container(cname):
160          get_next = within_container()
161
162        for _ in range(14):
163          for i in range(7):
164            result = sess.run(get_next)
165            for component, result_component in zip(components, result):
166              self.assertAllEqual(component[i]**2, result_component)
167        with self.assertRaises(errors.OutOfRangeError):
168          sess.run(get_next)
169
170  @combinations.generate(test_base.graph_only_combinations())
171  def testOneShotIteratorNonBlocking(self):
172    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
173    iterator = dataset_ops.make_one_shot_iterator(dataset)
174    next_element = iterator.get_next()
175
176    # Create a session with a single thread to ensure that the
177    # one-shot iterator initializer does not deadlock.
178    config = config_pb2.ConfigProto(
179        inter_op_parallelism_threads=1, use_per_session_threads=True)
180    with session.Session(config=config) as sess:
181      self.assertAllEqual([1, 4, 9], sess.run(next_element))
182      with self.assertRaises(errors.OutOfRangeError):
183        sess.run(next_element)
184
185    # Test with multiple threads invoking the one-shot iterator concurrently.
186    with session.Session(config=config) as sess:
187      results = []
188
189      def consumer_thread():
190        try:
191          results.append(sess.run(next_element))
192        except errors.OutOfRangeError:
193          results.append(None)
194
195      num_threads = 8
196      threads = [
197          self.checkedThread(consumer_thread) for _ in range(num_threads)
198      ]
199      for t in threads:
200        t.start()
201      for t in threads:
202        t.join()
203
204      self.assertLen(results, num_threads)
205      self.assertLen([None for r in results if r is None], num_threads - 1)
206      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
207
208  @combinations.generate(test_base.graph_only_combinations())
209  def testOneShotIteratorInitializerFails(self):
210    # Define a dataset whose initialization will always fail.
211    dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
212    iterator = dataset_ops.make_one_shot_iterator(dataset)
213    next_element = iterator.get_next()
214
215    with self.cached_session() as sess:
216      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
217        sess.run(next_element)
218
219      # Test that subsequent attempts to use the iterator also fail.
220      with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
221        sess.run(next_element)
222
223    with self.cached_session() as sess:
224
225      def consumer_thread():
226        with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
227          sess.run(next_element)
228
229      num_threads = 8
230      threads = [
231          self.checkedThread(consumer_thread) for _ in range(num_threads)
232      ]
233      for t in threads:
234        t.start()
235      for t in threads:
236        t.join()
237
238  @combinations.generate(test_base.graph_only_combinations())
239  def testSimpleSharedResource(self):
240    components = (np.array(1, dtype=np.int64),
241                  np.array([1, 2, 3], dtype=np.int64),
242                  np.array(37.0, dtype=np.float64))
243
244    server = server_lib.Server.create_local_server()
245
246    # Create two non-overlapping sessions that share the same iterator
247    # resource on the same server, and verify that an action of the
248    # first session (initializing the iterator) is visible in the
249    # second session.
250    with ops.Graph().as_default():
251      iterator = dataset_ops.make_initializable_iterator(
252          dataset_ops.Dataset.from_tensors(
253              components).map(lambda x, y, z: (x, y, z)),
254          shared_name="shared_iterator")
255      init_op = iterator.initializer
256      get_next = iterator.get_next()
257
258      with session.Session(server.target) as sess:
259        sess.run(init_op)
260        results = sess.run(get_next)
261        for component, result_component in zip(components, results):
262          self.assertAllEqual(component, result_component)
263        with self.assertRaises(errors.OutOfRangeError):
264          sess.run(get_next)
265
266        # Re-initialize the iterator in the first session.
267        sess.run(init_op)
268
269    with ops.Graph().as_default():
270      # Re-define the iterator manually, without defining any of the
271      # functions in this graph, to ensure that we are not
272      # accidentally redefining functions with the same names in the
273      # new graph.
274      iterator = iterator_ops.Iterator.from_structure(
275          shared_name="shared_iterator",
276          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
277          output_shapes=([], [3], []))
278      get_next = iterator.get_next()
279
280      with session.Session(server.target) as sess:
281        # Use the iterator without re-initializing in the second session.
282        results = sess.run(get_next)
283        for component, result_component in zip(components, results):
284          self.assertAllEqual(component, result_component)
285        with self.assertRaises(errors.OutOfRangeError):
286          sess.run(get_next)
287
288  @combinations.generate(test_base.graph_only_combinations())
289  def testNotInitializedError(self):
290    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
291    iterator = dataset_ops.make_initializable_iterator(
292        dataset_ops.Dataset.from_tensors(components))
293    get_next = iterator.get_next()
294
295    with self.cached_session() as sess:
296      with self.assertRaisesRegex(errors.FailedPreconditionError,
297                                  "iterator has not been initialized"):
298        sess.run(get_next)
299
300  @combinations.generate(test_base.graph_only_combinations())
301  def testReinitializableIterator(self):
302    dataset_3 = dataset_ops.Dataset.from_tensors(
303        constant_op.constant([1, 2, 3]))
304    dataset_4 = dataset_ops.Dataset.from_tensors(
305        constant_op.constant([4, 5, 6, 7]))
306    iterator = iterator_ops.Iterator.from_structure(
307        dataset_ops.get_legacy_output_types(dataset_3), [None])
308
309    dataset_3_init_op = iterator.make_initializer(dataset_3)
310    dataset_4_init_op = iterator.make_initializer(dataset_4)
311    get_next = iterator.get_next()
312
313    self.assertEqual(
314        dataset_ops.get_legacy_output_types(dataset_3),
315        dataset_ops.get_legacy_output_types(iterator))
316    self.assertEqual(
317        dataset_ops.get_legacy_output_types(dataset_4),
318        dataset_ops.get_legacy_output_types(iterator))
319    self.assertEqual(
320        [None], dataset_ops.get_legacy_output_shapes(iterator).as_list())
321
322    with self.cached_session() as sess:
323      # The iterator is initially uninitialized.
324      with self.assertRaises(errors.FailedPreconditionError):
325        sess.run(get_next)
326
327      # Initialize with one dataset.
328      sess.run(dataset_3_init_op)
329      self.assertAllEqual([1, 2, 3], sess.run(get_next))
330      with self.assertRaises(errors.OutOfRangeError):
331        sess.run(get_next)
332
333      # Initialize with a different dataset.
334      sess.run(dataset_4_init_op)
335      self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
336      with self.assertRaises(errors.OutOfRangeError):
337        sess.run(get_next)
338
339      # Reinitialize with the first dataset.
340      sess.run(dataset_3_init_op)
341      self.assertAllEqual([1, 2, 3], sess.run(get_next))
342      with self.assertRaises(errors.OutOfRangeError):
343        sess.run(get_next)
344
345  @combinations.generate(test_base.graph_only_combinations())
346  def testReinitializableIteratorWithFunctions(self):
347
348    def g():
349      for i in range(10):
350        yield i
351
352    iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
353    next_element = iterator.get_next()
354
355    with self.cached_session() as sess:
356      dataset_1 = dataset_ops.Dataset.from_generator(
357          g, output_types=dtypes.int64)
358      sess.run(iterator.make_initializer(dataset_1))
359      for expected in range(10):
360        self.assertEqual(expected, sess.run(next_element))
361      with self.assertRaises(errors.OutOfRangeError):
362        sess.run(next_element)
363
364      dataset_2 = dataset_ops.Dataset.from_generator(
365          g, output_types=dtypes.int64)
366      sess.run(iterator.make_initializer(dataset_2))
367      for expected in range(10):
368        self.assertEqual(expected, sess.run(next_element))
369      with self.assertRaises(errors.OutOfRangeError):
370        sess.run(next_element)
371
372  @combinations.generate(test_base.default_test_combinations())
373  def testReinitializableIteratorStaticErrors(self):
374    # Non-matching structure for types and shapes.
375    with self.assertRaises(TypeError):
376      iterator = iterator_ops.Iterator.from_structure(
377          (dtypes.int64, dtypes.float64), [None])
378
379    # Test validation of dataset argument.
380    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
381                                                     dtypes.float64))
382
383    # Incompatible structure.
384    with self.assertRaises(ValueError):
385      iterator.make_initializer(
386          dataset_ops.Dataset.from_tensors(((constant_op.constant(
387              [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
388                  [4., 5., 6., 7.], dtype=dtypes.float64),))))
389
390    # Incompatible types.
391    with self.assertRaises(TypeError):
392      iterator.make_initializer(
393          dataset_ops.Dataset.from_tensors(
394              (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
395               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))))
396
397    # Incompatible shapes.
398    iterator = iterator_ops.Iterator.from_structure(
399        (dtypes.int64, dtypes.float64), ([None], []))
400    with self.assertRaises(TypeError):
401      iterator.make_initializer(
402          dataset_ops.Dataset.from_tensors(
403              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
404               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))
405
406  @combinations.generate(test_base.graph_only_combinations())
407  def testIteratorStringHandle(self):
408    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
409    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
410
411    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
412    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
413
414    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
415    feedable_iterator = iterator_ops.Iterator.from_string_handle(
416        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
417        dataset_ops.get_legacy_output_shapes(dataset_3))
418    next_element = feedable_iterator.get_next()
419
420    self.assertTrue(
421        structure.are_compatible(
422            dataset_ops.get_structure(dataset_3),
423            dataset_ops.get_structure(feedable_iterator)))
424
425    with self.cached_session() as sess:
426      iterator_3_handle = sess.run(iterator_3.string_handle())
427      iterator_4_handle = sess.run(iterator_4.string_handle())
428
429      self.assertEqual(10,
430                       sess.run(
431                           next_element,
432                           feed_dict={handle_placeholder: iterator_4_handle}))
433      self.assertEqual(1,
434                       sess.run(
435                           next_element,
436                           feed_dict={handle_placeholder: iterator_3_handle}))
437      self.assertEqual(20,
438                       sess.run(
439                           next_element,
440                           feed_dict={handle_placeholder: iterator_4_handle}))
441      self.assertEqual(2,
442                       sess.run(
443                           next_element,
444                           feed_dict={handle_placeholder: iterator_3_handle}))
445      self.assertEqual(30,
446                       sess.run(
447                           next_element,
448                           feed_dict={handle_placeholder: iterator_4_handle}))
449      self.assertEqual(3,
450                       sess.run(
451                           next_element,
452                           feed_dict={handle_placeholder: iterator_3_handle}))
453      self.assertEqual(40,
454                       sess.run(
455                           next_element,
456                           feed_dict={handle_placeholder: iterator_4_handle}))
457      with self.assertRaises(errors.OutOfRangeError):
458        sess.run(
459            next_element, feed_dict={handle_placeholder: iterator_3_handle})
460      with self.assertRaises(errors.OutOfRangeError):
461        sess.run(
462            next_element, feed_dict={handle_placeholder: iterator_4_handle})
463
464  @combinations.generate(test_base.graph_only_combinations())
465  def testIteratorStringHandleFuture(self):
466    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
467    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
468
469    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
470    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)
471
472    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
473    feedable_iterator = iterator_ops.Iterator.from_string_handle(
474        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
475        dataset_ops.get_legacy_output_shapes(dataset_3))
476    next_element = feedable_iterator.get_next()
477
478    self.assertTrue(
479        structure.are_compatible(
480            dataset_ops.get_structure(dataset_3),
481            dataset_ops.get_structure(feedable_iterator)))
482
483    with self.cached_session() as sess:
484      iterator_3_handle = sess.run(iterator_3.string_handle())
485      iterator_4_handle = sess.run(iterator_4.string_handle())
486
487      self.assertEqual(
488          10,
489          sess.run(
490              next_element,
491              feed_dict={handle_placeholder: iterator_4_handle}))
492      self.assertEqual(
493          1,
494          sess.run(
495              next_element,
496              feed_dict={handle_placeholder: iterator_3_handle}))
497      self.assertEqual(
498          20,
499          sess.run(
500              next_element,
501              feed_dict={handle_placeholder: iterator_4_handle}))
502      self.assertEqual(
503          2,
504          sess.run(
505              next_element,
506              feed_dict={handle_placeholder: iterator_3_handle}))
507      self.assertEqual(
508          30,
509          sess.run(
510              next_element,
511              feed_dict={handle_placeholder: iterator_4_handle}))
512      self.assertEqual(
513          3,
514          sess.run(
515              next_element,
516              feed_dict={handle_placeholder: iterator_3_handle}))
517      self.assertEqual(
518          40,
519          sess.run(
520              next_element,
521              feed_dict={handle_placeholder: iterator_4_handle}))
522      with self.assertRaises(errors.OutOfRangeError):
523        sess.run(
524            next_element, feed_dict={handle_placeholder: iterator_3_handle})
525      with self.assertRaises(errors.OutOfRangeError):
526        sess.run(
527            next_element, feed_dict={handle_placeholder: iterator_4_handle})
528
529  @combinations.generate(test_base.graph_only_combinations())
530  def testIteratorStringHandleReuseTensorObject(self):
531    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
532    one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
533    initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
534    structure_iterator = iterator_ops.Iterator.from_structure(
535        dataset_ops.get_legacy_output_types(dataset))
536
537    created_ops = len(ops.get_default_graph().get_operations())
538
539    self.assertIs(one_shot_iterator.string_handle(),
540                  one_shot_iterator.string_handle())
541    self.assertIs(initializable_iterator.string_handle(),
542                  initializable_iterator.string_handle())
543    self.assertIs(structure_iterator.string_handle(),
544                  structure_iterator.string_handle())
545
546    # Assert that getting the (default) string handle creates no ops.
547    self.assertEqual(created_ops, len(ops.get_default_graph().get_operations()))
548
549    # Specifying an explicit name will create a new op.
550    handle_with_name = one_shot_iterator.string_handle(name="foo")
551    self.assertEqual("foo", handle_with_name.op.name)
552    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)
553
554    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
555    self.assertEqual("foo_1", handle_with_same_name.op.name)
556    self.assertIsNot(handle_with_name, handle_with_same_name)
557
558  @combinations.generate(test_base.graph_only_combinations())
559  def testIteratorStringHandleError(self):
560    dataset_int_scalar = (
561        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
562    dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))
563
564    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
565
566    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
567        handle_placeholder, dtypes.int32, [])
568    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
569        handle_placeholder, dtypes.int32, [None])
570    feedable_int_any = iterator_ops.Iterator.from_string_handle(
571        handle_placeholder, dtypes.int32)
572
573    with self.cached_session() as sess:
574      handle_int_scalar = sess.run(dataset_ops.make_one_shot_iterator(
575          dataset_int_scalar).string_handle())
576      handle_float_vector = sess.run(dataset_ops.make_one_shot_iterator(
577          dataset_float_vector).string_handle())
578
579      self.assertEqual(1,
580                       sess.run(
581                           feedable_int_scalar.get_next(),
582                           feed_dict={handle_placeholder: handle_int_scalar}))
583
584      self.assertEqual(2,
585                       sess.run(
586                           feedable_int_any.get_next(),
587                           feed_dict={handle_placeholder: handle_int_scalar}))
588
589      with self.assertRaises(errors.InvalidArgumentError):
590        print(sess.run(
591            feedable_int_vector.get_next(),
592            feed_dict={handle_placeholder: handle_int_scalar}))
593
594      with self.assertRaises(errors.InvalidArgumentError):
595        print(sess.run(
596            feedable_int_vector.get_next(),
597            feed_dict={handle_placeholder: handle_float_vector}))
598
599  @combinations.generate(test_base.graph_only_combinations())
600  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
601    worker_config = config_pb2.ConfigProto()
602    worker_config.device_count["CPU"] = 3
603
604    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
605      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
606      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
607      iterator_3_handle = iterator_3.string_handle()
608
609    @function.Defun(dtypes.string)
610    def _remote_fn(h):
611      remote_iterator = iterator_ops.Iterator.from_string_handle(
612          h, dataset_ops.get_legacy_output_types(dataset_3),
613          dataset_ops.get_legacy_output_shapes(dataset_3))
614      return remote_iterator.get_next()
615
616    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
617      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
618      remote_op = functional_ops.remote_call(
619          args=[iterator_3_handle],
620          Tout=[dtypes.int32],
621          f=_remote_fn,
622          target=target_placeholder)
623
624    with self.session(config=worker_config) as sess:
625      elem = sess.run(
626          remote_op,
627          feed_dict={
628              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
629          })
630      self.assertEqual(elem, [1])
631      # Fails when target is cpu:2 where the resource is not located.
632      with self.assertRaises(errors.InvalidArgumentError):
633        sess.run(
634            remote_op,
635            feed_dict={
636                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
637            })
638      elem = sess.run(
639          remote_op,
640          feed_dict={
641              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
642          })
643      self.assertEqual(elem, [2])
644      elem = sess.run(
645          remote_op,
646          feed_dict={
647              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
648          })
649      self.assertEqual(elem, [3])
650      with self.assertRaises(errors.OutOfRangeError):
651        sess.run(
652            remote_op,
653            feed_dict={
654                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
655            })
656
  @combinations.generate(test_base.graph_only_combinations())
  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    """remote_call fetches elements from iterators hosted on remote workers."""
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    # Cluster layout: two "worker" tasks host the iterators; one "client"
    # task drives the pipeline.
    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    # On each worker device, a one-element dataset containing that worker's
    # own device string, exposed via a string handle.
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        src = dataset_ops.Dataset.from_tensor_slices([device])
        itr = dataset_ops.make_one_shot_iterator(src)
        itr_handles.append(itr.string_handle())

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    @function.Defun(dtypes.string)
    def loading_func(h):
      # NOTE(review): `itr` here is the loop variable left over from the
      # handle-building loop above (the last worker iterator), used only for
      # its output types/shapes; `itr` is rebound to the client iterator
      # further down. Both have string elements, so the structure matches.
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, dataset_ops.get_legacy_output_types(itr),
          dataset_ops.get_legacy_output_shapes(itr))
      return remote_itr.get_next()

    def map_fn(target, handle):
      # Dispatch a remote_call to `target`, pulling one element through the
      # handle of the iterator hosted there.
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      itr = dataset_ops.make_initializable_iterator(client_dataset)
      n = itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(itr.initializer)
      # Each fetched element should be the device string of the worker that
      # produced it.
      expected_values = worker_devices
      for expected in expected_values:
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)
710
711  @combinations.generate(test_base.graph_only_combinations())
712  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
713    if not test_util.is_gpu_available():
714      self.skipTest("No GPU available")
715
716    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
717      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
718      iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
719      iterator_3_handle = iterator_3.string_handle()
720
721    def _encode_raw(byte_array):
722      return bytes(bytearray(byte_array))
723
724    @function.Defun(dtypes.uint8)
725    def _remote_fn(h):
726      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
727      remote_iterator = iterator_ops.Iterator.from_string_handle(
728          handle, dataset_ops.get_legacy_output_types(dataset_3),
729          dataset_ops.get_legacy_output_shapes(dataset_3))
730      return remote_iterator.get_next()
731
732    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
733      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
734      iterator_3_handle_uint8 = parsing_ops.decode_raw(
735          input_bytes=iterator_3_handle, out_type=dtypes.uint8)
736      remote_op = functional_ops.remote_call(
737          args=[iterator_3_handle_uint8],
738          Tout=[dtypes.int32],
739          f=_remote_fn,
740          target=target_placeholder)
741
742    with self.cached_session() as sess:
743      elem = sess.run(
744          remote_op,
745          feed_dict={
746              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
747          })
748      self.assertEqual(elem, [1])
749      elem = sess.run(
750          remote_op,
751          feed_dict={
752              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
753          })
754      self.assertEqual(elem, [2])
755      elem = sess.run(
756          remote_op,
757          feed_dict={
758              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
759          })
760      self.assertEqual(elem, [3])
761      with self.assertRaises(errors.OutOfRangeError):
762        sess.run(
763            remote_op,
764            feed_dict={
765                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
766            })
767
768  @combinations.generate(test_base.graph_only_combinations())
769  def testRepeatedGetNextWarning(self):
770    iterator = dataset_ops.make_one_shot_iterator(dataset_ops.Dataset.range(10))
771    warnings.simplefilter("always")
772    with warnings.catch_warnings(record=True) as w:
773      for _ in range(100):
774        iterator.get_next()
775    self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w))
776    for warning in w:
777      self.assertIn(
778          iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message))
779
780  @combinations.generate(
781      combinations.times(
782          test_base.default_test_combinations(),
783          combinations.combine(
784              expected_element_structure=tensor_spec.TensorSpec([],
785                                                                dtypes.float32),
786              expected_output_classes=ops.Tensor,
787              expected_output_types=dtypes.float32,
788              expected_output_shapes=[[]])))
789  def testTensorIteratorStructure(self, expected_element_structure,
790                                  expected_output_classes,
791                                  expected_output_types,
792                                  expected_output_shapes):
793    tf_value_fn = lambda: constant_op.constant(37.0)
794    tf_value = tf_value_fn()
795    iterator = dataset_ops.make_one_shot_iterator(
796        dataset_ops.Dataset.from_tensors(tf_value))
797
798    self.assertTrue(
799        structure.are_compatible(
800            dataset_ops.get_structure(iterator), expected_element_structure))
801    self.assertEqual(expected_output_classes,
802                     dataset_ops.get_legacy_output_classes(iterator))
803    self.assertEqual(expected_output_types,
804                     dataset_ops.get_legacy_output_types(iterator))
805    self.assertEqual(expected_output_shapes,
806                     dataset_ops.get_legacy_output_shapes(iterator))
807
808  @combinations.generate(
809      combinations.times(
810          test_base.default_test_combinations(),
811          combinations.combine(
812              expected_element_structure=sparse_tensor.SparseTensorSpec(
813                  [1], dtypes.int32),
814              expected_output_classes=sparse_tensor.SparseTensor,
815              expected_output_types=dtypes.int32,
816              expected_output_shapes=[[1]])))
817  def testSparseTensorIteratorStructure(self, expected_element_structure,
818                                        expected_output_classes,
819                                        expected_output_types,
820                                        expected_output_shapes):
821
822    def tf_value_fn():
823      return sparse_tensor.SparseTensor(
824          indices=[[0]],
825          values=constant_op.constant([0], dtype=dtypes.int32),
826          dense_shape=[1])
827
828    tf_value = tf_value_fn()
829    iterator = dataset_ops.make_one_shot_iterator(
830        dataset_ops.Dataset.from_tensors(tf_value))
831
832    self.assertTrue(
833        structure.are_compatible(
834            dataset_ops.get_structure(iterator), expected_element_structure))
835    self.assertEqual(expected_output_classes,
836                     dataset_ops.get_legacy_output_classes(iterator))
837    self.assertEqual(expected_output_types,
838                     dataset_ops.get_legacy_output_types(iterator))
839    self.assertEqual(expected_output_shapes,
840                     dataset_ops.get_legacy_output_shapes(iterator))
841
842  @combinations.generate(
843      combinations.times(
844          test_base.default_test_combinations(),
845          combinations.combine(
846              expected_element_structure={
847                  "a":
848                      tensor_spec.TensorSpec([], dtypes.float32),
849                  "b": (tensor_spec.TensorSpec([1], dtypes.string),
850                        tensor_spec.TensorSpec([], dtypes.string))
851              },
852              expected_output_classes={
853                  "a": ops.Tensor,
854                  "b": (ops.Tensor, ops.Tensor)
855              },
856              expected_output_types={
857                  "a": dtypes.float32,
858                  "b": (dtypes.string, dtypes.string)
859              },
860              expected_output_shapes={
861                  "a": [],
862                  "b": ([1], [])
863              })))
864  def testNestedTensorIteratorStructure(self, expected_element_structure,
865                                        expected_output_classes,
866                                        expected_output_types,
867                                        expected_output_shapes):
868
869    def tf_value_fn():
870      return {
871          "a": constant_op.constant(37.0),
872          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
873      }
874
875    tf_value = tf_value_fn()
876    iterator = dataset_ops.make_one_shot_iterator(
877        dataset_ops.Dataset.from_tensors(tf_value))
878
879    self.assertTrue(
880        structure.are_compatible(
881            dataset_ops.get_structure(iterator), expected_element_structure))
882    self.assertEqual(expected_output_classes,
883                     dataset_ops.get_legacy_output_classes(iterator))
884    self.assertEqual(expected_output_types,
885                     dataset_ops.get_legacy_output_types(iterator))
886    self.assertEqual(expected_output_shapes,
887                     dataset_ops.get_legacy_output_shapes(iterator))
888
889  @combinations.generate(test_base.default_test_combinations())
890  def testIteratorGetNextName(self):
891    with ops.Graph().as_default():
892      iterator = dataset_ops.make_one_shot_iterator(
893          dataset_ops.Dataset.from_tensors(37.0))
894      next_element = iterator.get_next(name="overridden_name")
895      self.assertEqual("overridden_name", next_element.op.name)
896
897  @combinations.generate(
898      combinations.combine(
899          tf_api_version=[1, 2],
900          mode="eager",
901          execution_mode=[context.ASYNC, context.SYNC]))
902  def testIteratorEagerIteration(self, execution_mode):
903    with context.eager_mode(), context.execution_mode(execution_mode):
904      val = 0
905      dataset = dataset_ops.Dataset.range(10)
906      iterator = iter(dataset)
907      for foo in iterator:
908        self.assertEqual(val, foo.numpy())
909        val += 1
910
911  @combinations.generate(test_base.eager_only_combinations())
912  def testOwnedIteratorFunction(self):
913
914    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
915
916    @def_function.function
917    def fn():
918      dataset = dataset_ops.Dataset.range(10)
919      iterator = iter(dataset)
920      for _ in range(10):
921        queue.enqueue(next(iterator))
922
923    fn()
924
925    for i in range(10):
926      self.assertEqual(queue.dequeue().numpy(), i)
927
928  @combinations.generate(test_base.eager_only_combinations())
929  def testOwnedIteratorFunctionError(self):
930    # In this test we verify that a function that raises an error ends up
931    # properly deallocating the iterator resource.
932
933    queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
934    queue.enqueue(0)
935
936    def init_fn(n):
937      return n
938
939    def next_fn(_):
940      ds = dataset_ops.Dataset.range(0)
941      return next(iter(ds))
942
943    def finalize_fn(n):
944      queue.enqueue(0)
945      return n
946
947    @def_function.function
948    def fn():
949      output_signature = tensor_spec.TensorSpec((), dtypes.int64)
950      dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn, finalize_fn,
951                                              output_signature)
952      iterator = iter(dataset)
953      next(iterator)
954
955    with self.assertRaises(errors.OutOfRangeError):
956      fn()
957
958    self.assertEqual(queue.size().numpy(), 2)
959
960  @combinations.generate(test_base.eager_only_combinations())
961  def testLimitedRetracing(self):
962    trace_count = [0]
963
964    @def_function.function
965    def f(iterator):
966      trace_count[0] += 1
967      counter = np.int64(0)
968      for elem in iterator:
969        counter += elem
970      return counter
971
972    dataset = dataset_ops.Dataset.range(5)
973    dataset2 = dataset_ops.Dataset.range(10)
974
975    for _ in range(10):
976      self.assertEqual(self.evaluate(f(iter(dataset))), 10)
977      self.assertEqual(self.evaluate(f(iter(dataset2))), 45)
978      self.assertEqual(trace_count[0], 1)
979
980  @combinations.generate(test_base.eager_only_combinations())
981  def testNestedFunctionsIteratorResource(self):
982
983    @def_function.function
984    def sum_dataset(ds):
985      it = iter(ds)
986
987      @def_function.function
988      def next_element(it):
989        return next(it)
990
991      total = 0
992      for _ in range(10):
993        total += next_element(it)
994      return total
995
996    ds = dataset_ops.Dataset.range(10)
997    self.assertEqual(sum_dataset(ds).numpy(), 45)
998    self.assertEqual(sum_dataset(ds).numpy(), 45)
999
1000  @combinations.generate(test_base.default_test_combinations())
1001  def testNestedAutomaticControlDependencies(self):
1002    counter_var = variables.Variable(0)
1003
1004    def map_fn(x):
1005      counter_var.assign_add(1)
1006      return x
1007
1008    def dataset_fn():
1009      return dataset_ops.Dataset.range(10).map(map_fn)
1010
1011    @def_function.function
1012    def fn():
1013      it = iter(dataset_fn())
1014      for _ in range(10):
1015        _ = next(it)
1016      return counter_var
1017
1018    self.evaluate(counter_var.initializer)
1019    self.assertEqual(self.evaluate(fn()), 10)
1020
1021
if __name__ == "__main__":
  # Run all test cases in this module under the TensorFlow test runner.
  test.main()
1024