# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Cifar10 and Cifar100 dataset operators
"""
import os
import pytest
import numpy as np
import matplotlib.pyplot as plt
import mindspore.dataset as ds
from mindspore import log as logger

DATA_DIR_10 = "../data/dataset/testCifar10Data"
DATA_DIR_100 = "../data/dataset/testCifar100Data"
NO_BIN_DIR = "../data/dataset/testMnistData"


def load_cifar(path, kind="cifar10"):
    """
    Load Cifar10/100 data directly from the .bin files under `path`.

    Args:
        path (str): directory containing the CIFAR binary files.
        kind (str): "cifar10" (one label byte per record) or
            "cifar100" (coarse + fine label bytes per record).

    Returns:
        tuple: (images, labels); images has shape (N, 32, 32, 3) in HWC order.
        For cifar100, labels has shape (N, 2) holding (coarse, fine) pairs.

    Raises:
        ValueError: if `kind` is neither "cifar10" nor "cifar100".
    """
    raw = np.empty(0, dtype=np.uint8)
    for file_name in os.listdir(path):
        if file_name.endswith(".bin"):
            with open(os.path.join(path, file_name), mode='rb') as file:
                raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0)
    if kind == "cifar10":
        # each record: 1 label byte + 32*32*3 image bytes
        raw = raw.reshape(-1, 3073)
        labels = raw[:, 0]
        images = raw[:, 1:]
    elif kind == "cifar100":
        # each record: coarse label byte + fine label byte + 32*32*3 image bytes
        raw = raw.reshape(-1, 3074)
        labels = raw[:, :2]
        images = raw[:, 2:]
    else:
        raise ValueError("Invalid parameter value")
    # stored as CHW planes; transpose to HWC to match the dataset op output
    images = images.reshape(-1, 3, 32, 32)
    images = images.transpose(0, 2, 3, 1)
    return images, labels


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    num_samples = len(images)
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(images[i])
        plt.title(labels[i])
    plt.show()


### Testcases for Cifar10Dataset Op ###


def test_cifar10_content_check():
    """
    Validate Cifar10Dataset image readings
    """
    logger.info("Test Cifar10Dataset Op with content check")
    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_10)
    num_iter = 0
    # in this example, each dictionary has keys "image" and "label"
    for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(d["image"], images[i])
        np.testing.assert_array_equal(d["label"], labels[i])
        num_iter += 1
    assert num_iter == 100


def test_cifar10_basic():
    """
    Validate CIFAR10
    """
    logger.info("Test Cifar10Dataset Op")

    # case 0: test loading the whole dataset
    data0 = ds.Cifar10Dataset(DATA_DIR_10)
    num_iter0 = 0
    for _ in data0.create_dict_iterator(num_epochs=1):
        num_iter0 += 1
    assert num_iter0 == 10000

    # case 1: test num_samples
    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 100

    # case 2: test num_parallel_workers
    data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=50, num_parallel_workers=1)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 50

    # case 3: test repeat
    data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    data3 = data3.repeat(3)
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 300

    # case 4: test batch with drop_remainder=False
    data4 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert data4.get_dataset_size() == 100
    assert data4.get_batch_size() == 1
    data4 = data4.batch(batch_size=7)  # drop_remainder is default to be False
    assert data4.get_dataset_size() == 15
    assert data4.get_batch_size() == 7
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 15

    # case 5: test batch with drop_remainder=True
    data5 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert data5.get_dataset_size() == 100
    assert data5.get_batch_size() == 1
    data5 = data5.batch(batch_size=7, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data5.get_dataset_size() == 14
    assert data5.get_batch_size() == 7
    num_iter5 = 0
    for _ in data5.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 14


def test_cifar10_pk_sampler():
    """
    Test Cifar10Dataset with PKSampler
    """
    logger.info("Test Cifar10Dataset Op with PKSampler")
    golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
              5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
    sampler = ds.PKSampler(3)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 30


def test_cifar10_sequential_sampler():
    """
    Test Cifar10Dataset with SequentialSampler
    """
    logger.info("Test Cifar10Dataset Op with SequentialSampler")
    num_samples = 30
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    data2 = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples)
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == num_samples


def test_cifar10_exception():
    """
    Test error cases for Cifar10Dataset
    """
    logger.info("Test error cases for Cifar10Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Cifar10Dataset(DATA_DIR_10, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=256)

    error_msg_7 = "no .bin files found"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds1 = ds.Cifar10Dataset(NO_BIN_DIR)
        for _ in ds1:
            pass


def test_cifar10_visualize(plot=False):
    """
    Visualize Cifar10Dataset results
    """
    logger.info("Test Cifar10Dataset visualization")

    data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


### Testcases for Cifar100Dataset Op ###

def test_cifar100_content_check():
    """
    Validate Cifar100Dataset image readings
    """
    logger.info("Test Cifar100Dataset with content check")
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_100, kind="cifar100")
    num_iter = 0
    # in this example, each dictionary has keys "image", "coarse_label" and "fine_label"
    for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(d["image"], images[i])
        np.testing.assert_array_equal(d["coarse_label"], labels[i][0])
        np.testing.assert_array_equal(d["fine_label"], labels[i][1])
        num_iter += 1
    assert num_iter == 100


def test_cifar100_basic():
    """
    Test Cifar100Dataset
    """
    logger.info("Test Cifar100Dataset")

    # case 1: test num_samples
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 100

    # case 2: test repeat
    data1 = data1.repeat(2)
    num_iter2 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 200

    # case 3: test num_parallel_workers
    data2 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, num_parallel_workers=1)
    num_iter3 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 100

    # case 4: test batch with drop_remainder=False
    data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    assert data3.get_dataset_size() == 100
    assert data3.get_batch_size() == 1
    data3 = data3.batch(batch_size=3)
    assert data3.get_dataset_size() == 34
    assert data3.get_batch_size() == 3
    num_iter4 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 34

    # case 5: test batch with drop_remainder=True
    data4 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    data4 = data4.batch(batch_size=3, drop_remainder=True)
    assert data4.get_dataset_size() == 33
    assert data4.get_batch_size() == 3
    num_iter5 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter5 += 1
    assert num_iter5 == 33


def test_cifar100_pk_sampler():
    """
    Test Cifar100Dataset with PKSampler
    """
    logger.info("Test Cifar100Dataset with PKSampler")
    golden = list(range(20))
    sampler = ds.PKSampler(1)
    data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["coarse_label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 20


def test_cifar100_exception():
    """
    Test error cases for Cifar100Dataset
    """
    logger.info("Test error cases for Cifar100Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Cifar100Dataset(DATA_DIR_100, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1)
    # BUGFIX: this case previously constructed a Cifar10Dataset by mistake;
    # it must use Cifar100Dataset so the CIFAR-100 API is the one validated.
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=256)

    error_msg_7 = "no .bin files found"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
        for _ in ds1:
            pass


def test_cifar100_visualize(plot=False):
    """
    Visualize Cifar100Dataset results
    """
    logger.info("Test Cifar100Dataset visualization")

    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        coarse_label = item["coarse_label"]
        fine_label = item["fine_label"]
        image_list.append(image)
        label_list.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label))
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert coarse_label.dtype == np.uint32
        assert fine_label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 10
    if plot:
        visualize_dataset(image_list, label_list)


def test_cifar_usage():
    """
    test usage of cifar
    """
    logger.info("Test Cifar100Dataset usage flag")

    # flag, if True, test cifar10 else test cifar100
    def test_config(usage, flag=True, cifar_path=None):
        if cifar_path is None:
            cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
        try:
            data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    # test the usage of CIFAR10 (flag defaults to True)
    assert test_config("train") == 10000
    assert test_config("all") == 10000
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
    assert "Cifar10Dataset API can't read the data file (interface mismatch or no data found)" in test_config("test")

    # test the usage of CIFAR100 (flag=False)
    assert test_config("test", False) == 10000
    assert test_config("all", False) == 10000
    assert "Cifar100Dataset API can't read the data file" in test_config("train", False)
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)

    # change this directory to the folder that contains all cifar10 files
    all_cifar10 = None
    if all_cifar10 is not None:
        assert test_config("train", True, all_cifar10) == 50000
        assert test_config("test", True, all_cifar10) == 10000
        assert test_config("all", True, all_cifar10) == 60000
        assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
        assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
        assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000

    # change this directory to the folder that contains all cifar100 files
    all_cifar100 = None
    if all_cifar100 is not None:
        assert test_config("train", False, all_cifar100) == 50000
        assert test_config("test", False, all_cifar100) == 10000
        assert test_config("all", False, all_cifar100) == 60000
        assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
        assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
        assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000


def test_cifar_exception_file_path():
    """
    Validate that a failing map() op on any Cifar10/Cifar100 column raises
    a RuntimeError naming the corresponding data files.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    def _check_map_failure(data, column):
        # one failing PyFunc map on `column`; iteration must raise RuntimeError
        data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
        with pytest.raises(RuntimeError) as info:
            for _ in data.create_dict_iterator():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(info.value)

    _check_map_failure(ds.Cifar10Dataset(DATA_DIR_10), "image")
    _check_map_failure(ds.Cifar10Dataset(DATA_DIR_10), "label")
    _check_map_failure(ds.Cifar100Dataset(DATA_DIR_100), "image")
    _check_map_failure(ds.Cifar100Dataset(DATA_DIR_100), "coarse_label")
    _check_map_failure(ds.Cifar100Dataset(DATA_DIR_100), "fine_label")


def test_cifar10_pk_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with PKSampler and get_dataset_size
    """
    sampler = ds.PKSampler(3)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    ds_sz = data.get_dataset_size()
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    assert ds_sz == num_iter == 30


def test_cifar10_with_chained_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size
    """
    sampler = ds.SequentialSampler(start_index=0, num_samples=5)
    child_sampler = ds.PKSampler(4)
    sampler.add_child(child_sampler)
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler)
    num_iter = 0
    ds_sz = data.get_dataset_size()
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert ds_sz == num_iter == 5


if __name__ == '__main__':
    test_cifar10_content_check()
    test_cifar10_basic()
    test_cifar10_pk_sampler()
    test_cifar10_sequential_sampler()
    test_cifar10_exception()
    test_cifar10_visualize(plot=False)

    test_cifar100_content_check()
    test_cifar100_basic()
    test_cifar100_pk_sampler()
    test_cifar100_exception()
    test_cifar100_visualize(plot=False)

    test_cifar_usage()
    test_cifar_exception_file_path()

    test_cifar10_with_chained_sampler_get_dataset_size()
    test_cifar10_pk_sampler_get_dataset_size()