# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the MixUpBatch op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.transforms.c_transforms as data_trans
from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
    config_get_set_num_parallel_workers

DATA_DIR = "../data/dataset/testCifar10Data"
DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
DATA_DIR3 = "../data/dataset/testCelebAData/"

GENERATE_GOLDEN = False


def _collect_images(dataset):
    """
    Iterate over a dataset and stack the first output column of every row/batch
    into a single numpy array along axis 0.

    Each row is unpacked as ``(image, _)``, so the dataset is expected to yield
    exactly two columns with the image first. Any RuntimeError raised by the
    pipeline during iteration propagates to the caller (the fail tests rely on this).
    """
    # One concatenate instead of repeated np.append, which re-copies the
    # accumulated array on every iteration.
    return np.concatenate([image.asnumpy() for image, _ in dataset], axis=0)


def _plot_and_log_mse(images_original, images_mixup, plot):
    """
    Optionally visualize original vs. MixUp images, then log the mean per-sample MSE.

    Args:
        images_original (numpy.ndarray): images stacked along axis 0 without MixUp.
        images_mixup (numpy.ndarray): the same images after MixUpBatch.
        plot (bool): whether to call visualize_list on the two image sets.
    """
    if plot:
        visualize_list(images_original, images_mixup)

    num_samples = images_original.shape[0]
    mse = np.array([diff_mse(images_mixup[i], images_original[i]) for i in range(num_samples)])
    logger.info("MSE= {}".format(str(np.mean(mse))))


def test_mixup_batch_success1(plot=False):
    """
    Test MixUpBatch op with specified alpha parameter
    """
    logger.info("test_mixup_batch_success1")

    # Original Images
    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    ds_original = ds_original.batch(5, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch(2)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
    images_mixup = _collect_images(data1)

    _plot_and_log_mse(images_original, images_mixup, plot)


def test_mixup_batch_success2(plot=False):
    """
    Test MixUpBatch op with specified alpha parameter on ImageFolderDataset
    """
    logger.info("test_mixup_batch_success2")

    # Original Images
    ds_original = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    decode_op = vision.Decode()
    ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # MixUp Images
    data1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)

    decode_op = vision.Decode()
    data1 = data1.map(operations=[decode_op], input_columns=["image"])

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])

    mixup_batch_op = vision.MixUpBatch(2.0)
    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
    images_mixup = _collect_images(data1)

    _plot_and_log_mse(images_original, images_mixup, plot)


def test_mixup_batch_success3(plot=False):
    """
    Test MixUpBatch op without specified alpha parameter.
    Alpha parameter will be selected by default in this case
    """
    logger.info("test_mixup_batch_success3")

    # Original Images
    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    ds_original = ds_original.batch(5, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
    images_mixup = _collect_images(data1)

    _plot_and_log_mse(images_original, images_mixup, plot)


def test_mixup_batch_success4(plot=False):
    """
    Test MixUpBatch op on a dataset where OneHot returns a 2D vector
    (the CelebA "attr" column is used as the label column).
    Alpha parameter will be selected by default in this case
    """
    logger.info("test_mixup_batch_success4")

    # Original Images
    ds_original = ds.CelebADataset(DATA_DIR3, shuffle=False)
    decode_op = vision.Decode()
    ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
    ds_original = ds_original.batch(2, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # MixUp Images
    data1 = ds.CelebADataset(DATA_DIR3, shuffle=False)

    decode_op = vision.Decode()
    data1 = data1.map(operations=[decode_op], input_columns=["image"])

    one_hot_op = data_trans.OneHot(num_classes=100)
    data1 = data1.map(operations=one_hot_op, input_columns=["attr"])

    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(2, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "attr"])
    images_mixup = _collect_images(data1)

    _plot_and_log_mse(images_original, images_mixup, plot)


def test_mixup_batch_md5():
    """
    Test MixUpBatch with MD5
    """
    logger.info("test_mixup_batch_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # The global config must be restored even if the md5 check fails, otherwise
    # the modified seed / worker count leaks into every subsequent test.
    try:
        # MixUp Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        mixup_batch_op = vision.MixUpBatch()
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=mixup_batch_op, input_columns=["image", "label"])

        filename = "mixup_batch_c_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mixup_batch_fail1():
    """
    Test MixUpBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_mixup_batch_fail1")

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch(0.1)
    with pytest.raises(RuntimeError) as error:
        # No batch() before the op, so it receives single (unbatched) images.
        data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
        _collect_images(data1)
    error_message = "You must make sure images are HWC or CHW and batched"
    assert error_message in str(error.value)


def test_mixup_batch_fail2():
    """
    Test MixUpBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_mixup_batch_fail2")

    # Parameter validation happens in the constructor; no pipeline is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(-1)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_mixup_batch_fail3():
    """
    Test MixUpBatch op
    We expect this to fail because label column is not passed to mixup_batch
    """
    logger.info("test_mixup_batch_fail3")

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    # Only the image column is routed to the op, which requires image and label.
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image"])

    with pytest.raises(RuntimeError) as error:
        _collect_images(data1)
    error_message = "size of input data should be 2 (including images or labels)"
    assert error_message in str(error.value)


def test_mixup_batch_fail4():
    """
    Test MixUpBatch Fail 4
    We expect this to fail because alpha is zero
    """
    logger.info("test_mixup_batch_fail4")

    # Parameter validation happens in the constructor; no pipeline is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(0.0)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_mixup_batch_fail5():
    """
    Test MixUpBatch Fail 5
    We expect this to fail because labels are not OneHot encoded
    """
    logger.info("test_mixup_batch_fail5")

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        _collect_images(data1)
    error_message = "wrong labels shape. The second column (labels) must have a shape of NC or NLC"
    assert error_message in str(error.value)


if __name__ == "__main__":
    test_mixup_batch_success1(plot=True)
    test_mixup_batch_success2(plot=True)
    test_mixup_batch_success3(plot=True)
    test_mixup_batch_success4(plot=True)
    test_mixup_batch_md5()
    test_mixup_batch_fail1()
    test_mixup_batch_fail2()
    test_mixup_batch_fail3()
    test_mixup_batch_fail4()
    test_mixup_batch_fail5()