# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.dataset.transforms.py_transforms as py_transforms

import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision

from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers

# When True, save_and_check_md5 is asked to (re)generate the golden files instead of comparing against them.
GENERATE_GOLDEN = False


def test_compose():
    """
    Test C++ and Python Compose Op

    Covers single-op and multi-op C++ Compose, a C++ Compose wrapping a Python op,
    plain and nested Python Compose, bare lists of Python ops passed to map(), and
    the errors raised for non-callable entries, empty op lists and a Compose that
    is invoked without an argument.
    """
    ds.config.set_seed(0)

    def test_config(arr, op_list):
        """Map op_list over column "col" of arr and return the rows as Python lists.

        A TypeError or ValueError raised while building or iterating the pipeline
        is returned as its message string instead of being propagated.
        """
        try:
            data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
            data = data.map(input_columns=["col"], operations=op_list)
            return [row["col"].tolist() for row in data.create_dict_iterator(output_numpy=True)]
        except (TypeError, ValueError) as e:
            return str(e)

    # Test simple compose with only 1 op, this would generate a warning
    assert test_config([[1, 0], [3, 4]], c_transforms.Compose([c_transforms.Fill(2)])) == [[2, 2], [2, 2]]

    # Test 1 column -> 2 columns -> 1 -> 2 -> 1
    assert test_config([[1, 0]],
                       c_transforms.Compose(
                           [c_transforms.Duplicate(), c_transforms.Concatenate(), c_transforms.Duplicate(),
                            c_transforms.Concatenate()])) \
           == [[1, 0] * 4]

    # Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
    assert test_config([1, 0],
                       c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
           == [[0, 1], [1, 0]]

    # Test exceptions.
    with pytest.raises(TypeError) as error_info:
        c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
    assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)

    # Test empty op list
    with pytest.raises(ValueError) as error_info:
        test_config([1, 0], c_transforms.Compose([]))
    assert "op_list can not be empty." in str(error_info.value)

    # Test Python compose op
    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[0, 1], [1, 0]]
    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[0, 2],
                                                                                                          [2, 0]]

    # Test nested Python compose op
    assert test_config([1, 0],
                       py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
           == [[0, 2], [2, 0]]

    # Test passing a list of Python ops without Compose wrapper
    assert test_config([1, 0],
                       [py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
           == [[0, 2], [2, 0]]
    assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[0, 2], [2, 0]]

    # Test a non callable function
    with pytest.raises(ValueError) as error_info:
        py_transforms.Compose([1])
    assert "transforms[0] is not callable." in str(error_info.value)

    # Test empty Python op list
    with pytest.raises(ValueError) as error_info:
        test_config([1, 0], py_transforms.Compose([]))
    assert "transforms list is empty." in str(error_info.value)

    # Pass in extra brackets
    with pytest.raises(TypeError) as error_info:
        py_transforms.Compose([(lambda x: x + x)])()
    assert "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])())." in str(
        error_info.value)


def test_lambdas():
    """
    Test Multi Column Python Compose Op

    Runs lambdas that take two input columns and return one or two output columns,
    both wrapped in py_transforms.Compose and as a bare list of callables.
    """
    ds.config.set_seed(0)

    def test_config(arr, input_columns, output_cols, op_list):
        """Map op_list from input_columns to output_cols and return the output values.

        The result is a flat list: for every row, one entry per name in output_cols.
        """
        data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
        data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
                        column_order=output_cols)
        res = []
        for row in data.create_dict_iterator(output_numpy=True):
            res.extend(row[col_name].tolist() for col_name in output_cols)
        return res

    arr = ([[1]], [[3]])

    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([(lambda x, y: x)])) == [[1]]
    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([lambda x, y: x, lambda x: x])) == [[1]]
    assert test_config(arr, ["col0", "col1"], ["a", "b"],
                       py_transforms.Compose([lambda x, y: x, lambda x: (x, x * 2)])) == \
           [[1], [2]]
    assert test_config(arr, ["col0", "col1"], ["a", "b"],
                       [lambda x, y: (x, x + y), lambda x, y: (x, y * 2)]) == [[1], [8]]


def test_c_py_compose_transforms_module():
    """
    Test combining Python and C++ transforms

    Passes lists that interleave py_transforms ops, lambdas and c_transforms ops
    directly to map() and checks the resulting column values.
    """
    ds.config.set_seed(0)

    def test_config(arr, input_columns, output_cols, op_list):
        """Map op_list from input_columns to output_cols and return the output values.

        The result is a flat list: for every row, one entry per name in output_cols.
        """
        data = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
        data = data.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
                        column_order=output_cols)
        res = []
        for row in data.create_dict_iterator(output_numpy=True):
            res.extend(row[col_name].tolist() for col_name in output_cols)
        return res

    arr = [1, 0]
    assert test_config(arr, ["cols"], ["cols"],
                       [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)]) == \
           [[False, True],
            [True, False]]
    assert test_config(arr, ["cols"], ["cols"],
                       [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)]) \
           == [[1, 1], [1, 1]]
    assert test_config(arr, ["cols"], ["cols"],
                       [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1), (lambda x: x + x)]) \
           == [[2, 2], [2, 2]]
    assert test_config([[1, 3]], ["cols"], ["cols"],
                       [c_transforms.PadEnd([3], -1), (lambda x: x + x)]) \
           == [[2, 6, -2]]

    arr = ([[1]], [[3]])
    assert test_config(arr, ["col0", "col1"], ["a"], [(lambda x, y: x + y), c_transforms.PadEnd([2], -1)]) == [[4, -1]]


