# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testPK/data"


def test_imagefolder_basic():
    """Iterate a plain ImageFolderDataset and verify all 44 samples are produced."""
    logger.info("Test Case basic")
    dataset = ds.ImageFolderDataset(DATA_DIR).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44


def test_imagefolder_numsamples():
    """Verify num_samples and RandomSampler (with/without replacement) bound the output size."""
    logger.info("Test Case numSamples")
    dataset = ds.ImageFolderDataset(DATA_DIR, num_samples=10, num_parallel_workers=2).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 10

    # draw 3 samples with replacement
    sampler = ds.RandomSampler(num_samples=3, replacement=True)
    dataset = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=sampler)
    assert sum(1 for _ in dataset.create_dict_iterator(num_epochs=1)) == 3

    # draw 3 samples without replacement
    sampler = ds.RandomSampler(num_samples=3, replacement=False)
    dataset = ds.ImageFolderDataset(DATA_DIR, num_parallel_workers=2, sampler=sampler)
    assert sum(1 for _ in dataset.create_dict_iterator(num_epochs=1)) == 3


def test_imagefolder_numshards():
    """Verify sharding into 4 shards yields a quarter (11) of the 44 samples."""
    logger.info("Test Case numShards")
    dataset = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=3).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 11


def test_imagefolder_shardid():
    """Verify a non-final shard_id also yields 11 of the 44 samples."""
    logger.info("Test Case withShardID")
    dataset = ds.ImageFolderDataset(DATA_DIR, num_shards=4, shard_id=1).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 11


def test_imagefolder_noshuffle():
    """Verify shuffle=False still iterates over every one of the 44 samples."""
    logger.info("Test Case noShuffle")
    dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=False).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44


def test_imagefolder_extrashuffle():
    """Verify an extra shuffle op on top of shuffle=True, repeated twice, yields 88 rows."""
    logger.info("Test Case extraShuffle")
    dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=True)
    dataset = dataset.shuffle(buffer_size=5).repeat(2)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 88


def test_imagefolder_classindex():
    """Verify an explicit class_indexing mapping is applied to the label column."""
    logger.info("Test Case classIndex")
    class_index = {"class3": 333, "class1": 111}
    dataset = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False).repeat(1)

    # 11 images labelled 111 (class1) followed by 11 labelled 333 (class3)
    golden = [111] * 11 + [333] * 11

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        assert row["label"] == golden[count]
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 22


def test_imagefolder_negative_classindex():
    """Verify negative values are accepted in the class_indexing mapping."""
    logger.info("Test Case negative classIndex")
    class_index = {"class3": -333, "class1": 111}
    dataset = ds.ImageFolderDataset(DATA_DIR, class_indexing=class_index, shuffle=False).repeat(1)

    # 11 images labelled 111 (class1) followed by 11 labelled -333 (class3)
    golden = [111] * 11 + [-333] * 11

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        assert row["label"] == golden[count]
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 22


def test_imagefolder_extensions():
    """Verify the extensions filter still admits all 44 jpg/JPEG samples."""
    logger.info("Test Case extensions")
    dataset = ds.ImageFolderDataset(DATA_DIR, extensions=[".jpg", ".JPEG"]).repeat(1)

    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1):
        # each row is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(row["image"]))
        logger.info("label is {}".format(row["label"]))
        count += 1

    logger.info("Number of data in data1: {}".format(count))
    assert count == 44

def test_imagefolder_decode():
    """Verify decode=True with an extensions filter still yields all 44 samples."""
    logger.info("Test Case decode")
    repeat_count = 1

    ext = [".jpg", ".JPEG"]
    data1 = ds.ImageFolderDataset(DATA_DIR, extensions=ext, decode=True)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44


def test_sequential_sampler():
    """Verify SequentialSampler returns labels in on-disk (class) order."""
    logger.info("Test Case SequentialSampler")

    # 4 classes with 11 images each, visited sequentially
    golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
              3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]

    repeat_count = 1

    sampler = ds.SequentialSampler()
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    result = []
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # each item is a dictionary with keys "image" and "label"
        result.append(item["label"])
        num_iter += 1

    assert num_iter == 44
    logger.info("Result: {}".format(result))
    assert result == golden


def test_random_sampler():
    """Verify RandomSampler still visits all 44 samples."""
    logger.info("Test Case RandomSampler")
    repeat_count = 1

    sampler = ds.RandomSampler()
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 44


def test_distributed_sampler():
    """Verify DistributedSampler(10, 1) yields 1/10 of 44 samples, rounded up to 5."""
    logger.info("Test Case DistributedSampler")
    repeat_count = 1

    sampler = ds.DistributedSampler(10, 1)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 5


