• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Tests for the RandomPerspective op in DE.
"""
18import numpy as np
19import mindspore.dataset as ds
20import mindspore.dataset.transforms.py_transforms
21import mindspore.dataset.vision.py_transforms as py_vision
22from mindspore.dataset.vision.utils import Inter
23from mindspore import log as logger
24from util import visualize_list, save_and_check_md5, \
25    config_get_set_seed, config_get_set_num_parallel_workers
26
# Passed as generate_golden to save_and_check_md5 in the md5 test. Presumably,
# when True, the golden .npz file is (re)generated instead of being compared
# against -- keep False for normal test runs (TODO confirm in util).
GENERATE_GOLDEN = False

# TFRecord data file(s) and schema shared by every test in this module
# (judging by the directory name, a small 3-image dataset). The paths are
# relative, so the tests presumably must run from this file's directory.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
31
32
def test_random_perspective_op(plot=False):
    """
    Test RandomPerspective in python transformations.

    Decodes the same images through two pipelines -- one with
    RandomPerspective (default arguments) and one without -- and checks that
    the perspective warp keeps each image's shape and that the dataset
    actually yields images (so the test cannot pass vacuously).

    Args:
        plot (bool): if True, display the original and warped images side by
            side with ``visualize_list``.
    """
    logger.info("test_random_perspective_op")
    # define map operations
    transforms1 = [
        py_vision.Decode(),
        py_vision.RandomPerspective(),
        py_vision.ToTensor()
    ]
    transform1 = mindspore.dataset.transforms.py_transforms.Compose(transforms1)

    # Reference pipeline: identical except for the perspective warp.
    transforms2 = [
        py_vision.Decode(),
        py_vision.ToTensor()
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(transforms2)

    # First dataset (warped). shuffle=False on both datasets keeps them in the
    # same order, so zip() below pairs each warped image with its original.
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    # Second dataset (original, no warp).
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_perspective = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # The perspective warp must not change the image's dimensions.
        assert item1["image"].shape == item2["image"].shape, \
            "RandomPerspective changed image shape: {} vs {}".format(
                item1["image"].shape, item2["image"].shape)
        # Move axis 0 last and rescale to uint8 for plotting (the ToTensor
        # output is treated as channel-first floats in [0, 1]).
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_perspective.append(image1)
        image_original.append(image2)
    # Guard against a vacuous pass when the dataset yields no rows.
    assert image_perspective, "expected at least one image from the dataset"
    if plot:
        visualize_list(image_original, image_perspective)
69
70
def skip_test_random_perspective_md5():
    """
    Test RandomPerspective with md5 comparison.

    Runs a seeded, single-worker pipeline (Decode -> RandomPerspective ->
    Resize -> ToTensor) and checks its output with ``save_and_check_md5``
    against ``random_perspective_01_result.npz`` (or regenerates that file,
    presumably, when GENERATE_GOLDEN is True).

    The ``skip_`` prefix means pytest's default discovery (which collects
    functions starting with ``test``) does not pick this up; it still runs
    when the module is executed directly.

    The global dataset seed and parallel-worker count are restored in a
    ``finally`` block, so a failed comparison cannot leak this test's
    configuration into tests that run afterwards.
    """
    logger.info("test_random_perspective_md5")
    # Pin the seed and use a single worker so the random output is
    # reproducible; the previous values are returned so they can be restored.
    original_seed = config_get_set_seed(5)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # define map operations
        transforms = [
            py_vision.Decode(),
            py_vision.RandomPerspective(distortion_scale=0.3, prob=0.7,
                                        interpolation=Inter.BILINEAR),
            py_vision.Resize(1450),  # resize to a smaller size to prevent round-off error
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)

        #  Generate dataset
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data = data.map(operations=transform, input_columns=["image"])

        # check results with md5 comparison
        filename = "random_perspective_01_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore configuration even when the md5 check fails.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
100
101
def test_random_perspective_exception_distortion_scale_range():
    """
    Test RandomPerspective: distortion_scale is not in [0, 1], expected to raise ValueError.

    Fails if the constructor accepts the out-of-range value, or if it raises
    ValueError with a message other than the expected one.
    """
    logger.info("test_random_perspective_exception_distortion_scale_range")
    try:
        _ = py_vision.RandomPerspective(distortion_scale=1.5)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "Input distortion_scale is not within the required interval of [0.0, 1.0]."
    else:
        # Without this branch a missing range check would let the test pass silently.
        raise AssertionError("RandomPerspective accepted distortion_scale=1.5; expected ValueError")
112
113
def test_random_perspective_exception_prob_range():
    """
    Test RandomPerspective: prob is not in [0, 1], expected to raise ValueError.

    Fails if the constructor accepts the out-of-range value, or if it raises
    ValueError with a message other than the expected one.
    """
    logger.info("test_random_perspective_exception_prob_range")
    try:
        _ = py_vision.RandomPerspective(prob=1.2)
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "Input prob is not within the required interval of [0.0, 1.0]."
    else:
        # Without this branch a missing range check would let the test pass silently.
        raise AssertionError("RandomPerspective accepted prob=1.2; expected ValueError")
124
125
if __name__ == "__main__":
    # Direct execution: show the comparison plot and also run the md5 check,
    # which pytest's default discovery skips because of its ``skip_`` prefix.
    test_random_perspective_op(plot=True)
    skip_test_random_perspective_md5()
    # Argument-validation checks, in their original order.
    for run_exception_check in (test_random_perspective_exception_distortion_scale_range,
                                test_random_perspective_exception_prob_range):
        run_exception_check()
131