# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the CutMixBatch op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.vision.utils as mode
from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
    config_get_set_num_parallel_workers

DATA_DIR = "../data/dataset/testCifar10Data"
DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
DATA_DIR3 = "../data/dataset/testCelebAData/"

GENERATE_GOLDEN = False


def _collect_images(dataset, nchw_to_nhwc=False):
    """
    Iterate over a two-column dataset and stack the first column of every row along axis 0.

    Args:
        dataset: iterable pipeline whose rows unpack into (image, other); `image` must
            provide `asnumpy()`.
        nchw_to_nhwc (bool): when True, every batch is transposed with axes (0, 2, 3, 1)
            before stacking, so that NCHW batches can be compared with NHWC ones.

    Returns:
        numpy.ndarray, all image batches concatenated along the first axis.

    Raises:
        ValueError: if the dataset yields no rows.
    """
    batches = []
    for image, _ in dataset:
        batch = image.asnumpy()
        if nchw_to_nhwc:
            batch = batch.transpose(0, 2, 3, 1)
        batches.append(batch)
    if not batches:
        raise ValueError("dataset produced no image batches")
    # A single concatenate avoids re-copying the accumulated array on every batch,
    # which is what repeated np.append calls do.
    return np.concatenate(batches, axis=0)


def _compare_with_original(images_original, images_cutmix, plot):
    """
    Optionally plot both image sets, then log the mean per-sample MSE between them.

    Args:
        images_original (numpy.ndarray): images read without CutMixBatch.
        images_cutmix (numpy.ndarray): images read through the CutMixBatch pipeline.
        plot (bool): when True, `visualize_list` is called on the two image sets.
    """
    if plot:
        visualize_list(images_original, images_cutmix)

    # Both pipelines read the same samples with the same batch size, and the images are
    # compared index by index below, so the two arrays must have identical shapes.
    assert images_cutmix.shape == images_original.shape

    mse = [diff_mse(cutmix, original) for cutmix, original in zip(images_cutmix, images_original)]
    logger.info("MSE= {}".format(str(np.mean(mse))))


