1# Copyright 2020-2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15""" 16Testing cache operation with mappable datasets 17""" 18import os 19import pytest 20import numpy as np 21import mindspore.dataset as ds 22import mindspore.dataset.vision as c_vision 23from mindspore import log as logger 24from util import save_and_check_md5 25 26DATA_DIR = "../data/dataset/testImageNetData/train/" 27COCO_DATA_DIR = "../data/dataset/testCOCO/train/" 28COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json" 29NO_IMAGE_DIR = "../data/dataset/testRandomData/" 30MNIST_DATA_DIR = "../data/dataset/testMnistData/" 31CELEBA_DATA_DIR = "../data/dataset/testCelebAData/" 32VOC_DATA_DIR = "../data/dataset/testVOC2012/" 33MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest" 34CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/" 35CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/" 36MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord" 37GENERATE_GOLDEN = False 38 39 40@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 41def test_cache_map_basic1(): 42 """ 43 Feature: DatasetCache op 44 Description: Test mappable leaf with Cache op right over the leaf 45 46 Repeat 47 | 48 Map(Decode) 49 | 50 Cache 51 | 52 ImageFolder 53 54 Expectation: Passes the md5 check test 55 """ 56 logger.info("Test cache map basic 1") 57 if "SESSION_ID" in os.environ: 58 session_id = int(os.environ['SESSION_ID']) 59 else: 60 raise RuntimeError("Testcase requires SESSION_ID environment variable") 61 62 some_cache = ds.DatasetCache(session_id=session_id, size=0) 63 64 # This DATA_DIR only has 2 images in it 65 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 66 decode_op = c_vision.Decode() 67 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 68 ds1 = ds1.repeat(4) 69 70 filename = "cache_map_01_result.npz" 71 save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) 72 73 logger.info("test_cache_map_basic1 Ended.\n") 74 75 76@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 77def test_cache_map_basic2(): 78 """ 79 Feature: DatasetCache op 80 Description: Test mappable leaf with the Cache op later in the tree above the Map (Decode) 81 82 Repeat 83 | 84 Cache 85 | 86 Map(Decode) 87 | 88 ImageFolder 89 90 Expectation: Passes the md5 check test 91 """ 92 logger.info("Test cache map basic 2") 93 if "SESSION_ID" in os.environ: 94 session_id = int(os.environ['SESSION_ID']) 95 else: 96 raise RuntimeError("Testcase requires SESSION_ID environment variable") 97 98 some_cache = ds.DatasetCache(session_id=session_id, size=0) 99 100 # This DATA_DIR only has 2 images in it 101 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 102 decode_op = c_vision.Decode() 103 ds1 = ds1.map(operations=decode_op, input_columns=[ 104 "image"], cache=some_cache) 105 ds1 = ds1.repeat(4) 106 107 filename = "cache_map_02_result.npz" 108 save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) 109 110 logger.info("test_cache_map_basic2 Ended.\n") 111 112 113@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 114def test_cache_map_basic3(): 115 """ 116 Feature: DatasetCache op 117 Description: Test mappable leaf with the Cache op later in the tree below the Map (Decode) 118 Expectation: Runs successfully 119 """ 120 logger.info("Test cache basic 3") 121 if "SESSION_ID" in os.environ: 122 session_id = int(os.environ['SESSION_ID']) 123 else: 124 raise RuntimeError("Testcase requires SESSION_ID environment variable") 125 some_cache = ds.DatasetCache(session_id=session_id, size=0) 126 127 # This DATA_DIR only has 2 images in it 128 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 129 decode_op = c_vision.Decode() 130 ds1 = ds1.repeat(4) 131 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 132 logger.info("ds1.dataset_size is ", ds1.get_dataset_size()) 133 shape = ds1.output_shapes() 134 logger.info(shape) 135 num_iter = 0 136 for _ in ds1.create_dict_iterator(num_epochs=1): 137 logger.info("get data from dataset") 138 num_iter += 1 139 140 logger.info("Number of data in ds1: {} ".format(num_iter)) 141 assert num_iter == 8 142 logger.info('test_cache_basic3 Ended.\n') 143 144 145@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 146def test_cache_map_basic4(): 147 """ 148 Feature: DatasetCache op 149 Description: Test Map containing random op above Cache 150 151 Repeat 152 | 153 Map(Decode, RandomCrop) 154 | 155 Cache 156 | 157 ImageFolder 158 159 Expectation: Runs successfully 160 """ 161 logger.info("Test cache basic 4") 162 if "SESSION_ID" in os.environ: 163 session_id = int(os.environ['SESSION_ID']) 164 else: 165 raise RuntimeError("Testcase requires SESSION_ID environment variable") 166 167 some_cache = ds.DatasetCache(session_id=session_id, size=0) 168 169 # This DATA_DIR only has 2 images in it 170 data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 171 random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) 172 decode_op = c_vision.Decode() 173 174 data = data.map(input_columns=["image"], operations=decode_op) 175 data = data.map(input_columns=["image"], operations=random_crop_op) 176 data = data.repeat(4) 177 178 num_iter = 0 179 for _ in data.create_dict_iterator(num_epochs=1): 180 num_iter += 1 181 182 logger.info("Number of data in ds1: {} ".format(num_iter)) 183 assert num_iter == 8 184 logger.info('test_cache_basic4 Ended.\n') 185 186 187@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 188def test_cache_map_basic5(): 189 """ 190 Feature: DatasetCache op 191 Description: Test Cache as root node 192 193 Cache 194 | 195 ImageFolder 196 197 Expectation: Runs successfully 198 """ 199 logger.info("Test cache basic 5") 200 if "SESSION_ID" in os.environ: 201 session_id = int(os.environ['SESSION_ID']) 202 else: 203 raise RuntimeError("Testcase requires SESSION_ID environment variable") 204 some_cache = ds.DatasetCache(session_id=session_id, size=0) 205 206 # This DATA_DIR only has 2 images in it 207 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 208 num_iter = 0 209 for _ in ds1.create_dict_iterator(num_epochs=1): 210 logger.info("get data from dataset") 211 num_iter += 1 212 213 logger.info("Number of data in ds1: {} ".format(num_iter)) 214 assert num_iter == 2 215 logger.info('test_cache_basic5 Ended.\n') 216 217 218@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 219def test_cache_map_failure1(): 220 """ 221 Feature: DatasetCache op 222 Description: Test nested Cache 223 224 Repeat 225 | 226 Cache 227 | 228 Map(Decode) 229 | 230 Cache 231 | 232 Coco 233 234 Expectation: Correct error is raised as expected 235 """ 236 logger.info("Test cache failure 1") 237 if "SESSION_ID" in os.environ: 238 session_id = int(os.environ['SESSION_ID']) 239 else: 240 raise RuntimeError("Testcase requires SESSION_ID environment variable") 241 242 some_cache = ds.DatasetCache(session_id=session_id, size=0) 243 244 # This DATA_DIR has 6 images in it 245 ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, 246 cache=some_cache) 247 decode_op = c_vision.Decode() 248 ds1 = ds1.map(operations=decode_op, input_columns=[ 249 "image"], cache=some_cache) 250 ds1 = ds1.repeat(4) 251 252 with pytest.raises(RuntimeError) as e: 253 ds1.get_batch_size() 254 assert "Nested cache operations" in str(e.value) 255 256 with pytest.raises(RuntimeError) as e: 257 num_iter = 0 258 for _ in ds1.create_dict_iterator(num_epochs=1): 259 num_iter += 1 260 assert "Nested cache operations" in str(e.value) 261 262 assert num_iter == 0 263 logger.info('test_cache_failure1 Ended.\n') 264 265 266@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 267def test_cache_map_failure2(): 268 """ 269 Feature: DatasetCache op 270 Description: Test Zip under Cache 271 272 Repeat 273 | 274 Cache 275 | 276 Map(Decode) 277 | 278 Zip 279 | | 280 ImageFolder ImageFolder 281 282 Expectation: Correct error is raised as expected 283 """ 284 logger.info("Test cache failure 2") 285 if "SESSION_ID" in os.environ: 286 session_id = int(os.environ['SESSION_ID']) 287 else: 288 raise RuntimeError("Testcase requires SESSION_ID environment variable") 289 290 some_cache = ds.DatasetCache(session_id=session_id, size=0) 291 292 # This DATA_DIR only has 2 images in it 293 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 294 ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 295 dsz = ds.zip((ds1, ds2)) 296 decode_op = c_vision.Decode() 297 dsz = dsz.map(input_columns=["image"], 298 operations=decode_op, cache=some_cache) 299 dsz = dsz.repeat(4) 300 301 with pytest.raises(RuntimeError) as e: 302 num_iter = 0 303 for _ in dsz.create_dict_iterator(num_epochs=1): 304 num_iter += 1 305 assert "ZipNode is not supported as a descendant operator under a cache" in str( 306 e.value) 307 308 assert num_iter == 0 309 logger.info('test_cache_failure2 Ended.\n') 310 311 312@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 313def test_cache_map_failure3(): 314 """ 315 Feature: DatasetCache op 316 Description: Test Batch under Cache 317 318 Repeat 319 | 320 Cache 321 | 322 Map(Resize) 323 | 324 Batch 325 | 326 Mnist 327 328 Expectation: Correct error is raised as expected 329 """ 330 logger.info("Test cache failure 3") 331 if "SESSION_ID" in os.environ: 332 session_id = int(os.environ['SESSION_ID']) 333 else: 334 raise RuntimeError("Testcase requires SESSION_ID environment variable") 335 336 some_cache = ds.DatasetCache(session_id=session_id, size=0) 337 338 ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) 339 ds1 = ds1.batch(2) 340 resize_op = c_vision.Resize((224, 224)) 341 ds1 = ds1.map(input_columns=["image"], 342 operations=resize_op, cache=some_cache) 343 ds1 = ds1.repeat(4) 344 345 with pytest.raises(RuntimeError) as e: 346 num_iter = 0 347 for _ in ds1.create_dict_iterator(num_epochs=1): 348 num_iter += 1 349 assert "BatchNode is not supported as a descendant operator under a cache" in str( 350 e.value) 351 352 assert num_iter == 0 353 logger.info('test_cache_failure3 Ended.\n') 354 355 356@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 357def test_cache_map_failure4(): 358 """ 359 Feature: DatasetCache op 360 Description: Test Filter under Cache 361 362 Repeat 363 | 364 Cache 365 | 366 Map(Decode) 367 | 368 Filter 369 | 370 CelebA 371 372 Expectation: Correct error is raised as expected 373 """ 374 logger.info("Test cache failure 4") 375 if "SESSION_ID" in os.environ: 376 session_id = int(os.environ['SESSION_ID']) 377 else: 378 raise RuntimeError("Testcase requires SESSION_ID environment variable") 379 380 some_cache = ds.DatasetCache(session_id=session_id, size=0) 381 382 # This dataset has 4 records 383 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) 384 ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"]) 385 386 decode_op = c_vision.Decode() 387 ds1 = ds1.map(input_columns=["image"], 388 operations=decode_op, cache=some_cache) 389 ds1 = ds1.repeat(4) 390 391 with pytest.raises(RuntimeError) as e: 392 num_iter = 0 393 for _ in ds1.create_dict_iterator(num_epochs=1): 394 num_iter += 1 395 assert "FilterNode is not supported as a descendant operator under a cache" in str( 396 e.value) 397 398 assert num_iter == 0 399 logger.info('test_cache_failure4 Ended.\n') 400 401 402@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 403def test_cache_map_failure5(): 404 """ 405 Feature: DatasetCache op 406 Description: Test Map containing Random operation under Cache 407 408 Repeat 409 | 410 Cache 411 | 412 Map(Decode, RandomCrop) 413 | 414 Manifest 415 416 Expectation: Correct error is raised as expected 417 """ 418 logger.info("Test cache failure 5") 419 if "SESSION_ID" in os.environ: 420 session_id = int(os.environ['SESSION_ID']) 421 else: 422 raise RuntimeError("Testcase requires SESSION_ID environment variable") 423 424 some_cache = ds.DatasetCache(session_id=session_id, size=0) 425 426 # This dataset has 4 records 427 data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) 428 random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200]) 429 decode_op = c_vision.Decode() 430 431 data = data.map(input_columns=["image"], operations=decode_op) 432 data = data.map(input_columns=["image"], 433 operations=random_crop_op, cache=some_cache) 434 data = data.repeat(4) 435 436 with pytest.raises(RuntimeError) as e: 437 num_iter = 0 438 for _ in data.create_dict_iterator(num_epochs=1): 439 num_iter += 1 440 assert "MapNode containing random operation is not supported as a descendant of cache" in str( 441 e.value) 442 443 assert num_iter == 0 444 logger.info('test_cache_failure5 Ended.\n') 445 446 447@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 448def test_cache_map_failure7(): 449 """ 450 Feature: DatasetCache op 451 Description: Test no-cache-supporting Generator leaf with Map under Cache 452 453 Repeat 454 | 455 Cache 456 | 457 Map(lambda x: x) 458 | 459 Generator 460 461 Expectation: Correct error is raised as expected 462 """ 463 def generator_1d(): 464 for i in range(64): 465 yield (np.array(i),) 466 467 logger.info("Test cache failure 7") 468 if "SESSION_ID" in os.environ: 469 session_id = int(os.environ['SESSION_ID']) 470 else: 471 raise RuntimeError("Testcase requires SESSION_ID environment variable") 472 473 some_cache = ds.DatasetCache(session_id=session_id, size=0) 474 475 data = ds.GeneratorDataset(generator_1d, ["data"]) 476 data = data.map(vision.not_random(lambda x: x), ["data"], cache=some_cache) 477 data = data.repeat(4) 478 479 with pytest.raises(RuntimeError) as e: 480 num_iter = 0 481 for _ in data.create_dict_iterator(num_epochs=1): 482 num_iter += 1 483 assert "There is currently no support for GeneratorOp under cache" in str( 484 e.value) 485 486 assert num_iter == 0 487 logger.info('test_cache_failure7 Ended.\n') 488 489 490@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 491def test_cache_map_failure8(): 492 """ 493 Feature: DatasetCache op 494 Description: Test a Repeat under mappable Cache 495 496 Cache 497 | 498 Map(decode) 499 | 500 Repeat 501 | 502 Cifar10 503 504 Expectation: Correct error is raised as expected 505 """ 506 507 logger.info("Test cache failure 8") 508 if "SESSION_ID" in os.environ: 509 session_id = int(os.environ['SESSION_ID']) 510 else: 511 raise RuntimeError("Testcase requires SESSION_ID environment variable") 512 513 some_cache = ds.DatasetCache(session_id=session_id, size=0) 514 515 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10) 516 decode_op = c_vision.Decode() 517 ds1 = ds1.repeat(4) 518 ds1 = ds1.map(operations=decode_op, input_columns=[ 519 "image"], cache=some_cache) 520 521 with pytest.raises(RuntimeError) as e: 522 num_iter = 0 523 for _ in ds1.create_dict_iterator(num_epochs=1): 524 num_iter += 1 525 assert "A cache over a RepeatNode of a mappable dataset is not supported" in str( 526 e.value) 527 528 assert num_iter == 0 529 logger.info('test_cache_failure8 Ended.\n') 530 531 532@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 533def test_cache_map_failure9(): 534 """ 535 Feature: DatasetCache op 536 Description: Test Take under Cache 537 538 Repeat 539 | 540 Cache 541 | 542 Map(Decode) 543 | 544 Take 545 | 546 Cifar100 547 548 Expectation: Correct error is raised as expected 549 """ 550 logger.info("Test cache failure 9") 551 if "SESSION_ID" in os.environ: 552 session_id = int(os.environ['SESSION_ID']) 553 else: 554 raise RuntimeError("Testcase requires SESSION_ID environment variable") 555 556 some_cache = ds.DatasetCache(session_id=session_id, size=0) 557 558 ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) 559 ds1 = ds1.take(2) 560 561 decode_op = c_vision.Decode() 562 ds1 = ds1.map(input_columns=["image"], 563 operations=decode_op, cache=some_cache) 564 ds1 = ds1.repeat(4) 565 566 with pytest.raises(RuntimeError) as e: 567 num_iter = 0 568 for _ in ds1.create_dict_iterator(num_epochs=1): 569 num_iter += 1 570 assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str( 571 e.value) 572 573 assert num_iter == 0 574 logger.info('test_cache_failure9 Ended.\n') 575 576 577@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 578def test_cache_map_failure10(): 579 """ 580 Feature: DatasetCache op 581 Description: Test Skip under Cache 582 583 Repeat 584 | 585 Cache 586 | 587 Map(Decode) 588 | 589 Skip 590 | 591 VOC 592 593 Expectation: Correct error is raised as expected 594 """ 595 logger.info("Test cache failure 10") 596 if "SESSION_ID" in os.environ: 597 session_id = int(os.environ['SESSION_ID']) 598 else: 599 raise RuntimeError("Testcase requires SESSION_ID environment variable") 600 601 some_cache = ds.DatasetCache(session_id=session_id, size=0) 602 603 # This dataset has 9 records 604 ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", 605 usage="train", shuffle=False, decode=True) 606 ds1 = ds1.skip(1) 607 608 decode_op = c_vision.Decode() 609 ds1 = ds1.map(input_columns=["image"], 610 operations=decode_op, cache=some_cache) 611 ds1 = ds1.repeat(4) 612 613 with pytest.raises(RuntimeError) as e: 614 num_iter = 0 615 for _ in ds1.create_dict_iterator(num_epochs=1): 616 num_iter += 1 617 assert "SkipNode is not supported as a descendant operator under a cache" in str( 618 e.value) 619 620 assert num_iter == 0 621 logger.info('test_cache_failure10 Ended.\n') 622 623 624@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 625def test_cache_map_failure11(): 626 """ 627 Feature: DatasetCache op 628 Description: Test set spilling=true when Cache server is started without spilling support 629 630 Cache(spilling=true) 631 | 632 ImageFolder 633 634 Expectation: Correct error is raised as expected 635 """ 636 logger.info("Test cache failure 11") 637 if "SESSION_ID" in os.environ: 638 session_id = int(os.environ['SESSION_ID']) 639 else: 640 raise RuntimeError("Testcase requires SESSION_ID environment variable") 641 642 some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) 643 644 # This DATA_DIR only has 2 images in it 645 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 646 647 with pytest.raises(RuntimeError) as e: 648 num_iter = 0 649 for _ in ds1.create_dict_iterator(num_epochs=1): 650 num_iter += 1 651 assert "Server is not set up with spill support" in str( 652 e.value) 653 654 assert num_iter == 0 655 logger.info('test_cache_failure11 Ended.\n') 656 657 658@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 659def test_cache_map_split1(): 660 """ 661 Feature: DatasetCache op 662 Description: Test split (after a non-source node, which is implemented with TakeOp/SkipOp) under Cache 663 664 Repeat 665 | 666 Cache 667 | 668 Map(Resize) 669 | 670 Split 671 | 672 Map(Decode) 673 | 674 ImageFolder 675 676 Expectation: Correct error is raised as expected 677 """ 678 logger.info("Test cache split 1") 679 if "SESSION_ID" in os.environ: 680 session_id = int(os.environ['SESSION_ID']) 681 else: 682 raise RuntimeError("Testcase requires SESSION_ID environment variable") 683 684 some_cache = ds.DatasetCache(session_id=session_id, size=0) 685 686 # This DATA_DIR only has 2 images in it 687 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 688 689 decode_op = c_vision.Decode() 690 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 691 ds1, ds2 = ds1.split([0.5, 0.5]) 692 resize_op = c_vision.Resize((224, 224)) 693 ds1 = ds1.map(input_columns=["image"], 694 operations=resize_op, cache=some_cache) 695 ds2 = ds2.map(input_columns=["image"], 696 operations=resize_op, cache=some_cache) 697 ds1 = ds1.repeat(4) 698 ds2 = ds2.repeat(4) 699 700 with pytest.raises(RuntimeError) as e: 701 num_iter = 0 702 for _ in ds1.create_dict_iterator(num_epochs=1): 703 num_iter += 1 704 assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str( 705 e.value) 706 707 with pytest.raises(RuntimeError) as e: 708 num_iter = 0 709 for _ in ds2.create_dict_iterator(num_epochs=1): 710 num_iter += 1 711 assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str( 712 e.value) 713 logger.info('test_cache_split1 Ended.\n') 714 715 716@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 717def test_cache_map_split2(): 718 """ 719 Feature: DatasetCache op 720 Description: Test split (after a source node, which is implemented with subset sampler) under Cache 721 722 Repeat 723 | 724 Cache 725 | 726 Map(Resize) 727 | 728 Split 729 | 730 VOCDataset 731 732 Expectation: Output is equal to the expected output 733 """ 734 logger.info("Test cache split 2") 735 if "SESSION_ID" in os.environ: 736 session_id = int(os.environ['SESSION_ID']) 737 else: 738 raise RuntimeError("Testcase requires SESSION_ID environment variable") 739 740 some_cache = ds.DatasetCache(session_id=session_id, size=0) 741 742 # This dataset has 9 records 743 ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", 744 usage="train", shuffle=False, decode=True) 745 746 ds1, ds2 = ds1.split([0.3, 0.7]) 747 resize_op = c_vision.Resize((224, 224)) 748 ds1 = ds1.map(input_columns=["image"], 749 operations=resize_op, cache=some_cache) 750 ds2 = ds2.map(input_columns=["image"], 751 operations=resize_op, cache=some_cache) 752 ds1 = ds1.repeat(4) 753 ds2 = ds2.repeat(4) 754 755 num_iter = 0 756 for _ in ds1.create_dict_iterator(num_epochs=1): 757 num_iter += 1 758 assert num_iter == 12 759 760 num_iter = 0 761 for _ in ds2.create_dict_iterator(num_epochs=1): 762 num_iter += 1 763 assert num_iter == 24 764 logger.info('test_cache_split2 Ended.\n') 765 766 767@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 768def test_cache_map_parameter_check(): 769 """ 770 Feature: DatasetCache op 771 Description: Test illegal parameters for DatasetCache op 772 Expectation: Correct error is raised as expected 773 """ 774 logger.info("Test cache map parameter check") 775 776 with pytest.raises(ValueError) as info: 777 ds.DatasetCache(session_id=-1, size=0) 778 assert "Input is not within the required interval" in str(info.value) 779 780 with pytest.raises(TypeError) as info: 781 ds.DatasetCache(session_id="1", size=0) 782 assert "Argument session_id with value 1 is not of type" in str(info.value) 783 784 with pytest.raises(TypeError) as info: 785 ds.DatasetCache(session_id=None, size=0) 786 assert "Argument session_id with value None is not of type" in str( 787 info.value) 788 789 with pytest.raises(ValueError) as info: 790 ds.DatasetCache(session_id=1, size=-1) 791 assert "Input size must be greater than 0" in str(info.value) 792 793 with pytest.raises(TypeError) as info: 794 ds.DatasetCache(session_id=1, size="1") 795 assert "Argument size with value 1 is not of type" in str(info.value) 796 797 with pytest.raises(TypeError) as info: 798 ds.DatasetCache(session_id=1, size=None) 799 assert "Argument size with value None is not of type" in str(info.value) 800 801 with pytest.raises(TypeError) as info: 802 ds.DatasetCache(session_id=1, size=0, spilling="illegal") 803 assert "Argument spilling with value illegal is not of type" in str( 804 info.value) 805 806 with pytest.raises(TypeError) as err: 807 ds.DatasetCache(session_id=1, size=0, hostname=50052) 808 assert "Argument hostname with value 50052 is not of type" in str( 809 err.value) 810 811 with pytest.raises(RuntimeError) as err: 812 ds.DatasetCache(session_id=1, size=0, hostname="illegal") 813 assert "now cache client has to be on the same host with cache server" in str( 814 err.value) 815 816 with pytest.raises(RuntimeError) as err: 817 ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2") 818 assert "now cache client has to be on the same host with cache server" in str( 819 err.value) 820 821 with pytest.raises(TypeError) as info: 822 ds.DatasetCache(session_id=1, size=0, port="illegal") 823 assert "Argument port with value illegal is not of type" in str(info.value) 824 825 with pytest.raises(TypeError) as info: 826 ds.DatasetCache(session_id=1, size=0, port="50052") 827 assert "Argument port with value 50052 is not of type" in str(info.value) 828 829 with pytest.raises(ValueError) as err: 830 ds.DatasetCache(session_id=1, size=0, port=0) 831 assert "Input port is not within the required interval of [1025, 65535]" in str( 832 err.value) 833 834 with pytest.raises(ValueError) as err: 835 ds.DatasetCache(session_id=1, size=0, port=65536) 836 assert "Input port is not within the required interval of [1025, 65535]" in str( 837 err.value) 838 839 with pytest.raises(TypeError) as err: 840 ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True) 841 assert "Argument cache with value True is not of type" in str(err.value) 842 843 logger.info("test_cache_map_parameter_check Ended.\n") 844 845 846@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 847def test_cache_map_running_twice1(): 848 """ 849 Feature: DatasetCache op 850 Description: Test executing the same pipeline for twice (from Python), with Cache injected after Map 851 852 Repeat 853 | 854 Cache 855 | 856 Map(decode) 857 | 858 ImageFolder 859 860 Expectation: Output is the same as expected output 861 """ 862 logger.info("Test cache map running twice 1") 863 if "SESSION_ID" in os.environ: 864 session_id = int(os.environ['SESSION_ID']) 865 else: 866 raise RuntimeError("Testcase requires SESSION_ID environment variable") 867 868 some_cache = ds.DatasetCache(session_id=session_id, size=0) 869 870 # This DATA_DIR only has 2 images in it 871 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 872 decode_op = c_vision.Decode() 873 ds1 = ds1.map(input_columns=["image"], 874 operations=decode_op, cache=some_cache) 875 ds1 = ds1.repeat(4) 876 877 num_iter = 0 878 for _ in ds1.create_dict_iterator(num_epochs=1): 879 num_iter += 1 880 logger.info("Number of data in ds1: {} ".format(num_iter)) 881 assert num_iter == 8 882 883 num_iter = 0 884 for _ in ds1.create_dict_iterator(num_epochs=1): 885 num_iter += 1 886 logger.info("Number of data in ds1: {} ".format(num_iter)) 887 assert num_iter == 8 888 889 logger.info("test_cache_map_running_twice1 Ended.\n") 890 891 892@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 893def test_cache_map_running_twice2(): 894 """ 895 Feature: DatasetCache op 896 Description: Test executing the same pipeline for twice (from shell), with Cache injected after leaf 897 898 Repeat 899 | 900 Map(Decode) 901 | 902 Cache 903 | 904 ImageFolder 905 906 Expectation: Output is the same as expected output 907 """ 908 logger.info("Test cache map running twice 2") 909 if "SESSION_ID" in os.environ: 910 session_id = int(os.environ['SESSION_ID']) 911 else: 912 raise RuntimeError("Testcase requires SESSION_ID environment variable") 913 914 some_cache = ds.DatasetCache(session_id=session_id, size=0) 915 916 # This DATA_DIR only has 2 images in it 917 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 918 decode_op = c_vision.Decode() 919 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 920 ds1 = ds1.repeat(4) 921 922 num_iter = 0 923 for _ in ds1.create_dict_iterator(num_epochs=1): 924 num_iter += 1 925 926 logger.info("Number of data in ds1: {} ".format(num_iter)) 927 assert num_iter == 8 928 logger.info("test_cache_map_running_twice2 Ended.\n") 929 930 931@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 932def test_cache_map_extra_small_size1(): 933 """ 934 Feature: DatasetCache op 935 Description: Test running pipeline with Cache of extra small size and spilling=True 936 937 Repeat 938 | 939 Map(Decode) 940 | 941 Cache 942 | 943 ImageFolder 944 945 Expectation: Output is the same as expected output 946 """ 947 logger.info("Test cache map extra small size 1") 948 if "SESSION_ID" in os.environ: 949 session_id = int(os.environ['SESSION_ID']) 950 else: 951 raise RuntimeError("Testcase requires SESSION_ID environment variable") 952 953 some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True) 954 955 # This DATA_DIR only has 2 images in it 956 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 957 decode_op = c_vision.Decode() 958 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 959 ds1 = ds1.repeat(4) 960 961 num_iter = 0 962 for _ in ds1.create_dict_iterator(num_epochs=1): 963 num_iter += 1 964 965 logger.info("Number of data in ds1: {} ".format(num_iter)) 966 assert num_iter == 8 967 logger.info("test_cache_map_extra_small_size1 Ended.\n") 968 969 970@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 971def test_cache_map_extra_small_size2(): 972 """ 973 Feature: DatasetCache op 974 Description: Test running pipeline with Cache of extra small size and spilling=False 975 976 Repeat 977 | 978 Cache 979 | 980 Map(Decode) 981 | 982 ImageFolder 983 984 Expectation: Output is the same as expected output 985 """ 986 logger.info("Test cache map extra small size 2") 987 if "SESSION_ID" in os.environ: 988 session_id = int(os.environ['SESSION_ID']) 989 else: 990 raise RuntimeError("Testcase requires SESSION_ID environment variable") 991 992 some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) 993 994 # This DATA_DIR only has 2 images in it 995 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 996 decode_op = c_vision.Decode() 997 ds1 = ds1.map(input_columns=["image"], 998 operations=decode_op, cache=some_cache) 999 ds1 = ds1.repeat(4) 1000 1001 num_iter = 0 1002 for _ in ds1.create_dict_iterator(num_epochs=1): 1003 num_iter += 1 1004 1005 logger.info("Number of data in ds1: {} ".format(num_iter)) 1006 assert num_iter == 8 1007 logger.info("test_cache_map_extra_small_size2 Ended.\n") 1008 1009 1010@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1011def test_cache_map_no_image(): 1012 """ 1013 Feature: DatasetCache op 1014 Description: Test Cache with no dataset existing in the path 1015 1016 Repeat 1017 | 1018 Map(Decode) 1019 | 1020 Cache 1021 | 1022 ImageFolder 1023 1024 Expectation: Error is raised as expected 1025 """ 1026 logger.info("Test cache map no image") 1027 if "SESSION_ID" in os.environ: 1028 session_id = int(os.environ['SESSION_ID']) 1029 else: 1030 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1031 1032 some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False) 1033 1034 # This DATA_DIR only has 2 images in it 1035 ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache) 1036 decode_op = c_vision.Decode() 1037 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1038 ds1 = ds1.repeat(4) 1039 1040 with pytest.raises(RuntimeError): 1041 num_iter = 0 1042 for _ in ds1.create_dict_iterator(num_epochs=1): 1043 num_iter += 1 1044 1045 assert num_iter == 0 1046 logger.info("test_cache_map_no_image Ended.\n") 1047 1048 1049@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1050def test_cache_map_parallel_pipeline1(shard): 1051 """ 1052 Feature: DatasetCache op 1053 Description: Test running two parallel pipelines (sharing Cache) with Cache injected after leaf op 1054 1055 Repeat 1056 | 1057 Map(Decode) 1058 | 1059 Cache 1060 | 1061 ImageFolder 1062 1063 Expectation: Output is the same as expected output 1064 """ 1065 logger.info("Test cache map parallel pipeline 1") 1066 if "SESSION_ID" in os.environ: 1067 session_id = int(os.environ['SESSION_ID']) 1068 else: 1069 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1070 1071 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1072 1073 # This DATA_DIR only has 2 images in it 1074 ds1 = ds.ImageFolderDataset( 1075 dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache) 1076 decode_op = c_vision.Decode() 1077 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1078 ds1 = ds1.repeat(4) 1079 1080 num_iter = 0 1081 for _ in ds1.create_dict_iterator(num_epochs=1): 1082 num_iter += 1 1083 1084 logger.info("Number of data in ds1: {} ".format(num_iter)) 1085 assert num_iter == 4 1086 logger.info("test_cache_map_parallel_pipeline1 Ended.\n") 1087 1088 1089@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1090def test_cache_map_parallel_pipeline2(shard): 1091 """ 1092 Feature: DatasetCache op 1093 Description: Test running two parallel pipelines (sharing Cache) with Cache injected after Map op 1094 1095 Repeat 1096 | 1097 Cache 1098 | 1099 Map(Decode) 1100 | 1101 ImageFolder 1102 1103 Expectation: Output is the same as expected output 1104 """ 1105 logger.info("Test cache map parallel pipeline 2") 1106 if "SESSION_ID" in os.environ: 1107 session_id = int(os.environ['SESSION_ID']) 1108 else: 1109 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1110 1111 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1112 1113 # This DATA_DIR only has 2 images in it 1114 ds1 = ds.ImageFolderDataset( 1115 dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard)) 1116 decode_op = c_vision.Decode() 1117 ds1 = ds1.map(input_columns=["image"], 1118 operations=decode_op, cache=some_cache) 1119 ds1 = ds1.repeat(4) 1120 1121 num_iter = 0 1122 for _ in ds1.create_dict_iterator(num_epochs=1): 1123 num_iter += 1 1124 1125 logger.info("Number of data in ds1: {} ".format(num_iter)) 1126 assert num_iter == 4 1127 logger.info("test_cache_map_parallel_pipeline2 Ended.\n") 1128 1129 1130@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1131def test_cache_map_parallel_workers(): 1132 """ 1133 Feature: DatasetCache op 1134 Description: Test Cache with num_parallel_workers > 1 set for Map op and leaf op 1135 1136 Repeat 1137 | 1138 Cache 1139 | 1140 Map(Decode) 1141 | 1142 ImageFolder 1143 1144 Expectation: Output is the same as expected output 1145 """ 1146 logger.info("Test cache map parallel workers") 1147 if "SESSION_ID" in os.environ: 1148 session_id = int(os.environ['SESSION_ID']) 1149 else: 1150 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1151 1152 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1153 1154 # This DATA_DIR only has 2 images in it 1155 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4) 1156 decode_op = c_vision.Decode() 1157 ds1 = ds1.map(input_columns=[ 1158 "image"], operations=decode_op, num_parallel_workers=4, cache=some_cache) 1159 ds1 = ds1.repeat(4) 1160 1161 num_iter = 0 1162 for _ in ds1.create_dict_iterator(num_epochs=1): 1163 num_iter += 1 1164 1165 logger.info("Number of data in ds1: {} ".format(num_iter)) 1166 assert num_iter == 8 1167 logger.info("test_cache_map_parallel_workers Ended.\n") 1168 1169 1170@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1171def test_cache_map_server_workers_1(): 1172 """ 1173 Feature: DatasetCache op 1174 Description: Start Cache server with --workers 1 and then test Cache function 1175 1176 Repeat 1177 | 1178 Cache 1179 | 1180 Map(Decode) 1181 | 1182 ImageFolder 1183 1184 Expectation: Output is the same as expected output 1185 """ 1186 logger.info("Test cache map server workers 1") 1187 if "SESSION_ID" in os.environ: 1188 session_id = int(os.environ['SESSION_ID']) 1189 else: 1190 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1191 1192 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1193 1194 # This DATA_DIR only has 2 images in it 1195 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1196 decode_op = c_vision.Decode() 1197 ds1 = ds1.map(input_columns=["image"], 1198 operations=decode_op, cache=some_cache) 1199 ds1 = ds1.repeat(4) 1200 1201 num_iter = 0 1202 for _ in ds1.create_dict_iterator(num_epochs=1): 1203 num_iter += 1 1204 1205 logger.info("Number of data in ds1: {} ".format(num_iter)) 1206 assert num_iter == 8 1207 logger.info("test_cache_map_server_workers_1 Ended.\n") 1208 1209 1210@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1211def test_cache_map_server_workers_100(): 1212 """ 1213 Feature: DatasetCache op 1214 Description: Start Cache server with --workers 100 and then test Cache function 1215 1216 Repeat 1217 | 1218 Map(decode) 1219 | 1220 Cache 1221 | 1222 ImageFolder 1223 1224 Expectation: Output is the same as expected output 1225 """ 1226 logger.info("Test cache map server workers 100") 1227 if "SESSION_ID" in os.environ: 1228 session_id = int(os.environ['SESSION_ID']) 1229 else: 1230 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1231 1232 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1233 1234 # This DATA_DIR only has 2 images in it 1235 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1236 decode_op = c_vision.Decode() 1237 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1238 ds1 = ds1.repeat(4) 1239 1240 num_iter = 0 1241 for _ in ds1.create_dict_iterator(num_epochs=1): 1242 num_iter += 1 1243 1244 logger.info("Number of data in ds1: {} ".format(num_iter)) 1245 assert num_iter == 8 1246 logger.info("test_cache_map_server_workers_100 Ended.\n") 1247 1248 1249@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1250def test_cache_map_num_connections_1(): 1251 """ 1252 Feature: DatasetCache op 1253 Description: Test setting num_connections=1 in DatasetCache 1254 1255 Repeat 1256 | 1257 Cache 1258 | 1259 Map(Decode) 1260 | 1261 ImageFolder 1262 1263 Expectation: Output is the same as expected output 1264 """ 1265 logger.info("Test cache map num_connections 1") 1266 if "SESSION_ID" in os.environ: 1267 session_id = int(os.environ['SESSION_ID']) 1268 else: 1269 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1270 1271 some_cache = ds.DatasetCache( 1272 session_id=session_id, size=0, num_connections=1) 1273 1274 # This DATA_DIR only has 2 images in it 1275 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1276 decode_op = c_vision.Decode() 1277 ds1 = ds1.map(input_columns=["image"], 1278 operations=decode_op, cache=some_cache) 1279 ds1 = ds1.repeat(4) 1280 1281 num_iter = 0 1282 for _ in ds1.create_dict_iterator(num_epochs=1): 1283 num_iter += 1 1284 1285 logger.info("Number of data in ds1: {} ".format(num_iter)) 1286 assert num_iter == 8 1287 logger.info("test_cache_map_num_connections_1 Ended.\n") 1288 1289 1290@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1291def test_cache_map_num_connections_100(): 1292 """ 1293 Feature: DatasetCache op 1294 Description: Test setting num_connections=100 in DatasetCache 1295 1296 Repeat 1297 | 1298 Map(Decode) 1299 | 1300 Cache 1301 | 1302 ImageFolder 1303 1304 Expectation: Output is the same as expected output 1305 """ 1306 logger.info("Test cache map num_connections 100") 1307 if "SESSION_ID" in os.environ: 1308 session_id = int(os.environ['SESSION_ID']) 1309 else: 1310 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1311 1312 some_cache = ds.DatasetCache( 1313 session_id=session_id, size=0, num_connections=100) 1314 1315 # This DATA_DIR only has 2 images in it 1316 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1317 decode_op = c_vision.Decode() 1318 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1319 ds1 = ds1.repeat(4) 1320 1321 num_iter = 0 1322 for _ in ds1.create_dict_iterator(num_epochs=1): 1323 num_iter += 1 1324 1325 logger.info("Number of data in ds1: {} ".format(num_iter)) 1326 assert num_iter == 8 1327 logger.info("test_cache_map_num_connections_100 Ended.\n") 1328 1329 1330@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1331def test_cache_map_prefetch_size_1(): 1332 """ 1333 Feature: DatasetCache op 1334 Description: Test setting prefetch_size=1 in DatasetCache 1335 1336 Repeat 1337 | 1338 Cache 1339 | 1340 Map(Decode) 1341 | 1342 ImageFolder 1343 1344 Expectation: Output is the same as expected output 1345 """ 1346 logger.info("Test cache map prefetch_size 1") 1347 if "SESSION_ID" in os.environ: 1348 session_id = int(os.environ['SESSION_ID']) 1349 else: 1350 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1351 1352 some_cache = ds.DatasetCache( 1353 session_id=session_id, size=0, prefetch_size=1) 1354 1355 # This DATA_DIR only has 2 images in it 1356 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1357 decode_op = c_vision.Decode() 1358 ds1 = ds1.map(input_columns=["image"], 1359 operations=decode_op, cache=some_cache) 1360 ds1 = ds1.repeat(4) 1361 1362 num_iter = 0 1363 for _ in ds1.create_dict_iterator(num_epochs=1): 1364 num_iter += 1 1365 1366 logger.info("Number of data in ds1: {} ".format(num_iter)) 1367 assert num_iter == 8 1368 logger.info("test_cache_map_prefetch_size_1 Ended.\n") 1369 1370 1371@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1372def test_cache_map_prefetch_size_100(): 1373 """ 1374 Feature: DatasetCache op 1375 Description: Test setting prefetch_size=100 in DatasetCache 1376 1377 Repeat 1378 | 1379 Map(Decode) 1380 | 1381 Cache 1382 | 1383 ImageFolder 1384 1385 Expectation: Output is the same as expected output 1386 """ 1387 logger.info("Test cache map prefetch_size 100") 1388 if "SESSION_ID" in os.environ: 1389 session_id = int(os.environ['SESSION_ID']) 1390 else: 1391 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1392 1393 some_cache = ds.DatasetCache( 1394 session_id=session_id, size=0, prefetch_size=100) 1395 1396 # This DATA_DIR only has 2 images in it 1397 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1398 decode_op = c_vision.Decode() 1399 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1400 ds1 = ds1.repeat(4) 1401 1402 num_iter = 0 1403 for _ in ds1.create_dict_iterator(num_epochs=1): 1404 num_iter += 1 1405 1406 logger.info("Number of data in ds1: {} ".format(num_iter)) 1407 assert num_iter == 8 1408 logger.info("test_cache_map_prefetch_size_100 Ended.\n") 1409 1410 1411@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1412def test_cache_map_device_que(): 1413 """ 1414 Feature: DatasetCache op 1415 Description: Test Cache with device_que 1416 1417 DeviceQueue 1418 | 1419 EpochCtrl 1420 | 1421 Repeat 1422 | 1423 Map(Decode) 1424 | 1425 Cache 1426 | 1427 ImageFolder 1428 1429 Expectation: Output is the same as expected output 1430 """ 1431 logger.info("Test cache map device_que") 1432 if "SESSION_ID" in os.environ: 1433 session_id = int(os.environ['SESSION_ID']) 1434 else: 1435 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1436 1437 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1438 1439 # This DATA_DIR only has 2 images in it 1440 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1441 decode_op = c_vision.Decode() 1442 ds1 = ds1.map(input_columns=["image"], 1443 operations=decode_op, cache=some_cache) 1444 ds1 = ds1.repeat(4) 1445 ds1 = ds1.device_que() 1446 ds1.send() 1447 1448 logger.info("test_cache_map_device_que Ended.\n") 1449 1450 1451@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1452def test_cache_map_epoch_ctrl1(): 1453 """ 1454 Feature: DatasetCache op 1455 Description: Test using two-loops method to run several epochs 1456 1457 Map(Decode) 1458 | 1459 Cache 1460 | 1461 ImageFolder 1462 1463 Expectation: Output is the same as expected output 1464 """ 1465 logger.info("Test cache map epoch ctrl1") 1466 if "SESSION_ID" in os.environ: 1467 session_id = int(os.environ['SESSION_ID']) 1468 else: 1469 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1470 1471 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1472 1473 # This DATA_DIR only has 2 images in it 1474 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1475 decode_op = c_vision.Decode() 1476 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1477 1478 num_epoch = 5 1479 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1480 1481 epoch_count = 0 1482 for _ in range(num_epoch): 1483 row_count = 0 1484 for _ in iter1: 1485 row_count += 1 1486 logger.info("Number of data in ds1: {} ".format(row_count)) 1487 assert row_count == 2 1488 epoch_count += 1 1489 assert epoch_count == num_epoch 1490 logger.info("test_cache_map_epoch_ctrl1 Ended.\n") 1491 1492 1493@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1494def test_cache_map_epoch_ctrl2(): 1495 """ 1496 Feature: DatasetCache op 1497 Description: Test using two-loops method with infinite epochs 1498 1499 Cache 1500 | 1501 Map(Decode) 1502 | 1503 ImageFolder 1504 1505 Expectation: Output is the same as expected output 1506 """ 1507 logger.info("Test cache map epoch ctrl2") 1508 if "SESSION_ID" in os.environ: 1509 session_id = int(os.environ['SESSION_ID']) 1510 else: 1511 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1512 1513 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1514 1515 # This DATA_DIR only has 2 images in it 1516 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR) 1517 decode_op = c_vision.Decode() 1518 ds1 = ds1.map(input_columns=["image"], 1519 operations=decode_op, cache=some_cache) 1520 1521 num_epoch = 5 1522 # iter1 will always assume there is a next epoch and never shutdown 1523 iter1 = ds1.create_dict_iterator(num_epochs=-1) 1524 1525 epoch_count = 0 1526 for _ in range(num_epoch): 1527 row_count = 0 1528 for _ in iter1: 1529 row_count += 1 1530 logger.info("Number of data in ds1: {} ".format(row_count)) 1531 assert row_count == 2 1532 epoch_count += 1 1533 assert epoch_count == num_epoch 1534 1535 # manually stop the iterator 1536 iter1.stop() 1537 logger.info("test_cache_map_epoch_ctrl2 Ended.\n") 1538 1539 1540@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1541def test_cache_map_epoch_ctrl3(): 1542 """ 1543 Feature: DatasetCache op 1544 Description: Test using two-loops method with infinite epochs over repeat 1545 1546 Repeat 1547 | 1548 Map(Decode) 1549 | 1550 Cache 1551 | 1552 ImageFolder 1553 1554 Expectation: Output is the same as expected output 1555 """ 1556 logger.info("Test cache map epoch ctrl3") 1557 if "SESSION_ID" in os.environ: 1558 session_id = int(os.environ['SESSION_ID']) 1559 else: 1560 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1561 1562 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1563 1564 # This DATA_DIR only has 2 images in it 1565 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 1566 decode_op = c_vision.Decode() 1567 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 1568 ds1 = ds1.repeat(2) 1569 1570 num_epoch = 5 1571 # iter1 will always assume there is a next epoch and never shutdown 1572 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1573 1574 epoch_count = 0 1575 for _ in range(num_epoch): 1576 row_count = 0 1577 for _ in iter1: 1578 row_count += 1 1579 logger.info("Number of data in ds1: {} ".format(row_count)) 1580 assert row_count == 4 1581 epoch_count += 1 1582 assert epoch_count == num_epoch 1583 1584 # reply on garbage collector to destroy iter1 1585 1586 logger.info("test_cache_map_epoch_ctrl3 Ended.\n") 1587 1588 1589@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1590def test_cache_map_coco1(): 1591 """ 1592 Feature: DatasetCache op 1593 Description: Test mappable Coco leaf with Cache op right over the leaf 1594 1595 Cache 1596 | 1597 Coco 1598 1599 Expectation: Output is the same as expected output 1600 """ 1601 logger.info("Test cache map coco1") 1602 if "SESSION_ID" in os.environ: 1603 session_id = int(os.environ['SESSION_ID']) 1604 else: 1605 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1606 1607 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1608 1609 # This dataset has 6 records 1610 ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True, 1611 cache=some_cache) 1612 1613 num_epoch = 4 1614 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1615 1616 epoch_count = 0 1617 for _ in range(num_epoch): 1618 assert sum([1 for _ in iter1]) == 6 1619 epoch_count += 1 1620 assert epoch_count == num_epoch 1621 1622 logger.info("test_cache_map_coco1 Ended.\n") 1623 1624 1625@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1626def test_cache_map_coco2(): 1627 """ 1628 Feature: DatasetCache op 1629 Description: Test mappable Coco leaf with the Cache op later in the tree above the Map(Resize) 1630 1631 Cache 1632 | 1633 Map(Resize) 1634 | 1635 Coco 1636 1637 Expectation: Output is the same as expected output 1638 """ 1639 logger.info("Test cache map coco2") 1640 if "SESSION_ID" in os.environ: 1641 session_id = int(os.environ['SESSION_ID']) 1642 else: 1643 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1644 1645 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1646 1647 # This dataset has 6 records 1648 ds1 = ds.CocoDataset( 1649 COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True) 1650 resize_op = c_vision.Resize((224, 224)) 1651 ds1 = ds1.map(input_columns=["image"], 1652 operations=resize_op, cache=some_cache) 1653 1654 num_epoch = 4 1655 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1656 1657 epoch_count = 0 1658 for _ in range(num_epoch): 1659 assert sum([1 for _ in iter1]) == 6 1660 epoch_count += 1 1661 assert epoch_count == num_epoch 1662 1663 logger.info("test_cache_map_coco2 Ended.\n") 1664 1665 1666@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1667def test_cache_map_mnist1(): 1668 """ 1669 Feature: DatasetCache op 1670 Description: Test mappable Mnist leaf with Cache op right over the leaf 1671 1672 Cache 1673 | 1674 Mnist 1675 1676 Expectation: Output is the same as expected output 1677 """ 1678 logger.info("Test cache map mnist1") 1679 if "SESSION_ID" in os.environ: 1680 session_id = int(os.environ['SESSION_ID']) 1681 else: 1682 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1683 1684 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1685 ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache) 1686 1687 num_epoch = 4 1688 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1689 1690 epoch_count = 0 1691 for _ in range(num_epoch): 1692 assert sum([1 for _ in iter1]) == 10 1693 epoch_count += 1 1694 assert epoch_count == num_epoch 1695 1696 logger.info("test_cache_map_mnist1 Ended.\n") 1697 1698 1699@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1700def test_cache_map_mnist2(): 1701 """ 1702 Feature: DatasetCache op 1703 Description: Test mappable Mnist leaf with the Cache op later in the tree above the Map(Resize) 1704 1705 Cache 1706 | 1707 Map(Resize) 1708 | 1709 Mnist 1710 1711 Expectation: Output is the same as expected output 1712 """ 1713 logger.info("Test cache map mnist2") 1714 if "SESSION_ID" in os.environ: 1715 session_id = int(os.environ['SESSION_ID']) 1716 else: 1717 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1718 1719 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1720 ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10) 1721 1722 resize_op = c_vision.Resize((224, 224)) 1723 ds1 = ds1.map(input_columns=["image"], 1724 operations=resize_op, cache=some_cache) 1725 1726 num_epoch = 4 1727 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1728 1729 epoch_count = 0 1730 for _ in range(num_epoch): 1731 assert sum([1 for _ in iter1]) == 10 1732 epoch_count += 1 1733 assert epoch_count == num_epoch 1734 1735 logger.info("test_cache_map_mnist2 Ended.\n") 1736 1737 1738@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1739def test_cache_map_celeba1(): 1740 """ 1741 Feature: DatasetCache op 1742 Description: Test mappable CelebA leaf with Cache op right over the leaf 1743 1744 Cache 1745 | 1746 CelebA 1747 1748 Expectation: Output is the same as expected output 1749 """ 1750 logger.info("Test cache map celeba1") 1751 if "SESSION_ID" in os.environ: 1752 session_id = int(os.environ['SESSION_ID']) 1753 else: 1754 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1755 1756 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1757 1758 # This dataset has 4 records 1759 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, 1760 decode=True, cache=some_cache) 1761 1762 num_epoch = 4 1763 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1764 1765 epoch_count = 0 1766 for _ in range(num_epoch): 1767 assert sum([1 for _ in iter1]) == 4 1768 epoch_count += 1 1769 assert epoch_count == num_epoch 1770 1771 logger.info("test_cache_map_celeba1 Ended.\n") 1772 1773 1774@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1775def test_cache_map_celeba2(): 1776 """ 1777 Feature: DatasetCache op 1778 Description: Test mappable CelebA leaf with the Cache op later in the tree above the Map(Resize) 1779 1780 Cache 1781 | 1782 Map(Resize) 1783 | 1784 CelebA 1785 1786 Expectation: Output is the same as expected output 1787 """ 1788 logger.info("Test cache map celeba2") 1789 if "SESSION_ID" in os.environ: 1790 session_id = int(os.environ['SESSION_ID']) 1791 else: 1792 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1793 1794 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1795 1796 # This dataset has 4 records 1797 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True) 1798 resize_op = c_vision.Resize((224, 224)) 1799 ds1 = ds1.map(input_columns=["image"], 1800 operations=resize_op, cache=some_cache) 1801 1802 num_epoch = 4 1803 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1804 1805 epoch_count = 0 1806 for _ in range(num_epoch): 1807 assert sum([1 for _ in iter1]) == 4 1808 epoch_count += 1 1809 assert epoch_count == num_epoch 1810 1811 logger.info("test_cache_map_celeba2 Ended.\n") 1812 1813 1814@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1815def test_cache_map_manifest1(): 1816 """ 1817 Feature: DatasetCache op 1818 Description: Test mappable Manifest leaf with Cache op right over the leaf 1819 1820 Cache 1821 | 1822 Manifest 1823 1824 Expectation: Output is the same as expected output 1825 """ 1826 logger.info("Test cache map manifest1") 1827 if "SESSION_ID" in os.environ: 1828 session_id = int(os.environ['SESSION_ID']) 1829 else: 1830 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1831 1832 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1833 1834 # This dataset has 4 records 1835 ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache) 1836 1837 num_epoch = 4 1838 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1839 1840 epoch_count = 0 1841 for _ in range(num_epoch): 1842 assert sum([1 for _ in iter1]) == 4 1843 epoch_count += 1 1844 assert epoch_count == num_epoch 1845 1846 logger.info("test_cache_map_manifest1 Ended.\n") 1847 1848 1849@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1850def test_cache_map_manifest2(): 1851 """ 1852 Feature: DatasetCache op 1853 Description: Test mappable Manifest leaf with the Cache op later in the tree above the Map(Resize) 1854 1855 Cache 1856 | 1857 Map(Resize) 1858 | 1859 Manifest 1860 1861 Expectation: Output is the same as expected output 1862 """ 1863 logger.info("Test cache map manifest2") 1864 if "SESSION_ID" in os.environ: 1865 session_id = int(os.environ['SESSION_ID']) 1866 else: 1867 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1868 1869 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1870 1871 # This dataset has 4 records 1872 ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True) 1873 resize_op = c_vision.Resize((224, 224)) 1874 ds1 = ds1.map(input_columns=["image"], 1875 operations=resize_op, cache=some_cache) 1876 1877 num_epoch = 4 1878 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1879 1880 epoch_count = 0 1881 for _ in range(num_epoch): 1882 assert sum([1 for _ in iter1]) == 4 1883 epoch_count += 1 1884 assert epoch_count == num_epoch 1885 1886 logger.info("test_cache_map_manifest2 Ended.\n") 1887 1888 1889@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1890def test_cache_map_cifar1(): 1891 """ 1892 Feature: DatasetCache op 1893 Description: Test mappable Cifar10 leaf with Cache op right over the leaf 1894 1895 Cache 1896 | 1897 Cifar10 1898 1899 Expectation: Output is the same as expected output 1900 """ 1901 logger.info("Test cache map cifar1") 1902 if "SESSION_ID" in os.environ: 1903 session_id = int(os.environ['SESSION_ID']) 1904 else: 1905 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1906 1907 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1908 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) 1909 1910 num_epoch = 4 1911 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1912 1913 epoch_count = 0 1914 for _ in range(num_epoch): 1915 assert sum([1 for _ in iter1]) == 10 1916 epoch_count += 1 1917 assert epoch_count == num_epoch 1918 1919 logger.info("test_cache_map_cifar1 Ended.\n") 1920 1921 1922@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1923def test_cache_map_cifar2(): 1924 """ 1925 Feature: DatasetCache op 1926 Description: Test mappable Cifar100 leaf with the Cache op later in the tree above the Map(Resize) 1927 1928 Cache 1929 | 1930 Map(Resize) 1931 | 1932 Cifar100 1933 1934 Expectation: Output is the same as expected output 1935 """ 1936 logger.info("Test cache map cifar2") 1937 if "SESSION_ID" in os.environ: 1938 session_id = int(os.environ['SESSION_ID']) 1939 else: 1940 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1941 1942 some_cache = ds.DatasetCache(session_id=session_id, size=0) 1943 1944 ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10) 1945 resize_op = c_vision.Resize((224, 224)) 1946 ds1 = ds1.map(input_columns=["image"], 1947 operations=resize_op, cache=some_cache) 1948 1949 num_epoch = 4 1950 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1951 1952 epoch_count = 0 1953 for _ in range(num_epoch): 1954 assert sum([1 for _ in iter1]) == 10 1955 epoch_count += 1 1956 assert epoch_count == num_epoch 1957 1958 logger.info("test_cache_map_cifar2 Ended.\n") 1959 1960 1961@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1962def test_cache_map_cifar3(): 1963 """ 1964 Feature: DatasetCache op 1965 Description: Test mappable Cifar10 leaf with the Cache op, extra-small size (size=1), and 10000 rows in the dataset 1966 1967 Cache 1968 | 1969 Cifar10 1970 1971 Expectation: Output is the same as expected output 1972 """ 1973 logger.info("Test cache map cifar3") 1974 if "SESSION_ID" in os.environ: 1975 session_id = int(os.environ['SESSION_ID']) 1976 else: 1977 raise RuntimeError("Testcase requires SESSION_ID environment variable") 1978 1979 some_cache = ds.DatasetCache(session_id=session_id, size=1) 1980 1981 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache) 1982 1983 num_epoch = 2 1984 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 1985 1986 epoch_count = 0 1987 for _ in range(num_epoch): 1988 assert sum([1 for _ in iter1]) == 10000 1989 epoch_count += 1 1990 assert epoch_count == num_epoch 1991 1992 logger.info("test_cache_map_cifar3 Ended.\n") 1993 1994 1995@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 1996def test_cache_map_cifar4(): 1997 """ 1998 Feature: DatasetCache op 1999 Description: Test mappable Cifar10 leaf with Cache op right over the leaf, and Shuffle op over the Cache op 2000 2001 Shuffle 2002 | 2003 Cache 2004 | 2005 Cifar10 2006 2007 Expectation: Output is the same as expected output 2008 """ 2009 logger.info("Test cache map cifar4") 2010 if "SESSION_ID" in os.environ: 2011 session_id = int(os.environ['SESSION_ID']) 2012 else: 2013 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2014 2015 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2016 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) 2017 ds1 = ds1.shuffle(10) 2018 2019 num_epoch = 1 2020 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 2021 2022 epoch_count = 0 2023 for _ in range(num_epoch): 2024 assert sum([1 for _ in iter1]) == 10 2025 epoch_count += 1 2026 assert epoch_count == num_epoch 2027 2028 logger.info("test_cache_map_cifar4 Ended.\n") 2029 2030 2031@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2032def test_cache_map_voc1(): 2033 """ 2034 Feature: DatasetCache op 2035 Description: Test mappable VOC leaf with Cache op right over the leaf 2036 2037 Cache 2038 | 2039 VOC 2040 2041 Expectation: Output is the same as expected output 2042 """ 2043 logger.info("Test cache map voc1") 2044 if "SESSION_ID" in os.environ: 2045 session_id = int(os.environ['SESSION_ID']) 2046 else: 2047 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2048 2049 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2050 2051 # This dataset has 9 records 2052 ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", 2053 shuffle=False, decode=True, cache=some_cache) 2054 2055 num_epoch = 4 2056 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 2057 2058 epoch_count = 0 2059 for _ in range(num_epoch): 2060 assert sum([1 for _ in iter1]) == 9 2061 epoch_count += 1 2062 assert epoch_count == num_epoch 2063 2064 logger.info("test_cache_map_voc1 Ended.\n") 2065 2066 2067@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2068def test_cache_map_voc2(): 2069 """ 2070 Feature: DatasetCache op 2071 Description: Test mappable VOC leaf with the Cache op later in the tree above the Map(Resize) 2072 2073 Cache 2074 | 2075 Map(Resize) 2076 | 2077 VOC 2078 2079 Expectation: Output is the same as expected output 2080 """ 2081 logger.info("Test cache map voc2") 2082 if "SESSION_ID" in os.environ: 2083 session_id = int(os.environ['SESSION_ID']) 2084 else: 2085 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2086 2087 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2088 2089 # This dataset has 9 records 2090 ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", 2091 usage="train", shuffle=False, decode=True) 2092 resize_op = c_vision.Resize((224, 224)) 2093 ds1 = ds1.map(input_columns=["image"], 2094 operations=resize_op, cache=some_cache) 2095 2096 num_epoch = 4 2097 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) 2098 2099 epoch_count = 0 2100 for _ in range(num_epoch): 2101 assert sum([1 for _ in iter1]) == 9 2102 epoch_count += 1 2103 assert epoch_count == num_epoch 2104 2105 logger.info("test_cache_map_voc2 Ended.\n") 2106 2107 2108class ReverseSampler(ds.Sampler): 2109 def __iter__(self): 2110 for i in range(self.dataset_size - 1, -1, -1): 2111 yield i 2112 2113 2114@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2115def test_cache_map_mindrecord1(): 2116 """ 2117 Feature: DatasetCache op 2118 Description: Test mappable MindRecord leaf with Cache op right over the leaf 2119 2120 Cache 2121 | 2122 MindRecord 2123 2124 Expectation: Output is the same as expected output 2125 """ 2126 logger.info("Test cache map mindrecord1") 2127 if "SESSION_ID" in os.environ: 2128 session_id = int(os.environ['SESSION_ID']) 2129 else: 2130 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2131 2132 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2133 2134 # This dataset has 5 records 2135 columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] 2136 ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, cache=some_cache) 2137 2138 num_epoch = 4 2139 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 2140 2141 epoch_count = 0 2142 for _ in range(num_epoch): 2143 assert sum([1 for _ in iter1]) == 5 2144 epoch_count += 1 2145 assert epoch_count == num_epoch 2146 2147 logger.info("test_cache_map_mindrecord1 Ended.\n") 2148 2149 2150@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2151def test_cache_map_mindrecord2(): 2152 """ 2153 Feature: DatasetCache op 2154 Description: Test mappable MindRecord leaf with the Cache op later in the tree above the Map(Decode) 2155 2156 Cache 2157 | 2158 Map(Decode) 2159 | 2160 MindRecord 2161 2162 Expectation: Output is the same as expected output 2163 """ 2164 logger.info("Test cache map mindrecord2") 2165 if "SESSION_ID" in os.environ: 2166 session_id = int(os.environ['SESSION_ID']) 2167 else: 2168 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2169 2170 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2171 2172 # This dataset has 5 records 2173 columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] 2174 ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list) 2175 2176 decode_op = c_vision.Decode() 2177 ds1 = ds1.map(input_columns=["img_data"], 2178 operations=decode_op, cache=some_cache) 2179 2180 num_epoch = 4 2181 iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) 2182 2183 epoch_count = 0 2184 for _ in range(num_epoch): 2185 assert sum([1 for _ in iter1]) == 5 2186 epoch_count += 1 2187 assert epoch_count == num_epoch 2188 2189 logger.info("test_cache_map_mindrecord2 Ended.\n") 2190 2191 2192@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2193def test_cache_map_mindrecord3(): 2194 """ 2195 Feature: DatasetCache op 2196 Description: Test Cache sharing between the following two pipelines with MindRecord leaf: 2197 2198 Cache Cache 2199 | | 2200 Map(Decode) Map(Decode) 2201 | | 2202 MindRecord(num_parallel_workers=5) MindRecord(num_parallel_workers=6) 2203 2204 Expectation: Output is the same as expected output 2205 """ 2206 logger.info("Test cache map mindrecord3") 2207 if "SESSION_ID" in os.environ: 2208 session_id = int(os.environ['SESSION_ID']) 2209 else: 2210 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2211 2212 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2213 2214 # This dataset has 5 records 2215 columns_list = ["id", "file_name", "label_name", "img_data", "label_data"] 2216 decode_op = c_vision.Decode() 2217 2218 ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list, 2219 num_parallel_workers=5, shuffle=True) 2220 ds1 = ds1.map(input_columns=["img_data"], 2221 operations=decode_op, cache=some_cache) 2222 2223 ds2 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list, 2224 num_parallel_workers=6, shuffle=True) 2225 ds2 = ds2.map(input_columns=["img_data"], 2226 operations=decode_op, cache=some_cache) 2227 2228 iter1 = ds1.create_dict_iterator(num_epochs=1, output_numpy=True) 2229 iter2 = ds2.create_dict_iterator(num_epochs=1, output_numpy=True) 2230 2231 assert sum([1 for _ in iter1]) == 5 2232 assert sum([1 for _ in iter2]) == 5 2233 2234 logger.info("test_cache_map_mindrecord3 Ended.\n") 2235 2236 2237@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2238def test_cache_map_python_sampler1(): 2239 """ 2240 Feature: DatasetCache op 2241 Description: Test using a Python sampler, and Cache after leaf 2242 2243 Repeat 2244 | 2245 Map(Decode) 2246 | 2247 Cache 2248 | 2249 ImageFolder 2250 2251 Expectation: Output is the same as expected output 2252 """ 2253 logger.info("Test cache map python sampler1") 2254 if "SESSION_ID" in os.environ: 2255 session_id = int(os.environ['SESSION_ID']) 2256 else: 2257 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2258 2259 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2260 2261 # This DATA_DIR only has 2 images in it 2262 ds1 = ds.ImageFolderDataset( 2263 dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache) 2264 decode_op = c_vision.Decode() 2265 ds1 = ds1.map(input_columns=["image"], operations=decode_op) 2266 ds1 = ds1.repeat(4) 2267 2268 num_iter = 0 2269 for _ in ds1.create_dict_iterator(num_epochs=1): 2270 num_iter += 1 2271 logger.info("Number of data in ds1: {} ".format(num_iter)) 2272 assert num_iter == 8 2273 logger.info("test_cache_map_python_sampler1 Ended.\n") 2274 2275 2276@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2277def test_cache_map_python_sampler2(): 2278 """ 2279 Feature: DatasetCache op 2280 Description: Test using a Python sampler, and Cache after Map 2281 2282 Repeat 2283 | 2284 Cache 2285 | 2286 Map(Decode) 2287 | 2288 ImageFolder 2289 2290 Expectation: Output is the same as expected output 2291 """ 2292 logger.info("Test cache map python sampler2") 2293 if "SESSION_ID" in os.environ: 2294 session_id = int(os.environ['SESSION_ID']) 2295 else: 2296 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2297 2298 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2299 2300 # This DATA_DIR only has 2 images in it 2301 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler()) 2302 decode_op = c_vision.Decode() 2303 ds1 = ds1.map(input_columns=["image"], 2304 operations=decode_op, cache=some_cache) 2305 ds1 = ds1.repeat(4) 2306 2307 num_iter = 0 2308 for _ in ds1.create_dict_iterator(num_epochs=1): 2309 num_iter += 1 2310 logger.info("Number of data in ds1: {} ".format(num_iter)) 2311 assert num_iter == 8 2312 logger.info("test_cache_map_python_sampler2 Ended.\n") 2313 2314 2315@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2316def test_cache_map_nested_repeat(): 2317 """ 2318 Feature: DatasetCache op 2319 Description: Test Cache on pipeline with nested Repeat ops 2320 2321 Repeat 2322 | 2323 Map(Decode) 2324 | 2325 Repeat 2326 | 2327 Cache 2328 | 2329 ImageFolder 2330 2331 Expectation: Output is the same as expected output 2332 """ 2333 logger.info("Test cache map nested repeat") 2334 if "SESSION_ID" in os.environ: 2335 session_id = int(os.environ['SESSION_ID']) 2336 else: 2337 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2338 2339 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2340 2341 # This DATA_DIR only has 2 images in it 2342 ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache) 2343 decode_op = c_vision.Decode() 2344 ds1 = ds1.repeat(4) 2345 ds1 = ds1.map(operations=decode_op, input_columns=["image"]) 2346 ds1 = ds1.repeat(2) 2347 2348 num_iter = 0 2349 for _ in ds1.create_dict_iterator(num_epochs=1): 2350 logger.info("get data from dataset") 2351 num_iter += 1 2352 2353 logger.info("Number of data in ds1: {} ".format(num_iter)) 2354 assert num_iter == 16 2355 logger.info('test_cache_map_nested_repeat Ended.\n') 2356 2357 2358@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2359def test_cache_map_interrupt_and_rerun(): 2360 """ 2361 Feature: DatasetCache op 2362 Description: Test interrupt a running pipeline and then re-use the same Cache to run another pipeline 2363 2364 Cache 2365 | 2366 Cifar10 2367 2368 Expectation: Error is raised when interrupted and output is the same as expected output after the rerun 2369 """ 2370 logger.info("Test cache map interrupt and rerun") 2371 if "SESSION_ID" in os.environ: 2372 session_id = int(os.environ['SESSION_ID']) 2373 else: 2374 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2375 2376 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2377 2378 ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache) 2379 iter1 = ds1.create_dict_iterator(num_epochs=-1) 2380 2381 num_iter = 0 2382 with pytest.raises(AttributeError) as e: 2383 for _ in iter1: 2384 num_iter += 1 2385 if num_iter == 10: 2386 iter1.stop() 2387 assert "'DictIterator' object has no attribute '_runtime_context'" in str( 2388 e.value) 2389 2390 num_epoch = 2 2391 iter2 = ds1.create_dict_iterator(num_epochs=num_epoch) 2392 epoch_count = 0 2393 for _ in range(num_epoch): 2394 num_iter = 0 2395 for _ in iter2: 2396 num_iter += 1 2397 logger.info("Number of data in ds1: {} ".format(num_iter)) 2398 assert num_iter == 10000 2399 epoch_count += 1 2400 2401 cache_stat = some_cache.get_stat() 2402 assert cache_stat.num_mem_cached == 10000 2403 2404 logger.info("test_cache_map_interrupt_and_rerun Ended.\n") 2405 2406 2407@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2408def test_cache_map_dataset_size1(): 2409 """ 2410 Feature: DatasetCache op 2411 Description: Test get_dataset_size op when Cache is injected directly after a mappable leaf 2412 2413 Cache 2414 | 2415 CelebA 2416 2417 Expectation: Output is the same as expected output 2418 """ 2419 logger.info("Test cache map dataset size 1") 2420 if "SESSION_ID" in os.environ: 2421 session_id = int(os.environ['SESSION_ID']) 2422 else: 2423 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2424 2425 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2426 2427 # This dataset has 4 records 2428 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3, 2429 shard_id=0, cache=some_cache) 2430 2431 dataset_size = ds1.get_dataset_size() 2432 assert dataset_size == 2 2433 2434 num_iter = 0 2435 for _ in ds1.create_dict_iterator(num_epochs=1): 2436 num_iter += 1 2437 2438 logger.info("Number of data in ds1: {} ".format(num_iter)) 2439 assert num_iter == dataset_size 2440 logger.info("test_cache_map_dataset_size1 Ended.\n") 2441 2442 2443@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") 2444def test_cache_map_dataset_size2(): 2445 """ 2446 Feature: DatasetCache op 2447 Description: Test get_dataset_size op when Cache is injected after Map 2448 2449 Cache 2450 | 2451 Map(Resize) 2452 | 2453 CelebA 2454 2455 Expectation: Output is the same as expected output 2456 """ 2457 logger.info("Test cache map dataset size 2") 2458 if "SESSION_ID" in os.environ: 2459 session_id = int(os.environ['SESSION_ID']) 2460 else: 2461 raise RuntimeError("Testcase requires SESSION_ID environment variable") 2462 2463 some_cache = ds.DatasetCache(session_id=session_id, size=0) 2464 2465 # This dataset has 4 records 2466 ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, 2467 decode=True, num_shards=3, shard_id=0) 2468 resize_op = c_vision.Resize((224, 224)) 2469 ds1 = ds1.map(operations=resize_op, input_columns=[ 2470 "image"], cache=some_cache) 2471 2472 dataset_size = ds1.get_dataset_size() 2473 assert dataset_size == 2 2474 2475 num_iter = 0 2476 for _ in ds1.create_dict_iterator(num_epochs=1): 2477 num_iter += 1 2478 2479 logger.info("Number of data in ds1: {} ".format(num_iter)) 2480 assert num_iter == dataset_size 2481 logger.info("test_cache_map_dataset_size2 Ended.\n") 2482 2483 2484if __name__ == '__main__': 2485 # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py' 2486 # since cache server is required to be brought up first 2487 test_cache_map_basic1() 2488 test_cache_map_basic2() 2489 test_cache_map_basic3() 2490 test_cache_map_basic4() 2491 test_cache_map_basic5() 2492 test_cache_map_failure1() 2493 test_cache_map_failure2() 2494 test_cache_map_failure3() 2495 test_cache_map_failure4() 2496 test_cache_map_failure5() 2497 test_cache_map_failure7() 2498 test_cache_map_failure8() 2499 test_cache_map_failure9() 2500 test_cache_map_failure10() 2501 test_cache_map_failure11() 2502 test_cache_map_split1() 2503 test_cache_map_split2() 2504 test_cache_map_parameter_check() 2505 test_cache_map_running_twice1() 2506 test_cache_map_running_twice2() 2507 test_cache_map_extra_small_size1() 2508 test_cache_map_extra_small_size2() 2509 test_cache_map_no_image() 2510 test_cache_map_parallel_pipeline1(shard=0) 2511 test_cache_map_parallel_pipeline2(shard=1) 2512 test_cache_map_parallel_workers() 2513 test_cache_map_server_workers_1() 2514 test_cache_map_server_workers_100() 2515 test_cache_map_num_connections_1() 2516 test_cache_map_num_connections_100() 2517 test_cache_map_prefetch_size_1() 2518 test_cache_map_prefetch_size_100() 2519 test_cache_map_device_que() 2520 test_cache_map_epoch_ctrl1() 2521 test_cache_map_epoch_ctrl2() 2522 test_cache_map_epoch_ctrl3() 2523 test_cache_map_coco1() 2524 test_cache_map_coco2() 2525 test_cache_map_mnist1() 2526 test_cache_map_mnist2() 2527 test_cache_map_celeba1() 2528 test_cache_map_celeba2() 2529 test_cache_map_manifest1() 2530 test_cache_map_manifest2() 2531 test_cache_map_cifar1() 2532 test_cache_map_cifar2() 2533 test_cache_map_cifar3() 2534 test_cache_map_cifar4() 2535 test_cache_map_voc1() 2536 test_cache_map_voc2() 2537 test_cache_map_mindrecord1() 2538 test_cache_map_mindrecord2() 2539 test_cache_map_python_sampler1() 2540 test_cache_map_python_sampler2() 2541 test_cache_map_nested_repeat() 2542 test_cache_map_dataset_size1() 2543 test_cache_map_dataset_size2() 2544