# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Flowers102 dataset operators
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
from scipy.io import loadmat

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testFlowers102Dataset"
WRONG_DIR = "../data/dataset/testMnistData"

# Keys of setid.mat that are read for each usage. For "all" the splits are
# concatenated in exactly this order (train, test, valid).
_SPLIT_KEYS = {
    "train": ("trnid",),
    "test": ("tstid",),
    "valid": ("valid",),
    "all": ("trnid", "tstid", "valid"),
}


def _read_rgb(image_path):
    """
    Read one image file and return it as an RGB numpy array
    """
    # convert() returns an independent, fully loaded image, so the source file
    # can be closed as soon as the conversion is done.
    with Image.open(image_path) as image:
        return np.asarray(image.convert("RGB"))


def load_flowers102(path, usage):
    """
    load Flowers102 data

    Reference loader used to cross-check what Flowers102Dataset returns.

    Args:
        path: dataset root containing "imagelabels.mat", "setid.mat" and the
            "jpg" and "segmim" folders.
        usage: one of "train", "valid", "test" or "all".

    Returns:
        Tuple (images, segmentations, labels) of equally long lists; images and
        segmentations are RGB numpy arrays, labels are zero-based np.uint32.

    Raises:
        ValueError: if usage is not one of the supported values.
    """
    if usage not in _SPLIT_KEYS:
        raise ValueError("usage must be one of {}, but got {!r}".format(sorted(_SPLIT_KEYS), usage))

    # The stored labels are shifted down by one to make them zero-based.
    imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32)
    split = loadmat(os.path.join(path, "setid.mat"))

    indices = []
    for key in _SPLIT_KEYS[usage]:
        indices += split[key][0].tolist()

    image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices]
    segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices]
    images = [_read_rgb(image_path) for image_path in image_paths]
    segmentations = [_read_rgb(segmentation_path) for segmentation_path in segmentation_paths]
    # Image ids are treated as 1-based positions into the label array.
    labels = [imagelabels[index - 1] for index in indices]

    return images, segmentations, labels


def visualize_dataset(images, labels):
    """
    Helper function to visualize the dataset samples

    Draws every image in one row of subplots, titled with the matching entry
    of labels, then shows the figure.
    """
    num_samples = len(images)
    for position, (image, label) in enumerate(zip(images, labels), start=1):
        plt.subplot(1, num_samples, position)
        plt.imshow(image.squeeze())
        plt.title(label)
    plt.show()


def _count_rows(dataset, output_numpy=False):
    """
    Iterate dataset for one epoch and return the number of rows produced
    """
    num_rows = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=output_numpy):
        num_rows += 1
    return num_rows


def _check_segmentation_content(usage, num_samples):
    """
    Compare the first num_samples rows of a split against load_flowers102
    """
    dataset = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage=usage,
                                   num_samples=num_samples, decode=True, shuffle=False)
    images, segmentations, labels = load_flowers102(DATA_DIR, usage)
    num_iter = 0
    # for the Segmentation task, each dictionary has keys "image", "segmentation" and "label"
    for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(data["image"], images[i])
        np.testing.assert_array_equal(data["segmentation"], segmentations[i])
        np.testing.assert_array_equal(data["label"], labels[i])
        num_iter += 1
    assert num_iter == num_samples


def test_flowers102_content_check():
    """
    Validate Flowers102Dataset image readings
    """
    logger.info("Test Flowers102Dataset Op with content check")
    for usage, num_samples in (("all", 6), ("train", 2), ("test", 2), ("valid", 2)):
        _check_segmentation_content(usage, num_samples)


def test_flowers102_basic():
    """
    Validate Flowers102Dataset
    """
    logger.info("Test Flowers102Dataset Op")

    # case 1: test decode
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False)
    all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1)
    all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False)

    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["label"], item2["label"])
        num_iter += 1
    assert num_iter == 6

    # case 2: test num_samples
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
    assert _count_rows(all_data) == 4

    # case 3: test repeat
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
    all_data = all_data.repeat(5)
    assert _count_rows(all_data) == 20

    # case 4: test get_dataset_size, resize and batch
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
    all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"],
                            num_parallel_workers=1)

    assert all_data.get_dataset_size() == 4
    assert all_data.get_batch_size() == 1
    all_data = all_data.batch(batch_size=3)  # drop_remainder is default to be False
    assert all_data.get_batch_size() == 3
    assert all_data.get_dataset_size() == 2
    assert _count_rows(all_data) == 2

    # case 5: test get_class_indexing
    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
    class_indexing = all_data.get_class_indexing()
    assert class_indexing["pink primrose"] == 0
    assert class_indexing["blackberry lily"] == 101


