# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Repeat Op
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from util import save_and_check_dict

DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json"

DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# When True, the save_and_check_dict helpers regenerate the golden .npz files
# instead of comparing against them.
GENERATE_GOLDEN = False


def test_tf_repeat_01():
    """
    Test a simple repeat operation.

    The repeated pipeline output is compared against a stored golden file.
    """
    logger.info("Test Simple Repeat")
    # define parameters
    repeat_count = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.repeat(repeat_count)

    filename = "repeat_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_tf_repeat_02():
    """
    Test Infinite Repeat.

    repeat(-1) repeats forever; iteration is stopped manually after 100 rows.
    """
    logger.info("Test Infinite Repeat")
    # define parameters
    repeat_count = -1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False)
    data1 = data1.repeat(repeat_count)

    itr = 0
    for _ in data1:
        itr = itr + 1
        # Must break explicitly: an infinite repeat never exhausts the iterator.
        if itr == 100:
            break
    assert itr == 100


def test_tf_repeat_03():
    """
    Test Repeat then Batch.

    3 images repeated 22 times = 66 rows; batch(32, drop_remainder=True)
    yields exactly 2 full batches.
    """
    logger.info("Test Repeat then Batch")
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)

    batch_size = 32
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.repeat(22)
    data1 = data1.batch(batch_size, drop_remainder=True)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of tf data in data1: {}".format(num_iter))
    assert num_iter == 2


def test_tf_repeat_04():
    """
    Test a simple repeat operation with column list.

    Only two columns are read; output is compared against a golden file.
    """
    logger.info("Test Simple Repeat Column List")
    # define parameters
    repeat_count = 2
    columns_list = ["col_sint64", "col_sint32"]
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False)
    data1 = data1.repeat(repeat_count)

    filename = "repeat_list_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def generator():
    """Yield three one-column rows: [0], [1], [2]."""
    for i in range(3):
        # Yield a 1-tuple so GeneratorDataset receives a one-column row.
        # (The previous form `(yield np.array([i]),)` yielded a bare ndarray
        # and built a discarded tuple from the yield expression's value.)
        yield (np.array([i]),)


def test_nested_repeat1():
    """Two chained repeats multiply: repeat(2).repeat(3) gives 2*3 epochs."""
    logger.info("test_nested_repeat1")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)

    for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert i % 3 == d[0][0]

    assert sum(1 for _ in data) == 2 * 3 * 3


def test_nested_repeat2():
    """repeat(1).repeat(1) is a no-op: one pass over the 3 rows."""
    logger.info("test_nested_repeat2")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1)
    data = data.repeat(1)

    for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert i % 3 == d[0][0]

    assert sum(1 for _ in data) == 3


def test_nested_repeat3():
    """repeat(1).repeat(2) gives 2 passes over the 3 rows."""
    logger.info("test_nested_repeat3")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(1)
    data = data.repeat(2)

    for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert i % 3 == d[0][0]

    assert sum(1 for _ in data) == 2 * 3


def test_nested_repeat4():
    """repeat(2).repeat(1) gives 2 passes over the 3 rows."""
    logger.info("test_nested_repeat4")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(1)

    for i, d in enumerate(data.create_tuple_iterator(output_numpy=True)):
        assert i % 3 == d[0][0]

    assert sum(1 for _ in data) == 2 * 3


def test_nested_repeat5():
    """Batch before nested repeat: each of the 2*3 epochs yields one batch."""
    logger.info("test_nested_repeat5")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(3)
    data = data.repeat(2)
    data = data.repeat(3)

    for _, d in enumerate(data):
        np.testing.assert_array_equal(d[0].asnumpy(), np.asarray([[0], [1], [2]]))

    assert sum(1 for _ in data) == 6


def test_nested_repeat6():
    """Batch between the two repeats: still 6 batches of [[0],[1],[2]]."""
    logger.info("test_nested_repeat6")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.batch(3)
    data = data.repeat(3)

    for _, d in enumerate(data):
        np.testing.assert_array_equal(d[0].asnumpy(), np.asarray([[0], [1], [2]]))

    assert sum(1 for _ in data) == 6


def test_nested_repeat7():
    """Batch after both repeats: still 6 batches of [[0],[1],[2]]."""
    logger.info("test_nested_repeat7")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)
    data = data.batch(3)

    for _, d in enumerate(data):
        np.testing.assert_array_equal(d[0].asnumpy(), np.asarray([[0], [1], [2]]))

    assert sum(1 for _ in data) == 6