def test_c_py_compose_vision_module(plot=False, run_golden=True):
    """
    Test combining Python and C++ vision transforms

    Args:
        plot (bool): when True, show original and transformed images via visualize_list.
        run_golden (bool): when True, check each pipeline against its golden file
            through save_and_check_md5.
    """
    original_seed = config_get_set_seed(10)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    def test_config(plot, file_name, op_list):
        """Apply op_list to the "image" column and optionally md5-check and plot the result."""
        data_dir = "../data/dataset/testImageNetData/train/"
        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data1 = data1.map(operations=op_list, input_columns=["image"])
        # Second pipeline only decodes, giving the reference images for plotting.
        data2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data2 = data2.map(operations=c_vision.Decode(), input_columns=["image"])

        transformed_images = [item["image"]
                              for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True)]
        original_images = [item["image"]
                           for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True)]

        if run_golden:
            # Compare with expected md5 from images
            save_and_check_md5(data1, file_name, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(original_images, transformed_images)

    try:
        test_config(op_list=[c_vision.Decode(),
                             py_vision.ToPIL(),
                             py_vision.Resize((224, 224)),
                             np.array],
                    plot=plot, file_name="compose_c_py_1.npz")

        test_config(op_list=[c_vision.Decode(),
                             c_vision.Resize((224, 244)),
                             py_vision.ToPIL(),
                             np.array,
                             c_vision.Resize((24, 24))],
                    plot=plot, file_name="compose_c_py_2.npz")

        test_config(op_list=[py_vision.Decode(),
                             py_vision.Resize((224, 224)),
                             np.array,
                             c_vision.RandomColor()],
                    plot=plot, file_name="compose_c_py_3.npz")
    finally:
        # Restore configuration even when a check above fails, so the modified
        # seed / worker count does not leak into subsequently executed tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_py_transforms_with_c_vision():
    """
    These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)

    The last case additionally expects a RuntimeError when OneHotOp(20, 0.1) is mapped
    over the image folder dataset.
    """

    ds.config.set_seed(0)

    def test_config(op_list):
        """Map op_list over the image folder dataset and return every column value of every row."""
        data_dir = "../data/dataset/testImageNetData/train/"
        data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data = data.map(operations=op_list)
        res = []
        for row in data.create_dict_iterator(output_numpy=True):
            # This helper has no explicit output column list, so collect whatever
            # columns the row actually contains.
            res.extend(value.tolist() for value in row.values())
        return res

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
    assert "transforms[0] is not a py transforms." in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
    assert "transforms[0] is not a py transforms." in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
    assert "transforms[1] is not a py transforms." in str(error_info.value)

    with pytest.raises(RuntimeError) as error_info:
        test_config([py_transforms.OneHotOp(20, 0.1)])
    assert "is smaller than the category number" in str(error_info.value)


def test_py_vision_with_c_transforms():
    """
    Test combining Python vision operations with C++ transforms operations

    Python Decode/CenterCrop followed by np.array feeds the C++ Mask, Fill and
    Concatenate ops; Concatenate is expected to raise on the non-1D image tensor.
    """

    ds.config.set_seed(0)

    def test_config(op_list):
        """Apply op_list to the "image" column and return the transformed images as a list."""
        data_dir = "../data/dataset/testImageNetData/train/"
        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data1 = data1.map(operations=op_list, input_columns=["image"])
        return [item["image"] for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True)]

    # Test with Mask Op
    output_arr = test_config([py_vision.Decode(),
                              py_vision.CenterCrop(2), np.array,
                              c_transforms.Mask(c_transforms.Relational.GE, 100)])

    exp_arr = [np.array([[[True, False, False],
                          [True, False, False]],
                         [[True, False, False],
                          [True, False, False]]]),
               np.array([[[True, False, False],
                          [True, False, False]],
                         [[True, False, False],
                          [True, False, False]]])]

    for exp_a, output in zip(exp_arr, output_arr):
        np.testing.assert_array_equal(exp_a, output)

    # Test with Fill Op
    output_arr = test_config([py_vision.Decode(),
                              py_vision.CenterCrop(4), np.array,
                              c_transforms.Fill(10)])

    exp_arr = [np.ones((4, 4, 3)) * 10] * 2
    for exp_a, output in zip(exp_arr, output_arr):
        np.testing.assert_array_equal(exp_a, output)

    # Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
    with pytest.raises(RuntimeError) as error_info:
        test_config([py_vision.Decode(),
                     py_vision.CenterCrop(2), np.array,
                     c_transforms.Concatenate(0)])
    assert "only 1D input supported" in str(error_info.value)


def test_compose_with_custom_function():
    """
    Test Python Compose with custom function

    A list of callables is passed to map(): a scaling lambda, a function returning
    two values, and a lambda that stacks those two values back into one column.
    """

    def custom_function(x):
        """Return x together with its element-wise square."""
        return (x, x * x)

    # First dataset
    op_list = [
        lambda x: x * 3,
        custom_function,
        # convert two column output to one
        lambda *images: np.stack(images)
    ]

    data = ds.NumpySlicesDataset([[1, 2]], column_names=["col0"], shuffle=False)
    data = data.map(input_columns=["col0"], operations=op_list)

    res = [row["col0"].tolist() for row in data.create_dict_iterator(output_numpy=True)]
    assert res == [[[3, 6], [9, 36]]]


if __name__ == "__main__":
    test_compose()
    test_lambdas()
    test_c_py_compose_transforms_module()
    test_c_py_compose_vision_module(plot=True)
    test_py_transforms_with_c_vision()
    test_py_vision_with_c_transforms()
    test_compose_with_custom_function()