# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Cifar10 and Cifar100 dataset operators
"""
import os

import pytest
import numpy as np
import matplotlib.pyplot as plt

import mindspore.dataset as ds
from mindspore import log as logger

# Test data locations, relative to the test working directory.
DATA_DIR_10 = "../data/dataset/testCifar10Data"    # directory with Cifar10 .bin files
DATA_DIR_100 = "../data/dataset/testCifar100Data"  # directory with Cifar100 .bin files
NO_BIN_DIR = "../data/dataset/testMnistData"       # directory without .bin files (negative tests)

def load_cifar(path, kind="cifar10"):
    """
    Load every Cifar10/Cifar100 record found in the .bin files under *path*.

    Args:
        path (str): directory containing the dataset .bin files.
        kind (str): "cifar10" (1 label byte per record) or "cifar100"
            (coarse + fine label bytes per record).

    Returns:
        tuple: (images, labels). images is a uint8 array of shape
        (N, 32, 32, 3) in HWC layout; labels is (N,) for cifar10 or
        (N, 2) [coarse, fine] for cifar100.

    Raises:
        ValueError: if kind is neither "cifar10" nor "cifar100".
    """
    # Collect chunks and concatenate once at the end; the previous
    # np.append-in-a-loop copied the whole buffer each iteration (O(n^2)).
    chunks = []
    # sorted() makes the record order deterministic; os.listdir order is
    # arbitrary and platform-dependent.
    for file_name in sorted(os.listdir(path)):
        if file_name.endswith(".bin"):
            with open(os.path.join(path, file_name), mode='rb') as file:
                chunks.append(np.fromfile(file, dtype=np.uint8))
    raw = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.uint8)
    if kind == "cifar10":
        raw = raw.reshape(-1, 3073)  # 1 label byte + 3072 image bytes
        labels = raw[:, 0]
        images = raw[:, 1:]
    elif kind == "cifar100":
        raw = raw.reshape(-1, 3074)  # coarse + fine label bytes + 3072 image bytes
        labels = raw[:, :2]
        images = raw[:, 2:]
    else:
        raise ValueError("Invalid parameter value")
    # Records are stored channel-first (3 x 32 x 32); convert to HWC to
    # match the layout produced by the dataset ops.
    images = images.reshape(-1, 3, 32, 32)
    images = images.transpose(0, 2, 3, 1)
    return images, labels


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples
    """
    total = len(images)
    for idx, (img, lab) in enumerate(zip(images, labels)):
        # One subplot per sample, laid out in a single row.
        plt.subplot(1, total, idx + 1)
        plt.imshow(img)
        plt.title(lab)
    plt.show()


### Testcases for Cifar10Dataset Op ###


def test_cifar10_content_check():
    """
    Validate Cifar10Dataset image readings
    """
    logger.info("Test Cifar10Dataset Op with content check")
    dataset = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_10)
    # Every row produced by the dataset op must match the raw .bin contents.
    # Each dictionary row has keys "image" and "label".
    row_count = 0
    for index, row in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(row["image"], images[index])
        np.testing.assert_array_equal(row["label"], labels[index])
        row_count += 1
    assert row_count == 100


def test_cifar10_basic():
    """
    Validate CIFAR10
    """
    logger.info("Test Cifar10Dataset Op")

    def row_count(dataset):
        # Count the rows produced by one full pass over the dataset.
        count = 0
        for _ in dataset.create_dict_iterator(num_epochs=1):
            count += 1
        return count

    # case 0: test loading the whole dataset
    assert row_count(ds.Cifar10Dataset(DATA_DIR_10)) == 10000

    # case 1: test num_samples
    assert row_count(ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)) == 100

    # case 2: test num_parallel_workers
    assert row_count(ds.Cifar10Dataset(DATA_DIR_10, num_samples=50, num_parallel_workers=1)) == 50

    # case 3: test repeat
    repeated = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100).repeat(3)
    assert row_count(repeated) == 300

    # case 4: test batch with drop_remainder=False (the default)
    batched = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert batched.get_dataset_size() == 100
    assert batched.get_batch_size() == 1
    batched = batched.batch(batch_size=7)
    assert batched.get_dataset_size() == 15  # ceil(100 / 7): partial batch kept
    assert batched.get_batch_size() == 7
    assert row_count(batched) == 15

    # case 5: test batch with drop_remainder=True (incomplete batch dropped)
    dropped = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100)
    assert dropped.get_dataset_size() == 100
    assert dropped.get_batch_size() == 1
    dropped = dropped.batch(batch_size=7, drop_remainder=True)
    assert dropped.get_dataset_size() == 14  # floor(100 / 7)
    assert dropped.get_batch_size() == 7
    assert row_count(dropped) == 14


def test_cifar10_pk_sampler():
    """
    Test Cifar10Dataset with PKSampler
    """
    logger.info("Test Cifar10Dataset Op with PKSampler")
    # Three samples per class, ten classes, emitted in class order.
    golden = [label for label in range(10) for _ in range(3)]
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3))
    seen = [row["label"] for row in data.create_dict_iterator(num_epochs=1, output_numpy=True)]
    np.testing.assert_array_equal(golden, seen)
    assert len(seen) == 30


