• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16import numpy as np
17import pytest
18import mindspore.common.dtype as mstype
19import mindspore.dataset as ds
20import mindspore.dataset.transforms.c_transforms as c_transforms
21import mindspore.dataset.transforms.py_transforms as py_transforms
22
23import mindspore.dataset.vision.c_transforms as c_vision
24import mindspore.dataset.vision.py_transforms as py_vision
25
26from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers
27
# Passed as ``generate_golden`` to ``save_and_check_md5`` in the vision test below.
# Presumably True (re)writes the golden md5 files instead of only checking them —
# confirm in util.py. Keep False in committed code.
GENERATE_GOLDEN = False
29
30
def test_compose():
    """
    Test C++ and Python Compose Op

    Covers c_transforms.Compose and py_transforms.Compose with a single op, column
    duplication/concatenation, mixed Python/C++ op lists, nested Compose, plain op lists
    passed without a Compose wrapper, and the validation errors raised for bad op lists.
    """
    ds.config.set_seed(0)

    def test_config(arr, op_list):
        """
        Map ``op_list`` over a one-column ("col") dataset built from ``arr`` and return
        every row as a Python list.

        Errors are deliberately allowed to propagate. Converting them into a return value
        would make a broken pipeline show up only as an opaque list-vs-string assertion
        mismatch. It would also hide any error raised during map/iteration from the
        ``pytest.raises`` checks below.
        """
        data = ds.NumpySlicesDataset(arr, column_names="col", shuffle=False)
        data = data.map(input_columns=["col"], operations=op_list)
        return [row["col"].tolist() for row in data.create_dict_iterator(output_numpy=True)]

    # Test simple compose with only 1 op, this would generate a warning
    assert test_config([[1, 0], [3, 4]], c_transforms.Compose([c_transforms.Fill(2)])) == [[2, 2], [2, 2]]

    # Test 1 column -> 2 columns -> 1 -> 2 -> 1
    assert test_config([[1, 0]],
                       c_transforms.Compose(
                           [c_transforms.Duplicate(), c_transforms.Concatenate(), c_transforms.Duplicate(),
                            c_transforms.Concatenate()])) \
           == [[1, 0] * 4]

    # Test one Python transform followed by a C++ transform. Type after OneHot is a float (mixed use-case)
    assert test_config([1, 0],
                       c_transforms.Compose([py_transforms.OneHotOp(2), c_transforms.TypeCast(mstype.int32)])) \
           == [[0, 1], [1, 0]]

    # Test exceptions.
    with pytest.raises(TypeError) as error_info:
        c_transforms.Compose([1, c_transforms.TypeCast(mstype.int32)])
    assert "op_list[0] is neither a c_transform op (TensorOperation) nor a callable pyfunc." in str(error_info.value)

    # Test empty op list
    with pytest.raises(ValueError) as error_info:
        test_config([1, 0], c_transforms.Compose([]))
    assert "op_list can not be empty." in str(error_info.value)

    # Test Python compose op
    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2)])) == [[0, 1], [1, 0]]
    assert test_config([1, 0], py_transforms.Compose([py_transforms.OneHotOp(2), (lambda x: x + x)])) == [[0, 2],
                                                                                                          [2, 0]]

    # Test nested Python compose op
    assert test_config([1, 0],
                       py_transforms.Compose([py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)])) \
           == [[0, 2], [2, 0]]

    # Test passing a list of Python ops without Compose wrapper
    assert test_config([1, 0],
                       [py_transforms.Compose([py_transforms.OneHotOp(2)]), (lambda x: x + x)]) \
           == [[0, 2], [2, 0]]
    assert test_config([1, 0], [py_transforms.OneHotOp(2), (lambda x: x + x)]) == [[0, 2], [2, 0]]

    # Test a non callable function
    with pytest.raises(ValueError) as error_info:
        py_transforms.Compose([1])
    assert "transforms[0] is not callable." in str(error_info.value)

    # Test empty Python op list
    with pytest.raises(ValueError) as error_info:
        test_config([1, 0], py_transforms.Compose([]))
    assert "transforms list is empty." in str(error_info.value)

    # Pass in extra brackets
    with pytest.raises(TypeError) as error_info:
        py_transforms.Compose([(lambda x: x + x)])()
    assert "Compose was called without an image. Fix invocation (avoid it being invoked as Compose([...])())." in str(
        error_info.value)
