• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing the random vertical flip op in DE
17"""
18import numpy as np
19import mindspore.dataset as ds
20import mindspore.dataset.transforms.py_transforms
21import mindspore.dataset.transforms.c_transforms as ops
22import mindspore.dataset.vision.c_transforms as c_vision
23import mindspore.dataset.vision.py_transforms as py_vision
24from mindspore import log as logger
25from util import save_and_check_md5, visualize_list, visualize_image, diff_mse, \
26    config_get_set_seed, config_get_set_num_parallel_workers
27
# Passed as `generate_golden` to save_and_check_md5 in the md5 tests below;
# presumably True regenerates the reference .npz files instead of checking
# against them — confirm in util.save_and_check_md5 before flipping it.
GENERATE_GOLDEN = False

# TFRecord data file(s) read by every test in this module (a list of paths).
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
# Path to the JSON schema file (a single file, despite the "_DIR" name).
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
32
33
def v_flip(image):
    """
    Mirror an image top-to-bottom by reversing its first (row) axis.

    Used as the reference result for RandomVerticalFlip with prob=1.0,
    which always flips. Assumes a height-first (HWC) image — TODO confirm
    for other layouts.
    """
    # Reverse axis 0 (rows); keep the remaining two axes unchanged.
    flipped_image = image[::-1, :, :]
    return flipped_image
43
44
def test_random_vertical_op(plot=False):
    """
    Test RandomVerticalFlip with prob=1.0 (always flip) against a manual flip.

    Args:
        plot (bool): If True, visualize original, op output and reference.
    """
    logger.info("Test random_vertical")

    # First dataset: decode, then flip with probability 1.0
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    random_vertical_op = c_vision.RandomVerticalFlip(1.0)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=random_vertical_op, input_columns=["image"])

    # Second dataset: decode only, flipped manually for reference
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=decode_op, input_columns=["image"])

    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # prob=1.0 makes the flip deterministic, so every image (not just
        # the first) must match the manual flip exactly.
        image_v_flipped = item1["image"]
        image = item2["image"]
        image_v_flipped_2 = v_flip(image)

        mse = diff_mse(image_v_flipped, image_v_flipped_2)
        assert mse == 0
        logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
        num_iter += 1
        if plot:
            visualize_image(image, image_v_flipped, mse, image_v_flipped_2)

    # Guard against a vacuous pass on an empty pipeline.
    assert num_iter > 0
80
81
def test_random_vertical_valid_prob_c():
    """
    Test RandomVerticalFlip op with c_transforms: valid non-default input, expect to pass
    """
    logger.info("test_random_vertical_valid_prob_c")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Generate dataset
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        decode_op = c_vision.Decode()
        random_vertical_op = c_vision.RandomVerticalFlip(0.8)
        data = data.map(operations=decode_op, input_columns=["image"])
        data = data.map(operations=random_vertical_op, input_columns=["image"])

        filename = "random_vertical_01_c_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even if the md5 check fails, so the
        # modified global seed/worker count does not leak into other tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
103
104
def test_random_vertical_valid_prob_py():
    """
    Test RandomVerticalFlip op with py_transforms: valid non-default input, expect to pass
    """
    logger.info("test_random_vertical_valid_prob_py")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Generate dataset
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        transforms = [
            py_vision.Decode(),
            py_vision.RandomVerticalFlip(0.8),
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data = data.map(operations=transform, input_columns=["image"])

        filename = "random_vertical_01_py_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even if the md5 check fails, so the
        # modified global seed/worker count does not leak into other tests.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
129
130
def test_random_vertical_invalid_prob_c():
    """
    Test RandomVerticalFlip op in c_transforms: invalid input, expect to raise error
    """
    logger.info("test_random_vertical_invalid_prob_c")

    # Generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    try:
        # Note: Valid range of prob should be [0.0, 1.0]
        random_vertical_op = c_vision.RandomVerticalFlip(1.5)
        data = data.map(operations=decode_op, input_columns=["image"])
        data = data.map(operations=random_vertical_op, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert 'Input prob is not within the required interval of [0.0, 1.0].' in str(e)
    else:
        # Without this, a missing validation error would let the test pass.
        raise AssertionError("RandomVerticalFlip(1.5) was expected to raise ValueError")
148
149
def test_random_vertical_invalid_prob_py():
    """
    Test RandomVerticalFlip op in py_transforms: invalid input, expect to raise error
    """
    logger.info("test_random_vertical_invalid_prob_py")

    # Generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            # Note: Valid range of prob should be [0.0, 1.0]
            py_vision.RandomVerticalFlip(1.5),
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data = data.map(operations=transform, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert 'Input prob is not within the required interval of [0.0, 1.0].' in str(e)
    else:
        # Without this, a missing validation error would let the test pass.
        raise AssertionError("py RandomVerticalFlip(1.5) was expected to raise ValueError")
170
171
def test_random_vertical_comp(plot=False):
    """
    Test RandomVerticalFlip and compare the c and python image augmentation ops.

    Both pipelines use prob=1, so every image is flipped and the outputs
    must agree.

    Args:
        plot (bool): If True, visualize the c and python results side by side.
    """
    logger.info("test_random_vertical_comp")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    # Note: The image must be flipped if prob is set to be 1
    random_vertical_op = c_vision.RandomVerticalFlip(1)
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=random_vertical_op, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        # Note: The image must be flipped if prob is set to be 1
        py_vision.RandomVerticalFlip(1),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data2 = data2.map(operations=transform, input_columns=["image"])

    images_list_c = []
    images_list_py = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_c = item1["image"]
        # ToTensor yields CHW floats in [0, 1]; convert back to HWC uint8
        # so it is comparable with the c pipeline output.
        image_py = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        images_list_c.append(image_c)
        images_list_py.append(image_py)

        # Check if the output images are the same
        mse = diff_mse(image_c, image_py)
        assert mse < 0.001

    # Guard against a vacuous pass on an empty pipeline.
    assert images_list_c
    if plot:
        visualize_list(images_list_c, images_list_py, visualize_mode=2)

211
def test_random_vertical_op_1():
    """
    Test RandomVerticalFlip applied to two columns at once.

    The "image" column is duplicated into "image_copy", both are decoded, and
    one RandomVerticalFlip(1.0) op is mapped over both columns; the two
    outputs are expected to be identical.
    """
    logger.info("Test RandomVerticalFlip with different fields.")

    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data = data.map(operations=ops.Duplicate(), input_columns=["image"],
                    output_columns=["image", "image_copy"], column_order=["image", "image_copy"])
    random_vertical_op = c_vision.RandomVerticalFlip(1.0)
    decode_op = c_vision.Decode()

    data = data.map(operations=decode_op, input_columns=["image"])
    data = data.map(operations=decode_op, input_columns=["image_copy"])
    data = data.map(operations=random_vertical_op, input_columns=["image", "image_copy"])

    num_iter = 0
    for data1 in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        image = data1["image"]
        image_copy = data1["image_copy"]
        mse = diff_mse(image, image_copy)
        logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
        assert mse == 0
        num_iter += 1

    # Guard against a vacuous pass on an empty pipeline.
    assert num_iter > 0
236
237
if __name__ == "__main__":
    # Run every test in file order; the plotting variants also visualize.
    _TESTS_TO_RUN = (
        (test_random_vertical_op, {"plot": True}),
        (test_random_vertical_valid_prob_c, {}),
        (test_random_vertical_valid_prob_py, {}),
        (test_random_vertical_invalid_prob_c, {}),
        (test_random_vertical_invalid_prob_py, {}),
        (test_random_vertical_comp, {"plot": True}),
        (test_random_vertical_op_1, {}),
    )
    for run_test, test_kwargs in _TESTS_TO_RUN:
        run_test(**test_kwargs)
246