def test_pk_sampler():
    """Verify PKSampler(3) yields 3 samples per class * 4 classes = 12 rows."""
    logger.info("Test Case PKSampler")
    repeat_count = 1

    sampler = ds.PKSampler(3)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


def test_subset_random_sampler():
    """Verify SubsetRandomSampler yields exactly one row per provided index."""
    logger.info("Test Case SubsetRandomSampler")
    repeat_count = 1

    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 12


def test_weighted_random_sampler():
    """Verify WeightedRandomSampler yields exactly num_samples (11) rows."""
    logger.info("Test Case WeightedRandomSampler")
    repeat_count = 1

    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 11


def test_weighted_random_sampler_exception():
    """
    Test error cases for WeightedRandomSampler
    """
    logger.info("Test error cases for WeightedRandomSampler")
    error_msg_1 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_1):
        weights = ""
        ds.WeightedRandomSampler(weights)

    error_msg_2 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_2):
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)

    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()

    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()


def test_chained_sampler_01():
    """Chained Random->Sequential sampler with repeat(3): expect 44*3=132 rows."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    data1 = data1.repeat(count=3)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 132

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 132


def test_chained_sampler_02():
    """Chained Random->Sequential sampler, batch(5, drop) then repeat(2): expect 16 batches."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with batch then repeat")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    data1 = data1.batch(batch_size=5, drop_remainder=True)
    data1 = data1.repeat(count=2)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 16

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 16


def test_chained_sampler_03():
    """Chained Random->Sequential sampler, repeat(2) then batch(5, keep): expect 18 batches."""
    logger.info("Test Case Chained Sampler - Random and Sequential, with repeat then batch")

    # Create chained sampler, random and sequential
    sampler = ds.RandomSampler()
    child_sampler = ds.SequentialSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    data1 = data1.repeat(count=2)
    data1 = data1.batch(batch_size=5, drop_remainder=False)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 18

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 18


def test_chained_sampler_04():
    """Chained Distributed->Random sampler, batch then repeat: expect 6 batches."""
    logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")

    # Create chained sampler, distributed and random
    sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
    child_sampler = ds.RandomSampler()
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    data1 = data1.batch(batch_size=5, drop_remainder=True)
    data1 = data1.repeat(count=3)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 6

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: Each of the 4 shards has 44/4=11 samples
    # Note: Number of iterations is (11/5 = 2) * 3 = 6
    assert num_iter == 6


def skip_test_chained_sampler_05():
    """Chained PKSampler->WeightedRandom sampler (currently skipped): expect 12 rows."""
    logger.info("Test Case Chained Sampler - PKSampler and WeightedRandom")

    # Create chained sampler, PKSampler and WeightedRandom
    sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 4 classes)
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    child_sampler = ds.WeightedRandomSampler(weights, num_samples=12)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: PKSampler produces 4x3=12 samples
    # Note: Child WeightedRandomSampler produces 12 samples
    assert num_iter == 12


def test_chained_sampler_06():
    """Chained WeightedRandom->PKSampler sampler: expect 12 rows."""
    logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")

    # Create chained sampler, WeightedRandom and PKSampler
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 0.5]
    sampler = ds.WeightedRandomSampler(weights=weights, num_samples=12)
    child_sampler = ds.PKSampler(num_val=3)  # Number of elements per class is 3 (and there are 4 classes)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: WeightedRandomSampler produces 12 samples
    # Note: Child PKSampler produces 12 samples
    assert num_iter == 12


def test_chained_sampler_07():
    """Chained SubsetRandom->Distributed (2 shards) sampler; shard split currently not applied."""
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 2 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=2, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 12

    # Verify number of iterations

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler produces 12 samples
    # Note: Each of 2 shards has 6 samples
    # FIXME: Uncomment the following assert when code issue is resolved; at runtime, number of samples is 12 not 6
    # assert num_iter == 6


def skip_test_chained_sampler_08():
    """Chained SubsetRandom->Distributed (4 shards) sampler (currently skipped): expect 3 rows."""
    logger.info("Test Case Chained Sampler - SubsetRandom and Distributed, 4 shards")

    # Create chained sampler, subset random and distributed
    indices = [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11]
    sampler = ds.SubsetRandomSampler(indices, num_samples=12)
    child_sampler = ds.DistributedSampler(num_shards=4, shard_id=1)
    sampler.add_child(child_sampler)
    # Create ImageFolderDataset with sampler
    data1 = ds.ImageFolderDataset(DATA_DIR, sampler=sampler)

    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 3

    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: SubsetRandomSampler returns 12 samples
    # Note: Each of 4 shards has 3 samples
    assert num_iter == 3