104
105
def test_lambdas():
    """
    Test Multi Column Python Compose Op

    Lambdas consume two input columns ("col0", "col1") and emit one or two output
    columns; the flattened per-row outputs are compared against expected lists.
    """
    ds.config.set_seed(0)

    def test_config(arr, input_columns, output_cols, op_list):
        # Build the pipeline, keep only ``output_cols`` and flatten rows column by column.
        pipeline = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
        pipeline = pipeline.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
                                column_order=output_cols)
        return [row[col].tolist()
                for row in pipeline.create_dict_iterator(output_numpy=True)
                for col in output_cols]

    arr = ([[1]], [[3]])

    # Two inputs -> one output via a single lambda.
    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([(lambda x, y: x)])) == [[1]]

    # Two inputs -> one output, followed by an identity lambda.
    assert test_config(arr, ["col0", "col1"], ["a"], py_transforms.Compose([lambda x, y: x, lambda x: x])) == [[1]]

    # Two inputs -> one -> two outputs.
    expected_split = [[1], [2]]
    assert test_config(arr, ["col0", "col1"], ["a", "b"],
                       py_transforms.Compose([lambda x, y: x, lambda x: (x, x * 2)])) == expected_split

    # Plain list (no Compose): two -> two -> two outputs.
    expected_pairwise = [[1], [8]]
    assert test_config(arr, ["col0", "col1"], ["a", "b"],
                       [lambda x, y: (x, x + y), lambda x, y: (x, y * 2)]) == expected_pairwise
131
132
def test_c_py_compose_transforms_module():
    """
    Test combining Python and C++ transforms

    Each case places Python callables / py_transforms ops and c_transforms ops in the
    same operations list and checks the flattened per-row output.
    """
    ds.config.set_seed(0)

    def test_config(arr, input_columns, output_cols, op_list):
        # Build the pipeline, keep only ``output_cols`` and flatten rows column by column.
        pipeline = ds.NumpySlicesDataset(arr, column_names=input_columns, shuffle=False)
        pipeline = pipeline.map(operations=op_list, input_columns=input_columns, output_columns=output_cols,
                                column_order=output_cols)
        return [row[col].tolist()
                for row in pipeline.create_dict_iterator(output_numpy=True)
                for col in output_cols]

    arr = [1, 0]

    # Python OneHot followed by a C++ Mask.
    masked = test_config(arr, ["cols"], ["cols"],
                         [py_transforms.OneHotOp(2), c_transforms.Mask(c_transforms.Relational.EQ, 1)])
    assert masked == [[False, True], [True, False]]

    # Python -> Python lambda -> C++ Fill.
    filled = test_config(arr, ["cols"], ["cols"],
                         [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1)])
    assert filled == [[1, 1], [1, 1]]

    # Python -> lambda -> C++ Fill -> lambda.
    filled_then_doubled = test_config(arr, ["cols"], ["cols"],
                                      [py_transforms.OneHotOp(2), (lambda x: x + x), c_transforms.Fill(1),
                                       (lambda x: x + x)])
    assert filled_then_doubled == [[2, 2], [2, 2]]

    # C++ PadEnd followed by a Python lambda.
    padded = test_config([[1, 3]], ["cols"], ["cols"],
                         [c_transforms.PadEnd([3], -1), (lambda x: x + x)])
    assert padded == [[2, 6, -2]]

    # Two columns merged by a lambda, then padded in C++.
    arr = ([[1]], [[3]])
    merged = test_config(arr, ["col0", "col1"], ["a"], [(lambda x, y: x + y), c_transforms.PadEnd([2], -1)])
    assert merged == [[4, -1]]
166
167
def test_c_py_compose_vision_module(plot=False, run_golden=True):
    """
    Test combining Python and C++ vision transforms

    Args:
        plot (bool): If True, show the decoded-only images next to the transformed ones.
        run_golden (bool): If True, compare each pipeline's output against its md5 golden file.
    """
    original_seed = config_get_set_seed(10)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    def test_config(plot, file_name, op_list):
        data_dir = "../data/dataset/testImageNetData/train/"
        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data1 = data1.map(operations=op_list, input_columns=["image"])
        # Reference pipeline (decode only); its images are used only for plotting.
        data2 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data2 = data2.map(operations=c_vision.Decode(), input_columns=["image"])
        original_images = []
        transformed_images = []

        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            transformed_images.append(item["image"])
        for item in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
            original_images.append(item["image"])

        if run_golden:
            # Compare with expected md5 from images
            save_and_check_md5(data1, file_name, generate_golden=GENERATE_GOLDEN)

        if plot:
            visualize_list(original_images, transformed_images)

    # ds.config holds global state: restore it in ``finally`` so a failing check here does
    # not leak seed 10 / a single worker into tests that run afterwards.
    try:
        test_config(op_list=[c_vision.Decode(),
                             py_vision.ToPIL(),
                             py_vision.Resize((224, 224)),
                             np.array],
                    plot=plot, file_name="compose_c_py_1.npz")

        # NOTE(review): (224, 244) may be a typo for (224, 224); left as-is because the
        # compose_c_py_2.npz golden md5 was generated from this exact pipeline.
        test_config(op_list=[c_vision.Decode(),
                             c_vision.Resize((224, 244)),
                             py_vision.ToPIL(),
                             np.array,
                             c_vision.Resize((24, 24))],
                    plot=plot, file_name="compose_c_py_2.npz")

        test_config(op_list=[py_vision.Decode(),
                             py_vision.Resize((224, 224)),
                             np.array,
                             c_vision.RandomColor()],
                    plot=plot, file_name="compose_c_py_3.npz")
    finally:
        # Restore configuration
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
218
219
def test_py_transforms_with_c_vision():
    """
    These examples will fail, as c_transform should not be used in py_transforms.Random(Apply/Choice/Order)

    The first three cases expect RandomApply / RandomChoice / RandomOrder to reject a
    c_vision op with a ValueError naming its index in the list. The last case expects a
    RuntimeError mentioning "is smaller than the category number" from
    py_transforms.OneHotOp(20, 0.1).
    """

    ds.config.set_seed(0)

    def test_config(op_list):
        data_dir = "../data/dataset/testImageNetData/train/"
        data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data = data.map(operations=op_list)
        # Collect every column of every row. There is no output-column list in scope here,
        # so iterate the row dict itself. Referencing an undefined name would raise a
        # NameError that masks the error the caller expects.
        return [{col_name: value.tolist() for col_name, value in row.items()}
                for row in data.create_dict_iterator(output_numpy=True)]

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomApply([c_vision.RandomResizedCrop(200)]))
    assert "transforms[0] is not a py transforms." in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomChoice([c_vision.RandomResizedCrop(200)]))
    assert "transforms[0] is not a py transforms." in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        test_config(py_transforms.RandomOrder([np.array, c_vision.RandomResizedCrop(200)]))
    assert "transforms[1] is not a py transforms." in str(error_info.value)

    with pytest.raises(RuntimeError) as error_info:
        test_config([py_transforms.OneHotOp(20, 0.1)])
    assert "is smaller than the category number" in str(error_info.value)
