• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing SlicePatches Python API
17"""
18import functools
19import numpy as np
20
21import mindspore.dataset as ds
22import mindspore.dataset.vision.c_transforms as c_vision
23import mindspore.dataset.vision.utils as mode
24
25from mindspore import log as logger
26from util import diff_mse, visualize_list
27
# TFRecord data file and its schema; the directory name suggests it holds three test images.
DATA_DIR = [
    "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data",
]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
30
31
def test_slice_patches_01(plot=False):
    """
    Slice an RGB image resized to (100, 200) into a 2x2 grid (4 patches) in PAD mode;
    both sides divide evenly, so no padding is actually needed.
    """
    slice_to_patches(ori_size=[100, 200], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
37
38
def test_slice_patches_02(plot=False):
    """
    A 1x1 grid is a no-op: the single patch is the whole (100, 200) resized image.
    """
    slice_to_patches(ori_size=[100, 200], num_h=1, num_w=1, pad_or_drop=True, plot=plot)
44
45
def test_slice_patches_03(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2x2 patches in PAD mode; the odd sizes
    force padding with the default fill value.
    """
    slice_to_patches(ori_size=[99, 199], num_h=2, num_w=2, pad_or_drop=True, plot=plot)
51
52
def test_slice_patches_04(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2x2 patches in DROP mode; the trailing
    row/column that does not fit the grid is discarded.
    """
    slice_to_patches(ori_size=[99, 199], num_h=2, num_w=2, pad_or_drop=False, plot=plot)
58
59
def test_slice_patches_05(plot=False):
    """
    Slice an RGB image resized to (99, 199) into 2x2 patches in PAD mode with a
    non-default fill_value of 255 (test_slice_patches_03 covers the default fill value).
    """
    slice_to_patches([99, 199], 2, 2, True, fill_value=255, plot=plot)
65
66
def _build_sliced_dataset(slice_op, cols, decode_op, resize_op):
    """Read the test TFRecord, decode and resize "image", then slice it with `slice_op` into `cols`."""
    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    dataset = dataset.map(operations=decode_op, input_columns=["image"])
    dataset = dataset.map(operations=resize_op, input_columns=["image"])
    return dataset.map(operations=slice_op, input_columns=["image"],
                       output_columns=cols, column_order=cols)


def slice_to_patches(ori_size, num_h, num_w, pad_or_drop, fill_value=0, plot=False):
    """
    Tool function for slice patches: compare c_transforms SlicePatches with the numpy
    reference `slice_patches` on the same decoded and resized images.

    Args:
        ori_size (list[int]): size passed to Resize before slicing (commented as H, W).
        num_h (int): number of patches along the height axis.
        num_w (int): number of patches along the width axis.
        pad_or_drop (bool): True uses SliceMode.PAD with `fill_value`; False uses SliceMode.DROP.
        fill_value (int): padding value for PAD mode; not passed to the op in DROP mode.
        plot (bool): when True, visualize the numpy patches next to the c_transforms patches.

    Raises:
        AssertionError: if any patch differs (mse != 0) or no rows were compared.
    """
    logger.info("test_slice_patches_pipeline")

    cols = ['img' + str(x) for x in range(num_h * num_w)]
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize(ori_size)  # H, W
    # Build only the op for the requested mode; DROP mode is constructed without fill_value.
    if pad_or_drop:
        slice_patches_op = c_vision.SlicePatches(num_h, num_w, mode.SliceMode.PAD, fill_value)
    else:
        slice_patches_op = c_vision.SlicePatches(num_h, num_w, mode.SliceMode.DROP)
    func_slice_patches = functools.partial(
        slice_patches, num_h=num_h, num_w=num_w, pad_or_drop=pad_or_drop, fill_value=fill_value)

    # First dataset uses the c_transforms op, second the numpy reference.
    dataset1 = _build_sliced_dataset(slice_patches_op, cols, decode_op, resize_op)
    dataset2 = _build_sliced_dataset(func_slice_patches, cols, decode_op, resize_op)

    num_iter = 0
    patches_c = []
    patches_py = []
    for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        for col in cols:
            mse = diff_mse(data1[col], data2[col])
            logger.info("slice_patches_{}, mse: {}".format(num_iter + 1, mse))
            assert mse == 0
            patches_c.append(data1[col])
            patches_py.append(data2[col])
        num_iter += 1
    # zip() yields nothing on empty input, so make sure at least one row was compared.
    assert num_iter > 0, "SlicePatches pipelines produced no rows to compare"
    if plot:
        visualize_list(patches_py, patches_c)
112
113
def _check_slice_patches_error(error_type, expected_msg, *args):
    """
    Assert that SlicePatches(*args) raises `error_type` with `expected_msg` in its message.

    A call that raises nothing fails with AssertionError rather than passing silently.
    """
    try:
        _ = c_vision.SlicePatches(*args)
    except error_type as e:
        logger.info("Got an exception in SlicePatches: {}".format(str(e)))
        assert expected_msg in str(e)
    else:
        raise AssertionError("SlicePatches{} did not raise {}".format(args, error_type.__name__))


def test_slice_patches_exception_01():
    """
    Test SlicePatches with invalid parameters.

    Covers num_height == 0, num_width == 0, a slice_mode that is not a SliceMode, and a
    negative fill_value; each must be rejected with the expected error type and message.
    """
    logger.info("test_Slice_Patches_exception")
    _check_slice_patches_error(ValueError, "Input num_height is not within", 0, 2)
    _check_slice_patches_error(ValueError, "Input num_width is not within", 2, 0)
    _check_slice_patches_error(TypeError, "Argument slice_mode with value", 2, 2, 1)
    _check_slice_patches_error(ValueError, "Input fill_value is not within",
                               2, 2, mode.SliceMode.PAD, -1)
142
def test_slice_patches_06():
    """
    Eager SlicePatches on a (158, 126, 1) int32 image with a 2x8 grid and default arguments.

    126 is not a multiple of 8, so the expected patch width is 16 (width padded up) and
    every one of the 16 patches is expected to be (79, 16, 1).
    """
    image = np.random.randint(0, 255, (158, 126, 1)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    for patch in patches:
        assert patch.shape == (79, 16, 1)
149
def test_slice_patches_07():
    """
    Eager SlicePatches on a 2-D (158, 126) int32 image (no channel axis) with a 2x8 grid.

    Every one of the 16 patches is expected to be (79, 16), mirroring test_slice_patches_06.
    """
    image = np.random.randint(0, 255, (158, 126)).astype(np.int32)
    slice_patches_op = c_vision.SlicePatches(2, 8)
    patches = slice_patches_op(image)
    assert len(patches) == 16
    for patch in patches:
        assert patch.shape == (79, 16)
156
def test_slice_patches_08():
    """
    SlicePatches inside a dataset pipeline: a single (56, 82, 256) uint8 image sliced 2x2.

    56 and 82 divide evenly by 2, so all four output columns must be (28, 41, 256), and the
    pipeline must yield exactly the one input row.
    """
    np_data = np.random.randint(0, 255, (1, 56, 82, 256)).astype(np.uint8)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["image"])
    slice_patches_op = c_vision.SlicePatches(2, 2)
    cols = ["img0", "img1", "img2", "img3"]
    dataset = dataset.map(input_columns=["image"], output_columns=cols,
                          column_order=cols, operations=slice_patches_op)
    num_rows = 0
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        for col in cols:
            assert item[col].shape == (28, 41, 256)
        num_rows += 1
    # Without this, an empty pipeline would make the loop above pass vacuously.
    assert num_rows == 1
167
def test_slice_patches_09():
    """
    Eager SlicePatches on a (56, 82, 256) uint8 image with a 4x3 grid in PAD mode.

    82 is not a multiple of 3, so the width is padded up and every one of the 12 patches
    is expected to be (14, 28, 256).
    """
    image = np.random.randint(0, 255, (56, 82, 256)).astype(np.uint8)
    slice_patches_op = c_vision.SlicePatches(4, 3, mode.SliceMode.PAD)
    patches = slice_patches_op(image)
    assert len(patches) == 12
    for patch in patches:
        assert patch.shape == (14, 28, 256)
174
def skip_test_slice_patches_10():
    """
    Eager DROP-mode slicing of a very large (7000, 7000, 255) uint8 image into a 10x13 grid.

    Kept out of normal runs by the skip_ prefix: the uint8 input alone is about 12.5 GB,
    and the randint array it is cast from is larger still.
    """
    big_image = np.random.randint(0, 255, (7000, 7000, 255)).astype(np.uint8)
    drop_slicer = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    first_patch = drop_slicer(big_image)[0]
    assert first_patch.shape == (700, 538, 255)
180
def skip_test_slice_patches_11():
    """
    Pipeline DROP-mode slicing of one very large (7000, 7000, 256) uint8 image into 10x13 patches.

    Kept out of normal runs by the skip_ prefix: the uint8 input alone is about 12.5 GB.
    """
    rows = np.random.randint(0, 255, (1, 7000, 7000, 256)).astype(np.uint8)
    pipeline = ds.NumpySlicesDataset(rows, column_names=["image"])
    drop_slicer = c_vision.SlicePatches(10, 13, mode.SliceMode.DROP)
    out_columns = ["img{}".format(idx) for idx in range(10 * 13)]
    pipeline = pipeline.map(input_columns=["image"], output_columns=out_columns,
                            column_order=out_columns, operations=drop_slicer)
    for row in pipeline.create_dict_iterator(output_numpy=True):
        assert row['img0'].shape == (700, 538, 256)
191
def slice_patches(image, num_h, num_w, pad_or_drop, fill_value):
    """
    Numpy reference implementation of SlicePatches: split `image` into a num_h x num_w grid.

    Args:
        image (numpy.ndarray): image laid out as (H, W) or (H, W, C).
        num_h (int): number of patches along the height axis.
        num_w (int): number of patches along the width axis.
        pad_or_drop (bool): True pads the bottom/right edges with `fill_value` so the grid
            covers the whole image; False drops the trailing rows/columns that do not fit.
        fill_value (int): padding value used when `pad_or_drop` is True.

    Returns:
        tuple[numpy.ndarray, ...]: the num_h * num_w patches in row-major order, or `image`
        itself (not wrapped in a tuple) when num_h == num_w == 1.
    """
    if num_h == 1 and num_w == 1:
        return image
    height, width = image.shape[:2]
    if pad_or_drop:
        # Ceil division so the grid covers every pixel; the remainder is filled below.
        patch_h = -(-height // num_h)
        patch_w = -(-width // num_w)
        # Keep the input dtype so non-uint8 images are padded without a lossy cast.
        img = np.full((patch_h * num_h, patch_w * num_w) + image.shape[2:], fill_value,
                      dtype=image.dtype)
        img[:height, :width] = image
    else:
        # Floor division; rows/columns past patch_h * num_h / patch_w * num_w are never sliced.
        patch_h = height // num_h
        patch_w = width // num_w
        img = image
    return tuple(img[top * patch_h:(top + 1) * patch_h, left * patch_w:(left + 1) * patch_w]
                 for top in range(num_h) for left in range(num_w))
217
218
if __name__ == "__main__":
    # Pipeline comparisons between the c_transforms op and the numpy reference, plotted.
    for plotted_case in (test_slice_patches_01, test_slice_patches_02, test_slice_patches_03,
                         test_slice_patches_04, test_slice_patches_05):
        plotted_case(plot=True)
    # Shape checks and argument validation; the skip_* cases are intentionally not run.
    for plain_case in (test_slice_patches_06, test_slice_patches_07, test_slice_patches_08,
                       test_slice_patches_09, test_slice_patches_exception_01):
        plain_case()
230