• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020-2021 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14import mindspore.dataset as ds
15import mindspore.dataset.vision.c_transforms as vision
16from mindspore import log as logger
17from mindspore.dataset.vision import Inter
18
# Relative path to the CelebA fixture directory passed to every CelebADataset in this file.
DATA_DIR = "../data/dataset/testCelebAData/"
20
21
def test_celeba_dataset_label():
    """
    Test CelebA dataset with labels

    Iterates the dataset with shuffling disabled and compares each row's
    "attr" column, element by element, against the expected attribute vector
    for that row. Finally checks that exactly four rows were produced.
    """
    logger.info("Test CelebA labels")
    dataset = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)

    # Only two distinct attribute vectors are expected: rows 0 and 3 share one,
    # rows 1 and 2 share the other.
    outer_attrs = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
                   0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
    inner_attrs = [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
                   0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
    expect_labels = [outer_attrs, inner_attrs, inner_attrs, outer_attrs]

    rows_seen = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        # Index into the produced attributes so a too-short "attr" still fails loudly.
        for position, expected in enumerate(expect_labels[rows_seen]):
            assert row["attr"][position] == expected
        rows_seen += 1
    assert rows_seen == 4
47
48
def test_celeba_dataset_op():
    """
    Test CelebA dataset with decode

    Builds a pipeline that repeats the decoded dataset twice, applies a center
    crop and then a linear-interpolation resize to the "image" column, and
    checks that one epoch of iteration yields eight rows.
    """
    logger.info("Test CelebA with decode")
    source = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)

    # define map operations
    crop_op = vision.CenterCrop((80, 80))
    scale_op = vision.Resize((24, 24), Inter.LINEAR)  # Bilinear mode

    # Same stage order as before: repeat first, then crop, then resize.
    pipeline = source.repeat(2)
    pipeline = pipeline.map(operations=crop_op, input_columns=["image"])
    pipeline = pipeline.map(operations=scale_op, input_columns=["image"])

    rows_seen = 0
    for row in pipeline.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(row["image"])
        rows_seen += 1
    assert rows_seen == 8
70
71
def test_celeba_dataset_ext():
    """
    Test CelebA dataset with extension

    Restricts the dataset to the ".JPEG" extension, compares each row's "attr"
    column element by element with the expected attribute vector, and checks
    that exactly two rows were produced.
    """
    logger.info("Test CelebA extension option")
    dataset = ds.CelebADataset(DATA_DIR, decode=True, extensions=[".JPEG"])

    # Both expected rows carry the same attribute vector.
    shared_attrs = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
                    0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
    expect_labels = [shared_attrs, shared_attrs]

    rows_seen = 0
    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("----------image--------")
        logger.info(row["image"])
        logger.info("----------attr--------")
        logger.info(row["attr"])
        # Index into the produced attributes so a too-short "attr" still fails loudly.
        for position, expected in enumerate(expect_labels[rows_seen]):
            assert row["attr"][position] == expected
        rows_seen += 1
    assert rows_seen == 2
94
95
def test_celeba_dataset_distribute():
    """
    Test CelebA dataset with distributed options

    Reads shard 0 of a two-shard split and checks that it yields two rows.
    """
    logger.info("Test CelebA with sharding")
    shard = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)

    # Banner text paired with the column logged right after it.
    logged_columns = (("----------image--------", "image"),
                      ("----------attr--------", "attr"))

    rows_seen = 0
    for row in shard.create_dict_iterator(num_epochs=1):
        for banner, column in logged_columns:
            logger.info(banner)
            logger.info(row[column])
        rows_seen += 1
    assert rows_seen == 2
110
111
def test_celeba_get_dataset_size():
    """
    Test CelebA dataset get dataset size

    Checks get_dataset_size() for the dataset built without a usage argument
    and for the "train", "valid" and "test" usages.
    """
    logger.info("Test CelebA get dataset size")

    # (extra constructor keyword arguments, expected size); the first case
    # deliberately passes no usage argument at all.
    cases = (
        ({}, 4),
        ({"usage": "train"}, 2),
        ({"usage": "valid"}, 1),
        ({"usage": "test"}, 1),
    )
    for extra_kwargs, expected_size in cases:
        dataset = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, **extra_kwargs)
        size = dataset.get_dataset_size()
        assert size == expected_size
132
133
def test_celeba_dataset_exception_file_path():
    """
    Test CelebA dataset error reporting when a map PyFunc fails

    Maps a Python function that always raises onto the "image" column (with
    and without a preceding Decode) and onto the "attr" column, and checks
    that iteration raises a RuntimeError whose message mentions the
    corresponding data files.
    """
    logger.info("Test CelebA with bad file path")

    def exception_func(item):
        raise Exception("Error occur!")

    def expect_pyfunc_failure(column, decode_first=False):
        # Build the pipeline, iterate it, and require the PyFunc failure message.
        try:
            pipeline = ds.CelebADataset(DATA_DIR, shuffle=False)
            if decode_first:
                pipeline = pipeline.map(operations=vision.Decode(), input_columns=["image"],
                                        num_parallel_workers=1)
            pipeline = pipeline.map(operations=exception_func, input_columns=[column],
                                    num_parallel_workers=1)
            for _ in pipeline.create_dict_iterator():
                pass
        except RuntimeError as err:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(err)
        else:
            # Iteration finishing without an error means the failure was swallowed.
            assert False

    expect_pyfunc_failure("image")
    expect_pyfunc_failure("image", decode_first=True)
    expect_pyfunc_failure("attr")
169
170
def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input

    Passes an empty string as the sampler and checks that a TypeError naming
    the unsupported sampler type is raised.
    """
    logger.info("Test CelebA with bad sampler input")
    try:
        dataset = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in dataset.create_dict_iterator():
            pass
    except TypeError as err:
        assert "Unsupported sampler object of type (<class 'str'>)" in str(err)
    else:
        # Reaching this point means the invalid sampler was accepted.
        assert False
183
184
if __name__ == '__main__':
    # Run every test in this file, in definition order, when executed as a script.
    for run_test in (
            test_celeba_dataset_label,
            test_celeba_dataset_op,
            test_celeba_dataset_ext,
            test_celeba_dataset_distribute,
            test_celeba_get_dataset_size,
            test_celeba_dataset_exception_file_path,
            test_celeba_sampler_exception,
    ):
        run_test()
193