def test_nested_repeat8():
    """
    Batch with a remainder before nested repeat.

    batch(2, drop_remainder=False) over 3 rows alternates a full batch
    [[0],[1]] and a partial batch [[2]] in every repeated epoch.
    """
    logger.info("test_nested_repeat8")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.batch(2, drop_remainder=False)
    data = data.repeat(2)
    data = data.repeat(3)

    for i, d in enumerate(data):
        if i % 2 == 0:
            np.testing.assert_array_equal(d[0].asnumpy(), np.asarray([[0], [1]]))
        else:
            np.testing.assert_array_equal(d[0].asnumpy(), np.asarray([[2]]))

    assert sum(1 for _ in data) == 6 * 2


def test_nested_repeat9():
    """Infinite repeat() followed by repeat(3) is still infinite; break manually."""
    logger.info("test_nested_repeat9")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat()
    data = data.repeat(3)

    for i, d in enumerate(data):
        assert i % 3 == d[0].asnumpy()[0]
        if i == 10:
            break


def test_nested_repeat10():
    """repeat(3) followed by infinite repeat() is still infinite; break manually."""
    logger.info("test_nested_repeat10")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(3)
    data = data.repeat()

    for i, d in enumerate(data):
        assert i % 3 == d[0].asnumpy()[0]
        if i == 10:
            break


def test_nested_repeat11():
    """Four chained repeats multiply: 2*3*4*5 passes over the 3 rows."""
    logger.info("test_nested_repeat11")
    data = ds.GeneratorDataset(generator, ["data"])
    data = data.repeat(2)
    data = data.repeat(3)
    data = data.repeat(4)
    data = data.repeat(5)

    for i, d in enumerate(data):
        assert i % 3 == d[0].asnumpy()[0]

    assert sum(1 for _ in data) == 2 * 3 * 4 * 5 * 3


def test_repeat_count1():
    """
    Repeat then batch: 3 rows * repeat(4) = 12 rows, batch(2) = 6 batches.

    Checks get_dataset_size agrees with the actual iteration count.
    """
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))
    batch_size = 2
    repeat_count = 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.repeat(repeat_count)
    data1 = data1.batch(batch_size, drop_remainder=False)
    dataset_size = data1.get_dataset_size()
    logger.info("dataset repeat then batch's size is {}".format(dataset_size))
    num1_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num1_iter += 1

    assert data1_size == 3
    assert dataset_size == num1_iter == 6


def test_repeat_count2():
    """
    Batch then repeat: 3 rows -> 2 batches per epoch, * repeat(4) = 8 batches.

    Checks get_dataset_size agrees with the actual iteration count.
    """
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))
    batch_size = 2
    repeat_count = 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=resize_op, input_columns=["image"])
    data1 = data1.batch(batch_size, drop_remainder=False)
    data1 = data1.repeat(repeat_count)
    dataset_size = data1.get_dataset_size()
    logger.info("dataset batch then repeat's size is {}".format(dataset_size))
    num1_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num1_iter += 1

    assert data1_size == 3
    assert dataset_size == num1_iter == 8


def test_repeat_count0():
    """
    Test Repeat with invalid count 0.

    repeat(0) must raise ValueError mentioning the bad "count" argument.
    """
    logger.info("Test Repeat with invalid count 0")
    with pytest.raises(ValueError) as info:
        data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
        data1.repeat(0)
    assert "count" in str(info.value)


def test_repeat_countneg2():
    """
    Test Repeat with invalid count -2.

    Only -1 means "infinite"; other negative counts raise ValueError.
    """
    logger.info("Test Repeat with invalid count -2")
    with pytest.raises(ValueError) as info:
        data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
        data1.repeat(-2)
    assert "count" in str(info.value)


if __name__ == "__main__":
    test_tf_repeat_01()
    test_tf_repeat_02()
    test_tf_repeat_03()
    test_tf_repeat_04()
    test_nested_repeat1()
    test_nested_repeat2()
    test_nested_repeat3()
    test_nested_repeat4()
    test_nested_repeat5()
    test_nested_repeat6()
    test_nested_repeat7()
    test_nested_repeat8()
    test_nested_repeat9()
    test_nested_repeat10()
    test_nested_repeat11()
    test_repeat_count1()
    test_repeat_count2()
    test_repeat_count0()
    test_repeat_countneg2()