252
253
def test_py_vision_with_c_transforms():
    """
    Test combining Python vision operations with C++ transforms operations

    Python Decode/CenterCrop output is converted with np.array and then fed to C++
    Mask, Fill and Concatenate ops.
    """

    ds.config.set_seed(0)

    def test_config(op_list):
        data_dir = "../data/dataset/testImageNetData/train/"
        data1 = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
        data1 = data1.map(operations=op_list, input_columns=["image"])
        return [item["image"] for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True)]

    def check_outputs(expected, actual):
        # zip() stops at the shorter sequence, so without the length check an empty or
        # short pipeline output would make the element-wise comparison pass vacuously.
        assert len(actual) == len(expected)
        for exp_a, output in zip(expected, actual):
            np.testing.assert_array_equal(exp_a, output)

    # Test with Mask Op
    output_arr = test_config([py_vision.Decode(),
                              py_vision.CenterCrop(2), np.array,
                              c_transforms.Mask(c_transforms.Relational.GE, 100)])

    exp_arr = [np.array([[[True, False, False],
                          [True, False, False]],
                         [[True, False, False],
                          [True, False, False]]]),
               np.array([[[True, False, False],
                          [True, False, False]],
                         [[True, False, False],
                          [True, False, False]]])]

    check_outputs(exp_arr, output_arr)

    # Test with Fill Op
    output_arr = test_config([py_vision.Decode(),
                              py_vision.CenterCrop(4), np.array,
                              c_transforms.Fill(10)])

    exp_arr = [np.ones((4, 4, 3)) * 10] * 2
    check_outputs(exp_arr, output_arr)

    # Test with Concatenate Op, which will raise an error since ConcatenateOp only supports rank 1 tensors.
    with pytest.raises(RuntimeError) as error_info:
        test_config([py_vision.Decode(),
                     py_vision.CenterCrop(2), np.array,
                     c_transforms.Concatenate(0)])
    assert "only 1D input supported" in str(error_info.value)
303
304
def test_compose_with_custom_function():
    """
    Test Python Compose with custom function

    A lambda scales the column, a custom function fans it out into two columns
    (value, value squared), and a final lambda stacks them back into one column.
    """

    def custom_function(x):
        return (x, x * x)

    # One column in -> two columns -> one stacked column out.
    fan_out_then_stack = [
        lambda x: x * 3,
        custom_function,
        lambda *images: np.stack(images),
    ]

    dataset = ds.NumpySlicesDataset([[1, 2]], column_names=["col0"], shuffle=False)
    dataset = dataset.map(input_columns=["col0"], operations=fan_out_then_stack)

    rows = [item["col0"].tolist() for item in dataset.create_dict_iterator(output_numpy=True)]
    assert rows == [[[3, 6], [9, 36]]]
329
330
if __name__ == "__main__":
    # Run every test in file order; only the vision test is asked to plot its images.
    _TESTS_TO_RUN = (
        (test_compose, {}),
        (test_lambdas, {}),
        (test_c_py_compose_transforms_module, {}),
        (test_c_py_compose_vision_module, {"plot": True}),
        (test_py_transforms_with_c_vision, {}),
        (test_py_vision_with_c_transforms, {}),
        (test_compose_with_custom_function, {}),
    )
    for _test_fn, _test_kwargs in _TESTS_TO_RUN:
        _test_fn(**_test_kwargs)
339