def test_flowers102_sequential_sampler():
    """
    Test Flowers102Dataset with SequentialSampler
    """
    logger.info("Test Flowers102Dataset Op with SequentialSampler")
    num_samples = 4
    sampler = ds.SequentialSampler(num_samples=num_samples)
    all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                      decode=True, sampler=sampler)
    all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
                                      decode=True, shuffle=False, num_samples=num_samples)
    label_list_1, label_list_2 = [], []
    num_iter = 0
    for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
                            all_data_2.create_dict_iterator(num_epochs=1)):
        label_list_1.append(item1["label"].asnumpy())
        label_list_2.append(item2["label"].asnumpy())
        num_iter += 1
    # Both pipelines are expected to yield the same labels in the same order.
    np.testing.assert_array_equal(label_list_1, label_list_2)
    assert num_iter == num_samples


def test_flowers102_exception():
    """
    Test error cases for Flowers102Dataset
    """
    logger.info("Test error cases for Flowers102Dataset")

    error_msg_5 = "Input shard_id is not within the required interval"
    error_msg_6 = "num_parallel_workers exceeds"

    # Each entry: (expected exception, message pattern, extra constructor arguments).
    # All of these are expected to fail while the dataset is being constructed.
    construction_cases = [
        (RuntimeError, "sampler and shuffle cannot be specified at the same time",
         {"shuffle": False, "sampler": ds.SequentialSampler(1)}),
        (RuntimeError, "sampler and sharding cannot be specified at the same time",
         {"sampler": ds.SequentialSampler(1), "num_shards": 2, "shard_id": 0}),
        (RuntimeError, "num_shards is specified and currently requires shard_id as well",
         {"num_shards": 10}),
        (RuntimeError, "shard_id is specified but num_shards is not",
         {"shard_id": 0}),
        (ValueError, error_msg_5, {"num_shards": 5, "shard_id": -1}),
        (ValueError, error_msg_5, {"num_shards": 5, "shard_id": 5}),
        (ValueError, error_msg_5, {"num_shards": 2, "shard_id": 5}),
        (ValueError, error_msg_6, {"shuffle": False, "num_parallel_workers": 0}),
        (ValueError, error_msg_6, {"shuffle": False, "num_parallel_workers": 256}),
        (ValueError, error_msg_6, {"shuffle": False, "num_parallel_workers": -2}),
        (TypeError, "Argument shard_id", {"num_shards": 2, "shard_id": "0"}),
    ]
    for error_type, error_msg, extra_args in construction_cases:
        with pytest.raises(error_type, match=error_msg):
            ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, **extra_args)

    # The remaining cases also iterate the dataset inside the raises block, so
    # the error may surface either at construction or during iteration.
    error_msg_8 = "does not exist or is not a directory or permission denied!"
    with pytest.raises(ValueError, match=error_msg_8):
        all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True)
        for _ in all_data.create_dict_iterator(num_epochs=1):
            pass

    error_msg_9 = "is not of type"
    with pytest.raises(TypeError, match=error_msg_9):
        all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123)
        for _ in all_data.create_dict_iterator(num_epochs=1):
            pass


def test_flowers102_visualize(plot=False):
    """
    Visualize Flowers102Dataset results

    Checks the type, rank, channel count and dtypes of four decoded samples,
    and plots them when plot is True.
    """
    logger.info("Test Flowers102Dataset visualization")

    all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4,
                                    decode=True, shuffle=False)
    num_iter = 0
    image_list, label_list = [], []
    for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = item["image"]
        label = item["label"]
        image_list.append(image)
        label_list.append("label {}".format(label))
        assert isinstance(image, np.ndarray)
        assert len(image.shape) == 3
        assert image.shape[-1] == 3
        assert image.dtype == np.uint8
        assert label.dtype == np.uint32
        num_iter += 1
    assert num_iter == 4
    if plot:
        visualize_dataset(image_list, label_list)


def _rows_or_error(task, usage):
    """
    Build and iterate a dataset; return its row count, or the error text

    ValueError, TypeError and RuntimeError raised while constructing or
    iterating the dataset are caught and returned as str(error).
    """
    try:
        data = ds.Flowers102Dataset(DATA_DIR, task=task, usage=usage, decode=True, shuffle=False)
        num_rows = _count_rows(data, output_numpy=True)
    except (ValueError, TypeError, RuntimeError) as e:
        return str(e)
    return num_rows


def test_flowers102_usage():
    """
    Validate Flowers102Dataset usage
    """
    logger.info("Test Flowers102Dataset usage flag")

    def test_config(usage):
        return _rows_or_error("Classification", usage)

    assert test_config("all") == 6
    assert test_config("train") == 2
    assert test_config("test") == 2
    assert test_config("valid") == 2

    assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])


def test_flowers102_task():
    """
    Validate Flowers102Dataset task
    """
    logger.info("Test Flowers102Dataset task flag")

    def test_config(task):
        return _rows_or_error(task, "all")

    assert test_config("Classification") == 6
    assert test_config("Segmentation") == 6

    assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid")
    assert "Argument task with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])


if __name__ == '__main__':
    test_flowers102_content_check()
    test_flowers102_basic()
    test_flowers102_sequential_sampler()
    test_flowers102_exception()
    test_flowers102_visualize(plot=True)
    test_flowers102_usage()
    test_flowers102_task()