def test_cutmix_batch_success1(plot=False):
    """
    Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
    """
    logger.info("test_cutmix_batch_success1")

    # Original Images
    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    ds_original = ds_original.batch(5, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # CutMix Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    hwc2chw_op = vision.HWC2CHW()
    data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    # The CutMix pipeline emits NCHW batches; transpose them back to NHWC so they can be
    # compared with the untouched originals.
    images_cutmix = _collect_images(data1, nchw_to_nhwc=True)

    _compare_with_original(images_original, images_cutmix, plot)


def test_cutmix_batch_success2(plot=False):
    """
    Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images
    """
    logger.info("test_cutmix_batch_success2")

    # Original Images
    ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    ds_original = ds_original.batch(5, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # CutMix Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
    data1 = data1.map(operations=rescale_op, input_columns=["image"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
    images_cutmix = _collect_images(data1)

    _compare_with_original(images_original, images_cutmix, plot)


def test_cutmix_batch_success3(plot=False):
    """
    Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDataset
    """
    logger.info("test_cutmix_batch_success3")

    # Original Images
    ds_original = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    decode_op = vision.Decode()
    ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
    resize_op = vision.Resize([224, 224])
    ds_original = ds_original.map(operations=[resize_op], input_columns=["image"])
    ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # CutMix Images
    data1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)

    decode_op = vision.Decode()
    data1 = data1.map(operations=[decode_op], input_columns=["image"])

    resize_op = vision.Resize([224, 224])
    data1 = data1.map(operations=[resize_op], input_columns=["image"])

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])

    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(4, pad_info={}, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
    images_cutmix = _collect_images(data1)

    _compare_with_original(images_original, images_cutmix, plot)


def test_cutmix_batch_success4(plot=False):
    """
    Test CutMixBatch on a dataset where OneHot returns a 2D vector
    """
    logger.info("test_cutmix_batch_success4")

    # Original Images
    ds_original = ds.CelebADataset(DATA_DIR3, shuffle=False)
    decode_op = vision.Decode()
    ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
    resize_op = vision.Resize([224, 224])
    ds_original = ds_original.map(operations=[resize_op], input_columns=["image"])
    ds_original = ds_original.batch(2, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # CutMix Images
    data1 = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)

    decode_op = vision.Decode()
    data1 = data1.map(operations=[decode_op], input_columns=["image"])

    resize_op = vision.Resize([224, 224])
    data1 = data1.map(operations=[resize_op], input_columns=["image"])

    one_hot_op = data_trans.OneHot(num_classes=100)
    data1 = data1.map(operations=one_hot_op, input_columns=["attr"])

    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
    data1 = data1.batch(2, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "attr"])
    images_cutmix = _collect_images(data1)

    _compare_with_original(images_original, images_cutmix, plot)


def test_cutmix_batch_nhwc_md5():
    """
    Test CutMixBatch on a batch of HWC images with MD5:
    """
    logger.info("test_cutmix_batch_nhwc_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # CutMixBatch Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])

        filename = "cutmix_batch_c_nhwc_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even when the check above fails, so that a failure
        # here does not leak the seed / worker count into later tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_cutmix_batch_nchw_md5():
    """
    Test CutMixBatch on a batch of CHW images with MD5:
    """
    logger.info("test_cutmix_batch_nchw_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # CutMixBatch Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
        hwc2chw_op = vision.HWC2CHW()
        data = data.map(operations=hwc2chw_op, input_columns=["image"])
        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])

        filename = "cutmix_batch_c_nchw_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even when the check above fails, so that a failure
        # here does not leak the seed / worker count into later tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_cutmix_batch_fail1():
    """
    Test CutMixBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_cutmix_batch_fail1")

    # CutMixBatch Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    with pytest.raises(RuntimeError) as error:
        data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
        # Iterating the pipeline is what runs the op.
        _collect_images(data1)
    error_message = "You must make sure images are HWC or CHW and batch "
    assert error_message in str(error.value)


def test_cutmix_batch_fail2():
    """
    Test CutMixBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_cutmix_batch_fail2")

    # The constructor alone is exercised here, so no dataset pipeline is built.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_cutmix_batch_fail3():
    """
    Test CutMixBatch Fail 3
    We expect this to fail because prob is larger than 1
    """
    logger.info("test_cutmix_batch_fail3")

    # The constructor alone is exercised here, so no dataset pipeline is built.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_cutmix_batch_fail4():
    """
    Test CutMixBatch Fail 4
    We expect this to fail because prob is negative
    """
    logger.info("test_cutmix_batch_fail4")

    # The constructor alone is exercised here, so no dataset pipeline is built.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_cutmix_batch_fail5():
    """
    Test CutMixBatch Fail 5
    We expect this to fail because label column is not passed to cutmix_batch
    """
    logger.info("test_cutmix_batch_fail5")

    # CutMixBatch Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image"])

    with pytest.raises(RuntimeError) as error:
        # Iterating the pipeline is what runs the op.
        _collect_images(data1)
    error_message = "size of input should be 2 (including image and label)"
    assert error_message in str(error.value)


def test_cutmix_batch_fail6():
    """
    Test CutMixBatch Fail 6
    We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
    """
    logger.info("test_cutmix_batch_fail6")

    # CutMixBatch Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # Iterating the pipeline is what runs the op.
        _collect_images(data1)
    error_message = "image doesn't match the <N,C,H,W> format"
    assert error_message in str(error.value)


def test_cutmix_batch_fail7():
    """
    Test CutMixBatch Fail 7
    We expect this to fail because labels are not in one-hot format
    """
    logger.info("test_cutmix_batch_fail7")

    # CutMixBatch Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # Iterating the pipeline is what runs the op.
        _collect_images(data1)
    error_message = "wrong labels shape. The second column (labels) must have a shape of NC or NLC"
    assert error_message in str(error.value)


def test_cutmix_batch_fail8():
    """
    Test CutMixBatch Fail 8
    We expect this to fail because alpha is zero
    """
    logger.info("test_cutmix_batch_fail8")

    # The constructor alone is exercised here, so no dataset pipeline is built.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


if __name__ == "__main__":
    test_cutmix_batch_success1(plot=True)
    test_cutmix_batch_success2(plot=True)
    test_cutmix_batch_success3(plot=True)
    test_cutmix_batch_success4(plot=True)
    test_cutmix_batch_nchw_md5()
    test_cutmix_batch_nhwc_md5()
    test_cutmix_batch_fail1()
    test_cutmix_batch_fail2()
    test_cutmix_batch_fail3()
    test_cutmix_batch_fail4()
    test_cutmix_batch_fail5()
    test_cutmix_batch_fail6()
    test_cutmix_batch_fail7()
    test_cutmix_batch_fail8()