# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomPosterize op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_list, save_and_check_md5, \
    config_get_set_seed, config_get_set_num_parallel_workers, diff_mse

GENERATE_GOLDEN = False

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def _build_pipelines(transforms):
    """
    Build the two pipelines compared by the RandomPosterize tests.

    Args:
        transforms (list): Map operations applied to the "image" column of the first pipeline.

    Returns:
        tuple: (data1, data2) where data1 applies `transforms` and data2 only decodes the images.
    """
    # First dataset: transformed images
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transforms, input_columns=["image"])
    # Second dataset: decoded-only reference images
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=[c_vision.Decode()], input_columns=["image"])
    return data1, data2


def _collect_images(data1, data2):
    """
    Iterate both pipelines in lockstep for one epoch.

    Returns:
        tuple: (image_posterize, image_original), lists of the "image" column from data1 and data2.
    """
    image_posterize = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_posterize.append(item1["image"])
        image_original.append(item2["image"])
    return image_posterize, image_original


def test_random_posterize_op_c(plot=False, run_golden=False):
    """
    Test RandomPosterize in C transformations (uses assertion on mse as using md5 could have jpeg decoding
    inconsistencies)
    """
    logger.info("test_random_posterize_op_c")

    original_seed = config_get_set_seed(55)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    try:
        # define map operations
        transforms1 = [
            c_vision.Decode(),
            c_vision.RandomPosterize((1, 8))
        ]
        data1, data2 = _build_pipelines(transforms1)
        image_posterize, image_original = _collect_images(data1, data2)

        # check mse as md5 can be inconsistent.
        # mse = 2.9668956 is calculated from
        # a thousand runs of diff_mse(np.array(image_original), np.array(image_posterize)) that all produced the
        # same mse. Allow for an error of 0.0000005.
        assert abs(2.9668956 - diff_mse(np.array(image_original), np.array(image_posterize))) <= 0.0000005

        if run_golden:
            # check results with md5 comparison
            filename = "random_posterize_01_result_c.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(image_original, image_posterize)
    finally:
        # Restore configuration even if an assertion above failed, so later tests see the original settings
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_random_posterize_op_fixed_point_c(plot=False, run_golden=True):
    """
    Test RandomPosterize in C transformations with fixed point
    """
    logger.info("test_random_posterize_op_fixed_point_c")

    # define map operations; a single int gives a fixed number of bits
    transforms1 = [
        c_vision.Decode(),
        c_vision.RandomPosterize(1)
    ]
    data1, data2 = _build_pipelines(transforms1)
    image_posterize, image_original = _collect_images(data1, data2)

    if run_golden:
        # check results with md5 comparison
        filename = "random_posterize_fixed_point_01_result_c.npz"
        save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

    if plot:
        visualize_list(image_original, image_posterize)


def test_random_posterize_default_c_md5(plot=False, run_golden=True):
    """
    Test RandomPosterize C Op (default params) with md5 comparison
    """
    logger.info("test_random_posterize_default_c_md5")
    original_seed = config_get_set_seed(5)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    try:
        # define map operations
        transforms1 = [
            c_vision.Decode(),
            c_vision.RandomPosterize()
        ]
        data1, data2 = _build_pipelines(transforms1)
        image_posterize, image_original = _collect_images(data1, data2)

        if run_golden:
            # check results with md5 comparison
            filename = "random_posterize_01_default_result_c.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(image_original, image_posterize)
    finally:
        # Restore configuration even if the md5 check above failed
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def _assert_posterize_bits_error(bits, error_type, expected_message):
    """
    Assert that constructing RandomPosterize(bits) raises `error_type` with exactly `expected_message`.

    Fails explicitly when nothing is raised, so an argument that is wrongly accepted cannot
    make the test pass silently.
    """
    try:
        _ = c_vision.RandomPosterize(bits)
    except error_type as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == expected_message
    else:
        raise AssertionError("RandomPosterize({!r}) did not raise {}.".format(bits, error_type.__name__))


def test_random_posterize_exception_bit():
    """
    Test RandomPosterize: out of range input bits and invalid type
    """
    logger.info("test_random_posterize_exception_bit")
    range_error = "Input is not within the required interval of [1, 8]."
    # Test max > 8
    _assert_posterize_bits_error((1, 9), ValueError, range_error)
    # Test min < 1
    _assert_posterize_bits_error((0, 7), ValueError, range_error)
    # Test max < min
    _assert_posterize_bits_error((8, 1), ValueError, range_error)
    # Test wrong type (float instead of int/list/tuple)
    _assert_posterize_bits_error(1.1, TypeError,
                                 "Argument bits with value 1.1 is not of type [<class 'list'>, <class 'tuple'>, "
                                 "<class 'int'>], but got <class 'float'>.")
    # Test wrong number of bits
    _assert_posterize_bits_error((1, 1, 1), TypeError,
                                 "Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
def test_rescale_with_random_posterize():
    """
    Test RandomPosterize: only support CV_8S/CV_8U

    Images are first passed through Rescale, and RandomPosterize is then expected to reject them
    with a RuntimeError once the pipeline executes. The test fails if no such error is raised.
    """
    logger.info("test_rescale_with_random_posterize")

    data_dir_10 = "../data/dataset/testCifar10Data"
    dataset = ds.Cifar10Dataset(data_dir_10)

    rescale_op = c_vision.Rescale((1.0 / 255.0), 0.0)
    dataset = dataset.map(operations=rescale_op, input_columns=["image"])

    random_posterize_op = c_vision.RandomPosterize((4, 8))
    dataset = dataset.map(operations=random_posterize_op, input_columns=["image"], num_parallel_workers=1)

    # NOTE(review): output_shapes() presumably pulls a row through the pipeline, which is what
    # triggers the map ops and hence the type error -- confirm against the MindSpore version in use.
    try:
        _ = dataset.output_shapes()
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "data type of input image should be int" in str(e)
    else:
        raise AssertionError("RandomPosterize accepted a rescaled image; expected a RuntimeError.")


if __name__ == "__main__":
    test_random_posterize_op_c(plot=False, run_golden=False)
    test_random_posterize_op_fixed_point_c(plot=False)
    test_random_posterize_default_c_md5(plot=False)
    test_random_posterize_exception_bit()
    test_rescale_with_random_posterize()