# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest


class DistributedIteratorTestBase(test.TestCase):

  # The `input_context` passed in is used to create a sharded dataset in the
  # between-graph case.
  # TODO(yuefengz): rewrite the following method to reduce repetition.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     num_replicas_in_sync,
                     strategy,
                     input_context=None):
    # The `input_context` passed in is used to shard the dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case, where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard the dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))

      input_contexts = []
      for i in range(input_workers.num_workers):
        input_contexts.append(
            distribute_lib.InputContext(
                # Note: `input_workers.num_workers` is always 1 in the
                # between-graph case.
                num_input_pipelines=input_workers.num_workers,
                input_pipeline_id=i,
                num_replicas_in_sync=len(devices)))

      iterator = input_lib.InputFunctionIterator(dataset_or_input_fn,
                                                 input_workers, input_contexts,
                                                 strategy)
    else:
      iterator = input_lib.DatasetIterator(
          dataset_or_input_fn,
          input_workers,
          strategy,
          num_replicas_in_sync=num_replicas_in_sync,
          input_context=input_context)
    return iterator

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    num_replicas_in_sync,
                    strategy,
                    input_context=None):
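    """Wraps `dataset` into a DistributedDataset(V1) for the "dataset" input
    type; otherwise distributes it via
    `strategy.distribute_datasets_from_function`."""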
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            input_workers,
            strategy,
            dataset,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)

  def _assert_iterator_values(self,
                              iterator,
                              expected_values,
                              evaluate_fn,
                              devices,
                              enable_get_next_as_optional=False):
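    """Draws `len(expected_values)` elements from `iterator` and checks the
    per-replica results against `expected_values`."""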
    actual_values = []
    for _ in range(len(expected_values)):
      if enable_get_next_as_optional:
        next_element = iterator.get_next_as_optional().get_value()
      else:
        next_element = iterator.get_next()
      computed_value = evaluate_fn([
          distribute_utils.select_replica(r, next_element)
          for r in range(len(devices))
      ])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _assert_dataset_values_for_loop(self, dataset, expected_values,
                                      evaluate_fn, devices):
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            num_replicas_in_sync=None,
                            input_context=None):
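    """Builds an iterator or dataset per `api_type` and verifies that
    iterating it yields `expected_values`."""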
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(
          iterator, iterator._type_spec, expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

        with self.assertRaises(errors.OutOfRangeError):
          self._assert_iterator_values(iterator, expected_values, evaluate,
                                       devices)

        # After re-initializing the iterator, we should be able to iterate
        # again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

      def test_get_next_as_optional(iterator):
        self._assert_iterator_values(
            iterator,
            expected_values,
            evaluate,
            devices,
            enable_get_next_as_optional=True)

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self._assert_iterator_values(
              iterator, [0],
              evaluate,
              devices,
              enable_get_next_as_optional=True)

      test_get_next(iterator)

      # Re-create the iterator before testing get_next_as_optional.
      if not tf2.enabled():
        # TODO(yuefengz): we should split this function.
        return
      else:
        if api_type == "wrap_into_iterator":
          return
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      self._assert_dataset_values_for_loop(dataset, expected_values,
                                           self.evaluate, devices)

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())


class DistributedIteratorTest(DistributedIteratorTestBase,
                              parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu
          ]))
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    input_workers = input_lib.InputWorkers(worker_device_pairs)
    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default Iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.tpu_strategy],
          enable_get_next_as_optional=[True, False]))
  def testTPU(self, input_type, api_type, iteration_type, distribution,
              enable_get_next_as_optional):
    worker_device_pairs = collections.OrderedDict()
    for tpu_device in distribution.extended.worker_devices:
      host_device = device_util.get_host_for_device(tpu_device)
      worker_device_pairs.setdefault(host_device, [])
      worker_device_pairs[host_device].append(tpu_device)
    worker_device_pairs = worker_device_pairs.items()
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    # `input_context` is not passed in, so there is no sharding.
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(10)
    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(distribution.experimental_local_results(element), [i])

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ]))
  def testIterableIteratorError(self, distribution):
    dataset = dataset_ops.Dataset.range(10).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    # Raises an error when next(iterator) is called without a strategy scope.
    with self.assertRaises(ValueError):

      def replica_fn1(iterator):
        return next(iterator)

      distribution.run(replica_fn1, args=(iterator,))

    if distribution.num_replicas_in_sync == 1:
      expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]]
    elif distribution.num_replicas_in_sync == 2:
      expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]],
                         [[8], [9]]]
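    # With one replica each global batch of 2 stays intact; with two replicas
    # it is split into per-replica batches of size 1.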

    with distribution.scope():

      def replica_fn2(iterator):
        return iterator

      result = distribution.run(replica_fn2, args=(next(iterator),))
      self.assertAllEqual(
          distribution.experimental_local_results(result), expected_result[0])

    # Confirm default ReplicaContext also works
    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(
          distribution.experimental_local_results(element), expected_result[i])

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
        2, drop_remainder=drop_remainder)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
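    # Without drop_remainder, the final step feeds the partial batch [8] to the
    # first replica and an empty batch to the second.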
    if drop_remainder:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    else:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
                                          iteration_type, drop_remainder,
                                          distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(9)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(2, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

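    # In both input modes, worker 0 ends up with the even elements and worker 1
    # with the odd ones, each as per-replica batches of size 1.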
    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]]]
    else:
      # The last global batch only contains data for one replica.
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ]))
  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
                                                      api_type, iteration_type,
                                                      drop_remainder,
                                                      distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 4.
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(15)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(4, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
    else:
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         num_replicas_in_sync, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
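    # Each step yields one batch of `updated_batch_size` consecutive values per
    # replica, so the two replicas consume 2 * updated_batch_size values a step.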
    expected_values = [[
        range(i, i + updated_batch_size),
        range(i + updated_batch_size, i + 2 * updated_batch_size)
    ] for i in range(0, 100, updated_batch_size * 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
                                    num_replicas_in_sync, distribution,
                                    enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    batch_size = 10
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
      return dataset

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [
        [  # pylint: disable=g-complex-comprehension
            range(i, i + updated_batch_size),
            range(i + updated_batch_size, i + 2 * updated_batch_size)
        ] for i in range(0, 100, updated_batch_size * 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
      ))
  def testCacheAcrossIteration(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

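    # Since cache() follows shuffle(), the shuffled order is recorded on the
    # first pass and replayed on the second, so both epochs should match.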
    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          reshuffle=[True, False]))
  def testShuffleAcrossIterations(self, distribution, reshuffle):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(12).shuffle(
        12, reshuffle_each_iteration=reshuffle).batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    if reshuffle:
      self.assertNotAllEqual(first_epoch, second_epoch)
    else:
      self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShape(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results, data)
        feature = data["feature"]
        label = data["label"]

        # Assert the shapes are still static from all replicas.
        for replica_id in range(len(distribution.extended.worker_devices)):
          self.assertEqual([per_replica_batch_size, 10],
                           feature[replica_id].shape)
          self.assertEqual([per_replica_batch_size], label[replica_id].shape)

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
  def testAutoshardingOption(self, distribution, input_type, api_type,
                             iteration_type, auto_shard_policy):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)
    ds_option = options_lib.Options()
    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset_fn = (
        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
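    # With AUTO sharding the 4-element dataset is split between the two
    # workers; with OFF every worker sees all four elements.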
    if auto_shard_policy == AutoShardPolicy.AUTO:
      if id_in_cluster == 0:
        expected_values = [[0], [2]]
      else:
        expected_values = [[1], [3]]
    else:
      expected_values = [[0], [1], [2], [3]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["input_fn"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"]))
  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
                                       iteration_type):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(ctx):
      if ctx.input_pipeline_id == 0:
        return dataset_ops.Dataset.range(8).batch(2)
      else:
        return dataset_ops.Dataset.range(9).batch(2)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

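    # Worker 0's 8-element dataset runs out one step before worker 1's
    # 9-element dataset, so its final optional batch is empty while worker 1
    # still sees the partial batch [8].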
    if id_in_cluster == 0:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
    else:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"]))
  def testLoopOverDatasetInTFFunction(self, strategy):
    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
        "y": math_ops.cast(x, dtypes.float32) ** 2,
    }).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)

    @def_function.function
    def iterator_fn(dist_dataset):

      def assign_add_fn(data):
        v.assign_add(math_ops.reduce_sum(data["y"]))

      for data in dist_dataset:
        strategy.run(assign_add_fn, args=(data,))

    iterator_fn(dist_dataset)
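    # The variable accumulates the sum of squares 0**2 + 1**2 + ... + 9**2 = 285.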
    self.assertEqual(v.numpy(), 285.0)


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {
        "lambda": lambda f: f,
        "tf_function": def_function.function
    }[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
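      # Row i repeats the value i (i % 4) times, so the rows cycle through
      # lengths 0, 1, 2, 3.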
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 len(distribution.extended.worker_devices),
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
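    # The row sums over all 20 rows total 310; when remainders are dropped only
    # 16 rows reach the loop, totaling 200.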
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call get_next_as_optional inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or
        (defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]))
  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
                                        drop_remainder, tensor_type,
                                        enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)
    iterator = iter(ds)

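    # get_next_as_optional is only enabled when partial batches are possible,
    # i.e. when remainders are not dropped and the flag is turned on.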
1209    self.assertEqual(iterator._enable_get_next_as_optional,
1210                     (not drop_remainder) and enable_get_next_as_optional)
1211
1212  @combinations.generate(
1213      combinations.combine(
1214          tf_api_version=2,
1215          mode=["eager"],
1216          distribution=[
1217              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1218              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
1219              strategy_combinations.one_device_strategy,
1220              strategy_combinations.mirrored_strategy_with_one_cpu,
1221              # TODO(mdan): Add these?
1222              # strategy_combinations.multi_worker_mirrored_2x1_cpu,
1223              # strategy_combinations.multi_worker_mirrored_2x1_gpu,
1224              # strategy_combinations.multi_worker_mirrored_2x2_gpu,
1225          ],
1226          input_type=["dataset", "input_fn"],
1227          drop_remainder=[False, True],
1228      ))
1229  def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type,
1230                                              drop_remainder):
1231    """Test with `RaggedTensor`s and `SparseTensor`s."""
1232    self.skipTest("b/323359921")
1233
1234    global_batch_size = 8
1235
1236    def dataset_fn(ctx=None):
1237      ctx = ctx or distribute_lib.InputContext()
1238      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
1239      # Use 20 which isn't divisible by 8 to test partial batch behavior.
1240      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
1241      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
1242          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
1243      dataset = dataset_ops.DatasetV2.from_tensor_slices({
1244          "dense": ragged_tensor.to_tensor(),
1245          "ragged": ragged_tensor,
1246          "sparse": ragged_tensor.to_sparse(),
1247      })
1248      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
1249      return dataset.batch(batch_size, drop_remainder=drop_remainder)
1250
1251    if input_type == "dataset":
1252      ds = distribution.experimental_distribute_dataset(
1253          dataset_fn(distribute_lib.InputContext()))
1254    else:
1255      ds = distribution.distribute_datasets_from_function(dataset_fn)
1256
1257    # Iterate through all the batches and sum them up.
1258    def sum_batch(per_replica_features):
1259      """Sums the `PerReplica` values in the `per_replica_features` map."""
1260
1261      def map_fn(per_replica_values):
1262        per_replica_sums = distribution.run(
1263            (lambda x: math_ops.reduce_sum(x.values)) if all(
1264                map(sparse_tensor.is_sparse, per_replica_values.values)) else
1265            math_ops.reduce_sum, (per_replica_values,))
1266        return distribution.reduce(
1267            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)
1268
1269      return nest.map_structure(map_fn, per_replica_features)
1270
1271    def _reduce(state, batch):
1272      sums = sum_batch(batch)
1273      return {name: value + sums[name] for name, value in state.items()}
1274
1275    def sum_while_loop(ds):
1276      iterator = iter(ds)
1277      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
1278      try_next = constant_op.constant(True)
1279
1280      while try_next:
1281        opt_iterate = iterator.get_next_as_optional()
1282        if opt_iterate.has_value():
1283          sums = _reduce(sums, opt_iterate.get_value())
1284        else:
1285          try_next = False
1286      return sums
1287
    sums = def_function.function(sum_while_loop)(ds)
    # For loops always call get_next_as_optional inside a tf.function, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
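    # Row i of the ragged tensor holds (i % 4) copies of the value i, so it
    # contributes i * (i % 4) to each of the three sums (the dense version is
    # zero-padded, which doesn't change the sum). Summing over all 20 rows
    # gives 310; summing over only the first 16 rows (two full global batches
    # of 8) gives 200.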
    expected_for_sum = 200.
    if not drop_remainder or input_type == "input_fn":
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
                           distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps.
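    # With auto-sharding turned OFF below, each worker sees the full
    # range(12).batch(8) dataset, i.e. worker-local batches [0..7] and [8..11].
    # Rebatching to the per-replica size of 4 yields the three per-worker
    # steps [0..3], [4..7], [8..11] checked in `expected_values` below.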
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(12).batch(8)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
                                            iteration_type, distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps. However, when we create a
    # DistributedDataset and cannot statically infer the intended global batch
    # size (e.g. if the user does not use a batching dataset), each worker will
    # rebatch based on the dynamic batch size of the data encountered, even when
    # it encounters partial batches. The last per-worker partial batch (size 4)
    # ends up being split across the two replicas, resulting in 4 steps in
    # total, of (global) batch sizes 8, 8, 4, 4.
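    # Concretely, each worker produces batches [0..7] and [8..11]. Because the
    # global batch size cannot be inferred statically, the trailing size-4
    # batch is split across the two replicas, so each worker steps through
    # [0..3], [4..7], [8, 9], [10, 11] (sizes 4, 4, 2, 2), matching
    # `expected_values` below.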
    def dataset_fn(ctx):
      del ctx
      # The following dataset is equivalent to
      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
      # This causes DistributedDataset to use LegacyRebatch instead.
      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))

      def map_fn(offset, batch_size):
        return math_ops.range(offset, offset + batch_size)

      dataset = dataset.map(map_fn)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is equivalent to
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4, except for the final partial batch, which is
    # split into two batches of size 2.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
                               distribution, auto_shard_policy):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the batch size is not divisible by the number of replicas. It
    # checks that the elements are as expected and that the batch sizes across
    # all workers add up to the global batch size of 3 (2 for the final
    # partial batch). This test will only pass if the autoshard rewrite
    # rewrites RebatchDatasetV2 to the legacy RebatchDataset when sharding by
    # data.
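    # With data sharding, the global batches [0, 1, 2], [3, 4, 5], [6, 7] are
    # split element-wise between the two workers: worker 0 receives [0, 1],
    # [3, 4], [6] and worker 1 receives [2], [5], [7].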
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(8).batch(3)

      # Shard the input pipeline by data using the parameterized
      # `auto_shard_policy` (AUTO falls back to data-based sharding here since
      # the input does not read from files). Each worker will see a data shard
      # of a dataset that is equivalent to
      # tf.data.Dataset.range(8).batch(3).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = auto_shard_policy
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. We expect each worker to see different shards of
    # data.
    cr = distribution.cluster_resolver
    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
                                                cr.task_id)

    if worker_id == 0:
      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
    elif worker_id == 1:
      expected_values = [[[2]], [[5]], [[7]]]

    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())


class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for the PER_WORKER and PER_REPLICA InputOptions variants."""

  def setUp(self):
    context._reset_context()
    strategy_combinations.set_virtual_cpus_to_at_least(3)
    super(DistributedIteratorPerDeviceTest, self).setUp()

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):

    def dataset_fn(input_context):  # pylint: disable=unused-argument
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

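    # With experimental_fetch_to_device=True, each per-replica value should be
    # prefetched onto (and backed by) its replica's device.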
    for x in ds:
      assert x.values[0].device == distribution.extended.worker_devices[0]
      assert x.values[0].backing_device == distribution.extended.worker_devices[
          0]
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert x.values[1].backing_device == distribution.extended.worker_devices[
          1]

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

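    # With experimental_fetch_to_device=False, the per-replica values are not
    # copied to the replica devices and are expected to stay on the host CPU.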
    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      assert x.values[
          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA)
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForInvalidCombinations(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

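    # Placing the dataset on the replica devices is only expected to work with
    # PER_REPLICA replication and experimental_fetch_to_device=False, so both
    # parameterized InputOptions above should be rejected.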
    with self.assertRaises(ValueError):
      distribution.experimental_distribute_datasets_from_function(
          dataset_fn, input_options)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_per_replica_buffer_size=2),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_per_replica_buffer_size=2),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testPrefetchBufferSizeInputOptions(self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

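    # `experimental_per_replica_buffer_size` only tunes the per-replica
    # prefetch buffer; it should not change the values delivered to each
    # replica.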
    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerWorkerInputOptions(self, distribution,
                                               input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

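    # In PER_WORKER mode there is a single input pipeline per worker
    # (input_pipeline_id is 0 here), and the two rows of the per-worker dataset
    # are distributed across the two replicas: replica 0 should get [1..5] and
    # replica 1 should get [6..10].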
    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
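    # In PER_REPLICA mode, dataset_fn is invoked once per replica with a
    # distinct input_pipeline_id, so replica 0 should see 1..9 and replica 1
    # should see the doubled values 2..18.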
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for i, x in enumerate(ds):
      # validating the values
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1


class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    self.num_workers = 3
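    # The dispatcher and tf.data service workers are started once, in the main
    # process; the child test processes pick up the dispatcher address through
    # combinations.env().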
    if combinations.in_main_process():
      self.dispatcher = server_lib.DispatchServer()
      self.workers = []
      for _ in range(self.num_workers):
        self.workers.append(
            server_lib.WorkerServer(
                server_lib.WorkerConfig(
                    dispatcher_address=self.dispatcher.target.split("://")[1],
                    heartbeat_interval_ms=100,
                    dispatcher_timeout_ms=1000)))
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(1, 50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode="parallel_epochs",
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))
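    # With "parallel_epochs", every tf.data service worker processes the full
    # dataset, so each element of range(1, 50) should appear num_workers (3)
    # times in the gathered results below.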

    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        # input_lib.distributed_dataset may add extra '0' elements to pad
        # per-replica results.
        if result.numpy() != 0:
          results.append(result.numpy())
    self.assertNotEmpty(results)
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)


if __name__ == "__main__":
  test_util.main()