# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.shard()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class ShardTest(test_base.DatasetTestBase):
  """Checks element selection and argument validation of `Dataset.shard()`.

  Each case shards a `Dataset.range(...)` with a given `(num_shards, index)`
  pair and checks which elements survive, or that an invalid pair is
  rejected with `errors.InvalidArgumentError`.
  """

  def _assertRangeShardProduces(self, size, num_shards, index, expected):
    """Asserts `range(size).shard(num_shards, index)` yields `expected`."""
    sharded = dataset_ops.Dataset.range(size).shard(num_shards, index)
    self.assertDatasetProduces(sharded, expected_output=expected)

  def _assertShardRejected(self, num_shards, index):
    """Asserts `range(10).shard(num_shards, index)` raises InvalidArgumentError.

    Both building the dataset and fetching its first element happen inside
    the `assertRaises` context, so an error raised at either point satisfies
    the assertion.
    """
    with self.assertRaises(errors.InvalidArgumentError):
      sharded = dataset_ops.Dataset.range(10).shard(num_shards, index)
      next_element = self.getNext(sharded)
      self.evaluate(next_element())

  def testSimpleCase(self):
    self._assertRangeShardProduces(10, 5, 2, [2, 7])

  def testNestedData(self):
    # Sharding a zipped dataset keeps tuple structure: every component of a
    # selected element is taken from the same position.
    ascending = dataset_ops.Dataset.range(10)
    descending = dataset_ops.Dataset.range(10, 0, -1)
    zipped = dataset_ops.Dataset.zip((ascending, descending))
    self.assertDatasetProduces(
        zipped.shard(5, 2), expected_output=[(2, 8), (7, 3)])

  def testOffsetZero(self):
    self._assertRangeShardProduces(10, 5, 0, [0, 5])

  def testOffsetGreaterNumShards(self):
    # index (7) exceeds num_shards (5).
    self._assertShardRejected(5, 7)

  def testNegativeOffset(self):
    self._assertShardRejected(5, -3)

  def testNegativeNumShards(self):
    self._assertShardRejected(-3, 1)

  def testZeroNumShards(self):
    self._assertShardRejected(0, 1)

  def testIteratorEndsBeforeFirstElem(self):
    # The input has a single element (0), which belongs to shard 0, so
    # shard 2 is empty.
    self._assertRangeShardProduces(1, 5, 2, [])

  def testLargerWorkerPool(self):
    self._assertRangeShardProduces(10, 7, 5, [5])

  def testIndexEqualsNumShards(self):
    # NOTE: despite the name, index here is num_shards - 1 (the last shard).
    self._assertRangeShardProduces(10, 5, 4, [4, 9])

  def testIndexEqualsNumShards2(self):
    # NOTE: despite the name, index here is num_shards - 1 (the last shard).
    self._assertRangeShardProduces(10, 4, 3, [3, 7])

  def testNumShardsLargerThanDataset(self):
    self._assertRangeShardProduces(10, 20, 5, [5])


if __name__ == "__main__":
  test.main()