# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Flowers102 dataset operators
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
from scipy.io import loadmat

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testFlowers102Dataset"
WRONG_DIR = "../data/dataset/testMnistData"


def load_flowers102(path, usage):
35    """
36    load Flowers102 data
37    """
38    assert usage in ["train", "valid", "test", "all"]
39
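    # imagelabels.mat stores 1-based class labels (shifted to 0-based below),
    # and setid.mat stores the 1-based image ids that make up each split.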
    imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32)
    split = loadmat(os.path.join(path, "setid.mat"))
    if usage == 'train':
        indices = split["trnid"][0].tolist()
    elif usage == 'test':
        indices = split["tstid"][0].tolist()
    elif usage == 'valid':
        indices = split["valid"][0].tolist()
    elif usage == 'all':
        indices = split["trnid"][0].tolist()
        indices += split["tstid"][0].tolist()
        indices += split["valid"][0].tolist()
    image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices]
    segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices]
    images = [np.asarray(Image.open(image_path).convert("RGB")) for image_path in image_paths]
    segmentations = [np.asarray(Image.open(seg_path).convert("RGB")) for seg_path in segmentation_paths]
    labels = [imagelabels[index - 1] for index in indices]

    return images, segmentations, labels


def visualize_dataset(images, labels):
63    """
64    Helper function to visualize the dataset samples
65    """
66    num_samples = len(images)
67    for i in range(num_samples):
68        plt.subplot(1, num_samples, i + 1)
69        plt.imshow(images[i].squeeze())
70        plt.title(labels[i])
71    plt.show()
72
73
74def test_flowers102_content_check():
75    """
76    Validate Flowers102Dataset image readings
77    """
78    logger.info("Test Flowers102Dataset Op with content check")
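    # Compare each column produced by the dataset op against the same sample
    # loaded straight from the raw files by load_flowers102.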
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="all",
                                    num_samples=6, decode=True, shuffle=False)
    images, segmentations, labels = load_flowers102(DATA_DIR, "all")
    num_iter = 0
    # each dictionary here has keys "image", "segmentation" and "label"
    for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["segmentation"], segmentations[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 6

    train_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="train",
                                      num_samples=2, decode=True, shuffle=False)
    images, segmentations, labels = load_flowers102(DATA_DIR, "train")
    num_iter = 0
    # each dictionary here has keys "image", "segmentation" and "label"
    for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["segmentation"], segmentations[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 2

    test_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="test",
                                     num_samples=2, decode=True, shuffle=False)
    images, segmentations, labels = load_flowers102(DATA_DIR, "test")
    num_iter = 0
    # each dictionary here has keys "image", "segmentation" and "label"
    for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["segmentation"], segmentations[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 2

    val_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="valid",
                                    num_samples=2, decode=True, shuffle=False)
    images, segmentations, labels = load_flowers102(DATA_DIR, "valid")
    num_iter = 0
    # each dictionary here has keys "image", "segmentation" and "label"
    for i, data in enumerate(val_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["segmentation"], segmentations[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == 2


def test_flowers102_basic():
    """
    Validate Flowers102Dataset
    """
    logger.info("Test Flowers102Dataset Op")

    # case 1: test decode
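    # A dataset built with decode=False plus an explicit Decode() map should
    # iterate in lockstep with one built with decode=True (rows compared by label).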
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False)
    all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1)
    all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False)

    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == 6

    # case 2: test num_samples
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 4

    # case 3: test repeat
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
    all_data = all_data.repeat(5)
    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 20

    # case 4: test get_dataset_size, resize and batch
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
    all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"],
                            num_parallel_workers=1)

    assert all_data.get_dataset_size() == 4
    assert all_data.get_batch_size() == 1
    all_data = all_data.batch(batch_size=3)  # drop_remainder defaults to False
    assert all_data.get_batch_size() == 3
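    # 4 samples batched with batch_size=3 and drop_remainder=False give
    # 2 batches: one of 3 samples and one of 1.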
    assert all_data.get_dataset_size() == 2

    num_iter = 0
    for _ in all_data.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 2

    # case 5: test get_class_indexing
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
    class_indexing = all_data.get_class_indexing()
    assert class_indexing["pink primrose"] == 0
    assert class_indexing["blackberry lily"] == 101


def test_flowers102_sequential_sampler():
    """
    Test Flowers102Dataset with SequentialSampler
    """
    logger.info("Test Flowers102Dataset Op with SequentialSampler")
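    # A SequentialSampler capped at num_samples should produce exactly the same
    # rows as shuffle=False combined with num_samples.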
    num_samples = 4
    sampler = ds.SequentialSampler(num_samples=num_samples)
    all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                      decode=True, sampler=sampler)
    all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                      decode=True, shuffle=False, num_samples=num_samples)
    label_list_1, label_list_2 = [], []
    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
                            all_data_2.create_dict_iterator(num_epochs=1)):
        label_list_1.append(item1["label"].asnumpy())
        label_list_2.append(item2["label"].asnumpy())
        num_iter += 1
    np.testing.assert_array_equal(label_list_1, label_list_2)
    assert num_iter == num_samples


def test_flowers102_exception():
    """
    Test error cases for Flowers102Dataset
    """
    logger.info("Test error cases for Flowers102Dataset")
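    # sampler cannot be combined with shuffle or with the sharding arguments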
    error_msg_1 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_1):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", shuffle=False,
                             decode=True, sampler=ds.SequentialSampler(1))

    error_msg_2 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_2):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", sampler=ds.SequentialSampler(1),
                             decode=True, num_shards=2, shard_id=0)

    error_msg_3 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_3):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=10)

    error_msg_4 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_4):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shard_id=0)

    error_msg_5 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_5):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=-1)

    with pytest.raises(ValueError, match=error_msg_5):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=5)

    with pytest.raises(ValueError, match=error_msg_5):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id=5)

    error_msg_6 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
                             shuffle=False, num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
                             shuffle=False, num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_6):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
                             shuffle=False, num_parallel_workers=-2)

    error_msg_7 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_7):
        ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id="0")

    error_msg_8 = "does not exist or is not a directory or permission denied!"
    with pytest.raises(ValueError, match=error_msg_8):
        all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True)
        for _ in all_data.create_dict_iterator(num_epochs=1):
            pass

    error_msg_9 = "is not of type"
    with pytest.raises(TypeError, match=error_msg_9):
        all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123)
        for _ in all_data.create_dict_iterator(num_epochs=1):
            pass


def test_flowers102_visualize(plot=False):
    """
    Visualize Flowers102Dataset results
    """
    logger.info("Test Flowers102Dataset visualization")

    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4,
                                    decode=True, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
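        # decoded images should be HWC uint8 arrays with 3 channels, labels uint32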
        assert isinstance(image, np.ndarray)
        assert len(image.shape) == 3
        assert image.shape[-1] == 3
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 4
    if plot:
        visualize_dataset(image_list, label_list)


def test_flowers102_usage():
    """
    Validate Flowers102Dataset usage
    """
    logger.info("Test Flowers102Dataset usage flag")

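    # test_config returns the row count on success, or the error message on failure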
    def test_config(usage):
        try:
            data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage=usage, decode=True, shuffle=False)
            num_rows = 0
            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_rows += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return num_rows

    assert test_config("all") == 6
    assert test_config("train") == 2
    assert test_config("test") == 2
    assert test_config("valid") == 2

    assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])


def test_flowers102_task():
320    """
321    Validate Flowers102Dataset task
322    """
323    logger.info("Test Flowers102Dataset task flag")
324
325    def test_config(task):
326        try:
327            data = ds.Flowers102Dataset(DATA_DIR, task=task, usage="all", decode=True, shuffle=False)
328            num_rows = 0
329            for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
330                num_rows += 1
331        except (ValueError, TypeError, RuntimeError) as e:
332            return str(e)
333        return num_rows
334
335    assert test_config("Classification") == 6
336    assert test_config("Segmentation") == 6
337
338    assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid")
339    assert "Argument task with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
340
341if __name__ == '__main__':
342    test_flowers102_content_check()
343    test_flowers102_basic()
344    test_flowers102_sequential_sampler()
345    test_flowers102_exception()
346    test_flowers102_visualize(plot=True)
347    test_flowers102_usage()
348    test_flowers102_task()
349