def test_cifar10_sequential_sampler():
    """
    Test Cifar10Dataset with SequentialSampler
    """
    logger.info("Test Cifar10Dataset Op with SequentialSampler")
    num_samples = 30
    # An explicit SequentialSampler must behave like shuffle=False.
    sampled = ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.SequentialSampler(num_samples=num_samples))
    plain = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples)
    matched = 0
    iter1 = sampled.create_dict_iterator(num_epochs=1, output_numpy=True)
    iter2 = plain.create_dict_iterator(num_epochs=1, output_numpy=True)
    for row1, row2 in zip(iter1, iter2):
        np.testing.assert_equal(row1["label"], row2["label"])
        matched += 1
    assert matched == num_samples


def test_cifar10_exception():
    """
    Test error cases for Cifar10Dataset
    """
    logger.info("Test error cases for Cifar10Dataset")

    # sampler and shuffle are mutually exclusive
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3))

    # sampler and sharding are mutually exclusive
    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    # num_shards without shard_id is rejected
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.Cifar10Dataset(DATA_DIR_10, num_shards=10)

    # shard_id without num_shards is rejected
    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.Cifar10Dataset(DATA_DIR_10, shard_id=0)

    # shard_id must lie in [0, num_shards)
    shard_interval_msg = "Input shard_id is not within the required interval"
    for bad_shard_id in (-1, 5):
        with pytest.raises(ValueError, match=shard_interval_msg):
            ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=bad_shard_id)

    # num_parallel_workers outside the supported range
    workers_msg = "num_parallel_workers exceeds"
    for bad_workers in (0, 256):
        with pytest.raises(ValueError, match=workers_msg):
            ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=bad_workers)

    # a directory without any .bin file only fails once iteration starts
    with pytest.raises(RuntimeError, match="no .bin files found"):
        bad_data = ds.Cifar10Dataset(NO_BIN_DIR)
        for _ in bad_data.__iter__():
            pass


def test_cifar10_visualize(plot=False):
    """
    Visualize Cifar10Dataset results
    """
    logger.info("Test Cifar10Dataset visualization")

    dataset = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False)
    images, titles = [], []
    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        image, label = row["image"], row["label"]
        # Sanity-check the tensor contract of every emitted row.
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        images.append(image)
        titles.append("label {}".format(label))
        count += 1
    assert count == 10
    if plot:
        visualize_dataset(images, titles)


### Testcases for Cifar100Dataset Op ###

def test_cifar100_content_check():
    """
    Validate Cifar100Dataset image readings
    """
    logger.info("Test Cifar100Dataset with content check")
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False)
    images, labels = load_cifar(DATA_DIR_100, kind="cifar100")
    num_iter = 0
    # in this example, each dictionary has keys "image", "coarse_label" and
    # "fine_label" (the original comment wrongly said "fine_image")
    for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(d["image"], images[i])
        # labels[i] is [coarse, fine] — see load_cifar
        np.testing.assert_array_equal(d["coarse_label"], labels[i][0])
        np.testing.assert_array_equal(d["fine_label"], labels[i][1])
        num_iter += 1
    assert num_iter == 100


def test_cifar100_basic():
    """
    Test Cifar100Dataset
    """
    logger.info("Test Cifar100Dataset")

    def row_count(dataset):
        # Count the rows produced by one full pass over the dataset.
        total = 0
        for _ in dataset.create_dict_iterator(num_epochs=1):
            total += 1
        return total

    # case 1: test num_samples
    data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    assert row_count(data1) == 100

    # case 2: test repeat
    data1 = data1.repeat(2)
    assert row_count(data1) == 200

    # case 3: test num_parallel_workers
    assert row_count(ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, num_parallel_workers=1)) == 100

    # case 4: test batch with drop_remainder=False
    batched = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    assert batched.get_dataset_size() == 100
    assert batched.get_batch_size() == 1
    batched = batched.batch(batch_size=3)
    assert batched.get_dataset_size() == 34  # ceil(100 / 3): partial batch kept
    assert batched.get_batch_size() == 3
    assert row_count(batched) == 34

    # case 5: test batch with drop_remainder=True
    dropped = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100)
    dropped = dropped.batch(batch_size=3, drop_remainder=True)
    assert dropped.get_dataset_size() == 33  # floor(100 / 3)
    assert dropped.get_batch_size() == 3
    assert row_count(dropped) == 33


def test_cifar100_pk_sampler():
    """
    Test Cifar100Dataset with PKSampler
    """
    logger.info("Test Cifar100Dataset with PKSampler")
    # One sample per coarse class; Cifar100 has 20 coarse classes.
    # list(range(20)) replaces the redundant [i for i in range(20)].
    golden = list(range(20))
    sampler = ds.PKSampler(1)
    data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler)
    num_iter = 0
    label_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        label_list.append(item["coarse_label"])
        num_iter += 1
    np.testing.assert_array_equal(golden, label_list)
    assert num_iter == 20


