# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing UniformAugment in DE
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.vision.py_transforms as F
from mindspore import log as logger
from util import visualize_list, diff_mse

DATA_DIR = "../data/dataset/testImageNetData/train/"


def _default_cpp_transforms():
    """
    Return a fresh list of the five C++ vision ops used by the cpp tests below.

    A new list (and new op instances) is built on every call so tests never share op objects.
    """
    return [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
            C.RandomHorizontalFlip(),
            C.RandomVerticalFlip(),
            C.RandomColorAdjust(),
            C.RandomRotation(degrees=45)]


def _collect_images(dataset):
    """
    Iterate a batched (image, label) dataset and stack every image into one array.

    Each batch is transposed with axes (0, 2, 3, 1); this assumes the pipeline ends with
    ToTensor so batches are (N, C, H, W) -- the result is then (total, H, W, C).
    Batches are gathered in a list and concatenated once, avoiding the quadratic
    copying of repeated np.append.
    """
    batches = [np.transpose(image.asnumpy(), (0, 2, 3, 1)) for image, _ in dataset]
    assert batches, "dataset produced no batches"
    return np.concatenate(batches, axis=0)


def _log_mean_mse(images_ua, images_original):
    """Log the mean of the per-image MSE between augmented and original images."""
    num_samples = images_original.shape[0]
    mse = np.array([diff_mse(images_ua[i], images_original[i]) for i in range(num_samples)],
                   dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(mse))))


def test_uniform_augment_callable(num_ops=2):
    """
    Test UniformAugment is callable
    """
    logger.info("test_uniform_augment_callable")
    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

    decode_op = C.Decode()
    img = decode_op(img)
    assert img.shape == (2268, 4032, 3)

    transforms_ua = [C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32]),
                     C.RandomCrop(size=[400, 400], padding=[32, 32, 32, 32])]
    uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
    img = uni_aug(img)
    # Either no crop was applied (original shape) or at least one crop ran.
    assert img.shape == (2268, 4032, 3) or img.shape == (400, 400, 3)


def test_uniform_augment(plot=False, num_ops=2):
    """
    Test Python UniformAugment
    """
    logger.info("Test UniformAugment")

    # Original Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                              F.Resize((224, 224)),
                                                                              F.ToTensor()])

    ds_original = data_set.map(operations=transforms_original, input_columns="image")
    ds_original = ds_original.batch(512)
    images_original = _collect_images(ds_original)

    # UniformAugment Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transform_list = [F.RandomRotation(45),
                      F.RandomColor(),
                      F.RandomSharpness(),
                      F.Invert(),
                      F.AutoContrast(),
                      F.Equalize()]

    transforms_ua = \
        mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                            F.Resize((224, 224)),
                                                            F.UniformAugment(transforms=transform_list,
                                                                             num_ops=num_ops),
                                                            F.ToTensor()])

    ds_ua = data_set.map(operations=transforms_ua, input_columns="image")
    ds_ua = ds_ua.batch(512)
    images_ua = _collect_images(ds_ua)

    _log_mean_mse(images_ua, images_original)

    if plot:
        visualize_list(images_original, images_ua)


def test_cpp_uniform_augment(plot=False, num_ops=2):
    """
    Test C++ UniformAugment
    """
    logger.info("Test CPP UniformAugment")

    # Original Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = [C.Decode(), C.Resize(size=[224, 224]),
                           F.ToTensor()]

    ds_original = data_set.map(operations=transforms_original, input_columns="image")
    ds_original = ds_original.batch(512)
    images_original = _collect_images(ds_original)

    # UniformAugment Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    uni_aug = C.UniformAugment(transforms=_default_cpp_transforms(), num_ops=num_ops)

    transforms_all = [C.Decode(), C.Resize(size=[224, 224]),
                      uni_aug,
                      F.ToTensor()]

    ds_ua = data_set.map(operations=transforms_all, input_columns="image", num_parallel_workers=1)
    ds_ua = ds_ua.batch(512)
    images_ua = _collect_images(ds_ua)

    if plot:
        visualize_list(images_original, images_ua)

    _log_mean_mse(images_ua, images_original)


def test_cpp_uniform_augment_exception_pyops(num_ops=2):
    """
    Test UniformAugment invalid op in operations

    The sixth transform (index 5) is a Python op, which the C++ UniformAugment rejects.
    """
    logger.info("Test CPP UniformAugment invalid OP exception")

    transforms_ua = _default_cpp_transforms() + [F.Invert()]

    with pytest.raises(TypeError) as e:
        C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Type of Transforms[5] must be c_transform" in str(e.value)


def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
    """
    Test UniformAugment invalid large number of ops

    num_ops (default 6) exceeds the five available transforms, so construction must fail.
    pytest.raises makes the test fail if no exception is raised at all.
    """
    logger.info("Test CPP UniformAugment invalid large num_ops exception")

    with pytest.raises(ValueError) as e:
        C.UniformAugment(transforms=_default_cpp_transforms(), num_ops=num_ops)

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "num_ops" in str(e.value)


def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
    """
    Test UniformAugment invalid non-positive number of ops
    """
    logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")

    with pytest.raises(ValueError) as e:
        C.UniformAugment(transforms=_default_cpp_transforms(), num_ops=num_ops)

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Input num_ops must be greater than 0" in str(e.value)


def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
    """
    Test UniformAugment invalid float number of ops
    """
    logger.info("Test CPP UniformAugment invalid float num_ops exception")

    with pytest.raises(TypeError) as e:
        C.UniformAugment(transforms=_default_cpp_transforms(), num_ops=num_ops)

    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Argument num_ops with value 2.5 is not of type [<class 'int'>]" in str(e.value)


def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
    """
    Test UniformAugment with greater crop size

    The crop size [224, 224] exceeds the CIFAR-10 image size [32, 32]; once RandomCrop is
    selected for some image, iterating the pipeline must raise.
    """
    logger.info("Test CPP UniformAugment with random_crop bad input")
    batch_size = 2
    cifar10_dir = "../data/dataset/testCifar10Data"
    ds1 = ds.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32,32,3]

    transforms_ua = [
        # Note: crop size [224, 224] > image size [32, 32]
        C.RandomCrop(size=[224, 224]),
        C.RandomHorizontalFlip()
    ]
    uni_aug = C.UniformAugment(transforms=transforms_ua, num_ops=num_ops)
    ds1 = ds1.map(operations=uni_aug, input_columns="image")

    # apply DatasetOps
    ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)
    num_batches = 0
    with pytest.raises(RuntimeError) as e:
        for _ in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_batches += 1

    logger.info("Got an exception in DE after {} batches: {}".format(num_batches, str(e.value)))
    assert "crop size" in str(e.value)


if __name__ == "__main__":
    test_uniform_augment_callable(num_ops=2)
    test_uniform_augment(num_ops=1, plot=True)
    test_cpp_uniform_augment(num_ops=1, plot=True)
    test_cpp_uniform_augment_exception_pyops(num_ops=1)
    test_cpp_uniform_augment_exception_large_numops(num_ops=6)
    test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0)
    test_cpp_uniform_augment_exception_float_numops(num_ops=2.5)
    test_cpp_uniform_augment_random_crop_badinput(num_ops=1)