# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Invert op in DE
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger
from util import visualize_list, save_and_check_md5, diff_mse

DATA_DIR = "../data/dataset/testImageNetData/train/"

GENERATE_GOLDEN = False


def _collect_images(dataset, channels_first=False):
    """
    Concatenate the "image" column of every batch in a batched dataset.

    Args:
        dataset: a batched dataset whose rows iterate as (image, label).
        channels_first (bool): when True, each batch is transposed with axes
            (0, 2, 3, 1), i.e. NCHW -> NHWC. This is needed for pipelines
            ending in py_transforms.ToTensor, so the images can be compared
            with / plotted like HWC images.

    Returns:
        numpy.ndarray: all images stacked along axis 0.
    """
    batches = []
    for image, _ in dataset:
        batch = image.asnumpy()
        if channels_first:
            batch = np.transpose(batch, (0, 2, 3, 1))
        batches.append(batch)
    # A single concatenate instead of repeated np.append avoids re-copying the
    # accumulated array on every batch.
    return np.concatenate(batches, axis=0)


def test_invert_callable():
    """
    Test Invert is callable
    """
    logger.info("Test Invert callable")
    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

    img = C.Decode()(img)
    img = C.Invert()(img)
    logger.info("Image.type: {}, Image.shape: {}".format(type(img), img.shape))

    assert img.shape == (2268, 4032, 3)


def test_invert_py(plot=False):
    """
    Test Invert python op

    Builds the same ImageFolder pipeline twice, without and with F.Invert,
    and logs the mean per-image MSE between the two results.
    """
    logger.info("Test Invert Python op")

    # Original Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                              F.Resize((224, 224)),
                                                                              F.ToTensor()])

    ds_original = data_set.map(operations=transforms_original, input_columns="image")
    ds_original = ds_original.batch(512)
    # ToTensor emits CHW images; transpose back to HWC for comparison/plotting.
    images_original = _collect_images(ds_original, channels_first=True)

    # Color Inverted Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_invert = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                            F.Resize((224, 224)),
                                                                            F.Invert(),
                                                                            F.ToTensor()])

    ds_invert = data_set.map(operations=transforms_invert, input_columns="image")
    ds_invert = ds_invert.batch(512)
    images_invert = _collect_images(ds_invert, channels_first=True)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        # Plain MSE here: ToTensor output is already scaled, so diff_mse's
        # /255 normalisation is not applied to these images.
        mse[i] = np.mean((images_invert[i] - images_original[i]) ** 2)
    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_original, images_invert)


def test_invert_c(plot=False):
    """
    Test Invert Cpp op

    Builds the same ImageFolder pipeline twice, without and with C.Invert,
    and logs the mean per-image diff_mse between the two results.
    """
    logger.info("Test Invert cpp op")

    # Original Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_original = [C.Decode(), C.Resize(size=[224, 224])]

    ds_original = data_set.map(operations=transforms_original, input_columns="image")
    ds_original = ds_original.batch(512)
    images_original = _collect_images(ds_original)

    # Invert Images
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transform_invert = [C.Decode(), C.Resize(size=[224, 224]),
                        C.Invert()]

    ds_invert = data_set.map(operations=transform_invert, input_columns="image")
    ds_invert = ds_invert.batch(512)
    images_invert = _collect_images(ds_invert)

    if plot:
        visualize_list(images_original, images_invert)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_invert[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))


def test_invert_py_c(plot=False):
    """
    Test Invert Cpp op and python op

    Both pipelines share C.Decode + C.Resize, then invert either with C.Invert
    or with the Python F.Invert (via PIL); the mean diff_mse between the two
    outputs is logged.
    """
    logger.info("Test Invert cpp and python op")

    # Invert Images in cpp
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224))], input_columns=["image"])

    ds_c_invert = data_set.map(operations=C.Invert(), input_columns="image")
    ds_c_invert = ds_c_invert.batch(512)
    images_c_invert = _collect_images(ds_c_invert)

    # invert images in python
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224))], input_columns=["image"])

    transforms_p_invert = mindspore.dataset.transforms.py_transforms.Compose([lambda img: img.astype(np.uint8),
                                                                              F.ToPIL(),
                                                                              F.Invert(),
                                                                              np.array])

    ds_p_invert = data_set.map(operations=transforms_p_invert, input_columns="image")
    ds_p_invert = ds_p_invert.batch(512)
    images_p_invert = _collect_images(ds_p_invert)

    num_samples = images_c_invert.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_p_invert[i], images_c_invert[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))

    if plot:
        visualize_list(images_c_invert, images_p_invert, visualize_mode=2)


def test_invert_one_channel():
    """
    Test Invert cpp op with one channel image

    The pipeline keeps only channel 0 of each decoded image (2-D result) and
    then applies C.Invert, which is expected to reject it with a RuntimeError.
    """
    logger.info("Test Invert C Op With One Channel Images")

    c_op = C.Invert()

    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)),
                                        lambda img: np.array(img[:, :, 0])], input_columns=["image"])
    data_set = data_set.map(operations=c_op, input_columns="image")

    # map() is lazy: Invert only executes -- and can only fail -- once the
    # pipeline is iterated, so the rows must actually be pulled here.
    try:
        for _ in data_set:
            pass
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        # NOTE(review): exact wording of the C++ shape error is not visible
        # here; match on "shape" rather than a full phrase.
        assert "shape" in str(e)
    else:
        raise AssertionError("Invert should reject a one-channel image, but no error was raised")


def test_invert_md5_py():
    """
    Test Invert python op with md5 check
    """
    logger.info("Test Invert python op with md5 check")

    # Generate dataset
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_invert = mindspore.dataset.transforms.py_transforms.Compose([F.Decode(),
                                                                            F.Invert(),
                                                                            F.ToTensor()])

    data = data_set.map(operations=transforms_invert, input_columns="image")
    # Compare with expected md5 from images
    filename = "invert_01_result_py.npz"
    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)


def test_invert_md5_c():
    """
    Test Invert cpp op with md5 check
    """
    logger.info("Test Invert cpp op with md5 check")

    # Generate dataset
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)

    transforms_invert = [C.Decode(),
                         C.Resize(size=[224, 224]),
                         C.Invert(),
                         F.ToTensor()]

    data = data_set.map(operations=transforms_invert, input_columns="image")
    # Compare with expected md5 from images
    filename = "invert_01_result_c.npz"
    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)


if __name__ == "__main__":
    test_invert_callable()
    test_invert_py(plot=False)
    test_invert_c(plot=False)
    test_invert_py_c(plot=False)
    test_invert_one_channel()
    test_invert_md5_py()
    test_invert_md5_c()