# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""
16Testing cache operator with mappable datasets
17"""
18import os
19import pytest
20import numpy as np
21import mindspore.dataset as ds
22import mindspore.dataset.vision.c_transforms as c_vision
23import mindspore.dataset.vision.py_transforms as py_vision
24from mindspore import log as logger
25from util import save_and_check_md5
26
27DATA_DIR = "../data/dataset/testImageNetData/train/"
28COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
29COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
30NO_IMAGE_DIR = "../data/dataset/testRandomData/"
31MNIST_DATA_DIR = "../data/dataset/testMnistData/"
32CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
33VOC_DATA_DIR = "../data/dataset/testVOC2012/"
34MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
35CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
36CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
37MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
38GENERATE_GOLDEN = False
39
40
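# The tests in this file expect an external cache server to be running and a
# cache session to exist. A minimal setup sketch (the exact cache_admin flags
# may vary across MindSpore versions; check `cache_admin --help`):
#
#   cache_admin --start    # bring up the cache server
#   cache_admin -g         # generate a session; note the printed session id
#   export SESSION_ID=<id printed above>
#   export RUN_CACHE_TEST=TRUE
#
# size=0 in the DatasetCache calls below requests a cache without an explicit
# memory cap.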
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic1():
    """
    Test mappable leaf with cache op right over the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    filename = "cache_map_01_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic2():
    """
    Test mappable leaf with the cache op later in the tree above the map(decode)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    filename = "cache_map_02_result.npz"
    save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)

    logger.info("test_cache_map_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic3():
    """
    Test a cached pipeline with repeat below the map (a layout that previously caused a core dump)
    """
    logger.info("Test cache map basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    logger.info("ds1.dataset_size is {}".format(ds1.get_dataset_size()))
    shape = ds1.output_shapes()
    logger.info(shape)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # 2 images x 4 repeats
    assert num_iter == 8
    logger.info('test_cache_map_basic3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic4():
    """
    Test Map containing random operation above cache

               repeat
                  |
             Map(decode, randomCrop)
                  |
                Cache
                  |
             ImageFolder

    """
    logger.info("Test cache map basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op)
    data = data.repeat(4)

    num_iter = 0
    for _ in data.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info('test_cache_map_basic4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_basic5():
    """
    Test cache as root node

       cache
         |
      ImageFolder
    """
    logger.info("Test cache map basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 2
    logger.info('test_cache_map_basic5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure1():
    """
    Test nested cache (failure)

        Repeat
          |
        Cache
          |
      Map(decode)
          |
        Cache
          |
        Coco

    """
    logger.info("Test cache map failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR has 6 images in it
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure2():
    """
    Test zip under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(decode)
                  |
                 Zip
                |    |
      ImageFolder     ImageFolder

    """
    logger.info("Test cache map failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator():
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure3():
    """
    Test batch under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(resize)
                  |
                Batch
                  |
                Mnist
    """
    logger.info("Test cache map failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure4():
    """
    Test filter under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(decode)
                  |
                Filter
                  |
               CelebA

    """
    logger.info("Test cache map failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure5():
    """
    Test Map containing random operation under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(decode, randomCrop)
                  |
              Manifest

    """
    logger.info("Test cache map failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure7():
    """
    Test no-cache-supporting Generator leaf with Map under cache (failure)

               repeat
                  |
                Cache
                  |
            Map(lambda x: x)
                  |
              Generator

    """

    def generator_1d():
        for i in range(64):
            yield (np.array(i),)

    logger.info("Test cache map failure 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    data = ds.GeneratorDataset(generator_1d, ["data"])
    data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "There is currently no support for GeneratorOp under cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure7 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure8():
    """
    Test a repeat under mappable cache (failure)

        Cache
          |
      Map(decode)
          |
        Repeat
          |
       Cifar10
    """

    logger.info("Test cache map failure 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure8 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure9():
    """
    Test take under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(decode)
                  |
                Take
                  |
             Cifar100

    """
    logger.info("Test cache map failure 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    ds1 = ds1.take(2)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure9 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure10():
    """
    Test skip under cache (failure)

               repeat
                  |
                Cache
                  |
             Map(decode)
                  |
                Skip
                  |
                VOC

    """
    logger.info("Test cache map failure 10")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    ds1 = ds1.skip(1)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure10 Ended.\n')


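# The spilling failure test below presumes the cache server was started
# WITHOUT a spill directory. A server with spilling support would instead be
# started with something like the following (sketch; confirm the flag
# spelling against `cache_admin --help` for your MindSpore version):
#
#   cache_admin --start -s /tmp/cache_spill_dir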
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_failure11():
    """
    Test setting spilling=True when the cache server was started without spilling support (failure)

         Cache(spilling=true)
                 |
             ImageFolder

    """
    logger.info("Test cache map failure 11")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Unexpected error. Server is not set up with spill support" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_map_failure11 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_split1():
    """
    Test split (after a non-source node) under cache (failure).
    Split after a non-source node is implemented with TakeOp/SkipOp, hence the failure.

               repeat
                  |
                Cache
                  |
             Map(resize)
                  |
                Split
                  |
             Map(decode)
                  |
             ImageFolder

    """
    logger.info("Test cache map split 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1, ds2 = ds1.split([0.5, 0.5])
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds2.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator():
            num_iter += 1
    assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
    logger.info('test_cache_map_split1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_split2():
    """
    Test split (after a source node) under cache (ok).
    Split after a source node is implemented with a subset sampler, hence ok.

               repeat
                  |
                Cache
                  |
             Map(resize)
                  |
                Split
                  |
             VOCDataset

    """
    logger.info("Test cache map split 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)

    ds1, ds2 = ds1.split([0.3, 0.7])
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    # 3 of the 9 records x 4 repeats
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator():
        num_iter += 1
    # 6 of the 9 records x 4 repeats
    assert num_iter == 24
    logger.info('test_cache_map_split2 Ended.\n')


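# For contrast with the invalid arguments exercised in the parameter check
# below, a typical valid construction looks like this (sketch; assumes a
# local server listening on the default port):
#
#   ok_cache = ds.DatasetCache(session_id=1, size=0, spilling=False,
#                              hostname="127.0.0.1", port=50052)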
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parameter_check():
    """
    Test illegal parameters for DatasetCache
    """

    logger.info("Test cache map parameter check")

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=-1, size=0)
    assert "Input is not within the required interval" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id="1", size=0)
    assert "Argument session_id with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=None, size=0)
    assert "Argument session_id with value None is not of type" in str(info.value)

    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=1, size=-1)
    assert "Input size must be greater than 0" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size="1")
    assert "Argument size with value 1 is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=None)
    assert "Argument size with value None is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling="illegal")
    assert "Argument spilling with value illegal is not of type" in str(info.value)

    with pytest.raises(TypeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname=50052)
    assert "Argument hostname with value 50052 is not of type" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="illegal")
    assert "now cache client has to be on the same host with cache server" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
    assert "now cache client has to be on the same host with cache server" in str(err.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="illegal")
    assert "Argument port with value illegal is not of type" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, port="50052")
    assert "Argument port with value 50052 is not of type" in str(info.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=0)
    assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)

    with pytest.raises(ValueError) as err:
        ds.DatasetCache(session_id=1, size=0, port=65536)
    assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)

    with pytest.raises(TypeError) as err:
        ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
    assert "Argument cache with value True is not of type" in str(err.value)

    logger.info("test_cache_map_parameter_check Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_running_twice1():
    """
    Execute the same pipeline twice (from Python), with cache injected after the map

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8

    logger.info("test_cache_map_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_running_twice2():
    """
    Execute the same pipeline twice (from shell), with cache injected after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_extra_small_size1():
    """
    Test running a pipeline with an extra-small cache size and spilling=True

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_extra_small_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_extra_small_size2():
    """
    Test running a pipeline with an extra-small cache size and spilling=False

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_extra_small_size2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_no_image():
    """
    Test cache when no images exist in the dataset path

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map no image")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # NO_IMAGE_DIR does not contain any images
    ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError):
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1

    assert num_iter == 0
    logger.info("test_cache_map_no_image Ended.\n")


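# The two parallel-pipeline tests below take a `shard` argument that plain
# pytest does not supply; the cache test scripts launch two processes and
# pass shard ids 0 and 1. A hypothetical conftest.py sketch that would
# provide it for standalone runs:
#
#   def pytest_addoption(parser):
#       parser.addoption("--shard", action="store", default=0)
#
#   @pytest.fixture
#   def shard(request):
#       return request.config.getoption("--shard")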
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_pipeline1(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after the leaf op

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """

    logger.info("Test cache map parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    # each shard sees 1 of the 2 images; 1 x 4 repeats
    assert num_iter == 4
    logger.info("test_cache_map_parallel_pipeline1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after the map op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
     ImageFolder
    """

    logger.info("Test cache map parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_map_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for the map op and the leaf op

       Repeat
         |
       cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_parallel_workers Ended.\n")


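# The next two tests pair with a specific server configuration: the cache
# test scripts restart the server with the worker count named in each test,
# e.g. (sketch; flag spelling per `cache_admin --help`):
#
#   cache_admin --stop && cache_admin --start --workers 1
#   cache_admin --stop && cache_admin --start --workers 100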
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_server_workers_1():
    """
    Start the cache server with --workers 1, then test the cache function

       Repeat
         |
       cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_server_workers_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_server_workers_100():
    """
    Start the cache server with --workers 100, then test the cache function

       Repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_server_workers_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

       Repeat
         |
       cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8
    logger.info("test_cache_map_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_to_device():
    """
    Test cache with to_device

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    # A to_device pipeline has no host-side iterator; building the pipeline
    # and sending it to the device queue is the whole test.
    ds1 = ds1.to_device()
    ds1.send()

    logger.info("test_cache_map_to_device Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_map_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

        cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_map_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

       repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1

    logger.info("test_cache_map_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_coco1():
    """
    Test mappable coco leaf with cache op right over the leaf

       cache
         |
       Coco
    """

    logger.info("Test cache map coco1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
                         cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_coco2():
    """
    Test mappable coco leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       Coco
    """

    logger.info("Test cache map coco2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 6 records
    ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_coco2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_mnist1():
    """
    Test mappable mnist leaf with cache op right over the leaf

       cache
         |
       Mnist
    """

    logger.info("Test cache map mnist1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mnist1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_mnist2():
    """
    Test mappable mnist leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       Mnist
    """

    logger.info("Test cache map mnist2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)

    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mnist2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_celeba1():
    """
    Test mappable celeba leaf with cache op right over the leaf

       cache
         |
       CelebA
    """

    logger.info("Test cache map celeba1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_celeba1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_map_celeba2():
    """
    Test mappable celeba leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       CelebA
    """

    logger.info("Test cache map celeba2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_celeba2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest1():
    """
    Test mappable manifest leaf with cache op right over the leaf

       cache
         |
      Manifest
    """

    logger.info("Test cache map manifest1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_manifest1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_manifest2():
    """
    Test mappable manifest leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
      Manifest
    """

    logger.info("Test cache map manifest2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records
    ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 4
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_manifest2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar1():
    """
    Test mappable cifar10 leaf with cache op right over the leaf

       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar2():
    """
    Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
      Cifar100
    """

    logger.info("Test cache map cifar2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar3():
    """
    Test mappable cifar10 leaf with cache op right over the leaf
    In this case, we set an extra-small cache size (size=1) while there are 10000 rows in the
    dataset, so the cache cannot hold all of them; the pipeline must still return every row.

       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)

    num_epoch = 2
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10000
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar4():
    """
    Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op

       shuffle
         |
       cache
         |
      Cifar10
    """

    logger.info("Test cache map cifar4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
    ds1 = ds1.shuffle(10)

    num_epoch = 1
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 10
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_cifar4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc1():
    """
    Test mappable voc leaf with cache op right over the leaf

       cache
         |
       VOC
    """

    logger.info("Test cache map voc1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 9
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_voc1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc2():
    """
    Test mappable voc leaf with the cache op later in the tree above the map(resize)

       cache
         |
     Map(resize)
         |
       VOC
    """

    logger.info("Test cache map voc2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 9 records
    ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 9
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_voc2 Ended.\n")


class ReverseSampler(ds.Sampler):
    """A user-defined Python sampler which yields indices in reverse order."""

    def __iter__(self):
        for i in range(self.dataset_size - 1, -1, -1):
            yield i


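# Illustrative usage sketch (hypothetical, not part of the test list below):
# ReverseSampler walks indices from dataset_size - 1 down to 0, so the 2-image
# folder used by the python_sampler tests comes back last-image-first. No cache
# is attached, so no cache server is needed for this snippet.
def _reverse_sampler_demo():
    data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
    return [row["label"] for row in data.create_dict_iterator(output_numpy=True)]

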
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mindrecord1():
    """
    Test mappable mindrecord leaf with cache op right over the leaf

       cache
         |
    MindRecord
    """

    logger.info("Test cache map mindrecord1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 5 records
    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 5
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mindrecord1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mindrecord2():
    """
    Test mappable mindrecord leaf with the cache op later in the tree above the map(decode)

       cache
         |
     Map(decode)
         |
     MindRecord
    """

    logger.info("Test cache map mindrecord2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 5 records
    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list)

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["img_data"], operations=decode_op, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 5
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_map_mindrecord2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_mindrecord3():
    """
    Test cache sharing between the following two pipelines with mindrecord leaf:

       Cache                                    Cache
         |                                      |
     Map(decode)                            Map(decode)
         |                                      |
      MindRecord(num_parallel_workers=5)    MindRecord(num_parallel_workers=6)
    """

    logger.info("Test cache map mindrecord3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 5 records
    columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
    decode_op = c_vision.Decode()

    ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list, num_parallel_workers=5, shuffle=True)
    ds1 = ds1.map(input_columns=["img_data"], operations=decode_op, cache=some_cache)

    ds2 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list=columns_list, num_parallel_workers=6, shuffle=True)
    ds2 = ds2.map(input_columns=["img_data"], operations=decode_op, cache=some_cache)

    iter1 = ds1.create_dict_iterator(num_epochs=1, output_numpy=True)
    iter2 = ds2.create_dict_iterator(num_epochs=1, output_numpy=True)

    assert sum([1 for _ in iter1]) == 5
    assert sum([1 for _ in iter2]) == 5

    logger.info("test_cache_map_mindrecord3 Ended.\n")


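# Hedged note on the sharing above: ds1 and ds2 differ only in
# num_parallel_workers (5 vs 6), which evidently does not change the cache
# identity, so both pipelines are served by the single cache created in
# test_cache_map_mindrecord3. A hypothetical checker sketch (get_stat() and
# num_mem_cached follow the usage seen elsewhere in this file):
def _assert_shared_cache_rows(cache, expected_rows=5):
    assert cache.get_stat().num_mem_cached == expected_rows

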
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler1():
    """
    Test using a Python sampler, with the cache op right over the leaf

        Repeat
         |
     Map(decode)
         |
       cache
         |
      ImageFolder
    """

    logger.info("Test cache map python sampler1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8  # 2 images * repeat(4)
    logger.info("test_cache_map_python_sampler1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_python_sampler2():
    """
    Test using a Python sampler, with the cache op after the map

       Repeat
         |
       cache
         |
     Map(decode)
         |
      ImageFolder
    """

    logger.info("Test cache map python sampler2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 8  # 2 images * repeat(4)
    logger.info("test_cache_map_python_sampler2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_nested_repeat():
    """
    Test cache on pipeline with nested repeat ops

        Repeat
          |
      Map(decode)
          |
        Repeat
          |
        Cache
          |
      ImageFolder
    """

    logger.info("Test cache map nested repeat")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This DATA_DIR only has 2 images in it
    ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 16  # 2 images * repeat(4) * repeat(2)
    logger.info('test_cache_map_nested_repeat Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_interrupt_and_rerun():
    """
    Test interrupting a running pipeline and then re-using the same cache to run another pipeline

       cache
         |
      Cifar10
    """

    logger.info("Test cache map interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
    iter1 = ds1.create_dict_iterator()

    num_iter = 0
    # Calling stop() tears down the iterator's runtime context, so the next
    # iteration step raises AttributeError; this is how we interrupt the run.
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1
    assert epoch_count == num_epoch

    cache_stat = some_cache.get_stat()
    assert cache_stat.num_mem_cached == 10000

    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")


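# Hedged helper sketch (hypothetical): dump cache-server statistics after a
# run. num_mem_cached follows the get_stat() usage above; num_disk_cached and
# avg_cache_sz are assumed to be exposed by the same stat object.
def _log_cache_stat(cache):
    stat = cache.get_stat()
    logger.info("rows cached in memory: {}".format(stat.num_mem_cached))
    logger.info("rows spilled to disk: {}".format(stat.num_disk_cached))
    logger.info("average cached row size: {}".format(stat.avg_cache_sz))

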
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_dataset_size1():
    """
    Test get_dataset_size() when cache is injected directly after a mappable leaf

       Cache
         |
      CelebA
    """

    logger.info("Test cache map dataset size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records; with num_shards=3, shard 0 gets ceil(4/3) = 2 rows
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3, shard_id=0, cache=some_cache)

    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_map_dataset_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_dataset_size2():
    """
    Test get_dataset_size() when cache is injected after map

       Cache
         |
    Map(resize)
         |
     CelebA
    """

    logger.info("Test cache map dataset size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 4 records; with num_shards=3, shard 0 gets ceil(4/3) = 2 rows
    ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=3, shard_id=0)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(operations=resize_op, input_columns=["image"], cache=some_cache)

    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_map_dataset_size2 Ended.\n")


if __name__ == '__main__':
    # This is just a list of the tests; don't run them directly with 'python test_cache_map.py',
    # since the cache server must be brought up first
    test_cache_map_basic1()
    test_cache_map_basic2()
    test_cache_map_basic3()
    test_cache_map_basic4()
    test_cache_map_basic5()
    test_cache_map_failure1()
    test_cache_map_failure2()
    test_cache_map_failure3()
    test_cache_map_failure4()
    test_cache_map_failure5()
    test_cache_map_failure7()
    test_cache_map_failure8()
    test_cache_map_failure9()
    test_cache_map_failure10()
    test_cache_map_failure11()
    test_cache_map_split1()
    test_cache_map_split2()
    test_cache_map_parameter_check()
    test_cache_map_running_twice1()
    test_cache_map_running_twice2()
    test_cache_map_extra_small_size1()
    test_cache_map_extra_small_size2()
    test_cache_map_no_image()
    test_cache_map_parallel_pipeline1(shard=0)
    test_cache_map_parallel_pipeline2(shard=1)
    test_cache_map_parallel_workers()
    test_cache_map_server_workers_1()
    test_cache_map_server_workers_100()
    test_cache_map_num_connections_1()
    test_cache_map_num_connections_100()
    test_cache_map_prefetch_size_1()
    test_cache_map_prefetch_size_100()
    test_cache_map_to_device()
    test_cache_map_epoch_ctrl1()
    test_cache_map_epoch_ctrl2()
    test_cache_map_epoch_ctrl3()
    test_cache_map_coco1()
    test_cache_map_coco2()
    test_cache_map_mnist1()
    test_cache_map_mnist2()
    test_cache_map_celeba1()
    test_cache_map_celeba2()
    test_cache_map_manifest1()
    test_cache_map_manifest2()
    test_cache_map_cifar1()
    test_cache_map_cifar2()
    test_cache_map_cifar3()
    test_cache_map_cifar4()
    test_cache_map_voc1()
    test_cache_map_voc2()
    test_cache_map_mindrecord1()
    test_cache_map_mindrecord2()
    test_cache_map_mindrecord3()
    test_cache_map_python_sampler1()
    test_cache_map_python_sampler2()
    test_cache_map_nested_repeat()
    test_cache_map_interrupt_and_rerun()
    test_cache_map_dataset_size1()
    test_cache_map_dataset_size2()