# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `_AutoShardDataset` transformation."""
import os

from absl.testing import parameterized

from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.kernel_tests import tf_record_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


def chunk(l, n):
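  """Yields successive `n`-sized slices of `l`."""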
  for i in range(0, len(l), n):
    yield l[i:i + n]


class AutoShardDatasetTest(tf_record_test_base.TFRecordTestBase,
                           parameterized.TestCase):

  def setUp(self):
    super(AutoShardDatasetTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self._filenames = self._createFiles()

  def getAllDatasetElements(self, dataset):
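    """Evaluates `dataset` to exhaustion and returns all of its elements."""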
    actual = []
    next_fn = self.getNext(dataset)
    while True:
      try:
        actual.append(self.evaluate(next_fn()))
      except errors.OutOfRangeError:
        break
    return actual

  def assertDatasetProducesWithShuffle(self, dataset, expected, batch,
                                       num_examples, shuffle):
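    """Asserts that `dataset` produces `expected`.

    When `shuffle` is True, `num_examples` batches are drawn and their
    flattened contents are compared to `expected` ignoring order; otherwise
    the elements must match `expected` split into `batch`-sized chunks.
    """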
    if shuffle:
      actual = []
      next_fn = self.getNext(dataset)
      for _ in range(num_examples):
        elem = self.evaluate(next_fn())
        if isinstance(elem, tuple):
          actual.extend(elem)
        else:
          actual.extend(elem.tolist())

      self.assertCountEqual(actual, expected)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_fn())
    else:
      self.assertDatasetProduces(dataset, list(chunk(expected, batch)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testFlatMapReaderPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
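    # Shard across 5 workers and take the shard for worker index 3; with 10
    # files this selects files 3 and 8 (see `expected` below).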
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (3, 8)
        for r in range(0, 10)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(batch_size=[1, 3, 10])))
  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
    # This tests a scenario where a list_files call might return multiple
    # files due to the glob containing wildcards.
    def batch(iterator, n):
      l = len(iterator)
      for i in range(0, l, n):
        yield iterator[i:min(i + n, l)]

    datasets = []
    for files in batch(self._filenames, batch_size):
      datasets.append(
          dataset_ops.Dataset.list_files(files, shuffle=False).map(
              core_readers.TFRecordDataset))
    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
    dataset = dataset.flat_map(lambda x: x)

    # Simulate additional ops in between flat_map and interleave. This should
    # be a no-op, since if ShardDataset is placed right after flat_map, we will
    # only have two datasets left at this point.
    dataset = dataset.prefetch(1)
    dataset = dataset.prefetch(1)

    dataset = dataset.interleave(
        lambda x: x, cycle_length=1, num_parallel_calls=1)

    dataset = distribute._AutoShardDataset(dataset, 5, 0)
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testZipReaderPipeline(self):
    dataset1 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=False)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=False)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))

    dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        (b"Record %d of file %d" % (r, f), b"Record %d of file %d" % (r, f))  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]

    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testConcatenateReaderPipeline(self, shuffle):
    dataset1 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset1 = dataset1.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset1 = dataset1.batch(5)
    dataset2 = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset2 = dataset2.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset2 = dataset2.batch(5)

    dataset = dataset1.concatenate(dataset2)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    expected += expected
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 8, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testPipelineWithMap(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(test_base.default_test_combinations())
  def testDirectFilenameTFRecordReaderPipeline(self):
    dataset = core_readers.TFRecordDataset(self._filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testValidPipelineWithRangeDataset(self, shuffle):
    dataset = dataset_ops.Dataset.range(self._num_files)
    dataset = dataset.map(lambda n: string_ops.string_join(  # pylint:disable=g-long-lambda
        [self.get_temp_dir(),
         string_ops.string_format("/tf_record.{}.txt", [n])]))
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.map(lambda x: string_ops.substr_v2(x, 2, 1000))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"cord %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(params=[(1, 0, 10, 10), (2, 1, 20, 5),
                                       (10, 1, 1, 10)])))
  def testStandardReaderPipeline(self, params):
    num_epochs, index, batch_size, parallel_reads = params
    dataset = readers.make_tf_record_dataset(
        file_pattern=self._filenames,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=None,
        num_parallel_reads=parallel_reads,
        drop_final_batch=True,
        shuffle=False)
    dataset = distribute._AutoShardDataset(dataset, 2, index)
    outputs = self.getNext(dataset)
    self._verify_records(
        outputs,
        batch_size=batch_size,
        file_index=[i for i in range(index, self._num_records, 2)],
        num_epochs=num_epochs,
        interleave_cycle_length=parallel_reads,
        drop_final_batch=True,
        use_parser_fn=None)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())

  @combinations.generate(test_base.default_test_combinations())
  def testShardInputToInterleave(self):
    file1 = self._writeFile("f0", [1, 2, 3])
    file2 = self._writeFile("f1", [4, 5, 6])
    file3 = self._writeFile("f2", [7, 8, 9])
    dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
    dataset = dataset.interleave(core_readers.TFRecordDataset, cycle_length=3)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    # Sharding by file will interleave files 0 and 2
    expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
    actual = self.getDatasetOutput(dataset)
    self.assertEqual(actual, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testShardInputToInterleaveWithIdentityFunction(self):
    self.skipTest("Currently fails due to b/238645949")
    file1 = self._writeFile("f0", [1, 2, 3])
    file2 = self._writeFile("f1", [4, 5, 6])
    file3 = self._writeFile("f2", [7, 8, 9])
    dataset = dataset_ops.Dataset.from_tensor_slices([file1, file2, file3])
    dataset = dataset.map(core_readers.TFRecordDataset)
    dataset = dataset.interleave(lambda x: x, cycle_length=3)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    # Sharding by file will interleave files 0 and 2
    expected = [str.encode(str(i)) for i in [1, 7, 2, 8, 3, 9]]
    actual = self.getDatasetOutput(dataset)
    self.assertEqual(actual, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(shuffle=[True, False])))
  def testSampleResNetPipeline(self, shuffle):
    dataset = dataset_ops.Dataset.list_files(
        self._filenames, shuffle=shuffle)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for r in range(0, 10)
        for f in (3, 8)
    ]
    self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(sharding_policy=[
              options_lib.AutoShardPolicy.DATA,
              options_lib.AutoShardPolicy.AUTO
          ])))
  def testShardByDataBeforePrefetch(self, sharding_policy):
    dataset = dataset_ops.Dataset.range(4)
    dataset = dataset.apply(testing.assert_next(["Shard", "Prefetch"]))
    dataset = dataset.prefetch(1)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [0, 2])

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(combinations.combine(
              sharding_policy=[options_lib.AutoShardPolicy.DATA,
                               options_lib.AutoShardPolicy.FILE]),
                             combinations.combine(shuffle=[True, False]))))
  def testReplicateAndShardProduceDisjointData(self, shuffle, sharding_policy):
    dataset = dataset_ops.Dataset.list_files(self._filenames,
                                             shuffle=shuffle)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)

    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=options_lib.ExternalStatePolicy.WARN)

    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy

    ds1 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)
    ds2 = distribute._RemoteDataset(graph_def, "/device:CPU:0",
                                    dataset.element_spec)

    ds1 = ds1.with_options(options)
    ds2 = ds2.with_options(options)

    ds1 = distribute._AutoShardDataset(ds1, 2, 0)
    ds2 = distribute._AutoShardDataset(ds2, 2, 1)

    elems1 = set(self.getAllDatasetElements(ds1))
    elems2 = set(self.getAllDatasetElements(ds2))

    self.assertEmpty(elems1.intersection(elems2))

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFilesWithDataSharding(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.DATA)

    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return "Record (0, 5) of file (0 --> 9)": since we are sharding by
    # individual elements, we should be able to get some data from all files.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testAutoshardPolicyOff(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.OFF)

    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    # Should return every record in every file since autosharding is turned off.
    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithoutReaderDatasetOp(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.FILE)

    dataset = dataset_ops.Dataset.range(1024)
    dataset = dataset.with_options(options)

    # We are specifying that we want a file sharding policy, and this pipeline
    # doesn't start with file reading, so we should error out.
    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testWorkersGreaterThanNumFiles(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames)
    dataset = dataset.apply(
        interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 500, 499)
    self.assertDatasetProduces(dataset, [])

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNames(self):
    # Using `_TFRecordDataset` creates the raw op directly rather than
    # automatically wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self._filenames)
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testTFRecordReaderWithDirectFileNamesAndShapes(self):
    # Using `_TFRecordDataset` creates the raw op directly rather than
    # automatically wrapping it in a flat_map.
    dataset = core_readers._TFRecordDataset(self._filenames)

    # BatchDataset contains `output_types` and `output_shapes`
    dataset = dataset.batch(5)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in range(0, 5)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRange(self):
    dataset = dataset_ops.Dataset.range(5)
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testShardOutOfRangeEmptyDataset(self):
    dataset = dataset_ops.Dataset.range(0)
    with self.assertRaises(errors.OutOfRangeError):
      dataset = distribute._AutoShardDataset(dataset, 10, 0)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testNoReaderPipelines(self):
    dataset = dataset_ops.Dataset.range(1024)
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    self.assertDatasetProduces(dataset, [i for i in range(1024) if i % 2 == 0])

  @combinations.generate(test_base.default_test_combinations())
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testInvalidWorkerIndex(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)

    with self.assertRaises(errors.InvalidArgumentError):
      dataset = distribute._AutoShardDataset(dataset, 2, 2)
      self.evaluate(self.getNext(dataset)())

  @combinations.generate(test_base.default_test_combinations())
  def testAssertCardinality(self):
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(cardinality.assert_cardinality(42))
    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in (0, 5)
        for r in range(0, 10)
    ]
    self.assertDatasetProduces(dataset, list(chunk(expected, 5)))

  @combinations.generate(test_base.default_test_combinations())
  def testMakeBatchedFeaturesDataset(self):
    files = 2
    records_per_file = 5

    def make_record(file_index):
      example = example_pb2.Example(
          features=feature_pb2.Features(
              feature={
                  "file":
                      feature_pb2.Feature(
                          int64_list=feature_pb2.Int64List(value=[file_index])),
              }))
      return example.SerializeToString()

    filenames = []
    for file_index in range(files):
      filename = os.path.join(self.get_temp_dir(),
                              "tf_record.%d.txt" % file_index)
      filenames.append(filename)
      writer = python_io.TFRecordWriter(filename)
      for _ in range(records_per_file):
        writer.write(make_record(file_index))
      writer.close()

    dataset = readers.make_batched_features_dataset(
        file_pattern=filenames,
        batch_size=records_per_file,
        features={
            "file": parsing_ops.FixedLenFeature([], dtypes.int64),
        },
        reader=core_readers.TFRecordDataset,
        num_epochs=1)
    # We should shard at the file level, so that all records come from file 0.
    dataset = distribute._AutoShardDataset(dataset, 2, 0)
    dataset = dataset.unbatch()
    output = self.getDatasetOutput(dataset)
    files = [elem["file"] for elem in output]
    self.assertEqual(files, [0] * records_per_file)

  @combinations.generate(test_base.default_test_combinations())
  def testHintShardingValidPattern(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.HINT)

    dataset = dataset_ops.Dataset.range(100).shard(distribute.SHARD_HINT, 0)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 10, 0)

    self.assertDatasetProduces(dataset, list(range(0, 100, 10)))

  @combinations.generate(test_base.default_test_combinations())
  def testHintShardingInvalidPattern(self):
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = (
        options_lib.AutoShardPolicy.HINT)

    dataset = dataset_ops.Dataset.range(100).shard(1, 0)
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 10, 0)

    self.assertDatasetProduces(dataset, list(range(100)))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              auto_shard_policy=list(options_lib.AutoShardPolicy))))
  def testEnumerateAutoShardPolicies(self, auto_shard_policy):
    """Verifies tf.data handles every auto-shard policy with no errors."""
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = auto_shard_policy
    dataset = dataset.with_options(options)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    self.getDatasetOutput(dataset, requires_initialization=True)


class AutoShardWithRebatchDatasetTest(tf_record_test_base.TFRecordTestBase,
                                      parameterized.TestCase):

  def _setUpFiles(self, num_files, num_records_per_file):
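    """Creates `num_files` test files with `num_records_per_file` records each."""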
    self._num_files = num_files
    self._num_records = num_records_per_file
    self._filenames = self._createFiles()

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithLegacyRebatch(self):
    # Tests that RebatchDatasetV1 is a passthrough op.
    self._setUpFiles(num_files=5, num_records_per_file=10)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._LegacyRebatchDataset(dataset, num_replicas=5)
    dataset = distribute._AutoShardDataset(dataset, 5, 3)
    expected = [[self._record(3, i)] for i in range(10)]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testFileShardingWithRebatch(self):
    # Tests that RebatchDatasetV2 is a passthrough op.
    self._setUpFiles(num_files=3, num_records_per_file=5)
    dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
    dataset = dataset.apply(
        testing.assert_next(["Shard", "FlatMap", "Batch", "Rebatch"]))
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = distribute._RebatchDataset(dataset, batch_sizes=[2, 1, 2])
    dataset = distribute._AutoShardDataset(dataset, 3, 1)
    expected = [[self._record(1, 0), self._record(1, 1)], [self._record(1, 2)],
                [self._record(1, 3), self._record(1, 4)]]
    self.assertDatasetProduces(dataset, expected)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.times(
              combinations.combine(sharding_policy=[
                  options_lib.AutoShardPolicy.DATA,
                  options_lib.AutoShardPolicy.AUTO
              ]), combinations.combine(with_prefetch=[True, False]))))
  def testUseLegacyRebatchWithDataSharding(self, sharding_policy,
                                           with_prefetch):
    # This test simulates a distributed environment with 3 workers, each with
    # 1 replica.
    dataset = dataset_ops.Dataset.range(8)
    dataset = dataset.batch(4)
    options = options_lib.Options()
    options.experimental_distribute.auto_shard_policy = sharding_policy
    dataset = dataset.with_options(options)
    # We expect the auto-shard rewrite to rewrite RebatchDatasetV2 to
    # RebatchDataset(V1) for correctness reasons. This will modify the output
    # of the dataset.
    worker_a_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[2, 1, 1])
    if with_prefetch:
      worker_a_dataset = worker_a_dataset.prefetch(1)
    worker_a_dataset = distribute._AutoShardDataset(
        worker_a_dataset, 3, 0, num_replicas=3)
    expected = [[0, 1], [4, 5]]
    self.assertDatasetProduces(worker_a_dataset, expected)

    worker_b_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 1, 2])
    if with_prefetch:
      worker_b_dataset = worker_b_dataset.prefetch(1)
    worker_b_dataset = distribute._AutoShardDataset(
        worker_b_dataset, 3, 1, num_replicas=3)
    expected = [[2, 3], [6, 7]]
    self.assertDatasetProduces(worker_b_dataset, expected)

    worker_c_dataset = distribute._RebatchDataset(
        dataset, batch_sizes=[1, 2, 1])
    if with_prefetch:
      worker_c_dataset = worker_c_dataset.prefetch(1)
    worker_c_dataset = distribute._AutoShardDataset(
        worker_c_dataset, 3, 2, num_replicas=3)
    expected = [[], []]
    self.assertDatasetProduces(worker_c_dataset, expected)


class AutoShardDatasetCheckpointTest(tf_record_test_base.TFRecordTestBase,
                                     checkpoint_test_base.CheckpointTestBase,
                                     parameterized.TestCase):

  def setUp(self):
    super(AutoShardDatasetCheckpointTest, self).setUp()
    self._num_files = 10
    self._num_records = 10
    self._filenames = self._createFiles()

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):

    def build_dataset():
      dataset = dataset_ops.Dataset.list_files(self._filenames, shuffle=False)
      dataset = dataset.apply(
          interleave_ops.parallel_interleave(core_readers.TFRecordDataset, 10))
      dataset = distribute._AutoShardDataset(dataset, 5, 3)
      return dataset

    verify_fn(self, build_dataset, num_outputs=20)


if __name__ == "__main__":
  test.main()