# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools
import numpy as np
import pytest
import cv2

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
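# The TFRecord test data above contains 3 images (see the `assert i == 3` checks in test_decode_op).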


def diff_mse(in1, in2):
    """
    Return the mean squared error between two images, scaled by 100.
    """
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100


def test_cifar10():
    """
    Test epoch control with Cifar10Dataset: repeat, batch, and a tuple iterator
    that is created without num_epochs and therefore never shuts down.
    """
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # in this example, each row contains the "image" and "label" columns
            row_count += 1
        assert row_count == int(limit_dataset * num_repeat / batch_size)
        logger.debug("row_count: %d", row_count)
        epoch_count += 1
        sample_count += row_count
    assert epoch_count == num_epoch
    logger.debug("total epochs: %d", epoch_count)
    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
    logger.debug("total sample: %d", sample_count)


def test_decode_op():
    """
    Test epoch control with the Decode op, comparing an unbounded iterator
    against one that shuts down after num_epoch epochs.
    """
    logger.info("test_decode_op")

    # Decode with rgb format set to True
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    # Serializing and loading the dataset requires vision.Decode instead of vision.Decode().
    data1 = data1.map(operations=[vision.Decode(True)], input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down.
    iter1 = data1.create_dict_iterator(output_numpy=True)
    # iter2 will stop and shut down the pipeline after num_epoch epochs.
    iter2 = data2.create_dict_iterator(num_epoch, output_numpy=True)
    for _ in range(num_epoch):
        i = 0
        for item1, item2 in itertools.zip_longest(iter1, iter2):
            actual = item1["image"]
            expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            diff = actual - expected
            mse = np.sum(np.power(diff, 2))
            assert mse == 0
            i = i + 1
        assert i == 3

    # Users have the option to manually stop the iterator, or rely on the garbage collector.
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)

    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


# Generate 1D int numpy arrays with values from 0 to 63
def generator_1d():
    """
    Yield 64 single-element numpy arrays containing 0 through 63.
    """
    for i in range(64):
        yield (np.array([i]),)


def test_generator_dict_0():
    """
    Test a dict iterator with num_epochs=1 created directly in the for statement.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1


def test_generator_dict_1():
    """
    Test re-creating a num_epochs=1 dict iterator inside the epoch loop (anti-pattern).
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        i = 0
        # BAD. Do not create the iterator inside the epoch loop every time;
        # create it once outside the epoch for loop.
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"], golden)
            i = i + 1
        assert i == 64


def test_generator_dict_2():
    """
    Test a dict iterator created without num_epochs; the pipeline stays alive
    across epochs and is destroyed by the garbage collector.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
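    # No num_epochs was given, so iter1 assumes another epoch will always follow
    # and keeps the pipeline running after each pass.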
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"].asnumpy(), golden)
            i = i + 1
        assert i == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on garbage collector to destroy iter1


def test_generator_dict_3():
    """
    Test manually stopping a dict iterator that was created without num_epochs.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"].asnumpy(), golden)
            i = i + 1
        assert i == 64
    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_dict_4():
    """
    Test a dict iterator with num_epochs=10; fetching beyond 10 epochs raises RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=10)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"].asnumpy(), golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


def test_generator_dict_4_1():
    """
    Test a dict iterator with num_epochs=1, where no epoch control op is injected.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1, output_numpy=True)
    for _ in range(1):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


def test_generator_dict_4_2():
    """
    Test a dict iterator with num_epochs=1 and repeat(1), where neither op is injected.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # the repeat op will not be injected when the repeat count is 1.
    data1 = data1.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1, output_numpy=True)
    for _ in range(1):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


def test_generator_dict_5():
    """
    Test a dict iterator with num_epochs=11: 10 epochs in a loop, one more afterwards,
    then a RuntimeError once the pipeline has shut down.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([i])
            np.testing.assert_array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each data is a dictionary
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    assert i == 64

    # now iter1 has been exhausted; the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


# Test tuple iterator

def test_generator_tuple_0():
    """
    Test a tuple iterator with num_epochs=1 created directly in the for statement.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each row is a list of columns
        golden = np.array([i])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1


def test_generator_tuple_1():
    """
    Test re-creating a num_epochs=1 tuple iterator inside the epoch loop (anti-pattern).
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        i = 0
        # BAD. Do not create the iterator inside the epoch loop every time;
        # create it once outside the epoch for loop.
        for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):  # each row is a list of columns
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64


def test_generator_tuple_2():
    """
    Test a tuple iterator created without num_epochs; the pipeline stays alive
    across epochs and is destroyed by the garbage collector.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on garbage collector to destroy iter1


def test_generator_tuple_3():
    """
    Test manually stopping a tuple iterator that was created without num_epochs.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64
    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_4():
    """
    Test a tuple iterator with num_epochs=10; fetching beyond 10 epochs raises RuntimeError.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=10, output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


def test_generator_tuple_5():
    """
    Test a tuple iterator with num_epochs=11: 10 epochs in a loop, one more afterwards,
    then a RuntimeError once the pipeline has shut down.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
    assert i == 64

    # now iter1 has been exhausted; the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


# Test with repeat
def test_generator_tuple_repeat_1():
    """
    Test a tuple iterator with num_epochs=11 over a repeat(2) pipeline.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2

    # now iter1 has been exhausted; the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


# Test with nested repeat
def test_generator_tuple_repeat_repeat_1():
    """
    Test a tuple iterator with num_epochs=11 over a repeat(2) + repeat(3) pipeline.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    # still one more epoch left in iter1.
    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2 * 3

    # now iter1 has been exhausted; the C++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. User tries to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)


def test_generator_tuple_repeat_repeat_2():
    """
    Test manually stopping a tuple iterator over a repeat(2) + repeat(3) pipeline.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    # optional
    iter1.stop()
    # Expect an AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute '_runtime_context'" in str(info.value)


def test_generator_tuple_repeat_repeat_3():
    """
    Test iterating a repeat(2) + repeat(3) pipeline for 10 and then 5 more epochs
    with an iterator that was created without num_epochs.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    for _ in range(5):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    # rely on garbage collector to destroy iter1


def test_generator_tuple_infinite_repeat_repeat_1():
    """
    Test an infinite repeat() followed by repeat(3); iteration never ends on its own.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
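    # repeat() with no count repeats the dataset indefinitely, so this pipeline never ends on its own.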
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break

    # rely on garbage collector to destroy iter1


def test_generator_tuple_infinite_repeat_repeat_2():
    """
    Test repeat(3) followed by an infinite repeat(); iteration never ends on its own.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(3)
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break

    # rely on garbage collector to destroy iter1


def test_generator_tuple_infinite_repeat_repeat_3():
    """
    Test two infinite repeat() ops with a num_epochs=11 iterator; iteration never ends on its own.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(num_epochs=11, output_numpy=True)

    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break

    # rely on garbage collector to destroy iter1


def test_generator_tuple_infinite_repeat_repeat_4():
    """
    Test two infinite repeat() ops with an iterator created without num_epochs.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat()
    data1 = data1.repeat()
    iter1 = data1.create_tuple_iterator(output_numpy=True)

    i = 0
    for item in iter1:  # each row is a list of columns
        golden = np.array([i % 64])
        np.testing.assert_array_equal(item[0], golden)
        i = i + 1
        if i == 100:
            break

    # rely on garbage collector to destroy iter1


def test_generator_reusedataset():
    """
    Test reusing the same dataset object: extend it with repeat and batch ops
    and create a fresh iterator after each change.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(10):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2

    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(output_numpy=True)
    for _ in range(5):
        i = 0
        for item in iter1:  # each row is a list of columns
            golden = np.array([i % 64])
            np.testing.assert_array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3

    data1 = data1.batch(2)
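    # With batch(2), each row stacks two consecutive samples, so an epoch now
    # yields 64 * 2 * 3 / 2 == 64 * 3 rows (checked by the sample counter below).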
    iter1 = data1.create_dict_iterator(output_numpy=True)
    for _ in range(5):
        i = 0
        sample = 0
        for item in iter1:  # each data is a dictionary
            golden = np.array([[i % 64], [(i + 1) % 64]])
            np.testing.assert_array_equal(item["data"], golden)
            i = i + 2
            sample = sample + 1
        assert sample == 64 * 3

    # rely on garbage collector to destroy iter1