def test_imagefolder_rename():
    """Verify rename() maps the "image" column to "image2" without changing row count."""
    logger.info("Test Case rename")
    repeat_count = 1

    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    data1 = data1.repeat(repeat_count)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label"
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10

    data1 = data1.rename(input_columns=["image"], output_columns="image2")

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        # after rename, the dictionary has keys "image2" and "label"
        logger.info("image is {}".format(item["image2"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10


def test_imagefolder_zip():
    """Verify zip() of two 10-sample datasets (one renamed) yields 10 rows."""
    logger.info("Test Case zip")
    repeat_count = 2

    data1 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)
    data2 = ds.ImageFolderDataset(DATA_DIR, num_samples=10)

    data1 = data1.repeat(repeat_count)
    # rename dataset2 for no conflict
    data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
    data3 = ds.zip((data1, data2))

    num_iter = 0
    for item in data3.create_dict_iterator(num_epochs=1):
        # each item is a dictionary with keys "image" and "label" (plus "image1"/"label1")
        logger.info("image is {}".format(item["image"]))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 10


def test_imagefolder_exception():
    """Test ImageFolderDataset error cases: failing pyfuncs and an invalid directory."""
    logger.info("Test imagefolder exception")

    def exception_func(item):
        raise Exception("Error occur!")

    def exception_func2(image, label):
        raise Exception("Error occur!")

    # pyfunc raising on a single input column
    with pytest.raises(RuntimeError) as error_info:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)

    # pyfunc raising with multiple input/output columns
    with pytest.raises(RuntimeError) as error_info:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=exception_func2, input_columns=["image", "label"],
                        output_columns=["image", "label", "label1"],
                        column_order=["image", "label", "label1"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)

    # pyfunc raising after a successful Decode op
    with pytest.raises(RuntimeError) as error_info:
        data = ds.ImageFolderDataset(DATA_DIR)
        data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
        data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
        for _ in data.__iter__():
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)

    # dataset root pointing one level too high (contains directories, not files)
    data_dir_invalid = "../data/dataset/testPK"
    with pytest.raises(RuntimeError) as error_info:
        data = ds.ImageFolderDataset(data_dir_invalid)
        for _ in data.__iter__():
            pass
    assert "should be file, but got directory" in str(error_info.value)


if __name__ == '__main__':
    test_imagefolder_basic()
    logger.info('test_imagefolder_basic Ended.\n')

    test_imagefolder_numsamples()
    logger.info('test_imagefolder_numsamples Ended.\n')

    test_sequential_sampler()
    logger.info('test_sequential_sampler Ended.\n')

    test_random_sampler()
    logger.info('test_random_sampler Ended.\n')

    test_distributed_sampler()
    logger.info('test_distributed_sampler Ended.\n')

    test_pk_sampler()
    logger.info('test_pk_sampler Ended.\n')

    test_subset_random_sampler()
    logger.info('test_subset_random_sampler Ended.\n')

    test_weighted_random_sampler()
    logger.info('test_weighted_random_sampler Ended.\n')

    test_weighted_random_sampler_exception()
    logger.info('test_weighted_random_sampler_exception Ended.\n')

    test_chained_sampler_01()
    logger.info('test_chained_sampler_01 Ended.\n')

    test_chained_sampler_02()
    logger.info('test_chained_sampler_02 Ended.\n')

    test_chained_sampler_03()
    logger.info('test_chained_sampler_03 Ended.\n')

    test_chained_sampler_04()
    logger.info('test_chained_sampler_04 Ended.\n')

    # test_chained_sampler_05()
    # logger.info('test_chained_sampler_05 Ended.\n')

    test_chained_sampler_06()
    logger.info('test_chained_sampler_06 Ended.\n')

    test_chained_sampler_07()
    logger.info('test_chained_sampler_07 Ended.\n')

    # test_chained_sampler_08()
    # logger.info('test_chained_sampler_08 Ended.\n')

    test_imagefolder_numshards()
    logger.info('test_imagefolder_numshards Ended.\n')

    test_imagefolder_shardid()
    logger.info('test_imagefolder_shardid Ended.\n')

    test_imagefolder_noshuffle()
    logger.info('test_imagefolder_noshuffle Ended.\n')

    test_imagefolder_extrashuffle()
    logger.info('test_imagefolder_extrashuffle Ended.\n')

    test_imagefolder_classindex()
    logger.info('test_imagefolder_classindex Ended.\n')

    test_imagefolder_negative_classindex()
    logger.info('test_imagefolder_negative_classindex Ended.\n')

    test_imagefolder_extensions()
    logger.info('test_imagefolder_extensions Ended.\n')

    test_imagefolder_decode()
    logger.info('test_imagefolder_decode Ended.\n')

    test_imagefolder_rename()
    logger.info('test_imagefolder_rename Ended.\n')

    test_imagefolder_zip()
    logger.info('test_imagefolder_zip Ended.\n')

    test_imagefolder_exception()
    logger.info('test_imagefolder_exception Ended.\n')