• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing AdjustGamma op in DE
17"""
18import numpy as np
19from numpy.testing import assert_allclose
20import PIL
21
22import mindspore.dataset as ds
23import mindspore.dataset.transforms.py_transforms
24import mindspore.dataset.vision.py_transforms as F
25import mindspore.dataset.vision.c_transforms as C
26from mindspore import log as logger
27
# TFRecord file list and its schema, read by TFRecordDataset in the pipeline tests.
DATA_DIR_2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# Image folder read by ImageFolderDataset in the invalid-parameter tests.
DATA_DIR = "../data/dataset/testImageNetData/train/"
# MNIST test data directory; kept for external users, not used by the tests in this file.
MNIST_DATA_DIR = "../data/dataset/testMnistData"
33
34
def generate_numpy_random_rgb(shape):
    """
    Generate random RGB-like pixel values normalized to [0.0, 1.0].

    Only floating points that are fractions like n / 255 (n an integer in
    [0, 255]) are produced, since they represent 8-bit RGB pixels. Some
    low-precision floating point types in this test can't handle arbitrary
    precision floating points well.

    Note: casting the result directly to an integer dtype (e.g. np.uint8)
    truncates every value below 1.0 to 0; scale by 255 first if 8-bit
    pixels are wanted.

    Args:
        shape (tuple): Shape of the returned array.

    Returns:
        numpy.ndarray: float64 array of the given shape with values in [0.0, 1.0].
    """
    # randint's upper bound is exclusive, so n ranges over 0..255.
    return np.random.randint(0, 256, shape) / 255.
42
43
def test_adjust_gamma_c_eager():
    """
    Test AdjustGamma C Op in eager mode on a random 8x8 3-channel float32 image.
    """
    image = generate_numpy_random_rgb((64, 3)).astype(np.float32).reshape((8, 8, 3))
    result = C.AdjustGamma(10, 1)(image)
    assert result is not None
52
53
def test_adjust_gamma_py_eager():
    """
    Test AdjustGamma python Op in eager mode on a random 8x8 3-channel uint8 PIL image.
    """
    # Draw 8-bit pixel values directly: casting the [0, 1] floats returned by
    # generate_numpy_random_rgb to uint8 would truncate almost every pixel to 0.
    rgb_flat = np.random.randint(0, 256, (64, 3)).astype(np.uint8)
    img_in = PIL.Image.fromarray(rgb_flat.reshape((8, 8, 3)))

    adjustgamma_op = F.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    # A point-wise intensity transform must not change the image dimensions.
    assert img_out.size == img_in.size
62
63
def test_adjust_gamma_c_eager_gray():
    """
    Test AdjustGamma C Op in eager mode on a random 8x8 single-channel float32 image.
    """
    # Eager 1-channel: a 2-D (H, W) grayscale array
    rgb_flat = generate_numpy_random_rgb((64, 1)).astype(np.float32)
    img_in = rgb_flat.reshape((8, 8))

    adjustgamma_op = C.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
72
73
def test_adjust_gamma_py_eager_gray():
    """
    Test AdjustGamma python Op in eager mode on a random 8x8 single-channel uint8 PIL image.
    """
    # Eager 1-channel. Draw 8-bit pixel values directly: casting the [0, 1]
    # floats returned by generate_numpy_random_rgb to uint8 would truncate
    # almost every pixel to 0.
    gray_flat = np.random.randint(0, 256, (64, 1)).astype(np.uint8)
    img_in = PIL.Image.fromarray(gray_flat.reshape((8, 8)))

    adjustgamma_op = F.AdjustGamma(10, 1)
    img_out = adjustgamma_op(img_in)
    assert img_out is not None
    # A point-wise intensity transform must not change the image dimensions.
    assert img_out.size == img_in.size
82
83
def test_adjust_gamma_invalid_gamma_param_c():
    """
    Test AdjustGamma C Op with invalid gamma parameter.

    A negative gamma must raise ValueError and a non-numeric gamma must raise
    TypeError; the test fails if either call is accepted silently.
    """
    logger.info("Test AdjustGamma C Op with invalid gamma parameter")

    # Dataset setup stays outside the try blocks so that setup errors are not
    # mistaken for the validation errors under test.
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
                            input_columns=["image"])
    try:
        # invalid gamma: out of range
        data_set.map(operations=C.AdjustGamma(gamma=-10.0, gain=1.0),
                     input_columns="image")
    except ValueError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "Input is not within the required interval of " in str(error)
    else:
        raise AssertionError("AdjustGamma(gamma=-10.0) did not raise ValueError")

    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
                            input_columns=["image"])
    try:
        # invalid gamma: wrong type
        data_set.map(operations=C.AdjustGamma(gamma=[1, 2], gain=1.0),
                     input_columns="image")
    except TypeError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
    else:
        raise AssertionError("AdjustGamma(gamma=[1, 2]) did not raise TypeError")
109
110
def test_adjust_gamma_invalid_gamma_param_py():
    """
    Test AdjustGamma python Op with invalid gamma parameter.

    A negative gamma must raise ValueError and a non-numeric gamma must raise
    TypeError; the test fails if either call is accepted silently.
    """
    logger.info("Test AdjustGamma python Op with invalid gamma parameter")

    # Dataset setup stays outside the try blocks so that setup errors are not
    # mistaken for the validation errors under test.
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    try:
        # invalid gamma: out of range
        trans = mindspore.dataset.transforms.py_transforms.Compose([
            F.Decode(),
            F.Resize((224, 224)),
            F.AdjustGamma(gamma=-10.0),
            F.ToTensor()
        ])
        data_set.map(operations=[trans], input_columns=["image"])
    except ValueError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "Input is not within the required interval of " in str(error)
    else:
        raise AssertionError("AdjustGamma(gamma=-10.0) did not raise ValueError")

    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    try:
        # invalid gamma: wrong type
        trans = mindspore.dataset.transforms.py_transforms.Compose([
            F.Decode(),
            F.Resize((224, 224)),
            F.AdjustGamma(gamma=[1, 2]),
            F.ToTensor()
        ])
        data_set.map(operations=[trans], input_columns=["image"])
    except TypeError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "is not of type [<class 'float'>, <class 'int'>], but got" in str(error)
    else:
        raise AssertionError("AdjustGamma(gamma=[1, 2]) did not raise TypeError")
140
141
def test_adjust_gamma_invalid_gain_param_c():
    """
    Test AdjustGamma C Op with invalid gain parameter.

    A non-numeric gain must raise TypeError; the test fails if it is accepted silently.
    """
    logger.info("Test AdjustGamma C Op with invalid gain parameter")

    # Dataset setup stays outside the try block so that setup errors are not
    # mistaken for the validation error under test.
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    data_set = data_set.map(operations=[C.Decode(), C.Resize((224, 224)), lambda img: np.array(img[:, :, 0])],
                            input_columns=["image"])
    try:
        # invalid gain: wrong type
        data_set.map(operations=C.AdjustGamma(gamma=10.0, gain=[1, 10]),
                     input_columns="image")
    except TypeError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "is not of type [<class 'float'>, <class 'int'>], but got " in str(error)
    else:
        raise AssertionError("AdjustGamma(gain=[1, 10]) did not raise TypeError")
157
158
def test_adjust_gamma_invalid_gain_param_py():
    """
    Test AdjustGamma python Op with invalid gain parameter.

    A non-numeric gain must raise TypeError; the test fails if it is accepted silently.
    """
    logger.info("Test AdjustGamma python Op with invalid gain parameter")

    # Dataset setup stays outside the try block so that setup errors are not
    # mistaken for the validation error under test.
    data_set = ds.ImageFolderDataset(dataset_dir=DATA_DIR, shuffle=False)
    try:
        # invalid gain: wrong type
        trans = mindspore.dataset.transforms.py_transforms.Compose([
            F.Decode(),
            F.Resize((224, 224)),
            F.AdjustGamma(gamma=10.0, gain=[1, 10]),
            F.ToTensor()
        ])
        data_set.map(operations=[trans], input_columns=["image"])
    except TypeError as error:
        logger.info("Got an exception in AdjustGamma: {}".format(str(error)))
        assert "is not of type [<class 'float'>, <class 'int'>], but got " in str(error)
    else:
        raise AssertionError("AdjustGamma(gain=[1, 10]) did not raise TypeError")
176
177
def test_adjust_gamma_pipeline_c():
    """
    Test AdjustGamma C Op Pipeline.

    Compares a Decode/Resize pipeline against the same pipeline with
    AdjustGamma(1.0, 1.0) appended. The test expects the two outputs to match
    element-wise and in shape, and requires at least one row to be compared.
    """
    # First dataset: reference pipeline without AdjustGamma.
    transforms1 = [C.Decode(), C.Resize([64, 64])]
    transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms1)
    ds1 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds1 = ds1.map(operations=transforms1, input_columns=["image"])

    # Second dataset: same pipeline with AdjustGamma(gamma=1.0, gain=1.0).
    transforms2 = [
        C.Decode(),
        C.Resize([64, 64]),
        C.AdjustGamma(1.0, 1.0)
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms2)
    ds2 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds2 = ds2.map(operations=transform2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        assert_allclose(ori_img.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)
        assert ori_img.shape == cvt_img.shape
    # Without this, an empty pipeline would make the test pass vacuously.
    assert num_iter > 0, "No rows were compared"
217
218
def test_adjust_gamma_pipeline_py():
    """
    Test AdjustGamma python Op Pipeline.

    Compares a Decode/Resize/ToTensor pipeline against the same pipeline with
    AdjustGamma(1.0, 1.0) inserted before ToTensor. The test expects the two
    outputs to match element-wise and in shape, and requires at least one row
    to be compared.
    """
    # First dataset: reference pipeline without AdjustGamma.
    transforms1 = [F.Decode(), F.Resize([64, 64]), F.ToTensor()]
    transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms1)
    ds1 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds1 = ds1.map(operations=transforms1, input_columns=["image"])

    # Second dataset: same pipeline with AdjustGamma(gamma=1.0, gain=1.0).
    transforms2 = [
        F.Decode(),
        F.Resize([64, 64]),
        F.AdjustGamma(1.0, 1.0),
        F.ToTensor()
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms2)
    ds2 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds2 = ds2.map(operations=transform2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        assert_allclose(ori_img.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)
        assert ori_img.shape == cvt_img.shape
    # Without this, an empty pipeline would make the test pass vacuously.
    assert num_iter > 0, "No rows were compared"
259
260
def test_adjust_gamma_pipeline_py_gray():
    """
    Test AdjustGamma python Op Pipeline 1-channel.

    Compares a Decode/Resize/Grayscale/ToTensor pipeline against the same
    pipeline with AdjustGamma(1.0, 1.0) inserted after Grayscale. The test
    expects the two outputs to match element-wise and in shape, and requires
    at least one row to be compared.
    """
    # First dataset: reference grayscale pipeline without AdjustGamma.
    transforms1 = [F.Decode(), F.Resize([64, 64]), F.Grayscale(), F.ToTensor()]
    transforms1 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms1)
    ds1 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds1 = ds1.map(operations=transforms1, input_columns=["image"])

    # Second dataset: same pipeline with AdjustGamma(gamma=1.0, gain=1.0).
    transforms2 = [
        F.Decode(),
        F.Resize([64, 64]),
        F.Grayscale(),
        F.AdjustGamma(1.0, 1.0),
        F.ToTensor()
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(
        transforms2)
    ds2 = ds.TFRecordDataset(DATA_DIR_2,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds2 = ds2.map(operations=transform2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        assert_allclose(ori_img.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)
        # Matches the shape check performed by the RGB pipeline tests.
        assert ori_img.shape == cvt_img.shape
    # Without this, an empty pipeline would make the test pass vacuously.
    assert num_iter > 0, "No rows were compared"
301
302
if __name__ == "__main__":
    # Run every test in this module, in the same order, when executed as a script:
    # eager tests, then invalid-parameter tests, then pipeline tests.
    for _test_fn in (
            test_adjust_gamma_c_eager,
            test_adjust_gamma_py_eager,
            test_adjust_gamma_c_eager_gray,
            test_adjust_gamma_py_eager_gray,
            test_adjust_gamma_invalid_gamma_param_c,
            test_adjust_gamma_invalid_gamma_param_py,
            test_adjust_gamma_invalid_gain_param_c,
            test_adjust_gamma_invalid_gain_param_py,
            test_adjust_gamma_pipeline_c,
            test_adjust_gamma_pipeline_py,
            test_adjust_gamma_pipeline_py_gray,
    ):
        _test_fn()
316