# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict

# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
GENERATE_GOLDEN = False

# Message expected when buffer_size is an int outside the accepted range.
_BUFFER_SIZE_RANGE_MSG = "Input buffer_size is not within the required interval of [2, 2147483647]"


def _check_shuffle_raises(expected_msg, *shuffle_args, **shuffle_kwargs):
    """
    Build a shuffle pipeline with the given arguments and require that it fails.

    The dataset is created from DATA_DIR, the seed is set to 1, then
    shuffle(*shuffle_args, **shuffle_kwargs) is applied and the pipeline is
    iterated. The test fails if no exception is raised, or if the exception
    text does not contain expected_msg.
    """
    data1 = ds.TFRecordDataset(DATA_DIR)
    ds.config.set_seed(1)
    try:
        data1 = data1.shuffle(*shuffle_args, **shuffle_kwargs)
        # Iterate the pipeline so errors raised lazily at execution time also surface.
        sum(1 for _ in data1)
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert expected_msg in str(e)
    else:
        # Raised outside the try body so the broad except above cannot swallow it.
        raise AssertionError(
            "shuffle() was expected to raise an exception containing '{}', "
            "but no exception was raised".format(expected_msg))


def test_shuffle_01():
    """
    Test shuffle: buffer_size < number-of-rows-in-dataset
    """
    logger.info("test_shuffle_01")
    # define parameters
    buffer_size = 5
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    filename = "shuffle_01_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_02():
    """
    Test shuffle: buffer_size = number-of-rows-in-dataset
    """
    logger.info("test_shuffle_02")
    # define parameters
    buffer_size = 12
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    filename = "shuffle_02_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_03():
    """
    Test shuffle: buffer_size=2 (minimum size), number-of-rows-in-dataset > 2
    """
    logger.info("test_shuffle_03")
    # define parameters
    buffer_size = 2
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    # buffer_size is deliberately passed positionally here
    data1 = data1.shuffle(buffer_size)

    filename = "shuffle_03_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_04():
    """
    Test shuffle: buffer_size=2 (minimum size), number-of-rows-in-dataset = 2
    """
    logger.info("test_shuffle_04")
    # define parameters
    buffer_size = 2
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, num_samples=2)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    filename = "shuffle_04_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_05():
    """
    Test shuffle: buffer_size > number-of-rows-in-dataset
    """
    logger.info("test_shuffle_05")
    # define parameters
    buffer_size = 13
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    filename = "shuffle_05_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_shuffle_06():
    """
    Test shuffle: with set seed, both datasets

    Two identical shuffle pipelines are built after the seed is set and are
    compared row by row.
    """
    logger.info("test_shuffle_06")
    # define parameters
    buffer_size = 13
    seed = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)

    data2 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(buffer_size=buffer_size)

    num_rows = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_equal(item1, item2)
        num_rows += 1

    # zip() stops silently on an empty iterator; make sure rows were actually compared.
    assert num_rows > 0


def test_shuffle_exception_01():
    """
    Test shuffle exception: buffer_size<0
    """
    logger.info("test_shuffle_exception_01")
    _check_shuffle_raises(_BUFFER_SIZE_RANGE_MSG, buffer_size=-1)


def test_shuffle_exception_02():
    """
    Test shuffle exception: buffer_size=0
    """
    logger.info("test_shuffle_exception_02")
    _check_shuffle_raises(_BUFFER_SIZE_RANGE_MSG, buffer_size=0)


def test_shuffle_exception_03():
    """
    Test shuffle exception: buffer_size=1
    """
    logger.info("test_shuffle_exception_03")
    _check_shuffle_raises(_BUFFER_SIZE_RANGE_MSG, buffer_size=1)


def test_shuffle_exception_05():
    """
    Test shuffle exception: Missing mandatory buffer_size input parameter
    """
    logger.info("test_shuffle_exception_05")
    _check_shuffle_raises("buffer_size")


def test_shuffle_exception_06():
    """
    Test shuffle exception: buffer_size wrong type, boolean value False
    """
    logger.info("test_shuffle_exception_06")
    _check_shuffle_raises("buffer_size", buffer_size=False)


def test_shuffle_exception_07():
    """
    Test shuffle exception: buffer_size wrong type, boolean value True
    """
    logger.info("test_shuffle_exception_07")
    _check_shuffle_raises("buffer_size", buffer_size=True)


if __name__ == '__main__':
    test_shuffle_01()
    test_shuffle_02()
    test_shuffle_03()
    test_shuffle_04()
    test_shuffle_05()
    test_shuffle_06()
    test_shuffle_exception_01()
    test_shuffle_exception_02()
    test_shuffle_exception_03()
    test_shuffle_exception_05()
    test_shuffle_exception_06()
    test_shuffle_exception_07()
    logger.info('\n')