• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18import mindspore.dataset.vision.c_transforms as vision
19import mindspore.dataset.transforms.c_transforms as data_trans
20from mindspore import log as logger
21
# Manifest fixture shared by the tests in this module. The path is relative, so the
# tests only find it when run from a working directory where "../data" resolves to the
# test data tree (presumably the directory containing this file -- confirm in the runner).
DATA_FILE = "../data/dataset/testManifestData/test.manifest"
23
24
def test_manifest_dataset_train():
    """Read the manifest with its default usage and tally the samples per label.

    Label 0 is 'cat' and label 1 is 'dog' in the default class indexing; samples whose
    label holds more than one value contribute to the total only.
    """
    data = ds.ManifestDataset(DATA_FILE, decode=True)
    n_total = n_cat = n_dog = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(row["image"]))
        n_total += 1
        label = row["label"]
        if label.size != 1:
            continue  # multi-label sample: only counted in the total
        if label == 0:
            n_cat += 1
        elif label == 1:
            n_dog += 1
    assert n_cat == 2
    assert n_dog == 1
    assert n_total == 4
40
41
def test_manifest_dataset_eval():
    """The "eval" usage split yields exactly two samples, each labelled 0 or 1."""
    data = ds.ManifestDataset(DATA_FILE, "eval", decode=True)
    seen = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(row["image"]))
        seen += 1
        label = row["label"]
        assert label == 0 or label == 1
    assert seen == 2
51
52
def test_manifest_dataset_class_index():
    """A custom class_indexing is reported back and applied to the produced labels.

    Only the 'dog' class is indexed (as 11); the test expects a single sample, labelled 11.
    """
    mapping = {"dog": 11}
    data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=mapping)
    assert data.get_class_indexing() == {"dog": 11}
    seen = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(row["image"]))
        seen += 1
        assert row["label"] == 11
    assert seen == 1
65
66
def test_manifest_dataset_get_class_index():
    """get_class_indexing() returns the same mapping before and after a shuffle."""
    expected_indexing = {'cat': 0, 'dog': 1, 'flower': 2}
    data = ds.ManifestDataset(DATA_FILE, decode=True)
    assert data.get_class_indexing() == expected_indexing

    data = data.shuffle(4)
    assert data.get_class_indexing() == expected_indexing

    seen = 0
    for row in data.create_dict_iterator(num_epochs=1):
        logger.info("item[image] is {}".format(row["image"]))
        seen += 1
    assert seen == 4
79
80
def test_manifest_dataset_multi_label():
    """Unshuffled samples carry the expected labels, including one multi-label sample."""
    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    expected_labels = [1, 0, 0, [0, 2]]
    seen = 0
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert row["label"].tolist() == expected_labels[seen]
        logger.info("item[image] is {}".format(row["image"]))
        seen += 1
    assert seen == 4
90
91
def multi_label_hot(x):
    """Collapse one-hot encoded label rows into a single multi-hot vector.

    Used as a PyFunc map op after ``OneHot``: a single-label sample arrives as a
    1-D one-hot vector and is returned as-is (accumulated onto an ``int`` zero
    vector); a multi-label sample arrives as a 2-D array with one one-hot row per
    label, and those rows are summed.

    Args:
        x: 1-D one-hot vector, or 2-D array of one-hot rows (any number of rows).
            The last axis indexes classes.

    Returns:
        np.ndarray of length ``x.shape[-1]``. It is built with ``np.add`` onto an
        ``int`` zero vector, so NumPy type promotion applies.
    """
    x = np.asarray(x)
    # Size from the class axis. Deriving it from x.size // x.ndim is only right
    # when there are exactly two label rows.
    result = np.zeros(x.shape[-1], dtype=int)
    if x.ndim > 1:
        # Iterate over the actual rows, not range(x.ndim), so any label count works.
        for row in x:
            result = np.add(result, row)
    else:
        result = np.add(result, x)
    return result
