# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing SlicePatches Python API
"""
import functools
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.utils as mode

from mindspore import log as logger
from util import diff_mse, visualize_list

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_slice_patches_01(plot=False):
    """
    slice rgb image(100, 200) to 4 patches
    """
    slice_to_patches([100, 200], 2, 2, True, plot=plot)


def test_slice_patches_02(plot=False):
    """
    no op
    """
    slice_to_patches([100, 200], 1, 1, True, plot=plot)


def test_slice_patches_03(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in pad mode
    """
    slice_to_patches([99, 199], 2, 2, True, plot=plot)


def test_slice_patches_04(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in drop mode
    """
    slice_to_patches([99, 199], 2, 2, False, plot=plot)


def test_slice_patches_05(plot=False):
    """
    slice rgb image(99, 199) to 4 patches in pad mode with fill_value 255
    """
    slice_to_patches([99, 199], 2, 2, True, 255, plot=plot)


def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
    """
    Tool function for slice patches

    Runs the same decode/resize pipeline twice, once with the SlicePatches
    operation and once with the numpy helper `slice_patches`, and asserts that
    every produced patch is identical.

    Args:
        ori_size: [H, W] the decoded image is resized to before slicing.
        num_h: number of patches along the height.
        num_w: number of patches along the width.
        pad_or_drop: True selects SliceMode.PAD, False selects SliceMode.DROP.
        fill_value: padding value, only passed to the operation in pad mode.
        plot: when True, the patches of both pipelines are visualized.
    """
    logger.info("test_slice_patches_pipeline")

    cols = ['img' + str(x) for x in range(num_h * num_w)]
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize(ori_size)  # H, W
    # Build exactly one operation; drop mode is constructed without fill_value.
    if pad_or_drop:
        slice_patches_op = c_vision.SlicePatches(
            num_h, num_w, mode.SliceMode.PAD, fill_value)
    else:
        slice_patches_op = c_vision.SlicePatches(
            num_h, num_w, mode.SliceMode.DROP)

    # First dataset
    dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=resize_op, input_columns=["image"])
    dataset1 = dataset1.map(operations=slice_patches_op,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    # Second dataset
    dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
    dataset2 = dataset2.map(operations=resize_op, input_columns=["image"])
    func_slice_patches = functools.partial(
        slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)
    dataset2 = dataset2.map(operations=func_slice_patches,
                            input_columns=["image"], output_columns=cols, column_order=cols)

    num_iter = 0
    patches_c = []
    patches_py = []
    for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for col in cols:
            mse = diff_mse(data1[col], data2[col])
            logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
            assert mse == 0
            patches_c.append(data1[col])
            patches_py.append(data2[col])
            num_iter += 1
    if plot:
        visualize_list(patches_py, patches_c)


def _check_slice_patches_error(expected_error, expected_message, *args):
    """
    Construct SlicePatches(*args) and require that it raises expected_error
    whose text contains expected_message.

    A call that does not raise fails the check explicitly, so the test cannot
    pass silently.
    """
    try:
        _ = c_vision.SlicePatches(*args)
    except expected_error as e:
        logger.info("Got an exception in SlicePatches: {}".format(str(e)))
        assert expected_message in str(e)
    else:
        raise AssertionError(
            "SlicePatches{} did not raise {}".format(args, expected_error.__name__))


def test_slice_patches_exception_01():
    """
    Test SlicePatches with invalid parameters
    """
    logger.info("test_Slice_Patches_exception")
    _check_slice_patches_error(ValueError, "Input num_height is not within", 0, 2)
    _check_slice_patches_error(ValueError, "Input num_width is not within", 2, 0)
    _check_slice_patches_error(TypeError, "Argument slice_mode with value", 2, 2, 1)
    _check_slice_patches_error(ValueError, "Input fill_value is not within",
                               2, 2, mode.SliceMode.PAD, -1)


def test_slice_patches_06():
    """
    Call the operation directly on a (158, 126, 1) int32 image with a 2 x 8 grid.
    126 is not divisible by 8; the expected patch width is 16.
    """
    image = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    assert patches[0].shape == (79, 16, 1)


def test_slice_patches_07():
    """
    Call the operation directly on a 2-D (158, 126) int32 image with a 2 x 8 grid.
    """
    image = np.random.randint(0, 255, (158, 126)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    assert patches[0].shape == (79, 16)


def test_slice_patches_08():
    """
    Slice a single (56, 82, 256) uint8 sample into 2 x 2 patches in a dataset pipeline.
    """
    np_data = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
    slice_patches_op = c_vision.SlicePatches(2, 2)
    dataset = dataset.map(input_columns=["image"], output_columns=["img0", "img1", "img2", "img3"],
                          column_order=["img0", "img1", "img2", "img3"],
                          operations=slice_patches_op)
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        patch_shape = item['img0'].shape
        assert patch_shape == (28, 41, 256)


def test_slice_patches_09():
    """
    Call the operation directly on a (56, 82, 256) uint8 image with a 4 x 3 grid in pad mode.
    82 is not divisible by 3; the expected patch width is 28.
    """
    image = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
    slice_patches_op = c_vision.SlicePatches(4, 3, mode.SliceMode.PAD)
    patches = slice_patches_op(image)
    assert len(patches) == 12
    assert patches[0].shape == (14, 28, 256)


def skip_test_slice_patches_10():
    """
    Large (7000, 7000, 255) image sliced 10 x 13 in drop mode, called directly.
    NOTE(review): the skip_ prefix presumably keeps this heavy case out of test
    discovery; it is also not called from the __main__ block.
    """
    image = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
    slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    patches = slice_patches_op(image)
    assert patches[0].shape == (700, 538, 255)


def skip_test_slice_patches_11():
    """
    Large (7000, 7000, 256) sample sliced 10 x 13 in drop mode in a dataset pipeline.
    NOTE(review): the skip_ prefix presumably keeps this heavy case out of test
    discovery; it is also not called from the __main__ block.
    """
    np_data = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
    slice_patches_op = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    cols = ['img' + str(x) for x in range(10 * 13)]
    dataset = dataset.map(input_columns=["image"], output_columns=cols,
                          column_order=cols, operations=slice_patches_op)
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        patch_shape = item['img0'].shape
        assert patch_shape == (700, 538, 256)


def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
    """
    help function which slice patches with numpy

    Args:
        image: numpy array whose first two dimensions are (H, W); any trailing
            dimensions (e.g. channels) are kept unchanged.
        num_h: number of patches along the height.
        num_w: number of patches along the width.
        pad_or_drop: True pads the bottom/right with fill_value so that the size
            becomes divisible; False drops the remainder rows/columns.
        fill_value: value used for the padded area in pad mode.

    Returns:
        The input image itself when num_h == num_w == 1, otherwise a tuple of
        num_h * num_w patches in row-major order.
    """
    if num_h == 1 and num_w == 1:
        return image
    # Only the two leading dimensions are sliced, so 2-D images work as well.
    height, width = image.shape[:2]
    patch_h, rest_h = divmod(height, num_h)
    patch_w, rest_w = divmod(width, num_w)
    if pad_or_drop:
        # Round the patch size up whenever the image is not evenly divisible.
        if rest_h:
            patch_h += 1
        if rest_w:
            patch_w += 1
        # Keep the input dtype so the padded copy does not silently cast the data.
        canvas = np.full((patch_h * num_h, patch_w * num_w) + image.shape[2:],
                         fill_value, dtype=image.dtype)
        canvas[:height, :width] = image
    else:
        canvas = image
    patches = [canvas[top * patch_h:(top + 1) * patch_h,
                      left * patch_w:(left + 1) * patch_w]
               for top in range(num_h)
               for left in range(num_w)]

    return tuple(patches)


if __name__ == "__main__":
    test_slice_patches_01(plot=True)
    test_slice_patches_02(plot=True)
    test_slice_patches_03(plot=True)
    test_slice_patches_04(plot=True)
    test_slice_patches_05(plot=True)
    test_slice_patches_06()
    test_slice_patches_07()
    test_slice_patches_08()
    test_slice_patches_09()
    test_slice_patches_exception_01()