# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `_AutoShardDataset` transformation."""
import os

from absl.testing import parameterized

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.kernel_tests import tf_record_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


def chunk(l, n):
  for i in range(0, len(l), n):
    yield l[i:i + n]


class AutoShardDatasetTest(tf_record_test_base.TFRecordTestBase,
                           parameterized.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self._filenames = self._createFiles()

  def getAllDatasetElements(self, dataset):
    actual = []
    next_fn = self.getNext(dataset)
    while True:
      try:
        actual.append(self.evaluate(next_fn()))
      except errors.OutOfRangeError:
        break
    return actual

  def assertDatasetProducesWithShuffle(self, dataset, expected, batch,
                                       num_examples, shuffle):
    if shuffle:
      actual = []
      next_fn = self.getNext(dataset)
      for _ in range(num_examples):
        elem = self.evaluate(next_fn())
        if isinstance(elem, tuple):
          actual.extend(elem)
        else:
          actual.extend(elem.tolist())

      self.assertCountEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())
    else:
      self.assertDatasetProduces(dataset, list(chunk(expected, batch)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testFlatMapReaderPipeline(self, shuffle):
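    """Checks that a list_files + flat_map reader pipeline shards by file.

    With 10 files, 5 workers, and worker index 3, only files 3 and 8 are
    read; every record of those two files is expected.
    """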
    dataset = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(batch_size=[1, 3, 10])))
  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
    # This tests a scenario where list_files may return multiple files due to
    # the glob containing wildcards.
    def batch(iterator, n):
      l = len(iterator)
      for i in range(0, l, n):
        yield iterator[i:min(i + n, l)]

    datasets = []
    for files in batch(self._filenames, batch_size):
      datasets.append(
          dataset_ops.Dataset.list_files(files, shuffle=False).map(
              core_readers.TFRecordDataset))
    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
    dataset = dataset.flat_map(lambda x: x)

    # Simulate additional ops in between flat_map and interleave. This should
    # be a no-op, since if ShardDataset were placed right after flat_map, we
    # would only have two datasets left at this point.
    dataset = dataset.prefetch(1)
    dataset = dataset.prefetch(1)

    dataset = dataset.interleave(
        lambda x: x, cycle_length=1, num_parallel_calls=1)

    dataset = distribute._AutoShardDataset(dataset, 5, 0)
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testZipReaderPipeline(self):
    dataset1 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testConcatenateReaderPipeline(self, shuffle):
    dataset1 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset1 = dataset1.batch(5)
    dataset2 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset2.batch(5)

    dataset = dataset1.concatenate(dataset2)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    expected += expected
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testPipelineWithMap(self, shuffle):
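    """Checks that a map between reader and batch keeps file-level sharding.

    Worker 3 of 5 should still see only files 3 and 8; the map strips the
    first two characters of each record.
    """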
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testDirectFilenameTFRecordReaderPipeline(self):
    dataset = core_readers.TFRecordDataset(self._filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
                                       (10, 1, 1, 10)])))
  def testStandardReaderPipeline(self, params):
    num_epochs, index, batch_size, parallel_reads = params
    dataset = readers.make_tf_record_dataset(
        file_pattern=self._filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=None,
        num_parallel_reads=parallel_reads,
        drop_final_batch=True,
        shuffle=False)
    dataset = distribute._AutoShardDataset(dataset, 2, index)
    outputs = self.getNext(dataset)
    self._verify_records(
        outputs,
        batch_size=batch_size,
        file_index=[i for i in range(index, self._num_records, 2)],
        num_epochs=num_epochs,
        interleave_cycle_length=parallel_reads,
        drop_final_batch=True,
        use_parser_fn=None)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  @combinations.generate(test_base.default_test_combinations())
  def testShardInputToInterleave(self):
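    """Checks that sharding is applied to the file names feeding interleave.

    With 2 workers, worker 0 should read files 0 and 2, and interleaving them
    yields their records in alternating order.
    """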
    file1 = self._writeFile("f0", [1, 2, 3])
    file2 = self._writeFile("f1", [4, 5, 6])
    file3 = self._writeFile("f2", [7, 8, 9])
    dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
    dataset = dataset.interleave(core_readers.TFRecordDataset, cycle_length=3)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    # Sharding by file will interleave files 0 and 2
    expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
    actual = self.getDatasetOutput(dataset)
    self.assertEqual(actual, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testShardInputToInterleaveWithIdentityFunction(self):
    self.skipTest("Currently fails due to b/238645949")
    file1 = self._writeFile("f0", [1, 2, 3])
    file2 = self._writeFile("f1", [4, 5, 6])
    file3 = self._writeFile("f2", [7, 8, 9])
    dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
    dataset = dataset.map(core_readers.TFRecordDataset)
    dataset = dataset.interleave(lambda x: x, cycle_length=3)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    # Sharding by file will interleave files 0 and 2
    expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
    actual = self.getDatasetOutput(dataset)
    self.assertEqual(actual, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testSampleResNetPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              options_lib.AutoShardPolicy.DATA,
              options_lib.AutoShardPolicy.AUTO
          ])))
  def testShardByDataBeforePrefetch(self, sharding_policy):
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"]))
    dataset = dataset.prefetch(1)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [0, 2])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(combinations.combine(
              sharding_policy=[options_lib.AutoShardPolicy.DATA,
                               options_lib.AutoShardPolicy.FILE]),
                             combinations.combine(shuffle=[True, False]))))
  def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
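    """Checks that two shards of a replicated pipeline see disjoint data.

    The same serialized dataset graph is rebuilt twice and sharded as worker
    0 and worker 1 of 2; the two result sets must not overlap.
    """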
    dataset = dataset_ops.Dataset.list_files(self._filenames,
                                             shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)

    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=options_lib.ExternalStatePolicy.WARN)

    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy

    ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)
    ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)

    ds1 = ds1.with_options(options)
    ds2 = ds2.with_options(options)

    ds1 = distribute._AutoShardDataset(ds1, 2, 0)
    ds2 = distribute._AutoShardDataset(ds2, 2, 1)

    elems1 = set(self.getAllDatasetElements(ds1))
    elems2 = set(self.getAllDatasetElements(ds2))

    self.assertEmpty(elems1.intersection(elems2))

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFilesWithDataSharding(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.DATA)

    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return "Record (0, 5) of file (0 --> 9)": since we are sharding
    # individual elements, we should get some data from every file.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testAutoshardPolicyOff(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.OFF)

    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return every record in every file since autosharding is turned
    # off.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithoutReaderDatasetOp(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.FILE)

    dataset = dataset_ops.Dataset.range(1024)
    dataset = dataset.with_options(options)

    # We are specifying that we want a file sharding policy, and this pipeline
    # doesn't start with file reading, so we should error out.
    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFiles(self):
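    """Checks that a worker index beyond the available data gets no elements.

    Sharding 10 files across 500 workers leaves worker 499 with an empty
    dataset.
    """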
    dataset = dataset_ops.Dataset.list_files(self._filenames)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 500, 499)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping the reader in a flat_map.
    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates a raw op rather than automatically
    # wrapping the reader in a flat_map.
    dataset = core_readers._TFRecordDataset(self._filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRange(self):
    dataset = dataset_ops.Dataset.range(5)
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRangeEmptyDataset(self):
    dataset = dataset_ops.Dataset.range(0)
    with self.assertRaises(errors.OutOfRangeError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNoReaderPipelines(self):
    dataset = dataset_ops.Dataset.range(1024)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])

  @combinations.generate(test_base.default_test_combinations())
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
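    """Checks that an unknown op falls back to sharding elements at the end.

    With `unique()` between the reader and the shard point, the rewrite cannot
    shard by file, so worker 0 of 5 sees every fifth record instead.
    """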
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidWorkerIndex(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testAssertCardinality(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(cardinality.assert_cardinality(42))
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMakeBatchedFeaturesDataset(self):
    files = 2
    records_per_file = 5

    def make_record(file_index):
      example = example_pb2.Example(
          features=feature_pb2.Features(
              feature={
                  "file":
                      feature_pb2.Feature(
                          int64_list=feature_pb2.Int64List(value=[file_index])),
              }))
      return example.SerializeToString()

    filenames = []
    for file_index in range(files):
      filename = os.path.join(self.get_temp_dir(),
                              "tf_record.%d.txt" % file_index)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for _ in range(records_per_file):
        writer.write(make_record(file_index))
      writer.close()

    dataset = readers.make_batched_features_dataset(
        file_pattern=filenames,
        batch_size=records_per_file,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        },
        reader=core_readers.TFRecordDataset,
        num_epochs=1)
    # We should shard at the file level, so that all records come from file 0.
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    dataset = dataset.unbatch()
    output = self.getDatasetOutput(dataset)
    files = [elem["file"] for elem in output]
    self.assertEqual(files, [0] * records_per_file)

  @combinations.generate(test_base.default_test_combinations())
  def testHintShardingValidPattern(self):
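    """Checks that a `shard(SHARD_HINT, ...)` placeholder is filled in.

    Under the HINT policy, the hinted shard op takes on the real worker
    configuration (10 workers, index 0), so every tenth element is produced.
    """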
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.HINT)

    dataset = dataset_ops.Dataset.range(100).shard(distribute.SHARD_HINT, 0)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 10, 0)

    self.assertDatasetProduces(dataset, list(range(0, 100, 10)))

  @combinations.generate(test_base.default_test_combinations())
  def testHintShardingInvalidPattern(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.HINT)

    dataset = dataset_ops.Dataset.range(100).shard(1, 0)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 10, 0)

    self.assertDatasetProduces(dataset, list(range(100)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              auto_shard_policy=list(options_lib.AutoShardPolicy))))
  def testEnumerateAutoShardPolicies(self, auto_shard_policy):
    """Verifies tf.data handles every auto-shard policy with no errors."""
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    self.getDatasetOutput(dataset, requires_initialization=True)


class AutoShardWithRebatchDatasetTest(tf_record_test_base.TFRecordTestBase,
                                      parameterized.TestCase):

  def _setUpFiles(self, num_files, num_records_per_file):
    self._num_files = num_files
    self._num_records = num_records_per_file
    self._filenames = self._createFiles()

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithLegacyRebatch(self):
    # Tests that RebatchDatasetV1 is a passthrough op.
    self._setUpFiles(num_files=5, num_records_per_file=10)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    expected = [[self._record(3, i)] for i in range(10)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithRebatch(self):
    # Tests that RebatchDatasetV2 is a passthrough op.
    self._setUpFiles(num_files=3, num_records_per_file=5)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
    dataset = distribute._AutoShardDataset(dataset, 3, 1)
    expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
                [self._record(1, 3), self._record(1, 4)]]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  options_lib.AutoShardPolicy.DATA,
                  options_lib.AutoShardPolicy.AUTO
              ]), combinations.combine(with_prefetch=[True, False]))))
  def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                           with_prefetch):
    # This test simulates a distributed environment with 3 workers, each with
    # 1 replica.
    dataset = dataset_ops.Dataset.range(8)
    dataset = dataset.batch(4)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
    # RebatchDataset(V1) for correctness reasons. This will modify the output
    # of the dataset.
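    # As the expected values below show, each batch of 4 ends up split into
    # minibatches of size ceil(4 / 3) = 2 (plus a trailing empty minibatch),
    # and those minibatches are then sharded across the 3 workers, so worker 2
    # only receives empty batches.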
    worker_a_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[2, 1, 1])
    if with_prefetch:
      worker_a_dataset = worker_a_dataset.prefetch(1)
    worker_a_dataset = distribute._AutoShardDataset(
        worker_a_dataset, 3, 0, num_replicas=3)
    expected = [[0, 1], [4, 5]]
    self.assertDatasetProduces(worker_a_dataset, expected)

    worker_b_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 1, 2])
    if with_prefetch:
      worker_b_dataset = worker_b_dataset.prefetch(1)
    worker_b_dataset = distribute._AutoShardDataset(
        worker_b_dataset, 3, 1, num_replicas=3)
    expected = [[2, 3], [6, 7]]
    self.assertDatasetProduces(worker_b_dataset, expected)

    worker_c_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 2, 1])
    if with_prefetch:
      worker_c_dataset = worker_c_dataset.prefetch(1)
    worker_c_dataset = distribute._AutoShardDataset(
        worker_c_dataset, 3, 2, num_replicas=3)
    expected = [[], []]
    self.assertDatasetProduces(worker_c_dataset, expected)


class AutoShardDatasetCheckpointTest(tf_record_test_base.TFRecordTestBase,
                                     checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):

  def setUp(self):
    super(AutoShardDatasetCheckpointTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self._filenames = self._createFiles()

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def build_dataset():
      dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
      dataset = dataset.apply(
          interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
      dataset = distribute._AutoShardDataset(dataset, 5, 3)
      return dataset

    verify_fn(self, build_dataset, num_outputs=20)


if __name__ == "__main__":
  test.main()