101
102
def test_manifest_dataset_multi_label_onehot():
    """OneHot followed by multi_label_hot yields multi-hot label vectors per sample.

    With shuffle disabled the four samples carry labels 1, 0, 0 and [0, 2]. After
    3-class one-hot encoding, collapsing each sample's rows and batching by 2, two
    batches are expected.
    """
    data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    expect_label = [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 1]]]
    one_hot_encode = data_trans.OneHot(3)
    data = data.map(operations=one_hot_encode, input_columns=["label"])
    data = data.map(operations=multi_label_hot, input_columns=["label"])
    data = data.batch(2)
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert item["label"].tolist() == expect_label[count]
        logger.info("item[image] is {}".format(item["image"]))
        count = count + 1
    # Without this the test would pass vacuously if the pipeline produced no batches.
    assert count == len(expect_label)
115
116
def test_manifest_dataset_get_num_class():
    """num_classes() is 3 for the manifest, and is still 3 after repeating the manifest
    and concatenating it with a repeated PaddedDataset."""
    manifest = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
    assert manifest.num_classes() == 3

    padded = ds.PaddedDataset([{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}])

    combined = manifest.repeat(2) + padded.repeat(2)
    assert combined.num_classes() == 3
129
130
def _expect_runtime_error(build_pipeline, expected_msg):
    """Build a dataset pipeline, drain it, and require a matching RuntimeError.

    Args:
        build_pipeline: zero-argument callable returning the dataset to iterate. It is
            invoked inside the ``try`` so an error raised while constructing the
            pipeline is checked the same way as one raised during iteration.
        expected_msg: substring that must appear in the RuntimeError's message.

    Raises:
        AssertionError: if no RuntimeError is raised, or its message lacks expected_msg.
    """
    try:
        data = build_pipeline()
        for _ in data:
            pass
    except RuntimeError as e:
        assert expected_msg in str(e), "unexpected error message: {}".format(e)
    else:
        # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
        raise AssertionError("expected a RuntimeError containing {!r}".format(expected_msg))


def test_manifest_dataset_exception():
    """Failing PyFunc map ops and malformed manifests surface as RuntimeError."""
    def exception_func(item):
        raise Exception("Error occur!")

    pyfunc_msg = "map operation: [PyFunc] failed. The corresponding data files"

    # Failing PyFunc on the raw image column.
    _expect_runtime_error(
        lambda: ds.ManifestDataset(DATA_FILE).map(
            operations=exception_func, input_columns=["image"], num_parallel_workers=1),
        pyfunc_msg)

    # Failing PyFunc placed after vision.Decode() on the image column.
    _expect_runtime_error(
        lambda: (ds.ManifestDataset(DATA_FILE)
                 .map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
                 .map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)),
        pyfunc_msg)

    # Failing PyFunc on the label column.
    _expect_runtime_error(
        lambda: ds.ManifestDataset(DATA_FILE).map(
            operations=exception_func, input_columns=["label"], num_parallel_workers=1),
        pyfunc_msg)

    # Manifest entries lacking required fields.
    no_source_data_file = "../data/dataset/testManifestData/invalidNoSource.manifest"
    _expect_runtime_error(
        lambda: ds.ManifestDataset(no_source_data_file),
        "Invalid data, 'source' is not found in Manifest file")

    no_usage_data_file = "../data/dataset/testManifestData/invalidNoUsage.manifest"
    _expect_runtime_error(
        lambda: ds.ManifestDataset(no_usage_data_file),
        "Invalid data, 'usage' is not found in Manifest file")
180
181
if __name__ == '__main__':
    # Run every test in declaration order when executed as a script.
    for _case in (
            test_manifest_dataset_train,
            test_manifest_dataset_eval,
            test_manifest_dataset_class_index,
            test_manifest_dataset_get_class_index,
            test_manifest_dataset_multi_label,
            test_manifest_dataset_multi_label_onehot,
            test_manifest_dataset_get_num_class,
            test_manifest_dataset_exception,
    ):
        _case()
191