# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops.options import AutoShardPolicy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest


class DistributedIteratorTestBase(test.TestCase):

  # The passed input_context is to create a sharded dataset in the
  # between-graph case.
  # TODO(yuefengz): rewrite the following method to make it less DRY.
  def _wrap_iterator(self,
                     input_type,
                     dataset_or_input_fn,
                     input_workers,
                     devices,
                     num_replicas_in_sync,
                     strategy,
                     input_context=None):
    # The `input_context` passed in is to shard dataset for
    # MultiWorkerMirroredStrategy. It doesn't apply to the in-graph case where
    # multiple InputContexts are needed.
    if input_type == "input_fn":
      self.assertIsNone(
          input_context,
          msg=("The `input_context` arg is only used to shard the dataset in "
               "`MultiWorkerMirroredStrategy` when the input type is dataset."))

      input_contexts = []
      for i in range(input_workers.num_workers):
        input_contexts.append(
            distribute_lib.InputContext(
                # Note: `input_workers.num_workers` is always 1 in between-graph
                # case.
                num_input_pipelines=input_workers.num_workers,
                input_pipeline_id=i,
                num_replicas_in_sync=len(devices)))

      iterator = input_lib.InputFunctionIterator(dataset_or_input_fn,
                                                 input_workers, input_contexts,
                                                 strategy)
    else:
      iterator = input_lib.DatasetIterator(
          dataset_or_input_fn,
          input_workers,
          strategy,
          num_replicas_in_sync=num_replicas_in_sync,
          input_context=input_context)
    return iterator

  def _wrap_dataset(self,
                    input_type,
                    dataset,
                    input_workers,
                    num_replicas_in_sync,
                    strategy,
                    input_context=None):
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            input_workers,
            strategy,
            dataset,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)

  def _assert_iterator_values(self,
                              iterator,
                              expected_values,
                              evaluate_fn,
                              devices,
                              enable_get_next_as_optional=False):
    actual_values = []
    for _ in range(len(expected_values)):
      if enable_get_next_as_optional:
        next_element = iterator.get_next_as_optional().get_value()
      else:
        next_element = iterator.get_next()
      computed_value = evaluate_fn([
          distribute_utils.select_replica(r, next_element)
          for r in range(len(devices))
      ])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _assert_dataset_values_for_loop(self, dataset, expected_values,
                                      evaluate_fn, devices):
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [distribute_utils.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for expected_value, actual_value in zip(expected_values, actual_values):
      for expected, actual in zip(expected_value, actual_value):
        self.assertAllEqual(expected, actual)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_or_input_fn,
                            worker_device_pairs,
                            expected_values,
                            strategy,
                            sess=None,
                            num_replicas_in_sync=None,
                            input_context=None):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type,
          dataset_or_input_fn,
          input_workers,
          devices,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
      # wrapping into a dataset:
      dataset = self._wrap_dataset(
          input_type,
          dataset_or_input_fn,
          input_workers,
          num_replicas_in_sync,
          strategy,
          input_context=input_context)

      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
          iterator = dataset.make_initializable_iterator()
        else:
          self.skipTest("unsupported test combination")

    if isinstance(iterator, composite_tensor.CompositeTensor):
      nest.assert_same_structure(
          iterator, iterator._type_spec, expand_composites=True)

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

        with self.assertRaises(errors.OutOfRangeError):
          self._assert_iterator_values(iterator, expected_values, evaluate,
                                       devices)

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        self._assert_iterator_values(iterator, expected_values, evaluate,
                                     devices)

      def test_get_next_as_optional(iterator):
        self._assert_iterator_values(
            iterator,
            expected_values,
            evaluate,
            devices,
            enable_get_next_as_optional=True)

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self._assert_iterator_values(
              iterator, [0],
              evaluate,
              devices,
              enable_get_next_as_optional=True)

      test_get_next(iterator)

      # re-initializing the iterator
      if not tf2.enabled():
        # TODO(yuefengz): we should split this function.
        return
      else:
        if api_type == "wrap_into_iterator":
          return
        else:
          iterator = iter(dataset)

      test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      self._assert_dataset_values_for_loop(dataset, expected_values,
                                           self.evaluate, devices)

  def _create_dataset_or_input_fn(self, input_type, input_fn):
    if input_type == "input_fn":
      return input_fn
    else:
      return input_fn(distribute_lib.InputContext())


class DistributedIteratorTest(DistributedIteratorTestBase,
                              parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu
          ]))
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    input_workers = input_lib.InputWorkers(worker_device_pairs)
    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default Iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    iterator = dataset_ops.make_one_shot_iterator(dist_dataset)

    @def_function.function
    def init_func_for_iter():
      self.evaluate(iterator.initializer)

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.multi_worker_mirrored_2x1_cpu],
          enable_get_next_as_optional=[True, False]))
  def testOneDeviceCPUMultiWorker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i] for i in range(10)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[strategy_combinations.tpu_strategy],
          enable_get_next_as_optional=[True, False]))
  def testTPU(self, input_type, api_type, iteration_type, distribution,
              enable_get_next_as_optional):
    worker_device_pairs = collections.OrderedDict()
    for tpu_device in distribution.extended.worker_devices:
      host_device = device_util.get_host_for_device(tpu_device)
      worker_device_pairs.setdefault(host_device, [])
      worker_device_pairs[host_device].append(tpu_device)
    worker_device_pairs = worker_device_pairs.items()
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [[i, i + 1] for i in range(0, 10, 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testTupleDatasetMultiworker(self, input_type, api_type, iteration_type,
                                  distribution, enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    expected_values = [
        [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 10, 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    # `input_context` is not passed in and thus there is no sharding.
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(10)
    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(distribution.experimental_local_results(element), [i])

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_one_cpu,
          ]))
  def testIterableIteratorError(self, distribution):
    dataset = dataset_ops.Dataset.range(10).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    # Raises an error when next(iterator) is called without a strategy scope.
    with self.assertRaises(ValueError):

      def replica_fn1(iterator):
        return next(iterator)

      distribution.run(replica_fn1, args=(iterator,))

    if distribution.num_replicas_in_sync == 1:
      expected_result = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8, 9]]]
    elif distribution.num_replicas_in_sync == 2:
      expected_result = [[[0], [1]], [[2], [3]], [[4], [5]], [[6], [7]],
                         [[8], [9]]]

    with distribution.scope():

      def replica_fn2(iterator):
        return iterator

      result = distribution.run(replica_fn2, args=(next(iterator),))
      self.assertAllEqual(
          distribution.experimental_local_results(result), expected_result[0])

    # Confirm default ReplicaContext also works
    iterator = iter(dist_dataset)
    for i, element in enumerate(iterator):
      self.assertAllEqual(
          distribution.experimental_local_results(element), expected_result[i])

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(  # pylint: disable=g-long-lambda
        2, drop_remainder=drop_remainder)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
    if drop_remainder:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    else:
      expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testUnevenDatasetBatchesMultiWorker(self, input_type, api_type,
                                          iteration_type, drop_remainder,
                                          distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(9)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(2, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]]]
    else:
      # The last global batch only contains data for one replica.
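      # Worker 1 runs out of data first and sees an empty batch at the final
      # step (see the expected values below).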
      if id_in_cluster == 0:
        expected_values = [[[0]], [[2]], [[4]], [[6]], [[8]]]
      else:
        expected_values = [[[1]], [[3]], [[5]], [[7]], [[]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          drop_remainder=[True, False],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ]))
  def testUnevenDatasetBatchesMultiWorkerFourReplicas(self, input_type,
                                                      api_type, iteration_type,
                                                      drop_remainder,
                                                      distribution):
    # Actual devices don't matter in this test as long as the number of global
    # replicas is 4.
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    worker_count = multi_worker_util.worker_count(cr.cluster_spec(),
                                                  cr.task_type)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(15)

      if input_type == "input_fn":
        # When input_fn is used, there is no automatic rebatching and sharding,
        # so we add them here.
        return dataset.shard(worker_count, id_in_cluster).batch(1)
      else:
        return dataset.batch(4, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    # The last global batch only contains data for one replica.
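    # With two replicas per worker, worker 1's second replica gets the empty
    # batch at the last step when the remainder is kept.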
    if drop_remainder and input_type == "dataset":
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]]]
    else:
      if id_in_cluster == 0:
        expected_values = [[[0], [2]], [[4], [6]], [[8], [10]], [[12], [14]]]
      else:
        expected_values = [[[1], [3]], [[5], [7]], [[9], [11]], [[13], []]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         num_replicas_in_sync, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [[
        range(i, i + updated_batch_size),
        range(i + updated_batch_size, i + 2 * updated_batch_size)
    ] for i in range(0, 100, updated_batch_size * 2)]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          num_replicas_in_sync=[None, 2],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
                                    num_replicas_in_sync, distribution,
                                    enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
    batch_size = 10
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(100).batch(batch_size)
      return dataset

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    updated_batch_size = (
        batch_size //
        num_replicas_in_sync if num_replicas_in_sync else batch_size)
    expected_values = [
        [  # pylint: disable=g-complex-comprehension
            range(i, i + updated_batch_size),
            range(i + updated_batch_size, i + 2 * updated_batch_size)
        ] for i in range(0, 100, updated_batch_size * 2)
    ]

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        sess=None,
        num_replicas_in_sync=num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
      ))
  def testCacheAcrossIteration(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(16).shuffle(16).cache().batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          reshuffle=[True, False]))
  def testShuffleAcrossIterations(self, distribution, reshuffle):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    dataset = dataset_ops.Dataset.range(12).shuffle(
        12, reshuffle_each_iteration=reshuffle).batch(4)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    first_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)
    second_epoch = list(
        distribution.experimental_local_results(x) for x in dist_dataset)

    if reshuffle:
      self.assertNotAllEqual(first_epoch, second_epoch)
    else:
      self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testGetNextOptionalShape(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results, data)
        feature = data["feature"]
        label = data["label"]

        # Assert the shapes are still static from all replicas.
        for replica_id in range(len(distribution.extended.worker_devices)):
          self.assertEqual([per_replica_batch_size, 10],
                           feature[replica_id].shape)
          self.assertEqual([per_replica_batch_size], label[replica_id].shape)

    train_fn()

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.OFF]))
  def testAutoshardingOption(self, distribution, input_type, api_type,
                             iteration_type, auto_shard_policy):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)
    ds_option = options_lib.Options()
    ds_option.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset_fn = (
        lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    if auto_shard_policy == AutoShardPolicy.AUTO:
      if id_in_cluster == 0:
        expected_values = [[0], [2]]
      else:
        expected_values = [[1], [3]]
    else:
      expected_values = [[0], [1], [2], [3]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset_or_input_fn,
        worker_device_pairs,
        expected_values,
        distribution,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ],
          input_type=["input_fn"],
          api_type=["wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"]))
  def testDifferentDatasetsMultiWorker(self, distribution, input_type, api_type,
                                       iteration_type):
    cr = distribution.cluster_resolver
    self.assertIsNotNone(cr)
    id_in_cluster = multi_worker_util.id_in_cluster(cr.cluster_spec(),
                                                    cr.task_type, cr.task_id)

    def dataset_fn(ctx):
      if ctx.input_pipeline_id == 0:
        return dataset_ops.Dataset.range(8).batch(2)
      else:
        return dataset_ops.Dataset.range(9).batch(2)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    if id_in_cluster == 0:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[]]]
    else:
      expected_values = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]], [[8]]]
    distribution.extended.experimental_enable_get_next_as_optional = True
    self._test_input_iteration(input_type, api_type, iteration_type,
                               dataset_or_input_fn, worker_device_pairs,
                               expected_values, distribution)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"]))
  def testLoopOverDatasetInTFFunction(self, strategy):
    dataset = dataset_ops.Dataset.range(10).map(lambda x: {  # pylint: disable=g-long-lambda
        "y": math_ops.cast(x, dtypes.float32) ** 2,
    }).batch(4)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    with strategy.scope():
      v = variables.Variable(0.0, aggregation=variables.VariableAggregation.SUM)

    @def_function.function
    def iterator_fn(dist_dataset):

      def assign_add_fn(data):
        v.assign_add(math_ops.reduce_sum(data["y"]))

      for data in dist_dataset:
        strategy.run(assign_add_fn, args=(data,))

    iterator_fn(dist_dataset)
    self.assertEqual(v.numpy(), 285.0)


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
  """Tests for DistributedDataset with non-dense tensors."""

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          defun_type=["lambda", "tf_function"],
      ))
  def testRaggedSparse(self, distribution, input_type, drop_remainder,
                       defun_type):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    defun = {
        "lambda": lambda f: f,
        "tf_function": def_function.function
    }[defun_type]
    distribution.extended.experimental_enable_get_next_as_optional = True
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)
    dataset = self._wrap_dataset(input_type, dataset_or_input_fn,
                                 distribution.extended._input_workers,
                                 len(distribution.extended.worker_devices),
                                 distribution)
    # Assert that the tensors are rebatched and sparsity is preserved.
    per_replica_batch = defun(lambda x: next(iter(x)))(dataset)
    self.assertAllEqual(
        distribute_utils.select_replica(0, per_replica_batch["dense"]),
        [[0., 0., 0.], [1., 0., 0.], [2., 2., 0.], [3., 3., 3.]])
    self.assertAllEqual(
        distribute_utils.select_replica(1, per_replica_batch["dense"]),
        [[0., 0., 0.], [5., 0., 0.], [6., 6., 0.], [7., 7., 7.]])
    # Transitively check the ragged and sparse tensors by densification.
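    # Each per-replica batch has row lengths 0, 1, 2, 3, i.e. 6 values total.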
    for i in range(2):
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["ragged"]).values,
          6)
      self.assertAllEqual(
          distribute_utils.select_replica(
              i, per_replica_batch["ragged"]).to_tensor(),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
      self.assertLen(
          distribute_utils.select_replica(i,
                                          per_replica_batch["sparse"]).indices,
          6)
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(
              distribute_utils.select_replica(i, per_replica_batch["sparse"])),
          distribute_utils.select_replica(i, per_replica_batch["dense"]))
    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_for_loop(dataset):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      for batch in dataset:
        sums = _reduce(sums, batch)
      return sums

    def sum_while_loop(iterator, reduce_fn):
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      while True:
        try:
          sums = reduce_fn(sums, iterator)
        except (StopIteration, errors.OutOfRangeError):
          return sums

    while_sums = sum_while_loop(
        iter(dataset),
        defun(lambda state, iterator: _reduce(state, next(iterator))))
    self.assertAllEqual(
        nest.flatten(while_sums),
        # When there's no partial batch, the sum is smaller.
        [200. if drop_remainder else 310.] * 3)
    for_sums = defun(sum_for_loop)(dataset)
    # For loops always call `get_next_as_optional` inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if (not drop_remainder or
        (defun_type == "tf_function" and input_type == "input_fn")):
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
          tensor_type=["sparse", "ragged"],
          enable_get_next_as_optional=[True, False]))
  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
                                        drop_remainder, tensor_type,
                                        enable_get_next_as_optional):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)
    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          tensor_type: (ragged_tensor if tensor_type == "ragged" else
                        ragged_tensor.to_sparse()),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)
    iterator = iter(ds)

    self.assertEqual(iterator._enable_get_next_as_optional,
                     (not drop_remainder) and enable_get_next_as_optional)

  @combinations.generate(
      combinations.combine(
          tf_api_version=2,
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              # TODO(mdan): Add these?
              # strategy_combinations.multi_worker_mirrored_2x1_cpu,
              # strategy_combinations.multi_worker_mirrored_2x1_gpu,
              # strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
      ))
  def testRaggedSparseGetNextAsOptionalInLoop(self, distribution, input_type,
                                              drop_remainder):
    """Test with `RaggedTensor`s and `SparseTensor`s."""
    self.skipTest("b/323359921")

    global_batch_size = 8

    def dataset_fn(ctx=None):
      ctx = ctx or distribute_lib.InputContext()
      batch_size = ctx.get_per_replica_batch_size(global_batch_size)
      # Use 20 which isn't divisible by 8 to test partial batch behavior.
      row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
      ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
          np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
      dataset = dataset_ops.DatasetV2.from_tensor_slices({
          "dense": ragged_tensor.to_tensor(),
          "ragged": ragged_tensor,
          "sparse": ragged_tensor.to_sparse(),
      })
      dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
      return dataset.batch(batch_size, drop_remainder=drop_remainder)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
    else:
      ds = distribution.distribute_datasets_from_function(dataset_fn)

    # Iterate through all the batches and sum them up.
    def sum_batch(per_replica_features):
      """Sums the `PerReplica` values in the `per_replica_features` map."""

      def map_fn(per_replica_values):
        per_replica_sums = distribution.run(
            (lambda x: math_ops.reduce_sum(x.values)) if all(
                map(sparse_tensor.is_sparse, per_replica_values.values)) else
            math_ops.reduce_sum, (per_replica_values,))
        return distribution.reduce(
            reduce_util.ReduceOp.SUM, per_replica_sums, axis=None)

      return nest.map_structure(map_fn, per_replica_features)

    def _reduce(state, batch):
      sums = sum_batch(batch)
      return {name: value + sums[name] for name, value in state.items()}

    def sum_while_loop(ds):
      iterator = iter(ds)
      sums = {"dense": 0., "ragged": 0., "sparse": 0.}
      try_next = constant_op.constant(True)

      while try_next:
        opt_iterate = iterator.get_next_as_optional()
        if opt_iterate.has_value():
          sums = _reduce(sums, opt_iterate.get_value())
        else:
          try_next = False
      return sums

    sums = def_function.function(sum_while_loop)(ds)
    # For loops always call `get_next_as_optional` inside tf.functions, so we
    # expect 310 here when using an input function (as there are 5 batches of
    # size 4 round-robined over 2 replicas).
    expected_for_sum = 200.
    if not drop_remainder or input_type == "input_fn":
      expected_for_sum = 310.
    self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
                           distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(12).batch(8)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ]))
  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
                                            iteration_type, distribution):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when we have two files each with
    # 12 elements and a global batch size of 8. When we consider the dataset in
    # aggregate (non-distributed), there are 24 elements divided into 3 batches
    # of size 8. Hence, the correct distributed behavior is for each replica to
    # see sub-batches of size 4, over three steps. However, when we create a
    # DistributedDataset and cannot statically infer the intended global batch
    # size (e.g. if the user does not use a batching dataset), each worker will
    # rebatch based on the dynamic batch size of the data encountered, even when
    # it encounters partial batches. The last per-worker partial batch (size 4)
    # ends up being split into two replicas, resulting in 4 steps in total, of
    # (global) batch sizes 8, 8, 4, 4.
    def dataset_fn(ctx):
      del ctx
      # The following dataset is equivalent to
      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
      # This causes DistributedDataset to use LegacyRebatch instead.
      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))

      def map_fn(offset, batch_size):
        return math_ops.range(offset, offset + batch_size)

      dataset = dataset.map(map_fn)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is equivalent to
      # tf.data.Dataset.range(12).batch(8).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as the number of global
    # replicas is 2.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. Each worker should rebatch its dataset into
    # smaller batches of size 4.
    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          distribution=[
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
                               distribution, auto_shard_policy):
    # Test case: 2 workers, 1 replica each.
    # This test simulates the sharded behavior when the dataset is sharded by
    # data and the batch size is indivisible by the number of replicas. This
    # checks that the elements are as expected and the batch size across all
    # workers adds up to 3. This test will only pass if the autoshard rewrite
    # rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
    def dataset_fn(ctx):
      del ctx
      dataset = dataset_ops.Dataset.range(8).batch(3)

      # Set the sharding behavior to OFF for simplicity of test setup; namely,
      # `dataset` defines the per-worker dataset and will not be further
      # sharded. Each worker will see a dataset that is
      # tf.data.Dataset.range(8).batch(3).rebatch(...).
      options = options_lib.Options()
      options.experimental_distribute.auto_shard_policy = auto_shard_policy
      dataset = dataset.with_options(options)
      return dataset

    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)

    # Actual devices don't matter in this test as long as there is 1 local
    # replica.
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]

    # Each test runs individually on each worker, so we compare the
    # values on each worker. We expect each worker to see different shards of
    # data.
    cr = distribution.cluster_resolver
    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
                                                cr.task_id)

    if worker_id == 0:
      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
    elif worker_id == 1:
      expected_values = [[[2]], [[5]], [[7]]]

    self._test_input_iteration(
        input_type,
        api_type,
        iteration_type,
        dataset,
        worker_device_pairs,
        expected_values,
        distribution,
        num_replicas_in_sync=distribution.num_replicas_in_sync,
        input_context=distribution.extended._make_input_context())


class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""

  def setUp(self):
    context._reset_context()
    strategy_combinations.set_virtual_cpus_to_at_least(3)
    super(DistributedIteratorPerDeviceTest, self).setUp()

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):

    def dataset_fn(input_context):  # pylint: disable=[unused-argument]
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      assert x.values[0].device == distribution.extended.worker_devices[0]
      assert x.values[0].backing_device == distribution.extended.worker_devices[
          0]
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert x.values[1].backing_device == distribution.extended.worker_devices[
          1]

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      assert (x.values[0].device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[0].backing_device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[1].device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")
      assert (x.values[1].backing_device ==
              "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA)
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForInvalidCombinations(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    with self.assertRaises(ValueError):
      distribution.experimental_distribute_datasets_from_function(
          dataset_fn, input_options)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_per_replica_buffer_size=2),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_per_replica_buffer_size=2),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testPrefetchBufferSizeInputOptions(self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

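  # Note on the expected values in testPrefetchBufferSizeInputOptions above
  # and the two tests below: with PER_WORKER replication (the default) there
  # is a single local input pipeline (input_pipeline_id == 0), so the dataset
  # rows [1..5] and [6..10] are simply dealt out to the two replicas. With
  # PER_REPLICA replication each replica gets its own pipeline, so replica 1's
  # values are scaled by (input_pipeline_id + 1) == 2.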
  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerWorkerInputOptions(self, distribution,
                                               input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_fetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for i, x in enumerate(ds):
      # validating the values
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1


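# The class below reads its input through the tf.data service. As a rough,
# non-executed sketch, the pieces wired up in its setUp and test method look
# like this:
#
#   dispatcher = server_lib.DispatchServer()
#   worker = server_lib.WorkerServer(
#       server_lib.WorkerConfig(
#           dispatcher_address=dispatcher.target.split("://")[1]))
#   dataset = dataset.apply(
#       data_service_ops._distribute(
#           processing_mode="parallel_epochs",
#           service=dispatcher.target,
#           job_name="foo"))
#
# With processing_mode="parallel_epochs" each of the three data-service
# workers serves a full epoch of the dataset, which is why the test expects
# every value in range(1, 50) to appear num_workers times.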
class DistributedIteratorTfDataServiceTest(DistributedIteratorTestBase,
                                           parameterized.TestCase):
  """Tests for distributed iterators which read from tf.data service."""

  def setUp(self):
    super(DistributedIteratorTfDataServiceTest, self).setUp()
    self.num_workers = 3
    if combinations.in_main_process():
      self.dispatcher = server_lib.DispatchServer()
      self.workers = []
      for _ in range(self.num_workers):
        self.workers.append(
            server_lib.WorkerServer(
                server_lib.WorkerConfig(
                    dispatcher_address=self.dispatcher.target.split("://")[1],
                    heartbeat_interval_ms=100,
                    dispatcher_timeout_ms=1000)))
      combinations.env().tf_data_service_dispatcher = self.dispatcher.target

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
              strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
          ]))
  def testTfDataService(self, distribution):
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.Dataset.range(1, 50)
    dataset = dataset.apply(
        data_service_ops._distribute(
            processing_mode="parallel_epochs",
            service=combinations.env().tf_data_service_dispatcher,
            job_name="foo"))

    dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                     distribution)

    iterator = iter(dist_dataset)
    results = []
    for element in iterator:
      local_results = distribution.experimental_local_results(element)
      for result in local_results:
        # input_lib.distributed_dataset may add extra '0' elements to pad
        # per-replica results.
        if result.numpy() != 0:
          results.append(result.numpy())
    self.assertNotEmpty(results)
    gathered = distribution.gather(constant_op.constant(results), axis=0)
    self.assertCountEqual(self.num_workers * list(range(1, 50)), gathered)


if __name__ == "__main__":
  test_util.main()