def test_cifar100_exception():
    """
    Test error cases for Cifar100Dataset
    """
    logger.info("Test error cases for Cifar100Dataset")
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Cifar100Dataset(DATA_DIR_100, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_5):
        # Bug fix: this case previously constructed Cifar10Dataset by mistake;
        # it must exercise Cifar100Dataset like the rest of this test.
        ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=256)

    error_msg_7 = "no .bin files found"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds1 = ds.Cifar100Dataset(NO_BIN_DIR)
        for _ in ds1.__iter__():
            pass


def test_cifar100_visualize(plot=False):
    """
    Visualize Cifar100Dataset results
    """
    logger.info("Test Cifar100Dataset visualization")

    dataset = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False)
    images, titles = [], []
    count = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = row["image"]
        coarse_label = row["coarse_label"]
        fine_label = row["fine_label"]
        # Sanity-check the tensor contract of every emitted row.
        assert isinstance(image, np.ndarray)
        assert image.shape == (32, 32, 3)
        assert image.dtype == np.uint8
        assert coarse_label.dtype == np.uint32
        assert fine_label.dtype == np.uint32
        images.append(image)
        titles.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label))
        count += 1
    assert count == 10
    if plot:
        visualize_dataset(images, titles)


def test_cifar_usage():
    """
    test usage of cifar
    """
    logger.info("Test Cifar100Dataset usage flag")

    # flag, if True, test cifar10 else test cifar100
    def test_config(usage, flag=True, cifar_path=None):
        # Returns the row count on success, or the error message string on failure.
        if cifar_path is None:
            cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
        try:
            data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    # test the usage of CIFAR10 (flag defaults to True -> Cifar10Dataset;
    # the original comment here incorrectly said CIFAR100)
    assert test_config("train") == 10000
    assert test_config("all") == 10000
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
    assert "Cifar10Dataset API can't read the data file (interface mismatch or no data found)" in test_config("test")

    # test the usage of CIFAR100 (flag=False -> Cifar100Dataset;
    # the original comment here incorrectly said CIFAR10)
    assert test_config("test", False) == 10000
    assert test_config("all", False) == 10000
    assert "Cifar100Dataset API can't read the data file" in test_config("train", False)
    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)

    # change this directory to the folder that contains all cifar10 files
    all_cifar10 = None
    if all_cifar10 is not None:
        assert test_config("train", True, all_cifar10) == 50000
        assert test_config("test", True, all_cifar10) == 10000
        assert test_config("all", True, all_cifar10) == 60000
        assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
        assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
        assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000

    # change this directory to the folder that contains all cifar100 files
    all_cifar100 = None
    if all_cifar100 is not None:
        assert test_config("train", False, all_cifar100) == 50000
        assert test_config("test", False, all_cifar100) == 10000
        assert test_config("all", False, all_cifar100) == 60000
        assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
        assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
        assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000


def test_cifar_exception_file_path():
    """
    A map() whose pyfunc raises must surface as a RuntimeError for every column.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    # (dataset class, data directory, column fed to the failing pyfunc)
    cases = [
        (ds.Cifar10Dataset, DATA_DIR_10, "image"),
        (ds.Cifar10Dataset, DATA_DIR_10, "label"),
        (ds.Cifar100Dataset, DATA_DIR_100, "image"),
        (ds.Cifar100Dataset, DATA_DIR_100, "coarse_label"),
        (ds.Cifar100Dataset, DATA_DIR_100, "fine_label"),
    ]
    for dataset_cls, data_dir, column in cases:
        try:
            data = dataset_cls(data_dir)
            data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            num_rows = 0
            for _ in data.create_dict_iterator():
                num_rows += 1
            assert False
        except RuntimeError as e:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


def test_cifar10_pk_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with PKSampler and get_dataset_size
    """
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3))
    # get_dataset_size must agree with the number of rows actually produced.
    reported_size = data.get_dataset_size()
    actual_rows = sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert reported_size == actual_rows == 30


def test_cifar10_with_chained_sampler_get_dataset_size():
    """
    Test Cifar10Dataset with PKSampler chained with a SequentialSampler and get_dataset_size
    """
    parent = ds.SequentialSampler(start_index=0, num_samples=5)
    parent.add_child(ds.PKSampler(4))
    data = ds.Cifar10Dataset(DATA_DIR_10, sampler=parent)
    # get_dataset_size must agree with the number of rows actually produced.
    reported_size = data.get_dataset_size()
    actual_rows = sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert reported_size == actual_rows == 5


if __name__ == '__main__':
    # Cifar10 cases
    test_cifar10_content_check()
    test_cifar10_basic()
    test_cifar10_pk_sampler()
    test_cifar10_sequential_sampler()
    test_cifar10_exception()
    test_cifar10_visualize(plot=False)

    # Cifar100 cases
    test_cifar100_content_check()
    test_cifar100_basic()
    test_cifar100_pk_sampler()
    test_cifar100_exception()
    test_cifar100_visualize(plot=False)

    # cases shared by both datasets
    test_cifar_usage()
    test_cifar_exception_file_path()

    # dataset-size consistency cases
    test_cifar10_with_chained_sampler_get_dataset_size()
    test_cifar10_pk_sampler_get_dataset_size()