# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operation with non-mappable datasets
"""
import os
import itertools
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision as c_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"

TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"

IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"

PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"

GENERATE_GOLDEN = False

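# A minimal helper sketch (an editorial addition, not part of the original
# suite): every test below inlines this SESSION_ID lookup verbatim, and it
# could be factored out along these lines.
def _get_session_id():
    """Return the cache session id from the SESSION_ID environment variable."""
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    return int(os.environ["SESSION_ID"])
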
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic1():
    """
    Feature: DatasetCache op
    Description: Test a RandomDataset (a non-mappable dataset) with a Cache over it just after the leaf
    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache.  arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic2():
    """
    Feature: DatasetCache op
    Description: Test RandomDataset (a non-mappable dataset with num_samples) with a Cache over it just after the leaf
    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache.  arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler:
    # num_samples, shuffle, num_shards, shard_id
    # In this case, the presence of num_samples chooses a sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic3():
    """
    Feature: DatasetCache op
    Description: Test a TFRecordDataset (a non-mappable dataset) with a Cache over it just after the leaf

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    # Contact the server to get the statistics
    stat = some_cache.get_stat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached

    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))

    logger.info("test_cache_nomap_basic3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic4():
    """
    Feature: DatasetCache op
    Description: Test a TFRecordDataset (a non-mappable dataset) with a Map(Decode) and a Cache after it.
        Since a global shuffle is used for the TF reader, a shuffle op is injected over it.
        But if there is a cache later in the tree, that shuffle becomes invalid and should be removed.

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # With shuffle not set, TF defaults to a "global" shuffle when there is no cache
    # in the picture.  This causes a shuffle-injection over the TF reader.  For clarity, this test
    # explicitly gives the global option, even though it's the default in Python.
    # But when caching is added in the ancestor tree above TF, global shuffling is done
    # through the sampler over the cache, not by the shuffle op.  In that case, tree prepare
    # will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()

    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic5():
    """
    Feature: DatasetCache op
    Description: Test a TFRecordDataset (a non-mappable dataset) with a Cache over it just after the leaf.
        Same as test 3, but this one does not pass the shuffle arg, causing TF to default to global
        shuffle, which attempts to inject a Shuffle operation. However, since there is a Cache,
        we do not need global shuffle, so the shuffle will not be built. The result is
        identical to test basic 3, but we arrive at the same tree through a different codepath
        (if there were no Cache, the Shuffle would be built)

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic6():
    """
    Feature: DatasetCache op
    Description: Test a TFRecordDataset (a non-mappable dataset) with a Cache over it just after the leaf.
        In this one, the TFRecordDataset is given a sharding configuration; however, since a Cache is
        used, tree prepare should undo the sharding configuration and instead choose a distributed
        sampler with the same shard config.

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 6")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard
    # However, the sharding will be done by the sampler, not by the tf record leaf node
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there were no cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic7():
    """
    Feature: DatasetCache op
    Description: Test a TFRecordDataset (a non-mappable dataset) that uses global shuffle and is cached,
        followed by a Map. Here the TFRecordDataset with global shuffle might want to inject a Shuffle
        op over top of itself, but since a Cache is given, it will choose not to.

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic8():
    """
    Feature: DatasetCache op
    Description: Test Cache as the root node

       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap basic 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_basic8 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic9():
    """
    Feature: DatasetCache op
    Description: Test the get_stat interface for getting info from the server when the Cache has not
        been used in any pipeline
    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap basic 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # Contact the server to get the statistics. This should fail because we have not used this cache
    # in any pipeline, so there will not be any cache to get stats on.
    with pytest.raises(RuntimeError) as e:
        stat = some_cache.get_stat()
        cache_sz = stat.avg_cache_sz
        logger.info("Average row cache size: {}".format(cache_sz))
    assert "Unexpected error" in str(e.value)

    logger.info("test_cache_nomap_basic9 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share1():
    """
    Feature: DatasetCache op
    Description: Test sharing the Cache between the following two trees:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
      TFReader    TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap allowed share 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share2():
    """
    Feature: DatasetCache op
    Description: Test sharing the Cache between the following two trees (with Map Decode):

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
     Map(Decode) Map(Decode)
         |           |
      TFReader    TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap allowed share 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share3():
    """
    Feature: DatasetCache op
    Description: Test sharing the Cache between the following two trees (different shard ids):

       Repeat                     Repeat
         |                          |
       Cache                      Cache
         |                          |
      TFReader(shard_id = 0)     TFReader(shard_id = 1)

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap allowed share 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share4():
    """
    Feature: DatasetCache op
    Description: Test sharing the Cache between the following two trees:

       Cache                                  Cache
         |                                      |
     Map(Decode, num_parallel_workers=1)    Map(Decode, num_parallel_workers=2)
         |                                      |
      TFReader                              TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap allowed share 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3

    logger.info("test_cache_nomap_allowed_share4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_disallowed_share1():
    """
    Feature: DatasetCache op
    Description: Test sharing the Cache between the following two trees:

       Cache       Cache
         |           |
     Map(Decode) Map(Rescale)
         |           |
      TFReader    TFReader

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap disallowed share1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    assert "Cannot re-use a cache for a different tree!" in str(e.value)

    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice1():
    """
    Feature: DatasetCache op
    Description: Test executing the same pipeline twice (from Python), with Cache injected after Map

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    logger.info("test_cache_nomap_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice2():
    """
    Feature: DatasetCache op
    Description: Test executing the same pipeline twice (from shell), with Cache injected after the leaf

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size1():
    """
    Feature: DatasetCache op
    Description: Test running the pipeline with an extra-small Cache size and spilling=True

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
     TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_extra_small_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size2():
    """
    Feature: DatasetCache op
    Description: Test running the pipeline with an extra-small Cache size and spilling=False

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
     TFRecord

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_extra_small_size2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline1(shard):
    """
    Feature: DatasetCache op
    Description: Test running two parallel pipelines (sharing Cache) with Cache injected after the leaf op

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline2(shard):
    """
    Feature: DatasetCache op
    Description: Test running two parallel pipelines (sharing Cache) with Cache injected after the Map op

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_workers():
    """
    Feature: DatasetCache op
    Description: Test Cache with num_parallel_workers > 1 set for the Map op and leaf op

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_parallel_workers Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_1():
    """
    Feature: DatasetCache op
    Description: Start the Cache server with --workers 1 and then test the Cache function

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_100():
    """
    Feature: DatasetCache op
    Description: Start the Cache server with --workers 100 and then test the Cache function

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_1():
    """
    Feature: DatasetCache op
    Description: Test setting num_connections=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_100():
    """
    Feature: DatasetCache op
    Description: Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_1():
    """
    Feature: DatasetCache op
    Description: Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(Decode)
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_100():
    """
    Feature: DatasetCache op
    Description: Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_device_que():
    """
    Feature: DatasetCache op
    Description: Test Cache with device_que

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFReader

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap device_que")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.device_que()
    ds1.send()

    logger.info("test_cache_nomap_device_que Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_session_destroy():
    """
    Feature: DatasetCache op
    Description: Test executing cache_admin -d while the pipeline is running

       Repeat
         |
       Cache
         |
     RandomDataset

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap session destroy")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Unexpected error" in str(e.value)

    logger.info("test_cache_nomap_session_destroy Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_stop():
    """
    Feature: DatasetCache op
    Description: Test executing cache_admin --stop while the pipeline is running

       Repeat
         |
       Cache
         |
     RandomDataset

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap server stop")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Network error. Cache server with port 50052 is unreachable. Make sure the server is running." in \
           str(e.value)

    logger.info("test_cache_nomap_server_stop Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_interrupt_and_rerun():
    """
    Feature: DatasetCache op
    Description: Test interrupting a running pipeline and then re-using the same Cache to run another pipeline

       Cache
         |
     RandomDataset

    Expectation: Error is raised after the interrupt, then output is equal to the expected output after the rerun
    """
    logger.info("Test cache nomap interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
    iter1 = ds1.create_dict_iterator(num_epochs=-1)

    num_iter = 0
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1

    cache_stat = some_cache.get_stat()
    assert cache_stat.num_mem_cached == 10000

    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl1():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method to run several epochs

     Map(Decode)
         |
       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl2():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method with infinite epochs

        Cache
         |
     Map(Decode)
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator(num_epochs=-1)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl3():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method to run several epochs over a Repeat op

       Repeat
         |
     Map(Decode)
         |
       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # unlike epoch ctrl2, iter1 is created with a fixed number of epochs
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1

    logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl4():
    """
    Feature: DatasetCache op
    Description: Test using the two-loop method with Repeat under Cache

        Cache
         |
     Map(Decode)
         |
       Repeat
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap epoch ctrl4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    ds1 = ds1.repeat(2)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n")


1488@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
1489def test_cache_nomap_multiple_cache1():
1490    """
1491    Feature: DatasetCache op
1492    Description: Test multiple Cache in the same python script
1493
1494       Cache                  Cache
1495         |                      |
1496    Map(Decode)             Map(Decode)
1497         |                      |
1498    TFRecord(train)        TFRecord(eval)
1499
1500    Expectation: Output is equal to the expected output
1501    """
1502    logger.info("Test cache nomap multiple cache 1")
1503    if "SESSION_ID" in os.environ:
1504        session_id = int(os.environ['SESSION_ID'])
1505    else:
1506        raise RuntimeError("Testcase requires SESSION_ID environment variable")
1507
1508    train_cache = ds.DatasetCache(session_id=session_id, size=0)
1509    eval_cache = ds.DatasetCache(session_id=session_id, size=0)
1510
1511    # This dataset has 12 records in it
1512    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
1513    decode_op = c_vision.Decode()
1514    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
1515
1516    # This dataset has 3 records in it only
1517    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
1518    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)
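    # the two caches share one session but back different pipelines; the server is
    # expected to keep their entries separate, so each pipeline sees its own data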

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache2():
    """
    Feature: DatasetCache op
    Description: Test multiple Caches in the same Python script

       Cache
         |
    Map(Decode)               Cache
         |                      |
    TFRecord(image)        TFRecord(text)

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap multiple cache 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    image_cache = ds.DatasetCache(session_id=session_id, size=0)
    text_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    # This dataset has 3 records in it only
    text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)

    num_epoch = 5
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
    text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
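        # both datasets yield 3 rows, so zip_longest produces exactly 3 pairs per epoch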
        for _, _ in itertools.zip_longest(image_iter, text_iter):
            row_count += 1
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache3():
    """
    Feature: DatasetCache op
    Description: Test multiple Caches in the same Python script

       Cache                   Cache
         |                      |
    Map(Decode)             Map(Decode)
         |                      |
    TFRecord                ImageFolder

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap multiple cache 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    tf_cache = ds.DatasetCache(session_id=session_id, size=0)
    image_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)

    # This dataset directory has only 2 images in it
    image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    num_epoch = 5
    tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch)
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in tf_iter]) == 3
        assert sum([1 for _ in image_iter]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_train():
    """
    Feature: DatasetCache op
    Description: Test multiple Caches in different Python scripts.
        Runs concurrently with test_cache_nomap_multiple_cache_eval

       Cache
         |
    Map(Decode)
         |
    TFRecord(train)

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap multiple cache train")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_train Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_eval():
    """
    Feature: DatasetCache op
    Description: Test multiple Caches in different Python scripts.
        Runs concurrently with test_cache_nomap_multiple_cache_train

       Cache
         |
    Map(Decode)
         |
    TFRecord(eval)

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap multiple cache eval")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    eval_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset only has 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
    """
    Feature: DatasetCache op
    Description: Test CLUEDataset (a non mappable dataset) with a Cache over it just after the leaf.
        In this one, the CLUEDataset is given a sharding configuration; however, since a Cache is
        used, the tree prepare should undo the sharding configuration and instead choose a
        distributed sampler with the same shard config.

       Cache
         |
       CLUE

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap clue 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the CLUE leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
    """
    Feature: DatasetCache op
    Description: Test CLUEDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given

       Cache
         |
    Map(lambda x: x)
         |
       CLUE

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap clue 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
    ds1 = ds1.map(c_vision.not_random(lambda x: x), ["label"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
    """
    Feature: DatasetCache op
    Description: Test CSVDataset (a non mappable dataset) with a Cache over it just after the leaf.
        In this one, the CSVDataset is given a sharding configuration; however, since a Cache is
        used, the tree prepare should undo the sharding configuration and instead choose a
        distributed sampler with the same shard config.

       Cache
         |
       CSV

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap csv 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the CSV leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
    """
    Feature: DatasetCache op
    Description: Test CSVDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given

       Cache
         |
    Map(lambda x: x)
         |
       CSV

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap csv 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
    ds1 = ds1.map(c_vision.not_random(lambda x: x), ["col1"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
    """
    Feature: DatasetCache op
    Description: Test TextFileDataset (a non mappable dataset) with a Cache over it just after the leaf.
        In this one, the text file dataset is given a sharding configuration; however, since a Cache is
        used, the tree prepare should undo the sharding configuration and instead choose a
        distributed sampler with the same shard config.

       Cache
         |
     TextFile

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap textfile 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the text file leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
    """
    Feature: DatasetCache op
    Description: Test TextFileDataset (a non mappable dataset) with a Cache over it after Map, num_samples arg is given

       Cache
         |
    Map(Tokenizer)
         |
     TextFile

    Expectation: Output is equal to the expected output
    """
    def my_tokenizer(line):
        words = line.split()
        if not words:
            return [""]
        return words

    logger.info("Test cache nomap textfile 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
    tokenizer = text.PythonTokenizer(my_tokenizer)
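    # my_tokenizer is deterministic and PythonTokenizer is a library transform, so
    # (unlike a bare, unwrapped Python lambda) this Map op is accepted under the cache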
    ds1 = ds1.map(operations=tokenizer, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_nested_repeat():
    """
    Feature: DatasetCache op
    Description: Test Cache on pipeline with nested Repeat ops

        Repeat
          |
        Cache
          |
      Map(Decode)
          |
        Repeat
          |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap nested repeat")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(2)
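    # 3 records x repeat(4) x repeat(2) = 24 rows in total, matching the assert below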

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 24
    logger.info('test_cache_nomap_nested_repeat Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_get_repeat_count():
    """
    Feature: DatasetCache op
    Description: Test get_repeat_count for a pipeline with Cache and a Repeat op

        Cache
          |
      Map(Decode)
          |
        Repeat
          |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap get_repeat_count")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.repeat(4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    repeat_count = ds1.get_repeat_count()
    logger.info("repeat_count: {}".format(repeat_count))
    assert repeat_count == 4

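    # 3 records x repeat(4) = 12 rows in a single epoch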
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1
    assert num_iter == 12
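    logger.info("test_cache_nomap_get_repeat_count Ended.\n")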


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_long_file_list():
    """
    Feature: DatasetCache op
    Description: Test Cache after TFRecord with a long list of files as arguments

        Cache
          |
      TFRecord

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap long file list")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1)
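    # size=1 caps the cache at 1 MB; spilling is off by default, so reading the
    # 1000-file list below is expected to overflow the cache and report "Out of memory"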

    ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"],
                             cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_long_file_list Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure1():
    """
    Feature: DatasetCache op
    Description: Test nested Cache

        Repeat
          |
        Cache
          |
      Map(Decode)
          |
        Cache
          |
      TFRecord

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
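    # the same cache is attached to both the leaf and the Map op, nesting one cache
    # under the other; tree prepare rejects this before any rows are produced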
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure2():
    """
    Feature: DatasetCache op
    Description: Test Zip under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                 Zip
                |    |
           Random    Random

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    ds1 = ds.RandomDataset(schema=schema)
    ds2 = ds.RandomDataset(schema=schema)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure3():
    """
    Feature: DatasetCache op
    Description: Test Batch under Cache

               Repeat
                  |
                Cache
                  |
             Map(Resize)
                  |
                Batch
                  |
                Clue

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train')
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure4():
    """
    Feature: DatasetCache op
    Description: Test Filter under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode)
                  |
                Filter
                  |
                 CSV

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'])
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])

    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure4 Ended.\n')

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure5():
    """
    Feature: DatasetCache op
    Description: Test Map containing Random operation under Cache

               Repeat
                  |
                Cache
                  |
             Map(Decode, RandomCrop)
                  |
              TextFile

    Expectation: Error is raised as expected
    """
    logger.info("Test cache nomap failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    data = ds.TextFileDataset(TEXT_FILE_DATA_DIR)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()

    data = data.map(input_columns=["image"], operations=decode_op)
    data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    assert num_iter == 0
    logger.info('test_cache_nomap_failure5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
    """
    Feature: DatasetCache op
    Description: Test Cache after Map op with a Python lambda function

        Cache
          |
        Map(lambda function1, lambda function2)
          |
      TFRecord

    Expectation: Succeeds only if the lambda function is wrapped by 'c_vision.not_random'; otherwise an error is raised
    """
    logger.info("Test cache nomap pyfunc lambda")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 12 records in it
    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    transforms = [c_vision.not_random(lambda x: x + x), c_vision.not_random(lambda x: x - 1)]
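    # not_random marks each lambda as deterministic; an unwrapped pyfunc is treated
    # as potentially random and rejected under a cache, as the second half shows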
    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
    """
    Feature: DatasetCache op
    Description: Test Cache after Map op with a Python builtin PyFunc

        Cache
          |
     Map([builtin pyfunc1, builtin pyfunc2])
          |
      TFRecord

    Expectation: An error is raised if the builtin PyFunc contains a random op; otherwise it runs successfully
    """
    logger.info("Test cache nomap pyfunc builtin")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[c_vision.Decode(), c_vision.ToTensor()], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[c_vision.Decode(), c_vision.RandomCrop(224), c_vision.ToTensor()],
                  input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
    """
    Feature: DatasetCache op
    Description: Test Cache after Map op with a Python user-defined function

        Cache
          |
     Map([function1, function2])
          |
      TFRecord

    Expectation: Succeeds only if the function is decorated with 'c_vision.not_random'; otherwise an error is raised
    """
    @c_vision.not_random
    def not_random_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    def normal_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    logger.info("Test cache nomap pyfunc function")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_function Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_all_rows_cached():
    """
    Feature: DatasetCache op
    Description: Test that all rows are cached before the pipeline switches to the fetching phase

       Cache
         |
     RandomDataset

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap all rows cached")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[450, 450, 3])
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # easier to reproduce the problem with 271 total rows
    num_total_rows = 271
    # RandomDataset leaf with the cache injected directly over it
    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
    iter1 = ds1.create_dict_iterator(num_epochs=1)

    num_iter = 0
    for _ in iter1:
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == num_total_rows

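    # get_stat queries the cache server; with no eviction or spilling here,
    # num_mem_cached should equal the number of rows written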
    cache_stat = some_cache.get_stat()
    assert cache_stat.num_mem_cached == num_total_rows

    logger.info("test_cache_nomap_all_rows_cached Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_dataset_size1():
    """
    Feature: DatasetCache op
    Description: Test get_dataset_size when Cache is injected directly after a non-mappable leaf

       Cache
         |
      TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap dataset size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0, cache=some_cache)
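    # with a cache, sharding is row-based, so shard 0 is expected to get ceil(3 / 2) = 2 rows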

    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_nomap_dataset_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_dataset_size2():
    """
    Feature: DatasetCache op
    Description: Test get_dataset_size when Cache is injected after Map

       Cache
         |
    Map(Decode)
         |
     TFRecord

    Expectation: Output is equal to the expected output
    """
    logger.info("Test cache nomap dataset size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    dataset_size = ds1.get_dataset_size()
    assert dataset_size == 2

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == dataset_size
    logger.info("test_cache_nomap_dataset_size2 Ended.\n")


if __name__ == '__main__':
    # This is just a list of tests; don't try to run them with 'python test_cache_nomap.py',
    # since the cache server must be brought up first
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_basic8()
    test_cache_nomap_basic9()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()
    test_cache_nomap_running_twice1()
    test_cache_nomap_running_twice2()
    test_cache_nomap_extra_small_size1()
    test_cache_nomap_extra_small_size2()
    test_cache_nomap_parallel_pipeline1(shard=0)
    test_cache_nomap_parallel_pipeline2(shard=1)
    test_cache_nomap_parallel_workers()
    test_cache_nomap_server_workers_1()
    test_cache_nomap_server_workers_100()
    test_cache_nomap_num_connections_1()
    test_cache_nomap_num_connections_100()
    test_cache_nomap_prefetch_size_1()
    test_cache_nomap_prefetch_size_100()
    test_cache_nomap_device_que()
    test_cache_nomap_session_destroy()
    test_cache_nomap_server_stop()
    test_cache_nomap_epoch_ctrl1()
    test_cache_nomap_epoch_ctrl2()
    test_cache_nomap_epoch_ctrl3()
    test_cache_nomap_epoch_ctrl4()
    test_cache_nomap_multiple_cache1()
    test_cache_nomap_multiple_cache2()
    test_cache_nomap_multiple_cache3()
    test_cache_nomap_multiple_cache_train()
    test_cache_nomap_multiple_cache_eval()
    test_cache_nomap_clue1()
    test_cache_nomap_clue2()
    test_cache_nomap_csv1()
    test_cache_nomap_csv2()
    test_cache_nomap_textfile1()
    test_cache_nomap_textfile2()
    test_cache_nomap_nested_repeat()
    test_cache_nomap_get_repeat_count()
    test_cache_nomap_long_file_list()
    test_cache_nomap_failure1()
    test_cache_nomap_failure2()
    test_cache_nomap_failure3()
    test_cache_nomap_failure4()
    test_cache_nomap_failure5()
    test_cache_nomap_pyfunc_lambda()
    test_cache_nomap_pyfunc_builtin()
    test_cache_nomap_pyfunc_function()
    test_cache_nomap_all_rows_cached()
    test_cache_nomap_dataset_size1()
    test_cache_nomap_dataset_size2()