# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.shuffle()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Exercises element preservation, seeding and reshuffling of `shuffle()`."""

  def _drain(self, get_next, num_elements):
    """Evaluates `num_elements` elements, then asserts the input is exhausted.

    Args:
      get_next: Zero-argument callable, as returned by `self.getNext`, whose
        result is passed to `self.evaluate` to obtain the next element.
      num_elements: Number of elements expected before the end of sequence.

    Returns:
      A list of the evaluated elements, in the order they were produced.
    """
    elements = [self.evaluate(get_next()) for _ in range(num_elements)]
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    return elements

  def testShuffleDataset(self):
    """Checks that shuffling permutes elements without losing any of them."""
    components = (
        np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
        np.array([9.0, 10.0, 11.0, 12.0])
    )
    # Each component has 4 slices and `dataset_fn` repeats them 5 times by
    # default.
    num_elements = 20

    def dataset_fn(count=5, buffer_size=None, seed=0):
      """Builds the repeated dataset, shuffled only when `buffer_size` is set."""
      repeat_dataset = (
          dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
      if not buffer_size:
        return repeat_dataset

      shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)
      # Shuffling must not change the per-element shapes.
      self.assertEqual(
          tuple(c.shape[1:] for c in components),
          dataset_ops.get_legacy_output_shapes(shuffle_dataset))
      return shuffle_dataset

    # First run without shuffling to collect the "ground truth".
    unshuffled_elements = self._drain(
        self.getNext(dataset_fn()), num_elements)

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth".
    get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
    shuffled_elements = self._drain(get_next, num_elements)
    # Requesting another element after exhaustion must raise again.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))

    # Assert that shuffling twice with the same seeds gives the same sequence.
    reshuffled_elements_same_seed = self._drain(
        self.getNext(dataset_fn(buffer_size=100, seed=37)), num_elements)
    self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)

    # Assert that shuffling twice with a different seed gives a different
    # permutation of the same elements.
    reshuffled_elements_different_seed = self._drain(
        self.getNext(dataset_fn(buffer_size=100, seed=137)), num_elements)
    self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
    self.assertAllEqual(
        sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))

    # Assert that the shuffled dataset has the same elements as the
    # "ground truth" when the buffer size is smaller than the input
    # dataset.
    reshuffled_elements_small_buffer = self._drain(
        self.getNext(dataset_fn(buffer_size=2, seed=37)), num_elements)
    self.assertAllEqual(
        sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))

    # Test the case of shuffling an empty dataset: the very first request must
    # already signal the end of the sequence.
    self._drain(
        self.getNext(dataset_fn(count=0, buffer_size=100, seed=37)), 0)

  @test_util.run_deprecated_v1
  def testSkipEagerSeedZero(self):
    """Test for same behavior when the seed is a Python or Tensor zero."""
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=0))
    get_next = iterator.get_next()

    # Reference order, produced with the Python integer seed.
    with self.cached_session() as sess:
      elems = [sess.run(get_next) for _ in range(10)]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.make_initializable_iterator(
        dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder))
    get_next = iterator.get_next()

    # Feeding zero through a tensor must reproduce the reference order.
    with self.cached_session() as sess:
      sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
      for elem in elems:
        self.assertEqual(elem, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testDefaultArguments(self):
    """Checks that `shuffle(5).repeat()` yields every element once per epoch."""
    components = [0, 1, 2, 3, 4]
    num_epochs = 10
    dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
        5).repeat()
    get_next = self.getNext(dataset)
    # The buffer holds a whole epoch, so after `num_epochs` epochs each value
    # must have been seen exactly `num_epochs` times.
    counts = collections.Counter(
        self.evaluate(get_next())
        for _ in range(num_epochs * len(components)))

    for i in range(5):
      self.assertEqual(num_epochs, counts[i])

  def testShuffleNoReshuffleEachIteration(self):
    """Checks that `reshuffle_each_iteration=False` repeats one permutation."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, reshuffle_each_iteration=False).batch(10).repeat(3)
    next_element = self.getNext(dataset)

    # Each batch holds one full pass over the 10 input elements.
    initial_permutation = self.evaluate(next_element())
    self.assertAllEqual(initial_permutation, self.evaluate(next_element()))
    self.assertAllEqual(initial_permutation, self.evaluate(next_element()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  def testShuffleReshuffleEachIteration(self):
    """Checks that `reshuffle_each_iteration=True` changes the permutation."""
    dataset = dataset_ops.Dataset.range(10).shuffle(
        10, seed=3, reshuffle_each_iteration=True).batch(10).repeat(3)
    next_element = self.getNext(dataset)

    initial_permutation = list(self.evaluate(next_element()))
    for _ in range(2):
      next_permutation = list(self.evaluate(next_element()))
      # Same elements as the first pass, but in a different order.
      self.assertNotEqual(initial_permutation, next_permutation)
      self.assertAllEqual(sorted(initial_permutation), sorted(next_permutation))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element())

  @parameterized.named_parameters(
      ("ReshuffleGraphLevelSeed", True, 38, None),
      ("ReshuffleOpLevelSeed", True, None, 42),
      ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
      ("NoReshuffleGraphLevelSeed", False, 38, None),
      ("NoReshuffleOpLevelSeed", False, None, 42),
      ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
  )
  def testSkipEagerShuffleSeed(self, reshuffle, graph_level_seed,
                               op_level_seed):
    """Checks that identical seeds in two fresh graphs give identical output.

    Args:
      reshuffle: Value passed as `reshuffle_each_iteration`.
      graph_level_seed: Seed passed to `random_seed.set_random_seed`, or None.
      op_level_seed: Seed passed to `shuffle()`, or None.
    """
    # 10 input elements repeated 3 times.
    num_elements = 30
    results = []
    for _ in range(2):
      with ops.Graph().as_default() as g:
        random_seed.set_random_seed(graph_level_seed)
        dataset = dataset_ops.Dataset.range(10).shuffle(
            10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
                3)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        with self.session(graph=g) as sess:
          run_results = [sess.run(next_element) for _ in range(num_elements)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
        results.append(run_results)

    self.assertAllEqual(results[0], results[1])

  # TODO(b/117581999): fails for eager mode with result[0] equal to result[1],
  # debug.
  @parameterized.named_parameters(
      ("ReshuffleOneShot", True, False),
      ("ReshuffleInitializable", True, True),
      ("NoReshuffleOneShot", False, False),
      ("NoReshuffleInitializable", False, True),
  )
  def testSkipEagerMultipleIterators(self, reshuffle, initializable):
    """Checks that two iterators over one unseeded dataset differ in order.

    Args:
      reshuffle: Value passed as `reshuffle_each_iteration`.
      initializable: If True, use initializable iterators; otherwise use
        one-shot iterators.
    """
    # 100 input elements repeated 3 times.
    num_elements = 300
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      if initializable:
        make_iterator = dataset_ops.make_initializable_iterator
      else:
        make_iterator = dataset_ops.make_one_shot_iterator
      iterators = [make_iterator(dataset) for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          run_results = [sess.run(next_element) for _ in range(num_elements)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)

          results.append(run_results)

        self.assertNotEqual(results[0], results[1])


if __name__ == "__main__":
  test.main()