# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operation with mappable datasets
"""
import os
import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as c_vision
from mindspore import log as logger
from util import save_and_check_md5

DATA_DIR = "../data/dataset/testImageNetData/train/"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
NO_IMAGE_DIR = "../data/dataset/testRandomData/"
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
GENERATE_GOLDEN = False
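# NOTE: these tests assume an external cache server is already running and that
# the test launcher exports SESSION_ID for the created cache session. As a
# rough sketch (exact flags may vary across MindSpore versions):
#   cache_admin --start    # bring up the cache server
#   cache_admin -g         # generate a session id, exported as SESSION_ID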


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic1():
    """
    Feature: DatasetCache op
    Description: Test mappable leaf with Cache op right over the leaf

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     ImageFolder

    Expectation: Passes the md5 check test
    """
    logger.info("Test cache map basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
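    # size=0 means no explicit cache size limit on the server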

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    filename = "cache_map_01_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
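    # save_and_check_md5 (from util) compares the md5 of the pipeline output
    # against the stored golden file; set GENERATE_GOLDEN=True to regenerate it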

    logger.info("test_cache_map_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic2():
    """
    Feature: DatasetCache op
    Description: Test mappable leaf with the Cache op later in the tree above the Map (Decode)

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     ImageFolder

    Expectation: Passes the md5 check test
    """
    logger.info("Test cache map basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=[
        "image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    filename = "cache_map_02_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic3():
    """
    Feature: DatasetCache op
    Description: Test mappable leaf with the Cache op later in the tree below the Map (Decode)
    Expectation: Runs successfully
    """
    logger.info("Test cache basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    logger.info("ds1.dataset_size is {}".format(ds1.get_dataset_size()))
    shape = ds1.output_shapes()
    logger.info(shape)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_map_basic3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic4():
    """
    Feature: DatasetCache op
    Description: Test Map containing random op above Cache

                Repeat
                  |
             Map(Decode, RandomCrop)
                  |
                Cache
                  |
             ImageFolder

    Expectation: Runs successfully
    """
    logger.info("Test cache basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
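    # The cache sits directly on the leaf, so only the raw images are cached;
    # Decode and RandomCrop run above the cache and re-execute every epoch.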
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op)
    data = data.repeat(4)

    num_iter = 0
    for _ in data.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_map_basic4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic5():
    """
    Feature: DatasetCache op
    Description: Test Cache as root node

       Cache
         |
      ImageFolder

    Expectation: Runs successfully
    """
    logger.info("Test cache basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 2
    logger.info('test_cache_map_basic5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure1():
    """
    Feature: DatasetCache op
    Description: Test nested Cache

        Repeat
          |
        Cache
          |
      Map(Decode)
          |
        Cache
          |
        Coco

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=[
        "image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure2():
    """
    Feature: DatasetCache op
    Description: Test Zip under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                 Zip
                |    |
      ImageFolder     ImageFolder

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure3():
    """
    Feature: DatasetCache op
    Description: Test Batch under Cache

               Repeat
                  |
                Cache
                  |
             Map(Resize)
                  |
                Batch
                  |
                Mnist

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure4():
    """
    Feature: DatasetCache op
    Description: Test Filter under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                Filter
                  |
               CelebA

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure5():
    """
    Feature: DatasetCache op
    Description: Test Map containing random operation under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode, RandomCrop)
                  |
              Manifest

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"],
                    operations=random_crop_op, cache=some_cache)
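    # Caching the output of a random op would freeze a single random outcome
    # and replay it from the cache, so this pipeline is rejected at runtime.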
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure7():
    """
    Feature: DatasetCache op
    Description: Test Generator leaf (which does not support cache) with Map under Cache

               Repeat
                  |
                Cache
                  |
            Map(lambda x: x)
                  |
              Generator

    Expectation: Correct error is raised as expected
    """
    def generator_1d():
        for i in range(64):
            yield (np.array(i),)

    logger.info("Test cache failure 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    data = ds.GeneratorDataset(generator_1d, ["data"])
    data = data.map(c_vision.not_random(lambda x: x), ["data"], cache=some_cache)
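    # not_random() marks the lambda as deterministic; the pipeline still fails
    # below because GeneratorDataset itself cannot sit under a cache.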
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "There is currently no support for GeneratorOp under cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure7 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure8():
    """
    Feature: DatasetCache op
    Description: Test a Repeat under mappable Cache

        Cache
          |
      Map(Decode)
          |
        Repeat
          |
       Cifar10

    Expectation: Correct error is raised as expected
    """

    logger.info("Test cache failure 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=[
        "image"], cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure8 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure9():
    """
    Feature: DatasetCache op
    Description: Test Take under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                Take
                  |
             Cifar100

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    ds1 = ds1.take(2)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure9 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure10():
    """
    Feature: DatasetCache op
    Description: Test Skip under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                Skip
                  |
                VOC

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 10")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection",
                        usage="train", shuffle=False, decode=True)
    ds1 = ds1.skip(1)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "SkipNode is not supported as a descendant operator under a cache" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure10 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure11():
    """
    Feature: DatasetCache op
    Description: Test setting spilling=True when the Cache server is started without spilling support

         Cache(spilling=true)
                 |
             ImageFolder

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache failure 11")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Server is not set up with spill support" in str(
        e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure11 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_split1():
    """
    Feature: DatasetCache op
    Description: Test split (after a non-source node, which is implemented with TakeOp/SkipOp) under Cache

               Repeat
                  |
                Cache
                  |
             Map(Resize)
                  |
                Split
                  |
             Map(Decode)
                  |
             ImageFolder

    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache split 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1, ds2 = ds1.split([0.5, 0.5])
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)
    ds2 = ds2.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds2.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(
        e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(
        e.value)
    logger.info('test_cache_map_split1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_split2():
    """
    Feature: DatasetCache op
    Description: Test split (after a source node, which is implemented with a subset sampler) under Cache

               Repeat
                  |
                Cache
                  |
             Map(Resize)
                  |
                Split
                  |
             VOCDataset

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache split 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection",
                        usage="train", shuffle=False, decode=True)

    ds1, ds2 = ds1.split([0.3, 0.7])
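    # 9 records split [0.3, 0.7] -> 3 and 6 rows; after repeat(4) the
    # iterations below see 12 and 24 rows respectively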
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)
    ds2 = ds2.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 24
    logger.info('test_cache_map_split2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parameter_check():
    """
    Feature: DatasetCache op
    Description: Test illegal parameters for DatasetCache op
    Expectation: Correct error is raised as expected
    """
    logger.info("Test cache map parameter check")

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=-1, size=0)
    assert "Input is not within the required interval" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id="1", size=0)
    assert "Argument session_id with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=None, size=0)
    assert "Argument session_id with value None is not of type" in str(
        info.value)

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=1, size=-1)
    assert "Input size must be greater than 0" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size="1")
    assert "Argument size with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=None)
    assert "Argument size with value None is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling="illegal")
    assert "Argument spilling with value illegal is not of type" in str(
        info.value)

    with pytest.raises(TypeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname=50052)
    assert "Argument hostname with value 50052 is not of type" in str(
        err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="illegal")
    assert "now cache client has to be on the same host with cache server" in str(
        err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
    assert "now cache client has to be on the same host with cache server" in str(
        err.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="illegal")
    assert "Argument port with value illegal is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="50052")
    assert "Argument port with value 50052 is not of type" in str(info.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=0)
    assert "Input port is not within the required interval of [1025, 65535]" in str(
        err.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=65536)
    assert "Input port is not within the required interval of [1025, 65535]" in str(
        err.value)

    with pytest.raises(TypeError) as err:
        ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
    assert "Argument cache with value True is not of type" in str(err.value)

    logger.info("test_cache_map_parameter_check Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_running_twice1():
    """
    Feature: DatasetCache op
    Description: Test executing the same pipeline twice (from Python), with Cache injected after Map

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
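    # Iterate the same pipeline a second time; rows should now be served from
    # the cache populated by the first pass.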

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8

    logger.info("test_cache_map_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_running_twice2():
    """
    Feature: DatasetCache op
    Description: Test executing the same pipeline twice (from shell), with Cache injected after leaf

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_extra_small_size1():
    """
    Feature: DatasetCache op
    Description: Test running pipeline with Cache of extra small size and spilling=True

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
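    # size=1 (MB) is far smaller than the decoded images; with spilling=True
    # the overflow should spill to disk instead of failing the pipeline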

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_extra_small_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_extra_small_size2():
    """
    Feature: DatasetCache op
    Description: Test running pipeline with Cache of extra small size and spilling=False

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
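    # With a 1 MB cache and spilling disabled, rows that do not fit are
    # presumably passed through without being cached; the pipeline should
    # still yield every row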

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_extra_small_size2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_no_image():
    """
    Feature: DatasetCache op
    Description: Test Cache with no dataset existing in the path

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     ImageFolder

    Expectation: Error is raised as expected
    """
    logger.info("Test cache map no image")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # NO_IMAGE_DIR contains no image files, so iteration fails at runtime
    ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError):
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1

    assert num_iter == 0
    logger.info("test_cache_map_no_image Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_pipeline1(shard):
    """
    Feature: DatasetCache op
    Description: Test running two parallel pipelines (sharing Cache) with Cache injected after leaf op

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
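    # 'shard' is supplied by the external test launcher (assumption: two pytest
    # processes run this test in parallel, one per shard id, sharing the cache)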
    logger.info("Test cache map parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(
        dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_map_parallel_pipeline1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_pipeline2(shard):
    """
    Feature: DatasetCache op
    Description: Test running two parallel pipelines (sharing Cache) with Cache injected after Map op

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(
        dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_map_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_workers():
    """
    Feature: DatasetCache op
    Description: Test Cache with num_parallel_workers > 1 set for Map op and leaf op

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=[
        "image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_parallel_workers Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_server_workers_1():
    """
    Feature: DatasetCache op
    Description: Start Cache server with --workers 1 and then test Cache function

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_server_workers_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_server_workers_100():
    """
    Feature: DatasetCache op
    Description: Start Cache server with --workers 100 and then test Cache function

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_server_workers_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_num_connections_1():
    """
    Feature: DatasetCache op
    Description: Test setting num_connections=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(
        session_id=session_id, size=0, num_connections=1)
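    # num_connections sets how many connections the client opens to the cache
    # server; 1 serializes all cache requests (the default is higher)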

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_num_connections_100():
    """
    Feature: DatasetCache op
    Description: Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(
        session_id=session_id, size=0, num_connections=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_prefetch_size_1():
    """
    Feature: DatasetCache op
    Description: Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(
        session_id=session_id, size=0, prefetch_size=1)
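    # prefetch_size sets the client-side prefetch buffer (in rows);
    # 1 effectively disables read-ahead from the cache server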

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_prefetch_size_100():
    """
    Feature: DatasetCache op
    Description: Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(
        session_id=session_id, size=0, prefetch_size=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_device_que():
    """
    Feature: DatasetCache op
    Description: Test Cache with device_que

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map device_que")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.device_que()
    ds1.send()
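    # send() starts pushing data towards the device queue; this test only
    # verifies that the cached pipeline can be built and launched this way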

    logger.info("test_cache_map_device_que Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl1():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method to run several epochs

     Map(Decode)
         |
       Cache
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_map_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl2():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method with infinite epochs

        Cache
         |
     Map(Decode)
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"],
                  operations=decode_op, cache=some_cache)

    num_epoch = 5
    # with num_epochs=-1, iter1 always assumes there is a next epoch and never shuts down
    iter1 = ds1.create_dict_iterator(num_epochs=-1)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_map_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl3():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method to run several epochs over Repeat

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      ImageFolder

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 is created with a fixed epoch count and is never stopped explicitly
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1

    logger.info("test_cache_map_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_coco1():
    """
    Feature: DatasetCache op
    Description: Test mappable Coco leaf with Cache op right over the leaf

       Cache
         |
       Coco

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map coco1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_coco2():
    """
    Feature: DatasetCache op
    Description: Test mappable Coco leaf with the Cache op later in the tree above the Map(Resize)

       Cache
         |
     Map(Resize)
         |
       Coco

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map coco2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(
        COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"],
                  operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_mnist1():
    """
    Feature: DatasetCache op
    Description: Test mappable Mnist leaf with Cache op right over the leaf

       Cache
         |
       Mnist

    Expectation: Output is the same as expected output
    """
    logger.info("Test cache map mnist1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mnist1 Ended.\n")
1697
1698
1699@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1700def test_cache_map_mnist2():
1701    """
1702    Feature: DatasetCache op
1703    Description: Test mappable Mnist leaf with the Cache op later in the tree above the Map(Resize)
1704
1705       Cache
1706         |
1707     Map(Resize)
1708         |
1709       Mnist
1710
1711    Expectation: Output is the same as expected output
1712    """
1713    logger.info("Test cache map mnist2")
1714    if "SESSION_ID" in os.environ:
1715        session_id = int(os.environ['SESSION_ID'])
1716    else:
1717        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1718
1719    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1720    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
1721
1722    resize_op = c_vision.Resize((224, 224))
1723    ds1 = ds1.map(input_columns=["image"],
1724                  operations=resize_op, cache=some_cache)
1725
1726    num_epoch = 4
1727    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1728
1729    epoch_count = 0
1730    for _ in range(num_epoch):
1731        assert sum([1 for _ in iter1]) == 10
1732        epoch_count += 1
1733    assert epoch_count == num_epoch
1734
1735    logger.info("test_cache_map_mnist2 Ended.\n")
1736
1737
1738@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1739def test_cache_map_celeba1():
1740    """
1741    Feature: DatasetCache op
1742    Description: Test mappable CelebA leaf with Cache op right over the leaf
1743
1744       Cache
1745         |
1746       CelebA
1747
1748    Expectation: Output is the same as expected output
1749    """
1750    logger.info("Test cache map celeba1")
1751    if "SESSION_ID" in os.environ:
1752        session_id = int(os.environ['SESSION_ID'])
1753    else:
1754        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1755
1756    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1757
1758    # This dataset has 4 records
1759    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False,
1760                           decode=True, cache=some_cache)
1761
1762    num_epoch = 4
1763    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1764
1765    epoch_count = 0
1766    for _ in range(num_epoch):
1767        assert sum([1 for _ in iter1]) == 4
1768        epoch_count += 1
1769    assert epoch_count == num_epoch
1770
1771    logger.info("test_cache_map_celeba1 Ended.\n")
1772
1773
1774@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1775def test_cache_map_celeba2():
1776    """
1777    Feature: DatasetCache op
1778    Description: Test mappable CelebA leaf with the Cache op later in the tree above the Map(Resize)
1779
1780       Cache
1781         |
1782     Map(Resize)
1783         |
1784       CelebA
1785
1786    Expectation: Output is the same as expected output
1787    """
1788    logger.info("Test cache map celeba2")
1789    if "SESSION_ID" in os.environ:
1790        session_id = int(os.environ['SESSION_ID'])
1791    else:
1792        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1793
1794    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1795
1796    # This dataset has 4 records
1797    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
1798    resize_op = c_vision.Resize((224, 224))
1799    ds1 = ds1.map(input_columns=["image"],
1800                  operations=resize_op, cache=some_cache)
1801
1802    num_epoch = 4
1803    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1804
1805    epoch_count = 0
1806    for _ in range(num_epoch):
1807        assert sum([1 for _ in iter1]) == 4
1808        epoch_count += 1
1809    assert epoch_count == num_epoch
1810
1811    logger.info("test_cache_map_celeba2 Ended.\n")
1812
1813
1814@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1815def test_cache_map_manifest1():
1816    """
1817    Feature: DatasetCache op
1818    Description: Test mappable Manifest leaf with Cache op right over the leaf
1819
1820       Cache
1821         |
1822      Manifest
1823
1824    Expectation: Output is the same as expected output
1825    """
1826    logger.info("Test cache map manifest1")
1827    if "SESSION_ID" in os.environ:
1828        session_id = int(os.environ['SESSION_ID'])
1829    else:
1830        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1831
1832    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1833
1834    # This dataset has 4 records
1835    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
1836
1837    num_epoch = 4
1838    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1839
1840    epoch_count = 0
1841    for _ in range(num_epoch):
1842        assert sum([1 for _ in iter1]) == 4
1843        epoch_count += 1
1844    assert epoch_count == num_epoch
1845
1846    logger.info("test_cache_map_manifest1 Ended.\n")
1847
1848
1849@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1850def test_cache_map_manifest2():
1851    """
1852    Feature: DatasetCache op
1853    Description: Test mappable Manifest leaf with the Cache op later in the tree above the Map(Resize)
1854
1855       Cache
1856         |
1857     Map(Resize)
1858         |
1859      Manifest
1860
1861    Expectation: Output is the same as expected output
1862    """
1863    logger.info("Test cache map manifest2")
1864    if "SESSION_ID" in os.environ:
1865        session_id = int(os.environ['SESSION_ID'])
1866    else:
1867        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1868
1869    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1870
1871    # This dataset has 4 records
1872    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
1873    resize_op = c_vision.Resize((224, 224))
1874    ds1 = ds1.map(input_columns=["image"],
1875                  operations=resize_op, cache=some_cache)
1876
1877    num_epoch = 4
1878    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1879
1880    epoch_count = 0
1881    for _ in range(num_epoch):
1882        assert sum([1 for _ in iter1]) == 4
1883        epoch_count += 1
1884    assert epoch_count == num_epoch
1885
1886    logger.info("test_cache_map_manifest2 Ended.\n")
1887
1888
1889@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1890def test_cache_map_cifar1():
1891    """
1892    Feature: DatasetCache op
1893    Description: Test mappable Cifar10 leaf with Cache op right over the leaf
1894
1895       Cache
1896         |
1897      Cifar10
1898
1899    Expectation: Output is the same as expected output
1900    """
1901    logger.info("Test cache map cifar1")
1902    if "SESSION_ID" in os.environ:
1903        session_id = int(os.environ['SESSION_ID'])
1904    else:
1905        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1906
1907    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1908    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
1909
1910    num_epoch = 4
1911    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1912
1913    epoch_count = 0
1914    for _ in range(num_epoch):
1915        assert sum([1 for _ in iter1]) == 10
1916        epoch_count += 1
1917    assert epoch_count == num_epoch
1918
1919    logger.info("test_cache_map_cifar1 Ended.\n")
1920
1921
1922@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1923def test_cache_map_cifar2():
1924    """
1925    Feature: DatasetCache op
1926    Description: Test mappable Cifar100 leaf with the Cache op later in the tree above the Map(Resize)
1927
1928       Cache
1929         |
1930     Map(Resize)
1931         |
1932      Cifar100
1933
1934    Expectation: Output is the same as expected output
1935    """
1936    logger.info("Test cache map cifar2")
1937    if "SESSION_ID" in os.environ:
1938        session_id = int(os.environ['SESSION_ID'])
1939    else:
1940        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1941
1942    some_cache = ds.DatasetCache(session_id=session_id, size=0)
1943
1944    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
1945    resize_op = c_vision.Resize((224, 224))
1946    ds1 = ds1.map(input_columns=["image"],
1947                  operations=resize_op, cache=some_cache)
1948
1949    num_epoch = 4
1950    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1951
1952    epoch_count = 0
1953    for _ in range(num_epoch):
1954        assert sum([1 for _ in iter1]) == 10
1955        epoch_count += 1
1956    assert epoch_count == num_epoch
1957
1958    logger.info("test_cache_map_cifar2 Ended.\n")
1959
1960
1961@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1962def test_cache_map_cifar3():
1963    """
1964    Feature: DatasetCache op
1965    Description: Test mappable Cifar10 leaf with the Cache op, extra-small size (size=1), and 10000 rows in the dataset
1966
1967       Cache
1968         |
1969      Cifar10
1970
1971    Expectation: Output is the same as expected output
1972    """
1973    logger.info("Test cache map cifar3")
1974    if "SESSION_ID" in os.environ:
1975        session_id = int(os.environ['SESSION_ID'])
1976    else:
1977        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1978
1979    some_cache = ds.DatasetCache(session_id=session_id, size=1)
1980
1981    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
1982
1983    num_epoch = 2
1984    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
1985
1986    epoch_count = 0
1987    for _ in range(num_epoch):
1988        assert sum([1 for _ in iter1]) == 10000
1989        epoch_count += 1
1990    assert epoch_count == num_epoch
1991
1992    logger.info("test_cache_map_cifar3 Ended.\n")
1993
1994
1995@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1996def test_cache_map_cifar4():
1997    """
1998    Feature: DatasetCache op
1999    Description: Test mappable Cifar10 leaf with Cache op right over the leaf, and Shuffle op over the Cache op
2000
2001       Shuffle
2002         |
2003       Cache
2004         |
2005      Cifar10
2006
2007    Expectation: Output is the same as expected output
2008    """
2009    logger.info("Test cache map cifar4")
2010    if "SESSION_ID" in os.environ:
2011        session_id = int(os.environ['SESSION_ID'])
2012    else:
2013        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2014
2015    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2016    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
2017    ds1 = ds1.shuffle(10)
2018
2019    num_epoch = 1
2020    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
2021
2022    epoch_count = 0
2023    for _ in range(num_epoch):
2024        assert sum([1 for _ in iter1]) == 10
2025        epoch_count += 1
2026    assert epoch_count == num_epoch
2027
2028    logger.info("test_cache_map_cifar4 Ended.\n")
2029
2030
2031@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2032def test_cache_map_voc1():
2033    """
2034    Feature: DatasetCache op
2035    Description: Test mappable VOC leaf with Cache op right over the leaf
2036
2037       Cache
2038         |
2039       VOC
2040
2041    Expectation: Output is the same as expected output
2042    """
2043    logger.info("Test cache map voc1")
2044    if "SESSION_ID" in os.environ:
2045        session_id = int(os.environ['SESSION_ID'])
2046    else:
2047        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2048
2049    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2050
2051    # This dataset has 9 records
2052    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train",
2053                        shuffle=False, decode=True, cache=some_cache)
2054
2055    num_epoch = 4
2056    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
2057
2058    epoch_count = 0
2059    for _ in range(num_epoch):
2060        assert sum([1 for _ in iter1]) == 9
2061        epoch_count += 1
2062    assert epoch_count == num_epoch
2063
2064    logger.info("test_cache_map_voc1 Ended.\n")
2065
2066
2067@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2068def test_cache_map_voc2():
2069    """
2070    Feature: DatasetCache op
2071    Description: Test mappable VOC leaf with the Cache op later in the tree above the Map(Resize)
2072
2073       Cache
2074         |
2075     Map(Resize)
2076         |
2077       VOC
2078
2079    Expectation: Output is the same as expected output
2080    """
2081    logger.info("Test cache map voc2")
2082    if "SESSION_ID" in os.environ:
2083        session_id = int(os.environ['SESSION_ID'])
2084    else:
2085        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2086
2087    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2088
2089    # This dataset has 9 records
2090    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection",
2091                        usage="train", shuffle=False, decode=True)
2092    resize_op = c_vision.Resize((224, 224))
2093    ds1 = ds1.map(input_columns=["image"],
2094                  operations=resize_op, cache=some_cache)
2095
2096    num_epoch = 4
2097    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
2098
2099    epoch_count = 0
2100    for _ in range(num_epoch):
2101        assert sum([1 for _ in iter1]) == 9
2102        epoch_count += 1
2103    assert epoch_count == num_epoch
2104
2105    logger.info("test_cache_map_voc2 Ended.\n")
2106
2107
class ReverseSampler(ds.Sampler):
    """A user-defined sampler that yields dataset indices in reverse order."""

    def __iter__(self):
2110        for i in range(self.dataset_size - 1, -1, -1):
2111            yield i
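
# For a dataset of n rows, ReverseSampler yields indices n-1, ..., 1, 0, so
# the 2-image DATA_DIR below is read back to front. A standalone check would
# (hypothetically, since the framework normally populates dataset_size) look
# like:
#
#   sampler = ReverseSampler()
#   sampler.dataset_size = 4
#   assert list(sampler) == [3, 2, 1, 0]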
2112
2113
2114@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2115def test_cache_map_mindrecord1():
2116    """
2117    Feature: DatasetCache op
2118    Description: Test mappable MindRecord leaf with Cache op right over the leaf
2119
2120       Cache
2121         |
2122    MindRecord
2123
2124    Expectation: Output is the same as expected output
2125    """
2126    logger.info("Test cache map mindrecord1")
2127    if "SESSION_ID" in os.environ:
2128        session_id = int(os.environ['SESSION_ID'])
2129    else:
2130        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2131
2132    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2133
2134    # This dataset has 5 records
2135    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
2136    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, cache=some_cache)
2137
2138    num_epoch = 4
2139    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
2140
2141    epoch_count = 0
2142    for _ in range(num_epoch):
2143        assert sum([1 for _ in iter1]) == 5
2144        epoch_count += 1
2145    assert epoch_count == num_epoch
2146
2147    logger.info("test_cache_map_mindrecord1 Ended.\n")
2148
2149
2150@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2151def test_cache_map_mindrecord2():
2152    """
2153    Feature: DatasetCache op
2154    Description: Test mappable MindRecord leaf with the Cache op later in the tree above the Map(Decode)
2155
2156       Cache
2157         |
2158     Map(Decode)
2159         |
2160     MindRecord
2161
2162    Expectation: Output is the same as expected output
2163    """
2164    logger.info("Test cache map mindrecord2")
2165    if "SESSION_ID" in os.environ:
2166        session_id = int(os.environ['SESSION_ID'])
2167    else:
2168        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2169
2170    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2171
2172    # This dataset has 5 records
2173    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
2174    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list)
2175
2176    decode_op = c_vision.Decode()
2177    ds1 = ds1.map(input_columns=["img_data"],
2178                  operations=decode_op, cache=some_cache)
2179
2180    num_epoch = 4
2181    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
2182
2183    epoch_count = 0
2184    for _ in range(num_epoch):
2185        assert sum([1 for _ in iter1]) == 5
2186        epoch_count += 1
2187    assert epoch_count == num_epoch
2188
2189    logger.info("test_cache_map_mindrecord2 Ended.\n")
2190
2191
2192@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2193def test_cache_map_mindrecord3():
2194    """
2195    Feature: DatasetCache op
2196    Description: Test Cache sharing between the following two pipelines with MindRecord leaf:
2197
2198       Cache                                    Cache
2199         |                                      |
2200     Map(Decode)                            Map(Decode)
2201         |                                      |
2202      MindRecord(num_parallel_workers=5)    MindRecord(num_parallel_workers=6)
2203
2204    Expectation: Output is the same as expected output
2205    """
2206    logger.info("Test cache map mindrecord3")
2207    if "SESSION_ID" in os.environ:
2208        session_id = int(os.environ['SESSION_ID'])
2209    else:
2210        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2211
2212    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2213
2214    # This dataset has 5 records
2215    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
2216    decode_op = c_vision.Decode()
2217
2218    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list,
2219                         num_parallel_workers=5, shuffle=True)
2220    ds1 = ds1.map(input_columns=["img_data"],
2221                  operations=decode_op, cache=some_cache)
2222
2223    ds2 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list,
2224                         num_parallel_workers=6, shuffle=True)
2225    ds2 = ds2.map(input_columns=["img_data"],
2226                  operations=decode_op, cache=some_cache)
2227
2228    iter1 = ds1.create_dict_iterator(num_epochs=1, output_numpy=True)
2229    iter2 = ds2.create_dict_iterator(num_epochs=1, output_numpy=True)
2230
2231    assert sum([1 for _ in iter1]) == 5
2232    assert sum([1 for _ in iter2]) == 5
2233
2234    logger.info("test_cache_map_mindrecord3 Ended.\n")
2235
2236
2237@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2238def test_cache_map_python_sampler1():
2239    """
2240    Feature: DatasetCache op
2241    Description: Test using a Python sampler, and Cache after leaf
2242
2243        Repeat
2244         |
2245     Map(Decode)
2246         |
2247       Cache
2248         |
2249      ImageFolder
2250
2251    Expectation: Output is the same as expected output
2252    """
2253    logger.info("Test cache map python sampler1")
2254    if "SESSION_ID" in os.environ:
2255        session_id = int(os.environ['SESSION_ID'])
2256    else:
2257        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2258
2259    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2260
2261    # This DATA_DIR only has 2 images in it
2262    ds1 = ds.ImageFolderDataset(
2263        dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
2264    decode_op = c_vision.Decode()
2265    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
2266    ds1 = ds1.repeat(4)
2267
2268    num_iter = 0
2269    for _ in ds1.create_dict_iterator(num_epochs=1):
2270        num_iter += 1
2271    logger.info("Number of data in ds1: {} ".format(num_iter))
2272    assert num_iter == 8
2273    logger.info("test_cache_map_python_sampler1 Ended.\n")
2274
2275
2276@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2277def test_cache_map_python_sampler2():
2278    """
2279    Feature: DatasetCache op
2280    Description: Test using a Python sampler, and Cache after Map
2281
2282       Repeat
2283         |
2284       Cache
2285         |
2286     Map(Decode)
2287         |
2288      ImageFolder
2289
2290    Expectation: Output is the same as expected output
2291    """
2292    logger.info("Test cache map python sampler2")
2293    if "SESSION_ID" in os.environ:
2294        session_id = int(os.environ['SESSION_ID'])
2295    else:
2296        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2297
2298    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2299
2300    # This DATA_DIR only has 2 images in it
2301    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
2302    decode_op = c_vision.Decode()
2303    ds1 = ds1.map(input_columns=["image"],
2304                  operations=decode_op, cache=some_cache)
2305    ds1 = ds1.repeat(4)
2306
2307    num_iter = 0
2308    for _ in ds1.create_dict_iterator(num_epochs=1):
2309        num_iter += 1
2310    logger.info("Number of data in ds1: {} ".format(num_iter))
2311    assert num_iter == 8
2312    logger.info("test_cache_map_python_sampler2 Ended.\n")
2313
2314
2315@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2316def test_cache_map_nested_repeat():
2317    """
2318    Feature: DatasetCache op
2319    Description: Test Cache on pipeline with nested Repeat ops
2320
2321        Repeat
2322          |
2323      Map(Decode)
2324          |
2325        Repeat
2326          |
2327        Cache
2328          |
2329      ImageFolder
2330
2331    Expectation: Output is the same as expected output
2332    """
2333    logger.info("Test cache map nested repeat")
2334    if "SESSION_ID" in os.environ:
2335        session_id = int(os.environ['SESSION_ID'])
2336    else:
2337        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2338
2339    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2340
2341    # This DATA_DIR only has 2 images in it
2342    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
2343    decode_op = c_vision.Decode()
2344    ds1 = ds1.repeat(4)
2345    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
2346    ds1 = ds1.repeat(2)
2347
2348    num_iter = 0
2349    for _ in ds1.create_dict_iterator(num_epochs=1):
2350        logger.info("get data from dataset")
2351        num_iter += 1
2352
2353    logger.info("Number of data in ds1: {} ".format(num_iter))
2354    assert num_iter == 16
2355    logger.info('test_cache_map_nested_repeat Ended.\n')
2356
2357
2358@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2359def test_cache_map_interrupt_and_rerun():
2360    """
2361    Feature: DatasetCache op
2362    Description: Test interrupt a running pipeline and then re-use the same Cache to run another pipeline
2363
2364       Cache
2365         |
2366      Cifar10
2367
2368    Expectation: Error is raised when interrupted and output is the same as expected output after the rerun
2369    """
2370    logger.info("Test cache map interrupt and rerun")
2371    if "SESSION_ID" in os.environ:
2372        session_id = int(os.environ['SESSION_ID'])
2373    else:
2374        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2375
2376    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2377
2378    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
2379    iter1 = ds1.create_dict_iterator(num_epochs=-1)
2380
2381    num_iter = 0
2382    with pytest.raises(AttributeError) as e:
2383        for _ in iter1:
2384            num_iter += 1
2385            if num_iter == 10:
2386                iter1.stop()
2387    assert "'DictIterator' object has no attribute '_runtime_context'" in str(
2388        e.value)
2389
2390    num_epoch = 2
2391    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
2392    epoch_count = 0
2393    for _ in range(num_epoch):
2394        num_iter = 0
2395        for _ in iter2:
2396            num_iter += 1
2397        logger.info("Number of data in ds1: {} ".format(num_iter))
2398        assert num_iter == 10000
        epoch_count += 1
    assert epoch_count == num_epoch

2401    cache_stat = some_cache.get_stat()
2402    assert cache_stat.num_mem_cached == 10000
2403
2404    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
2405
2406
2407@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2408def test_cache_map_dataset_size1():
2409    """
2410    Feature: DatasetCache op
2411    Description: Test get_dataset_size op when Cache is injected directly after a mappable leaf
2412
2413       Cache
2414         |
2415      CelebA
2416
2417    Expectation: Output is the same as expected output
2418    """
2419    logger.info("Test cache map dataset size 1")
2420    if "SESSION_ID" in os.environ:
2421        session_id = int(os.environ['SESSION_ID'])
2422    else:
2423        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2424
2425    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2426
2427    # This dataset has 4 records
2428    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3,
2429                           shard_id=0, cache=some_cache)
2430
2431    dataset_size = ds1.get_dataset_size()
2432    assert dataset_size == 2
2433
2434    num_iter = 0
2435    for _ in ds1.create_dict_iterator(num_epochs=1):
2436        num_iter += 1
2437
2438    logger.info("Number of data in ds1: {} ".format(num_iter))
2439    assert num_iter == dataset_size
2440    logger.info("test_cache_map_dataset_size1 Ended.\n")
2441
2442
2443@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
2444def test_cache_map_dataset_size2():
2445    """
2446    Feature: DatasetCache op
2447    Description: Test get_dataset_size op when Cache is injected after Map
2448
2449       Cache
2450         |
2451    Map(Resize)
2452         |
2453     CelebA
2454
2455    Expectation: Output is the same as expected output
2456    """
2457    logger.info("Test cache map dataset size 2")
2458    if "SESSION_ID" in os.environ:
2459        session_id = int(os.environ['SESSION_ID'])
2460    else:
2461        raise RuntimeError("Testcase requires SESSION_ID environment variable")
2462
2463    some_cache = ds.DatasetCache(session_id=session_id, size=0)
2464
2465    # This dataset has 4 records
2466    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False,
2467                           decode=True, num_shards=3, shard_id=0)
2468    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(operations=resize_op,
                  input_columns=["image"], cache=some_cache)
2471
2472    dataset_size = ds1.get_dataset_size()
2473    assert dataset_size == 2
2474
2475    num_iter = 0
2476    for _ in ds1.create_dict_iterator(num_epochs=1):
2477        num_iter += 1
2478
2479    logger.info("Number of data in ds1: {} ".format(num_iter))
2480    assert num_iter == dataset_size
2481    logger.info("test_cache_map_dataset_size2 Ended.\n")
2482
2483
2484if __name__ == '__main__':
    # This is just a list of tests; don't try to run them with 'python test_cache_map.py',
    # since the cache server needs to be brought up first
2487    test_cache_map_basic1()
2488    test_cache_map_basic2()
2489    test_cache_map_basic3()
2490    test_cache_map_basic4()
2491    test_cache_map_basic5()
2492    test_cache_map_failure1()
2493    test_cache_map_failure2()
2494    test_cache_map_failure3()
2495    test_cache_map_failure4()
2496    test_cache_map_failure5()
2497    test_cache_map_failure7()
2498    test_cache_map_failure8()
2499    test_cache_map_failure9()
2500    test_cache_map_failure10()
2501    test_cache_map_failure11()
2502    test_cache_map_split1()
2503    test_cache_map_split2()
2504    test_cache_map_parameter_check()
2505    test_cache_map_running_twice1()
2506    test_cache_map_running_twice2()
2507    test_cache_map_extra_small_size1()
2508    test_cache_map_extra_small_size2()
2509    test_cache_map_no_image()
2510    test_cache_map_parallel_pipeline1(shard=0)
2511    test_cache_map_parallel_pipeline2(shard=1)
2512    test_cache_map_parallel_workers()
2513    test_cache_map_server_workers_1()
2514    test_cache_map_server_workers_100()
2515    test_cache_map_num_connections_1()
2516    test_cache_map_num_connections_100()
2517    test_cache_map_prefetch_size_1()
2518    test_cache_map_prefetch_size_100()
2519    test_cache_map_device_que()
2520    test_cache_map_epoch_ctrl1()
2521    test_cache_map_epoch_ctrl2()
2522    test_cache_map_epoch_ctrl3()
2523    test_cache_map_coco1()
2524    test_cache_map_coco2()
2525    test_cache_map_mnist1()
2526    test_cache_map_mnist2()
2527    test_cache_map_celeba1()
2528    test_cache_map_celeba2()
2529    test_cache_map_manifest1()
2530    test_cache_map_manifest2()
2531    test_cache_map_cifar1()
2532    test_cache_map_cifar2()
2533    test_cache_map_cifar3()
2534    test_cache_map_cifar4()
2535    test_cache_map_voc1()
2536    test_cache_map_voc2()
2537    test_cache_map_mindrecord1()
    test_cache_map_mindrecord2()
    test_cache_map_mindrecord3()
2539    test_cache_map_python_sampler1()
2540    test_cache_map_python_sampler2()
    test_cache_map_nested_repeat()
    test_cache_map_interrupt_and_rerun()
2542    test_cache_map_dataset_size1()
2543    test_cache_map_dataset_size2()
2544