# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing AdjustGamma op in DE
"""
import numpy as np
from numpy.testing import assert_allclose
import PIL

import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as F
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger

# Test data locations; relative paths, so these tests assume they are run from
# the test directory that sits next to ../data (NOTE(review): confirm runner cwd).
DATA_DIR = "../data/dataset/testImageNetData/train/"
MNIST_DATA_DIR = "../data/dataset/testMnistData"

DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def generate_numpy_random_rgb(shape):
    """
    Return an array of the given shape filled with random RGB-like values in [0, 1].

    Every element is n / 255 for an integer n drawn uniformly from 0..255, i.e. an
    8-bit pixel value rescaled to the unit interval. Values are restricted to this
    coarse grid (rather than arbitrary floats) because some low-precision floating
    point types used in these tests do not handle arbitrary-precision values well.

    The result has NumPy's default floating dtype (float64); callers cast it to the
    dtype they need.
    """
    return np.random.randint(0, 256, shape) / 255.
def test_adjust_gamma_c_eager():
    """
    Test AdjustGamma C Op in eager mode on a 3-channel float32 image
    """
    rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    img_in = rgb_flat.reshape((8, 8, 3))

    adjustgamma_op = C.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    assert img_out.shape == img_in.shape


def test_adjust_gamma_py_eager():
    """
    Test AdjustGamma python Op in eager mode on a 3-channel PIL image
    """
    # Draw genuine 8-bit pixels: casting the [0, 1] floats produced by
    # generate_numpy_random_rgb straight to uint8 would yield an almost all-zero image.
    img_in = PIL.Image.fromarray(np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8))

    adjustgamma_op = F.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    assert img_out.size == img_in.size


def test_adjust_gamma_c_eager_gray():
    """
    Test AdjustGamma C Op in eager mode on a 1-channel (2-D) float32 image
    """
    rgb_flat = generate_numpy_random_rgb((64, 1)).astype(np.float32)
    img_in = rgb_flat.reshape((8, 8))

    adjustgamma_op = C.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    assert img_out.shape == img_in.shape


def test_adjust_gamma_py_eager_gray():
    """
    Test AdjustGamma python Op in eager mode on a 1-channel PIL image
    """
    # Genuine 8-bit pixels (see test_adjust_gamma_py_eager for why not a float cast).
    img_in = PIL.Image.fromarray(np.random.randint(0, 256, (8, 8), dtype=np.uint8))

    adjustgamma_op = F.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    assert img_out.size == img_in.size


def _assert_raises_with_message(func, error_type, expected_msg):
    """
    Call ``func`` and check that it raises ``error_type`` containing ``expected_msg``.

    Unlike a bare try/except, this fails the test when no exception is raised at all,
    so a missing parameter check cannot go unnoticed.
    """
    try:
        func()
    except error_type as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert expected_msg in str(error)
    else:
        raise AssertionError(
            "Expected {} to be raised, but no exception occurred".format(error_type.__name__))


def _map_adjust_gamma_c(**adjust_gamma_kwargs):
    """
    Build the C++ op pipeline used by the invalid-parameter tests.

    Decodes and resizes the ImageFolder images in DATA_DIR, keeps only channel 0,
    then maps C.AdjustGamma(**adjust_gamma_kwargs) over the "image" column.
    """
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
                            input_columns=["image"])
    return data_set.map(operations=C.AdjustGamma(**adjust_gamma_kwargs), input_columns="image")


def _map_adjust_gamma_py(**adjust_gamma_kwargs):
    """
    Build the Python op pipeline used by the invalid-parameter tests.

    Maps Decode -> Resize -> AdjustGamma(**adjust_gamma_kwargs) -> ToTensor, composed
    as Python transforms, over the "image" column of the ImageFolder images in DATA_DIR.
    """
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    trans = mindspore.dataset.transforms.py_transforms.Compose([
        F.Decode(),
        F.Resize((224, 224)),
        F.AdjustGamma(**adjust_gamma_kwargs),
        F.ToTensor()
    ])
    return data_set.map(operations=[trans], input_columns=["image"])


def test_adjust_gamma_invalid_gamma_param_c():
    """
    Test AdjustGamma C Op with invalid gamma parameter
    """
    logger.info("Test AdjustGamma C Op with invalid gamma parameter")
    # invalid gamma: out of the allowed interval
    _assert_raises_with_message(lambda: _map_adjust_gamma_c(gamma=-10.0, gain=1.0),
                                ValueError, "Input is not within the required interval of ")
    # invalid gamma: wrong type
    _assert_raises_with_message(lambda: _map_adjust_gamma_c(gamma=[1, 2], gain=1.0),
                                TypeError, "is not of type [<class 'float'>, <class 'int'>], but got")


def test_adjust_gamma_invalid_gamma_param_py():
    """
    Test AdjustGamma python Op with invalid gamma parameter
    """
    logger.info("Test AdjustGamma python Op with invalid gamma parameter")
    # invalid gamma: out of the allowed interval
    _assert_raises_with_message(lambda: _map_adjust_gamma_py(gamma=-10.0),
                                ValueError, "Input is not within the required interval of ")
    # invalid gamma: wrong type
    _assert_raises_with_message(lambda: _map_adjust_gamma_py(gamma=[1, 2]),
                                TypeError, "is not of type [<class 'float'>, <class 'int'>], but got")


def test_adjust_gamma_invalid_gain_param_c():
    """
    Test AdjustGamma C Op with invalid gain parameter
    """
    logger.info("Test AdjustGamma C Op with invalid gain parameter")
    # invalid gain: wrong type
    _assert_raises_with_message(lambda: _map_adjust_gamma_c(gamma=10.0, gain=[1, 10]),
                                TypeError, "is not of type [<class 'float'>, <class 'int'>], but got ")


def test_adjust_gamma_invalid_gain_param_py():
    """
    Test AdjustGamma python Op with invalid gain parameter
    """
    logger.info("Test AdjustGamma python Op with invalid gain parameter")
    # invalid gain: wrong type
    _assert_raises_with_message(lambda: _map_adjust_gamma_py(gamma=10.0, gain=[1, 10]),
                                TypeError, "is not of type [<class 'float'>, <class 'int'>], but got ")


def _tfrecord_images(operations):
    """
    Load the "image" column of the TFRecord data in DATA_DIR_2 (unshuffled, using
    SCHEMA_DIR) and map ``operations`` over it.
    """
    data_set = ds.TFRecordDataset(DATA_DIR_2,
                                  SCHEMA_DIR,
                                  columns_list=["image"],
                                  shuffle=False)
    return data_set.map(operations=operations, input_columns=["image"])


def _assert_image_columns_match(ds1, ds2):
    """
    Iterate two datasets in lockstep and assert their "image" columns agree.

    Images are compared elementwise (rtol=1e-5, atol=0) and by shape. The check also
    fails when no rows are produced, so an empty pipeline cannot pass vacuously.
    """
    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        assert_allclose(ori_img.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)
        assert ori_img.shape == cvt_img.shape
    assert num_iter > 0, "datasets produced no rows to compare"


def test_adjust_gamma_pipeline_c():
    """
    Test AdjustGamma C Op Pipeline: gamma=1.0, gain=1.0 is expected to leave images unchanged
    """
    # C++ ops are handed to map() directly; wrapping them in a Python Compose
    # would only force them to run eagerly on the Python side.
    # First dataset: decode and resize only
    ds1 = _tfrecord_images([C.Decode(), C.Resize([64, 64])])
    # Second dataset: same ops followed by AdjustGamma(1.0, 1.0)
    ds2 = _tfrecord_images([C.Decode(), C.Resize([64, 64]), C.AdjustGamma(1.0, 1.0)])
    _assert_image_columns_match(ds1, ds2)


def test_adjust_gamma_pipeline_py():
    """
    Test AdjustGamma python Op Pipeline: gamma=1.0, gain=1.0 is expected to leave images unchanged
    """
    compose = mindspore.dataset.transforms.py_transforms.Compose
    # First dataset
    ds1 = _tfrecord_images(compose([F.Decode(), F.Resize([64, 64]), F.ToTensor()]))
    # Second dataset
    ds2 = _tfrecord_images(compose([
        F.Decode(),
        F.Resize([64, 64]),
        F.AdjustGamma(1.0, 1.0),
        F.ToTensor()
    ]))
    _assert_image_columns_match(ds1, ds2)


def test_adjust_gamma_pipeline_py_gray():
    """
    Test AdjustGamma python Op Pipeline 1-channel
    """
    compose = mindspore.dataset.transforms.py_transforms.Compose
    # First dataset
    ds1 = _tfrecord_images(compose([F.Decode(), F.Resize([64, 64]), F.Grayscale(), F.ToTensor()]))
    # Second dataset
    ds2 = _tfrecord_images(compose([
        F.Decode(),
        F.Resize([64, 64]),
        F.Grayscale(),
        F.AdjustGamma(1.0, 1.0),
        F.ToTensor()
    ]))
    _assert_image_columns_match(ds1, ds2)


if __name__ == "__main__":
    test_adjust_gamma_c_eager()
    test_adjust_gamma_py_eager()
    test_adjust_gamma_c_eager_gray()
    test_adjust_gamma_py_eager_gray()

    test_adjust_gamma_invalid_gamma_param_c()
    test_adjust_gamma_invalid_gamma_param_py()
    test_adjust_gamma_invalid_gain_param_c()
    test_adjust_gamma_invalid_gain_param_py()
    test_adjust_gamma_pipeline_c()
    test_adjust_gamma_pipeline_py()
    test_adjust_gamma_pipeline_py_gray()