# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with mappable datasets
"""
import os
import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import save_and_check_md5

DATA_DIR = "../data/dataset/testImageNetData/train/"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
NO_IMAGE_DIR = "../data/dataset/testRandomData/"
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
GENERATE_GOLDEN = False
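

# --- Editor's sketch (not part of the original suite) -------------------------
# Every test below repeats the same SESSION_ID lookup before building its
# pipeline. A hypothetical helper like this one captures that pattern; the
# tests keep the explicit inline form. It assumes the cache server was brought
# up externally by the test driver, roughly (command names are assumptions
# based on the cache_admin tool):
#   cache_admin --start          # launch the cache server
#   cache_admin -g               # generate a session id, exported as SESSION_ID
def session_id_from_env():
    """Hypothetical helper: fetch the cache session id set up by the test driver."""
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    return int(os.environ["SESSION_ID"])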


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic1():
    """
    Test mappable leaf with cache op right over the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    filename = "cache_map_01_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic2():
    """
    Test mappable leaf with the cache op later in the tree above the map(decode)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    filename = "cache_map_02_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic3():
    """
    Regression test: a repeat inserted between a cached leaf and map(decode)
    used to trigger a core dump; verify the pipeline runs and yields the
    expected number of rows.
    """
    logger.info("Test cache basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    logger.info("ds1.dataset_size is %s", ds1.get_dataset_size())
    shape = ds1.output_shapes()
    logger.info(shape)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_basic3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic4():
    """
    Test Map containing random operation above cache

       repeat
         |
     Map(decode, randomCrop)
         |
       Cache
         |
     ImageFolder

    """
    logger.info("Test cache basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op)
    data = data.repeat(4)

    num_iter = 0
    for _ in data.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_basic4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_basic5():
    """
    Test cache as root node

       cache
         |
     ImageFolder
    """
    logger.info("Test cache basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 2
    logger.info('test_cache_basic5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure1():
    """
    Test nested cache (failure)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
       Cache
         |
       Coco

    """
    logger.info("Test cache failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR has 6 images in it
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure2():
    """
    Test zip under cache (failure)

            repeat
              |
            Cache
              |
         Map(decode)
              |
             Zip
            /   \\
    ImageFolder ImageFolder

    """
    logger.info("Test cache failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator():
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure2 Ended.\n')
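

# --- Editor's sketch (not part of the original suite) -------------------------
# For contrast with test_cache_map_failure2: caching each branch at its own
# leaf and placing the zip above the caches avoids putting a ZipNode under a
# cache. A minimal sketch under that assumption; the rename avoids the
# duplicate column names that zipping two ImageFolder pipelines would produce.
def sketch_zip_over_cached_leaves(session_id):
    """Hypothetical supported layout: Cache on each leaf, Zip above the caches."""
    cache1 = ds.DatasetCache(session_id=session_id, size=0)
    cache2 = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=cache1)
    ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=cache2)
    ds2 = ds2.rename(input_columns=["image", "label"], output_columns=["image2", "label2"])
    return ds.zip((ds1, ds2))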


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure3():
    """
    Test batch under cache (failure)

       repeat
         |
       Cache
         |
     Map(resize)
         |
       Batch
         |
       Mnist
    """
    logger.info("Test cache failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure4():
    """
    Test filter under cache (failure)

       repeat
         |
       Cache
         |
     Map(decode)
         |
      Filter
         |
      CelebA

    """
    logger.info("Test cache failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure5():
    """
    Test Map containing random operation under cache (failure)

       repeat
         |
       Cache
         |
     Map(decode, randomCrop)
         |
      Manifest

    """
    logger.info("Test cache failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure5 Ended.\n')
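

# --- Editor's sketch (not part of the original suite) -------------------------
# For contrast with test_cache_map_failure5: a random transform cannot sit
# under a cache, but the reverse ordering is fine (test_cache_map_basic4
# exercises it over ImageFolder). A minimal sketch with the same Manifest
# data: the deterministic decoded rows are cached, and RandomCrop runs above
# the cache on every pass.
def sketch_random_crop_above_cache(session_id):
    """Hypothetical supported layout: cache the decoded rows, randomize above."""
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    return data.map(input_columns=["image"], operations=random_crop_op)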


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure7():
    """
    Test no-cache-supporting Generator leaf with Map under cache (failure)

       repeat
         |
       Cache
         |
     Map(lambda x: x)
         |
     Generator

    """

    def generator_1d():
        for i in range(64):
            yield (np.array(i),)

    logger.info("Test cache failure 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    data = ds.GeneratorDataset(generator_1d, ["data"])
    data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "There is currently no support for GeneratorOp under cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure7 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure8():
    """
    Test a repeat under mappable cache (failure)

       Cache
         |
     Map(decode)
         |
      Repeat
         |
      Cifar10
    """

    logger.info("Test cache failure 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure8 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure9():
    """
    Test take under cache (failure)

       repeat
         |
       Cache
         |
     Map(decode)
         |
       Take
         |
      Cifar100

    """
    logger.info("Test cache failure 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    ds1 = ds1.take(2)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure9 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_failure10():
    """
    Test skip under cache (failure)

       repeat
         |
       Cache
         |
     Map(decode)
         |
       Skip
         |
        VOC

    """
    logger.info("Test cache failure 10")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    ds1 = ds1.skip(1)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_failure10 Ended.\n')
""" 587 logger.info("Test cache failure 11") 588 if "SESSION_ID" in os.environ: 589 session_id = int(os.environ['SESSION_ID']) 590 else: 591 raise RuntimeError("Testcase requires SESSION_ID environment variable") 592 593 some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) 594 595 # This DATA_DIR only has 2 images in it 596 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 597 598 with pytest.raises(RuntimeError) as e: 599 num_iter = 0 600 for _ in ds1.create_dict_iterator(): 601 num_iter += 1 602 assert "Unexpected error. Server is not set up with spill support" in str(e.value) 603 604 assert num_iter == 0 605 logger.info('test_cache_failure11 Ended.\n') 606 607 608@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 609def test_cache_map_split1(): 610 """ 611 Test split (after a non-source node) under cache (failure). 612 Split after a non-source node is implemented with TakeOp/SkipOp, hence the failure. 613 614 repeat 615 | 616 Cache 617 | 618 Map(resize) 619 | 620 Split 621 | 622 Map(decode) 623 | 624 ImageFolder 625 626 """ 627 logger.info("Test cache split 1") 628 if "SESSION_ID" in os.environ: 629 session_id = int(os.environ['SESSION_ID']) 630 else: 631 raise RuntimeError("Testcase requires SESSION_ID environment variable") 632 633 some_cache = ds.DatasetCache(session_id=session_id, size=0) 634 635 # This DATA_DIR only has 2 images in it 636 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 637 638 decode_op = c_vision.Decode() 639 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 640 ds1, ds2 = ds1.split([0.5, 0.5]) 641 resize_op = c_vision.Resize((224, 224)) 642 ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) 643 ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache) 644 ds1 = ds1.repeat(4) 645 ds2 = ds2.repeat(4) 646 647 with pytest.raises(RuntimeError) as e: 648 num_iter = 0 649 for _ in ds1.create_dict_iterator(): 650 num_iter += 1 651 assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) 652 653 with pytest.raises(RuntimeError) as e: 654 num_iter = 0 655 for _ in ds2.create_dict_iterator(): 656 num_iter += 1 657 assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value) 658 logger.info('test_cache_split1 Ended.\n') 659 660 661@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 662def test_cache_map_split2(): 663 """ 664 Test split (after a source node) under cache (ok). 665 Split after a source node is implemented with subset sampler, hence ok. 


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_split2():
    """
    Test split (after a source node) under cache (ok).
    Split after a source node is implemented with subset sampler, hence ok.

       repeat
         |
       Cache
         |
     Map(resize)
         |
       Split
         |
     VOCDataset

    """
    logger.info("Test cache split 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

    ds1, ds2 = ds1.split([0.3, 0.7])
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator():
        num_iter += 1
    assert num_iter == 24
    logger.info('test_cache_split2 Ended.\n')
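

# --- Editor's sketch (not part of the original suite) -------------------------
# test_cache_map_split2 passes because a split applied directly to a source
# node is implemented as a subset sampler on the leaf, not as Take/Skip nodes.
# A rough hand-built equivalent of the first (0.3) branch, assuming the same
# 9-record VOC data; note the real split randomizes which indices land in each
# subset unless randomize=False is given.
def sketch_split_as_subset_sampler(session_id):
    """Hypothetical equivalent of split-at-source: subset sampler at the leaf."""
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    sampler = ds.SubsetRandomSampler(indices=[0, 1, 2])  # roughly 0.3 of 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train",
                        decode=True, sampler=sampler)
    resize_op = c_vision.Resize((224, 224))
    return ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)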


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parameter_check():
    """
    Test illegal parameters for DatasetCache
    """

    logger.info("Test cache map parameter check")

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=-1, size=0)
    assert "Input is not within the required interval" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id="1", size=0)
    assert "Argument session_id with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=None, size=0)
    assert "Argument session_id with value None is not of type" in str(info.value)

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=1, size=-1)
    assert "Input size must be greater than 0" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size="1")
    assert "Argument size with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=None)
    assert "Argument size with value None is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling="illegal")
    assert "Argument spilling with value illegal is not of type" in str(info.value)

    with pytest.raises(TypeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname=50052)
    assert "Argument hostname with value 50052 is not of type" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="illegal")
    assert "now cache client has to be on the same host with cache server" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
    assert "now cache client has to be on the same host with cache server" in str(err.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="illegal")
    assert "Argument port with value illegal is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="50052")
    assert "Argument port with value 50052 is not of type" in str(info.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=0)
    assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=65536)
    assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)

    with pytest.raises(TypeError) as err:
        ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
    assert "Argument cache with value True is not of type" in str(err.value)

    logger.info("test_cache_map_parameter_check Ended.\n")
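

# --- Editor's sketch (not part of the original suite) -------------------------
# Complementing the negative checks above, a construction that should pass
# validation, with the accepted ranges implied by the assertions: session_id
# and size are non-negative ints (size=0 means no memory cap), spilling is a
# bool, hostname currently must resolve to the local host, and port must be an
# int in [1025, 65535]. The specific values here are illustrative only.
def sketch_valid_cache_parameters(session_id):
    """Hypothetical fully-parameterized DatasetCache that passes validation."""
    return ds.DatasetCache(session_id=session_id, size=0, spilling=False,
                           hostname="127.0.0.1", port=50052,
                           num_connections=12, prefetch_size=20)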


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_running_twice1():
    """
    Executing the same pipeline twice (from Python), with cache injected after map

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8

    logger.info("test_cache_map_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_running_twice2():
    """
    Executing the same pipeline twice (from shell), with cache injected after leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_extra_small_size1():
    """
    Test running the pipeline with a cache of extra-small size and spilling=True

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_extra_small_size1 Ended.\n")
ds1.map(input_columns=["image"], operations=decode_op) 996 ds1 = ds1.repeat(4) 997 998 num_iter = 0 999 for _ in ds1.create_dict_iterator(): 1000 num_iter += 1 1001 1002 logger.info("Number of data in ds1: {} ".format(num_iter)) 1003 assert num_iter == 4 1004 logger.info("test_cache_map_parallel_pipeline1 Ended.\n") 1005 1006 1007@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1008def test_cache_map_parallel_pipeline2(shard): 1009 """ 1010 Test running two parallel pipelines (sharing cache) with cache injected after map op 1011 1012 Repeat 1013 | 1014 Cache 1015 | 1016 Map(decode) 1017 | 1018 ImageFolder 1019 """ 1020 1021 logger.info("Test cache map parallel pipeline 2") 1022 if "SESSION_ID" in os.environ: 1023 session_id = int(os.environ['SESSION_ID']) 1024 else: 1025 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1026 1027 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1028 1029 # This DATA_DIR only has 2 images in it 1030 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard)) 1031 decode_op = c_vision.Decode() 1032 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1033 ds1 = ds1.repeat(4) 1034 1035 num_iter = 0 1036 for _ in ds1.create_dict_iterator(): 1037 num_iter += 1 1038 1039 logger.info("Number of data in ds1: {} ".format(num_iter)) 1040 assert num_iter == 4 1041 logger.info("test_cache_map_parallel_pipeline2 Ended.\n") 1042 1043 1044@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1045def test_cache_map_parallel_workers(): 1046 """ 1047 Test cache with num_parallel_workers > 1 set for map op and leaf op 1048 1049 Repeat 1050 | 1051 cache 1052 | 1053 Map(decode) 1054 | 1055 ImageFolder 1056 """ 1057 1058 logger.info("Test cache map parallel workers") 1059 if "SESSION_ID" in os.environ: 1060 session_id = int(os.environ['SESSION_ID']) 1061 else: 1062 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1063 1064 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1065 1066 # This DATA_DIR only has 2 images in it 1067 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4) 1068 decode_op = c_vision.Decode() 1069 ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache) 1070 ds1 = ds1.repeat(4) 1071 1072 num_iter = 0 1073 for _ in ds1.create_dict_iterator(): 1074 num_iter += 1 1075 1076 logger.info("Number of data in ds1: {} ".format(num_iter)) 1077 assert num_iter == 8 1078 logger.info("test_cache_map_parallel_workers Ended.\n") 1079 1080 1081@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1082def test_cache_map_server_workers_1(): 1083 """ 1084 start cache server with --workers 1 and then test cache function 1085 1086 Repeat 1087 | 1088 cache 1089 | 1090 Map(decode) 1091 | 1092 ImageFolder 1093 """ 1094 1095 logger.info("Test cache map server workers 1") 1096 if "SESSION_ID" in os.environ: 1097 session_id = int(os.environ['SESSION_ID']) 1098 else: 1099 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1100 1101 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1102 1103 # This DATA_DIR only has 2 images in it 1104 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1105 decode_op = c_vision.Decode() 1106 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1107 ds1 = 


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after map op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it.
    # `shard` (0 or 1) is presumably supplied by the external test driver.
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_map_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for map op and leaf op

       Repeat
         |
       cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_parallel_workers Ended.\n")
".format(num_iter)) 1225 assert num_iter == 8 1226 logger.info("test_cache_map_num_connections_100 Ended.\n") 1227 1228 1229@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1230def test_cache_map_prefetch_size_1(): 1231 """ 1232 Test setting prefetch_size=1 in DatasetCache 1233 1234 Repeat 1235 | 1236 cache 1237 | 1238 Map(decode) 1239 | 1240 ImageFolder 1241 """ 1242 1243 logger.info("Test cache map prefetch_size 1") 1244 if "SESSION_ID" in os.environ: 1245 session_id = int(os.environ['SESSION_ID']) 1246 else: 1247 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1248 1249 some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1) 1250 1251 # This DATA_DIR only has 2 images in it 1252 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1253 decode_op = c_vision.Decode() 1254 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1255 ds1 = ds1.repeat(4) 1256 1257 num_iter = 0 1258 for _ in ds1.create_dict_iterator(): 1259 num_iter += 1 1260 1261 logger.info("Number of data in ds1: {} ".format(num_iter)) 1262 assert num_iter == 8 1263 logger.info("test_cache_map_prefetch_size_1 Ended.\n") 1264 1265 1266@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1267def test_cache_map_prefetch_size_100(): 1268 """ 1269 Test setting prefetch_size=100 in DatasetCache 1270 1271 Repeat 1272 | 1273 Map(decode) 1274 | 1275 cache 1276 | 1277 ImageFolder 1278 """ 1279 1280 logger.info("Test cache map prefetch_size 100") 1281 if "SESSION_ID" in os.environ: 1282 session_id = int(os.environ['SESSION_ID']) 1283 else: 1284 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1285 1286 some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100) 1287 1288 # This DATA_DIR only has 2 images in it 1289 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1290 decode_op = c_vision.Decode() 1291 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1292 ds1 = ds1.repeat(4) 1293 1294 num_iter = 0 1295 for _ in ds1.create_dict_iterator(): 1296 num_iter += 1 1297 1298 logger.info("Number of data in ds1: {} ".format(num_iter)) 1299 assert num_iter == 8 1300 logger.info("test_cache_map_prefetch_size_100 Ended.\n") 1301 1302 1303@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1304def test_cache_map_to_device(): 1305 """ 1306 Test cache with to_device 1307 1308 DeviceQueue 1309 | 1310 EpochCtrl 1311 | 1312 Repeat 1313 | 1314 Map(decode) 1315 | 1316 cache 1317 | 1318 ImageFolder 1319 """ 1320 1321 logger.info("Test cache map to_device") 1322 if "SESSION_ID" in os.environ: 1323 session_id = int(os.environ['SESSION_ID']) 1324 else: 1325 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1326 1327 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1328 1329 # This DATA_DIR only has 2 images in it 1330 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1331 decode_op = c_vision.Decode() 1332 ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) 1333 ds1 = ds1.repeat(4) 1334 ds1 = ds1.to_device() 1335 ds1.send() 1336 1337 logger.info("test_cache_map_to_device Ended.\n") 1338 1339 1340@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1341def test_cache_map_epoch_ctrl1(): 1342 """ 1343 Test using two-loops method to run several epochs 


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

       Repeat
         |
       cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       cache
         |
     ImageFolder
    """

    logger.info("Test cache map num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       cache
         |
     ImageFolder
    """

    logger.info("Test cache map prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_to_device():
    """
    Test cache with to_device

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
     Map(decode)
         |
       cache
         |
     ImageFolder
    """

    logger.info("Test cache map to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.to_device()
    ds1.send()

    logger.info("test_cache_map_to_device Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

     Map(decode)
         |
       cache
         |
     ImageFolder
    """

    logger.info("Test cache map epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_map_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

       cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_map_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

       repeat
         |
     Map(decode)
         |
       cache
         |
     ImageFolder
    """

    logger.info("Test cache map epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1

    logger.info("test_cache_map_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_coco1():
    """
    Test mappable coco leaf with cache op right over the leaf

       cache
         |
       Coco
    """

    logger.info("Test cache map coco1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_coco2():
    """
    Test mappable coco leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       Coco
    """

    logger.info("Test cache map coco2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mnist1():
    """
    Test mappable mnist leaf with cache op right over the leaf

       cache
         |
       Mnist
    """

    logger.info("Test cache map mnist1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mnist1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mnist2():
    """
    Test mappable mnist leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       Mnist
    """

    logger.info("Test cache map mnist2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)

    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mnist2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_celeba1():
    """
    Test mappable celeba leaf with cache op right over the leaf

       cache
         |
       CelebA
    """

    logger.info("Test cache map celeba1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_celeba1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_celeba2():
    """
    Test mappable celeba leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       CelebA
    """

    logger.info("Test cache map celeba2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_celeba2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest1():
    """
    Test mappable manifest leaf with cache op right over the leaf

       cache
         |
      Manifest
    """

    logger.info("Test cache map manifest1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_manifest1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest2():
    """
    Test mappable manifest leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
      Manifest
    """

    logger.info("Test cache map manifest2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_manifest2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar1():
    """
    Test mappable cifar10 leaf with cache op right over the leaf

       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar2():
    """
    Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
      Cifar100
    """

    logger.info("Test cache map cifar2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar3():
    """
    Test mappable cifar10 leaf with cache op right over the leaf.
    In this case, we set an extra-small size for the cache (size=1) and there
    are 10000 rows in the dataset.

       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)

    num_epoch = 2
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10000
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar4():
    """
    Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op

      shuffle
         |
       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
    ds1 = ds1.shuffle(10)

    num_epoch = 1
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar4 Ended.\n")
cache map voc2") 1929 if "SESSION_ID" in os.environ: 1930 session_id = int(os.environ['SESSION_ID']) 1931 else: 1932 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1933 1934 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1935 1936 # This dataset has 9 records 1937 ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True) 1938 resize_op = c_vision.Resize((224, 224)) 1939 ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache) 1940 1941 num_epoch = 4 1942 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1943 1944 epoch_count = 0 1945 for _ in range(num_epoch): 1946 assert sum([1 for _ in iter1]) == 9 1947 epoch_count += 1 1948 assert epoch_count == num_epoch 1949 1950 logger.info("test_cache_map_voc2 Ended.\n") 1951 1952 1953class ReverseSampler(ds.Sampler): 1954 def __iter__(self): 1955 for i in range(self.dataset_size - 1, -1, -1): 1956 yield i 1957 1958 1959@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1960def test_cache_map_mindrecord1(): 1961 """ 1962 Test mappable mindrecord leaf with cache op right over the leaf 1963 1964 cache 1965 | 1966 MindRecord 1967 """ 1968 1969 logger.info("Test cache map mindrecord1") 1970 if "SESSION_ID" in os.environ: 1971 session_id = int(os.environ['SESSION_ID']) 1972 else: 1973 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1974 1975 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1976 1977 # This dataset has 5 records 1978 columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] 1979 ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, cache=some_cache) 1980 1981 num_epoch = 4 1982 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 1983 1984 epoch_count = 0 1985 for _ in range(num_epoch): 1986 assert sum([1 for _ in iter1]) == 5 1987 epoch_count += 1 1988 assert epoch_count == num_epoch 1989 1990 logger.info("test_cache_map_mindrecord1 Ended.\n") 1991 1992 1993@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1994def test_cache_map_mindrecord2(): 1995 """ 1996 Test mappable mindrecord leaf with the cache op later in the tree above the map(decode) 1997 1998 cache 1999 | 2000 Map(decode) 2001 | 2002 MindRecord 2003 """ 2004 2005 logger.info("Test cache map mindrecord2") 2006 if "SESSION_ID" in os.environ: 2007 session_id = int(os.environ['SESSION_ID']) 2008 else: 2009 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2010 2011 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2012 2013 # This dataset has 5 records 2014 columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] 2015 ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list) 2016 2017 decode_op = c_vision.Decode() 2018 ds1 = ds1.map(input_columns=["img_data"], operations=decode_op, cache=some_cache) 2019 2020 num_epoch = 4 2021 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 2022 2023 epoch_count = 0 2024 for _ in range(num_epoch): 2025 assert sum([1 for _ in iter1]) == 5 2026 epoch_count += 1 2027 assert epoch_count == num_epoch 2028 2029 logger.info("test_cache_map_mindrecord2 Ended.\n") 2030 2031 2032@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2033def test_cache_map_mindrecord3(): 2034 """ 2035 Test cache sharing between the following two pipelines with 
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler1():
    """
    Test using a python sampler, and cache after leaf

       Repeat
         |
     Map(decode)
         |
       cache
         |
    ImageFolder
    """

    logger.info("Test cache map python sampler1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_python_sampler1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler2():
    """
    Test using a python sampler, and cache after map

       Repeat
         |
       cache
         |
     Map(decode)
         |
    ImageFolder
    """

    logger.info("Test cache map python sampler2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_python_sampler2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_nested_repeat():
    """
    Test cache on pipeline with nested repeat ops

       Repeat
         |
     Map(decode)
         |
       Repeat
         |
       Cache
         |
    ImageFolder
    """

    logger.info("Test cache map nested repeat")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 16
    logger.info('test_cache_map_nested_repeat Ended.\n')

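# Row-count arithmetic for test_cache_map_nested_repeat: the cache sits below
# both repeat ops, so the 2 cached rows are multiplied twice,
# 2 rows * repeat(4) * repeat(2) = 16, which is what the assert checks.
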
logger.info("test_cache_map_python_sampler2 Ended.\n") 2141 2142 2143@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2144def test_cache_map_nested_repeat(): 2145 """ 2146 Test cache on pipeline with nested repeat ops 2147 2148 Repeat 2149 | 2150 Map(decode) 2151 | 2152 Repeat 2153 | 2154 Cache 2155 | 2156 ImageFolder 2157 """ 2158 2159 logger.info("Test cache map nested repeat") 2160 if "SESSION_ID" in os.environ: 2161 session_id = int(os.environ['SESSION_ID']) 2162 else: 2163 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2164 2165 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2166 2167 # This DATA_DIR only has 2 images in it 2168 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 2169 decode_op = c_vision.Decode() 2170 ds1 = ds1.repeat(4) 2171 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 2172 ds1 = ds1.repeat(2) 2173 2174 num_iter = 0 2175 for _ in ds1.create_dict_iterator(num_epochs=1): 2176 logger.info("get data from dataset") 2177 num_iter += 1 2178 2179 logger.info("Number of data in ds1: {} ".format(num_iter)) 2180 assert num_iter == 16 2181 logger.info('test_cache_map_nested_repeat Ended.\n') 2182 2183 2184@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2185def test_cache_map_interrupt_and_rerun(): 2186 """ 2187 Test interrupt a running pipeline and then re-use the same cache to run another pipeline 2188 2189 cache 2190 | 2191 Cifar10 2192 """ 2193 2194 logger.info("Test cache map interrupt and rerun") 2195 if "SESSION_ID" in os.environ: 2196 session_id = int(os.environ['SESSION_ID']) 2197 else: 2198 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2199 2200 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2201 2202 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache) 2203 iter1 = ds1.create_dict_iterator() 2204 2205 num_iter = 0 2206 with pytest.raises(AttributeError) as e: 2207 for _ in iter1: 2208 num_iter += 1 2209 if num_iter == 10: 2210 iter1.stop() 2211 assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value) 2212 2213 num_epoch = 2 2214 iter2 = ds1.create_dict_iterator(num_epochs=num_epoch) 2215 epoch_count = 0 2216 for _ in range(num_epoch): 2217 num_iter = 0 2218 for _ in iter2: 2219 num_iter += 1 2220 logger.info("Number of data in ds1: {} ".format(num_iter)) 2221 assert num_iter == 10000 2222 epoch_count += 1 2223 2224 cache_stat = some_cache.get_stat() 2225 assert cache_stat.num_mem_cached == 10000 2226 2227 logger.info("test_cache_map_interrupt_and_rerun Ended.\n") 2228 2229 2230@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2231def test_cache_map_dataset_size1(): 2232 """ 2233 Test get_dataset_size() when cache is injected directly after a mappable leaf 2234 2235 Cache 2236 | 2237 CelebA 2238 """ 2239 2240 logger.info("Test cache map dataset size 1") 2241 if "SESSION_ID" in os.environ: 2242 session_id = int(os.environ['SESSION_ID']) 2243 else: 2244 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2245 2246 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2247 2248 # This dataset has 4 records 2249 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3, shard_id=0, cache=some_cache) 2250 2251 dataset_size = ds1.get_dataset_size() 2252 assert dataset_size == 2 2253 2254 num_iter = 0 2255 for _ in ds1.create_dict_iterator(): 
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_dataset_size1():
    """
    Test get_dataset_size() when cache is injected directly after a mappable leaf

       Cache
         |
      CelebA
    """

    logger.info("Test cache map dataset size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3, shard_id=0, cache=some_cache)

    # 4 records sharded 3 ways -> ceil(4/3) = 2 rows in shard 0
    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_map_dataset_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_dataset_size2():
    """
    Test get_dataset_size() when cache is injected after map

       Cache
         |
     Map(resize)
         |
      CelebA
    """

    logger.info("Test cache map dataset size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=3, shard_id=0)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(operations=resize_op, input_columns=["image"], cache=some_cache)

    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_map_dataset_size2 Ended.\n")


if __name__ == '__main__':
    # This is just a list of the tests; do not run them with 'python test_cache_map.py',
    # since the cache server must be brought up first
    test_cache_map_basic1()
    test_cache_map_basic2()
    test_cache_map_basic3()
    test_cache_map_basic4()
    test_cache_map_basic5()
    test_cache_map_failure1()
    test_cache_map_failure2()
    test_cache_map_failure3()
    test_cache_map_failure4()
    test_cache_map_failure5()
    test_cache_map_failure7()
    test_cache_map_failure8()
    test_cache_map_failure9()
    test_cache_map_failure10()
    test_cache_map_failure11()
    test_cache_map_split1()
    test_cache_map_split2()
    test_cache_map_parameter_check()
    test_cache_map_running_twice1()
    test_cache_map_running_twice2()
    test_cache_map_extra_small_size1()
    test_cache_map_extra_small_size2()
    test_cache_map_no_image()
    test_cache_map_parallel_pipeline1(shard=0)
    test_cache_map_parallel_pipeline2(shard=1)
    test_cache_map_parallel_workers()
    test_cache_map_server_workers_1()
    test_cache_map_server_workers_100()
    test_cache_map_num_connections_1()
    test_cache_map_num_connections_100()
    test_cache_map_prefetch_size_1()
    test_cache_map_prefetch_size_100()
    test_cache_map_to_device()
    test_cache_map_epoch_ctrl1()
    test_cache_map_epoch_ctrl2()
    test_cache_map_epoch_ctrl3()
    test_cache_map_coco1()
    test_cache_map_coco2()
    test_cache_map_mnist1()
    test_cache_map_mnist2()
    test_cache_map_celeba1()
    test_cache_map_celeba2()
    test_cache_map_manifest1()
    test_cache_map_manifest2()
    test_cache_map_cifar1()
    test_cache_map_cifar2()
    test_cache_map_cifar3()
    test_cache_map_cifar4()
    test_cache_map_voc1()
    test_cache_map_voc2()
    test_cache_map_mindrecord1()
    test_cache_map_mindrecord2()
    test_cache_map_mindrecord3()
    test_cache_map_python_sampler1()
    test_cache_map_python_sampler2()
    test_cache_map_nested_repeat()
    test_cache_map_interrupt_and_rerun()
    test_cache_map_dataset_size1()
    test_cache_map_dataset_size2()