• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RandomPosterize op in DE
17"""
18import numpy as np
19import mindspore.dataset as ds
20import mindspore.dataset.vision.c_transforms as c_vision
21from mindspore import log as logger
22from util import visualize_list, save_and_check_md5, \
23    config_get_set_seed, config_get_set_num_parallel_workers, diff_mse
24
# Passed as ``generate_golden`` to save_and_check_md5; presumably True rewrites the
# golden .npz files instead of checking against them.
GENERATE_GOLDEN = False

# TFRecord test data shared by the RandomPosterize tests below.
_TF_3_IMAGES_DIR = "../data/dataset/test_tf_file_3_images"
DATA_DIR = [_TF_3_IMAGES_DIR + "/train-0000-of-0001.data"]
SCHEMA_DIR = _TF_3_IMAGES_DIR + "/datasetSchema.json"
29
30
def test_random_posterize_op_c(plot=False, run_golden=False):
    """
    Test RandomPosterize in C transformations (uses assertion on mse as using md5 could have jpeg decoding
    inconsistencies)

    Args:
        plot (bool): if True, visualize the original images next to the posterized ones.
        run_golden (bool): if True, also compare the posterized dataset against its md5 golden file.
    """
    logger.info("test_random_posterize_op_c")

    # Pin the global seed and worker count (presumably for reproducible random bit choices);
    # the previous values are kept so they can be restored below.
    original_seed = config_get_set_seed(55)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # define map operations
        transforms1 = [
            c_vision.Decode(),
            c_vision.RandomPosterize((1, 8))
        ]

        #  First dataset: decoded, then posterized
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data1 = data1.map(operations=transforms1, input_columns=["image"])
        #  Second dataset: decoded only, used as the reference
        data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data2 = data2.map(operations=[c_vision.Decode()], input_columns=["image"])

        image_posterize = []
        image_original = []
        for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
            image_posterize.append(item1["image"])
            image_original.append(item2["image"])

        # check mse as md5 can be inconsistent.
        # mse = 2.9668956 is calculated from
        # a thousand runs of diff_mse(np.array(image_original), np.array(image_posterize)) that all produced the same mse.
        # allow for an error of 0.0000005
        assert abs(2.9668956 - diff_mse(np.array(image_original), np.array(image_posterize))) <= 0.0000005

        if run_golden:
            # check results with md5 comparison
            filename = "random_posterize_01_result_c.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(image_original, image_posterize)
    finally:
        # Restore configuration even when an assertion above fails, so the modified
        # global seed / worker count does not leak into later tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
80
81
def test_random_posterize_op_fixed_point_c(plot=False, run_golden=True):
    """
    Test RandomPosterize in C transformations with fixed point

    RandomPosterize is given a single int (1) for bits rather than a (min, max) range.

    Args:
        plot (bool): if True, visualize the original images next to the posterized ones.
        run_golden (bool): if True, compare the posterized dataset against its md5 golden file.
    """
    # Log this test's own name (previously logged "test_random_posterize_op_c" by mistake).
    logger.info("test_random_posterize_op_fixed_point_c")

    # define map operations
    transforms1 = [
        c_vision.Decode(),
        c_vision.RandomPosterize(1)
    ]

    #  First dataset: decoded, then posterized
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transforms1, input_columns=["image"])
    #  Second dataset: decoded only, used as the reference for plotting
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=[c_vision.Decode()], input_columns=["image"])

    image_posterize = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_posterize.append(item1["image"])
        image_original.append(item2["image"])

    if run_golden:
        # check results with md5 comparison
        filename = "random_posterize_fixed_point_01_result_c.npz"
        save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

    if plot:
        visualize_list(image_original, image_posterize)
117
118
def test_random_posterize_default_c_md5(plot=False, run_golden=True):
    """
    Test RandomPosterize C Op (default params) with md5 comparison

    Args:
        plot (bool): if True, visualize the original images next to the posterized ones.
        run_golden (bool): if True, compare the posterized dataset against its md5 golden file.
    """
    logger.info("test_random_posterize_default_c_md5")
    # Pin the global seed and worker count; the previous values are restored in `finally`.
    original_seed = config_get_set_seed(5)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # define map operations
        transforms1 = [
            c_vision.Decode(),
            c_vision.RandomPosterize()
        ]

        #  First dataset: decoded, then posterized with default params
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data1 = data1.map(operations=transforms1, input_columns=["image"])
        #  Second dataset: decoded only, used as the reference for plotting
        data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data2 = data2.map(operations=[c_vision.Decode()], input_columns=["image"])

        image_posterize = []
        image_original = []
        for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1),
                                data2.create_dict_iterator(output_numpy=True, num_epochs=1)):
            image_posterize.append(item1["image"])
            image_original.append(item2["image"])

        if run_golden:
            # check results with md5 comparison
            filename = "random_posterize_01_default_result_c.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(image_original, image_posterize)
    finally:
        # Restore configuration even when the md5 check fails, so later tests are unaffected.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
159
160
def _assert_raises_with_message(error_type, expected_message, func, *args):
    """
    Call func(*args) and assert it raises error_type whose message equals expected_message.

    Unlike a bare try/except, this fails the test when no exception is raised at all.
    """
    try:
        func(*args)
    except error_type as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == expected_message
    else:
        raise AssertionError("{} was not raised for arguments {!r}".format(error_type.__name__, args))


def test_random_posterize_exception_bit():
    """
    Test RandomPosterize: out of range input bits and invalid type
    """
    logger.info("test_random_posterize_exception_bit")
    interval_message = "Input is not within the required interval of [1, 8]."
    # Test max > 8
    _assert_raises_with_message(ValueError, interval_message, c_vision.RandomPosterize, (1, 9))
    # Test min < 1
    _assert_raises_with_message(ValueError, interval_message, c_vision.RandomPosterize, (0, 7))
    # Test max < min
    _assert_raises_with_message(ValueError, interval_message, c_vision.RandomPosterize, (8, 1))
    # Test wrong type (float is neither int nor list/tuple)
    _assert_raises_with_message(
        TypeError,
        "Argument bits with value 1.1 is not of type [<class 'list'>, <class 'tuple'>, "
        "<class 'int'>], but got <class 'float'>.",
        c_vision.RandomPosterize, 1.1)
    # Test wrong number of bits
    _assert_raises_with_message(
        TypeError,
        "Size of bits should be a single integer or a list/tuple (min, max) of length 2.",
        c_vision.RandomPosterize, (1, 1, 1))
197
198
def test_rescale_with_random_posterize():
    """
    Test RandomPosterize: only support CV_8S/CV_8U

    Rescale is applied first so the images reaching RandomPosterize are no longer
    integer-typed (presumably float); the pipeline must then fail with a RuntimeError.
    """
    logger.info("test_rescale_with_random_posterize")

    data_dir_10 = "../data/dataset/testCifar10Data"
    dataset = ds.Cifar10Dataset(data_dir_10)

    rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0)
    dataset = dataset.map(operations=rescale_op, input_columns=["image"])

    random_posterize_op = c_vision.RandomPosterize((4, 8))
    dataset = dataset.map(operations=random_posterize_op, input_columns=["image"], num_parallel_workers=1)

    try:
        # Presumably runs the pipeline far enough to reach RandomPosterize's dtype check.
        _ = dataset.output_shapes()
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "data type of input image should be int" in str(e)
    else:
        # Previously the test passed silently when no error was raised.
        raise AssertionError("RandomPosterize did not reject rescaled (non-integer) images")
219
220
if __name__ == "__main__":
    # Run every test of this file, in file order, when executed as a script.
    for _test_fn, _test_kwargs in (
            (test_random_posterize_op_c, {"plot": False, "run_golden": False}),
            (test_random_posterize_op_fixed_point_c, {"plot": False}),
            (test_random_posterize_default_c_md5, {"plot": False}),
            (test_random_posterize_exception_bit, {}),
            (test_rescale_with_random_posterize, {}),
    ):
        _test_fn(**_test_kwargs)
227