• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Test SBU dataset operators
17"""
18import os
19
20import matplotlib.pyplot as plt
21import numpy as np
22import pytest
23from PIL import Image
24
25import mindspore.dataset as ds
26import mindspore.dataset.vision.c_transforms as vision
27from mindspore import log as logger
28
# Root of the SBU test fixtures: it is read both by SBUDataset and by load_sbu,
# which expects the URL list, the caption list and an sbu_images/ folder in it.
# Relative path, so it resolves against the working directory the tests run from.
DATA_DIR = "../data/dataset/testSBUDataset"
# A directory that is not an SBU dataset; test_sbu_exception expects SBUDataset
# to reject it with a "does not exist or permission denied" ValueError.
WRONG_DIR = "../data/dataset/testMnistData"
31
32
def load_sbu(path):
    """
    Load the reference SBU samples directly from disk, without SBUDataset.

    The URL list and the caption list are read line by line in lockstep. Each URL
    is mapped to a local file name under ``<path>/sbu_images``; pairs whose image
    file does not exist are skipped.

    Args:
        path (str): Root directory of the SBU test data.

    Returns:
        tuple[list[numpy.ndarray], list[str]]: RGB ``uint8`` image arrays and their
        captions (trailing whitespace stripped), in file order.
    """
    images = []
    captions = []

    urls_file = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt'))
    captions_file = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt'))

    # Close both text files deterministically, and decode them as UTF-8 so the
    # captions match the utf8-decoded strings the tests read back from SBUDataset.
    with open(urls_file, encoding='utf-8') as url_lines, \
            open(captions_file, encoding='utf-8') as caption_lines:
        for url_line, caption_line in zip(url_lines, caption_lines):
            url = url_line.rstrip()
            # Drop the first 23 characters of the URL (presumably the host prefix --
            # TODO confirm it matches SBUDataset's own mapping) and flatten the rest
            # of the path into a single file name.
            image_name = url[23:].replace("/", "_")
            filename = os.path.join(path, 'sbu_images', image_name)
            if not os.path.exists(filename):
                continue
            # Release the PIL file handle once the pixels are copied into numpy.
            with Image.open(filename) as img:
                images.append(np.asarray(img.convert('RGB')).astype(np.uint8))
            captions.append(caption_line.rstrip())
    return images, captions
52
53
def visualize_dataset(images, captions):
    """
    Draw all images side by side in one figure row, each titled with its caption,
    then display the figure.
    """
    total = len(images)
    for position, picture in enumerate(images, start=1):
        plt.subplot(1, total, position)
        plt.imshow(picture.squeeze())
        plt.title(captions[position - 1])
    plt.show()
65
def test_sbu_content_check():
    """
    Compare SBUDataset rows (image shape and caption text) with the samples
    loaded straight from disk by load_sbu
    """
    logger.info("Test SBUDataset Op with content check")
    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False)
    expected_images, expected_captions = load_sbu(DATA_DIR)
    rows = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)

    seen = 0
    for idx, row in enumerate(rows):
        # every row is a dict keyed by "image" and "caption"
        assert row["image"].shape == expected_images[idx].shape
        caption_text = row["caption"].item().decode("utf8")
        assert caption_text == expected_captions[idx]
        seen += 1
    assert seen == 5
80
81
def test_sbu_case():
    """
    Run SBUDataset through a map -> repeat -> batch pipeline, once with built-in
    decoding and once with an explicit Decode op, and count the batches
    """
    def count_batches(source, operations):
        # resize (and optionally decode) images, repeat 4 times, batch in pairs
        pipeline = source.map(operations=operations, input_columns=["image"])
        pipeline = pipeline.repeat(4)
        pipeline = pipeline.batch(2, drop_remainder=True, pad_info={})
        batches = 0
        for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            batches += 1
        return batches

    # decode=True: images arrive decoded, only resize; 4 x 5 / 2 batches
    assert count_batches(ds.SBUDataset(DATA_DIR, decode=True),
                         [vision.Resize((224, 224))]) == 10

    # decode=False: decode explicitly before resizing; 4 x 5 / 2 batches
    assert count_batches(ds.SBUDataset(DATA_DIR, decode=False),
                         [vision.Decode(rgb=True), vision.Resize((224, 224))]) == 10
113
114
def test_sbu_basic():
    """
    Validate basic SBUDataset behaviour: row counts, repeat, and metadata getters
    """
    logger.info("Test SBUDataset Op")

    def count_rows(source):
        return sum(1 for _ in source.create_dict_iterator(num_epochs=1, output_numpy=True))

    # case 1: without num_samples the whole dataset is read
    assert count_rows(ds.SBUDataset(DATA_DIR, decode=True)) == 5

    # case 2: num_samples limits the number of rows
    assert count_rows(ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)) == 5

    # case 3: repeat(5) replays the five samples five times
    repeated = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5).repeat(5)
    assert count_rows(repeated) == 25

    # case 4: dataset size and the default (unbatched) batch size of 1
    sized = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert sized.get_dataset_size() == 5
    assert sized.get_batch_size() == 1
    assert count_rows(sized) == 5

    # case 5: SBU has no class labels, so the class indexing is empty
    indexed = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert indexed.get_class_indexing() == {}

    # case 6: output column names
    named = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
    assert named.get_col_names() == ["image", "caption"]
161
162
def test_sbu_sequential_sampler():
    """
    Check that an explicit SequentialSampler yields captions in the same order
    as shuffle=False with the same num_samples
    """
    logger.info("Test SBUDataset Op with SequentialSampler")
    limit = 5
    sampled = ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(num_samples=limit))
    ordered = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=limit)

    rows_a = sampled.create_dict_iterator(num_epochs=1, output_numpy=True)
    rows_b = ordered.create_dict_iterator(num_epochs=1, output_numpy=True)
    compared = 0
    for row_a, row_b in zip(rows_a, rows_b):
        np.testing.assert_array_equal(row_a["caption"], row_b["caption"])
        compared += 1
    assert compared == limit
179
180
def test_sbu_exception():
    """
    Check that SBUDataset rejects invalid arguments and that failures inside a
    map() pipeline surface as errors while iterating
    """
    logger.info("Test error cases for SBUDataset")

    # a sampler cannot be combined with shuffle or with sharding
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler())
    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0)

    # num_shards and shard_id must be supplied together
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=10)
    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.SBUDataset(DATA_DIR, decode=True, shard_id=0)

    # shard_id values outside the accepted range for the given num_shards
    for shards, shard in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match="Input shard_id is not within the required interval"):
            ds.SBUDataset(DATA_DIR, decode=True, num_shards=shards, shard_id=shard)

    # rejected worker counts
    for workers in (0, 256, -2):
        with pytest.raises(ValueError, match="num_parallel_workers exceeds"):
            ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=workers)

    # shard_id must be an int, not a string
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # a raising Python op, and decoding already-decoded images, both fail at iteration time
    for failing_op in (exception_func, vision.Decode()):
        with pytest.raises(RuntimeError, match="The corresponding data files"):
            dataset = ds.SBUDataset(DATA_DIR, decode=True)
            dataset = dataset.map(operations=failing_op, input_columns=["image"], num_parallel_workers=1)
            for _ in dataset.__iter__():
                pass

    # a directory that is not an SBU dataset
    with pytest.raises(ValueError, match="does not exist or permission denied"):
        dataset = ds.SBUDataset(WRONG_DIR, decode=True)
        for _ in dataset.__iter__():
            pass

    # decode must be a bool
    with pytest.raises(TypeError, match="Argument decode with value"):
        dataset = ds.SBUDataset(DATA_DIR, decode="not_bool")
        for _ in dataset.__iter__():
            pass
249
250
def test_sbu_visualize(plot=False):
    """
    Check the types of decoded SBUDataset rows and, when plot is True,
    display them with their captions
    """
    logger.info("Test SBUDataset visualization")

    dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False)
    pictures, titles = [], []
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        picture = row["image"]
        text = row["caption"].item().decode("utf8")
        pictures.append(picture)
        titles.append("caption {}".format(text))
        assert isinstance(picture, np.ndarray)
        assert picture.dtype == np.uint8
        assert isinstance(text, str)
    assert len(pictures) == 5
    if plot:
        visualize_dataset(pictures, titles)
273
274
def test_sbu_decode():
    """
    Validate that decode=True yields the same rows as decode=False followed by
    an explicit Decode(rgb=True) map

    Both pipelines use the same SequentialSampler, so rows are paired in order and
    both the decoded images and the captions must match exactly.
    """
    logger.info("Test SBUDataset decode flag")

    sampler = ds.SequentialSampler(num_samples=50)
    dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler)
    dataset_1 = dataset.map(operations=[vision.Decode(rgb=True)], input_columns=["image"])

    dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler)

    num_iter = 0
    for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Compare the decoded pixels, not only the captions, so a mismatch between
        # the built-in decode path and the Decode op is actually detected.
        np.testing.assert_array_equal(item1["image"], item2["image"])
        np.testing.assert_array_equal(item1["caption"], item2["caption"])
        num_iter += 1

    # num_samples=50 is only an upper bound; the fixture directory holds 5 samples.
    assert num_iter == 5
294
295
if __name__ == '__main__':
    # run the non-plotting checks in their original order
    for run_check in (test_sbu_content_check, test_sbu_basic, test_sbu_case,
                      test_sbu_sequential_sampler, test_sbu_exception):
        run_check()
    test_sbu_visualize(plot=True)
    test_sbu_decode()
304