# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing LinearTransformation op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import diff_mse, visualize_list, save_and_check_md5

GENERATE_GOLDEN = False

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_linear_transformation_op(plot=False):
    """
    Test LinearTransformation op: verify if images transform correctly

    Two pipelines read the same TFRecord data with the same decode/crop/ToTensor
    steps; the first additionally applies LinearTransformation with an identity
    matrix and a zero mean vector. Every image pair must have an MSE of 0.

    Args:
        plot (bool): when True, show original and transformed images side by
            side through visualize_list.
    """
    logger.info("test_linear_transformation_01")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.eye(dim)
    mean_vector = np.zeros(dim)

    # Define operations
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform, input_columns=["image"])
    # Note: if transformation matrix is diagonal matrix with all 1 in diagonal,
    # the output matrix is expected to be the same as the input matrix.
    data1 = data1.map(operations=py_vision.LinearTransformation(transformation_matrix, mean_vector),
                      input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform, input_columns=["image"])

    image_transformed = []
    image = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Convert the channel-first arrays to HWC uint8 images before comparing.
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_transformed.append(image1)
        image.append(image2)

        # Check every pair inside the loop so that no image is skipped and an
        # empty dataset cannot leave image1/image2 undefined.
        mse = diff_mse(image1, image2)
        assert mse == 0
    if plot:
        visualize_list(image, image_transformed)


def test_linear_transformation_md5():
    """
    Test LinearTransformation op: valid params (transformation_matrix, mean_vector)
    Expected to pass

    The pipeline output is compared against the golden file
    "linear_transformation_01_result.npz" through save_and_check_md5.
    """
    logger.info("test_linear_transformation_md5")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])
    mean_vector = np.zeros(dim)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor(),
        py_vision.LinearTransformation(transformation_matrix, mean_vector)
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data1 = data1.map(operations=transform, input_columns=["image"])

    # Compare with expected md5 from images
    filename = "linear_transformation_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_linear_transformation_exception_01():
    """
    Test LinearTransformation op: transformation_matrix is not provided
    Expected to raise TypeError
    """
    logger.info("test_linear_transformation_exception_01")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    mean_vector = np.zeros(dim)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(None, mean_vector)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Argument transformation_matrix with value None is not of type [<class 'numpy.ndarray'>]" in str(e)
    else:
        # Without this branch the test would pass silently when nothing is raised.
        raise AssertionError("Expected TypeError for transformation_matrix=None, but no exception was raised")


def test_linear_transformation_exception_02():
    """
    Test LinearTransformation op: mean_vector is not provided
    Expected to raise TypeError
    """
    logger.info("test_linear_transformation_exception_02")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, None)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Argument mean_vector with value None is not of type [<class 'numpy.ndarray'>]" in str(e)
    else:
        # Without this branch the test would pass silently when nothing is raised.
        raise AssertionError("Expected TypeError for mean_vector=None, but no exception was raised")


def test_linear_transformation_exception_03():
    """
    Test LinearTransformation op: transformation_matrix is not a square matrix
    Expected to raise ValueError
    """
    logger.info("test_linear_transformation_exception_03")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim - 1])
    mean_vector = np.zeros(dim)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, mean_vector)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "square matrix" in str(e)
    else:
        # Without this branch the test would pass silently when nothing is raised.
        raise AssertionError("Expected ValueError for a non-square transformation_matrix, "
                             "but no exception was raised")


def test_linear_transformation_exception_04():
    """
    Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix
    Expected to raise ValueError
    """
    logger.info("test_linear_transformation_exception_04")

    # Initialize parameters
    height = 50
    width = 50
    dim = 3 * height * width
    transformation_matrix = np.ones([dim, dim])
    mean_vector = np.zeros(dim - 1)

    # Generate dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.CenterCrop([height, width]),
            py_vision.ToTensor(),
            py_vision.LinearTransformation(transformation_matrix, mean_vector)
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data1 = data1.map(operations=transform, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "should match" in str(e)
    else:
        # Without this branch the test would pass silently when nothing is raised.
        raise AssertionError("Expected ValueError for a mean_vector of mismatched length, "
                             "but no exception was raised")


if __name__ == '__main__':
    test_linear_transformation_op(plot=True)
    test_linear_transformation_md5()
    test_linear_transformation_exception_01()
    test_linear_transformation_exception_02()
    test_linear_transformation_exception_03()
    test_linear_transformation_exception_04()