# Copyright 2020-2021 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for mindspore.dataset.CelebADataset."""
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter

DATA_DIR = "../data/dataset/testCelebAData/"

# Substring expected in the RuntimeError raised when a map PyFunc fails.
PYFUNC_ERROR_MSG = "map operation: [PyFunc] failed. The corresponding data files"


def _check_attr_labels(data, expect_labels):
    """
    Iterate `data` once and compare each row's "attr" column with `expect_labels`.

    Rows are matched against `expect_labels` in iteration order. The whole
    attribute vector is compared, so a length mismatch is also reported.

    Args:
        data: the dataset to iterate (must expose "image" and "attr" columns).
        expect_labels (list[list[int]]): expected attr vector per row.

    Returns:
        int, the number of rows iterated.
    """
    count = 0
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        # Fail with a clear message instead of an IndexError if the dataset
        # yields more rows than we have expectations for.
        assert count < len(expect_labels), f"dataset yielded more than {len(expect_labels)} rows"
        actual = item["attr"].tolist()
        expected = expect_labels[count]
        assert actual == expected, f"attr mismatch at row {count}: {actual} != {expected}"
        count += 1
    return count


def test_celeba_dataset_label():
    """
    Test CelebA dataset with labels
    """
    logger.info("Test CelebA labels")
    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
         0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 1],
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
         0, 0, 1]]
    count = _check_attr_labels(data, expect_labels)
    assert count == 4


def test_celeba_dataset_op():
    """
    Test CelebA dataset with decode, repeat and CenterCrop/Resize map operations
    """
    logger.info("Test CelebA with decode")
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0)
    crop_size = (80, 80)
    resize_size = (24, 24)
    # Repeat the 4-row dataset twice, so 8 rows are expected below.
    data = data.repeat(2)
    # define map operations
    center_crop = vision.CenterCrop(crop_size)
    resize_op = vision.Resize(resize_size, Inter.LINEAR)  # Bilinear mode
    data = data.map(operations=center_crop, input_columns=["image"])
    data = data.map(operations=resize_op, input_columns=["image"])

    count = 0
    for item in data.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(item["image"])
        count += 1
    assert count == 8


def test_celeba_dataset_ext():
    """
    Test CelebA dataset with extension
    """
    logger.info("Test CelebA extension option")
    ext = [".JPEG"]
    data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
    expect_labels = [
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
         0, 1, 0, 1, 0, 0, 1],
        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
         0, 1, 0, 1, 0, 0, 1]]
    count = _check_attr_labels(data, expect_labels)
    assert count == 2


def test_celeba_dataset_distribute():
    """
    Test CelebA dataset with distributed options
    """
    logger.info("Test CelebA with sharding")
    data = ds.CelebADataset(DATA_DIR, decode=True, num_shards=2, shard_id=0)
    count = 0
    for item in data.create_dict_iterator(num_epochs=1):
        logger.info("----------image--------")
        logger.info(item["image"])
        logger.info("----------attr--------")
        logger.info(item["attr"])
        count += 1
    assert count == 2


def test_celeba_get_dataset_size():
    """
    Test CelebA dataset get dataset size, both unfiltered and per usage split
    """
    logger.info("Test CelebA get dataset size")
    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
    assert data.get_dataset_size() == 4

    for usage, expected_size in (("train", 2), ("valid", 1), ("test", 1)):
        data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage=usage)
        size = data.get_dataset_size()
        assert size == expected_size, f"usage={usage}: got size {size}, expected {expected_size}"


def test_celeba_dataset_exception_file_path():
    """
    Test that an exception raised inside a map PyFunc surfaces as a RuntimeError
    whose message refers to the corresponding data files
    """
    logger.info("Test CelebA with bad file path")

    def exception_func(item):
        raise Exception("Error occur!")

    def expect_pyfunc_failure(data):
        # Iterating must raise RuntimeError; silently completing is a failure.
        try:
            for _ in data.create_dict_iterator():
                pass
        except RuntimeError as e:
            assert PYFUNC_ERROR_MSG in str(e)
        else:
            assert False, "expected RuntimeError from the failing PyFunc"

    # PyFunc on the raw (undecoded) image column.
    data = ds.CelebADataset(DATA_DIR, shuffle=False)
    data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
    expect_pyfunc_failure(data)

    # PyFunc on the image column after a C++ Decode op.
    data = ds.CelebADataset(DATA_DIR, shuffle=False)
    data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
    data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
    expect_pyfunc_failure(data)

    # PyFunc on the attr column.
    data = ds.CelebADataset(DATA_DIR, shuffle=False)
    data = data.map(operations=exception_func, input_columns=["attr"], num_parallel_workers=1)
    expect_pyfunc_failure(data)


def test_celeba_sampler_exception():
    """
    Test CelebA with bad sampler input
    """
    logger.info("Test CelebA with bad sampler input")
    try:
        # The TypeError may come from construction or iteration, so both stay in the try.
        data = ds.CelebADataset(DATA_DIR, sampler="")
        for _ in data.create_dict_iterator():
            pass
    except TypeError as e:
        assert "Unsupported sampler object of type (<class 'str'>)" in str(e)
    else:
        assert False, "expected TypeError for a string sampler"


if __name__ == '__main__':
    test_celeba_dataset_label()
    test_celeba_dataset_op()
    test_celeba_dataset_ext()
    test_celeba_dataset_distribute()
    test_celeba_get_dataset_size()
    test_celeba_dataset_exception_file_path()
    test_celeba_sampler_exception()