• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16import matplotlib.pyplot as plt
17
18import mindspore.dataset as ds
19import mindspore.dataset.vision.c_transforms as c_vision
20from mindspore import log as logger
21
# Image directory passed as the first argument to ds.FlickrDataset in every test of this file.
# The path is relative, so it is resolved against the current working directory.
FLICKR30K_DATASET_DIR = "../data/dataset/testFlickrData/flickr30k/flickr30k-images"
# Annotation (token) files; the tests in this file expect 2 samples from test1.token
# and 3 samples from test2.token.
FLICKR30K_ANNOTATION_FILE_1 = "../data/dataset/testFlickrData/flickr30k/test1.token"
FLICKR30K_ANNOTATION_FILE_2 = "../data/dataset/testFlickrData/flickr30k/test2.token"
25
26
def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples.

    Each image is drawn with its captions as the title and saved to
    ``./flickr_<i>.jpg`` in the current working directory, where ``i`` starts at 1.

    Args:
        images: Iterable of image arrays accepted by ``plt.imshow``.
        labels: Iterable, parallel to ``images``, of sequences of UTF-8 encoded
            byte strings; the decoded strings are joined with newlines to form the title.
    """
    figure = plt.figure(figsize=(10, 10))
    try:
        for i, (image, captions) in enumerate(zip(images, labels), start=1):
            # Start from an empty canvas so earlier images are not left underneath this one.
            figure.clf()
            plt.imshow(image)
            plt.title('\n'.join([s.decode('utf-8') for s in captions]))
            plt.savefig('./flickr_' + str(i) + '.jpg')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(figure)
36
37
def test_flickr30k_dataset_train(plot=False):
    """
    Read every sample listed in the first annotation file and check that there are two of them.

    Args:
        plot (bool): When True, pass the collected images and annotations to
            visualize_dataset. Default: False.
    """
    dataset = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
    images_list, annotation_list = [], []
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("item[image] is {}".format(row["image"]))
        images_list.append(row['image'])
        annotation_list.append(row['annotation'])
    # One image is appended per row, so the list length equals the number of rows read.
    assert len(images_list) == 2
    if plot:
        visualize_dataset(images_list, annotation_list)
51
52
def test_flickr30k_dataset_annotation_check():
    """
    Read the first annotation file without shuffling and compare the decoded
    annotations of each sample, in order, with the expected captions.
    """
    dataset = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True, shuffle=False)
    banana_captions = np.array([
        r'This is \*a banana.',
        'This is a yellow banana.',
        'This is a banana on the table.',
        'The banana is yellow.',
        'The banana is very big.',
    ])
    pen_captions = np.array([
        'This is a pen.',
        'This is a red and black pen.',
        'This is a pen on the table.',
        'The color of the pen is red and black.',
        'The pen has two colors.',
    ])
    expect_annotation_arr = [banana_captions, pen_captions]

    rows_seen = 0
    for index, row in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Annotations arrive as byte strings; decode them before comparing with the expected text.
        decoded = [s.decode("utf8") for s in row["annotation"]]
        np.testing.assert_array_equal(decoded, expect_annotation_arr[index])
        logger.info("item[image] is {}".format(row["image"]))
        rows_seen = index + 1
    assert rows_seen == 2
78
79
def test_flickr30k_dataset_basic():
    """
    Exercise basic FlickrDataset pipelines: num_samples, repeat, and batch with
    and without drop_remainder, checking the reported sizes and the rows produced.
    """
    def count_rows(dataset):
        """Iterate the dataset for one epoch and return how many rows it produced."""
        return sum(1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

    # case 1: test num_samples
    sampled = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, num_samples=2, decode=True)
    assert count_rows(sampled) == 2

    # case 2: test repeat
    repeated = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True).repeat(5)
    assert count_rows(repeated) == 10

    # case 3: test batch with drop_remainder=False
    kept = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
    kept = kept.map(operations=c_vision.Resize((100, 100)), input_columns=["image"], num_parallel_workers=1)
    assert kept.get_dataset_size() == 3
    assert kept.get_batch_size() == 1
    kept = kept.batch(batch_size=2)  # drop_remainder is default to be False
    assert kept.get_dataset_size() == 2
    assert kept.get_batch_size() == 2
    assert count_rows(kept) == 2

    # case 4: test batch with drop_remainder=True
    dropped = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_2, decode=True, shuffle=False)
    dropped = dropped.map(operations=c_vision.Resize((100, 100)), input_columns=["image"], num_parallel_workers=1)
    assert dropped.get_dataset_size() == 3
    assert dropped.get_batch_size() == 1
    dropped = dropped.batch(batch_size=2, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert dropped.get_dataset_size() == 1
    assert dropped.get_batch_size() == 2
    assert count_rows(dropped) == 1
123
124
def test_flickr30k_dataset_exception():
    """
    Map a Python function that always raises over the "image" column and then the
    "annotation" column, and check that iterating the pipeline raises a RuntimeError
    carrying the expected map-operation message.
    """
    def exception_func(item):
        raise Exception("Error occur!")

    expected_message = "map operation: [PyFunc] failed. The corresponding data files"
    for column in ("image", "annotation"):
        try:
            pipeline = ds.FlickrDataset(FLICKR30K_DATASET_DIR, FLICKR30K_ANNOTATION_FILE_1, decode=True)
            pipeline = pipeline.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
            for _ in pipeline.create_dict_iterator():
                pass
        except RuntimeError as e:
            assert expected_message in str(e)
        else:
            # Iteration finished without a RuntimeError, so the failure was not reported.
            assert False
148
149
if __name__ == "__main__":
    # Run every test case in order when the file is executed as a script.
    # test_flickr30k_dataset_train is called with its default plot=False.
    for test_case in (
            test_flickr30k_dataset_train,
            test_flickr30k_dataset_annotation_check,
            test_flickr30k_dataset_basic,
            test_flickr30k_dataset_exception,
    ):
        test_case()
154    test_flickr30k_dataset_exception()
155