• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15Testing TenCrop in DE
16"""
17import pytest
18import numpy as np
19
20import mindspore.dataset as ds
21import mindspore.dataset.transforms.py_transforms
22import mindspore.dataset.vision.py_transforms as vision
23from mindspore import log as logger
24from util import visualize_list, save_and_check_md5
25
# Forwarded as `generate_golden` to save_and_check_md5 in the md5 test;
# presumably True regenerates the golden file instead of checking it.
GENERATE_GOLDEN = False

# TFRecord data file(s) and schema used by every pipeline in this module.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
30
31
def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
    """
    Utility function for testing TenCrop. Input arguments are given by other tests.

    Builds two pipelines over the same images, with shuffle disabled so rows line up:
    a reference pipeline that only decodes each image, and a pipeline that applies
    TenCrop and stacks the ten crops into one 4D array. For every row it checks
    that the TenCrop output is a stack of 10 crops with the requested spatial size
    and the same channel count as the decoded image.

    Args:
        crop_size (int or sequence): size passed to TenCrop. A single number means
            a square crop; a pair is (height, width).
        vertical_flip (bool): forwarded to TenCrop as use_vertical_flip.
        plot (bool): if True, show each decoded image alongside its ten crops.
    """
    # Expected (height, width) of every crop; TenCrop accepts a single integer
    # or an (h, w) pair for `size`.
    if np.ndim(crop_size) == 0:
        expected_hw = (int(crop_size), int(crop_size))
    else:
        expected_hw = tuple(int(dim) for dim in crop_size)

    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms_1 = [
        vision.Decode(),
        vision.ToTensor(),
    ]
    transform_1 = mindspore.dataset.transforms.py_transforms.Compose(transforms_1)
    data1 = data1.map(operations=transform_1, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms_2 = [
        vision.Decode(),
        vision.TenCrop(crop_size, use_vertical_flip=vertical_flip),
        lambda *images: np.stack([vision.ToTensor()(image) for image in images])  # 4D stack of 10 images
    ]
    transform_2 = mindspore.dataset.transforms.py_transforms.Compose(transforms_2)
    data2 = data2.map(operations=transform_2, input_columns=["image"])
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        num_iter += 1
        # ToTensor output is channel-first, hence the transpose back to HWC;
        # scaling by 255 restores a uint8 image for display.
        image_1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_2 = item2["image"]

        logger.info("shape of image_1: {}".format(image_1.shape))
        logger.info("shape of image_2: {}".format(image_2.shape))

        logger.info("dtype of image_1: {}".format(image_1.dtype))
        logger.info("dtype of image_2: {}".format(image_2.dtype))

        if plot:
            visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))

        # The output data should be of a 4D tensor shape, a stack of 10 images.
        assert len(image_2.shape) == 4
        assert image_2.shape[0] == 10
        # Each crop is channel-first: same channel count as the decoded image,
        # and its spatial size must equal the requested crop size.
        assert image_2.shape[1] == image_1.shape[2]
        assert tuple(image_2.shape[2:]) == expected_hw

    # Guard against a vacuous pass if the pipelines yield no rows at all.
    assert num_iter > 0
72
73
def test_ten_crop_op_square(plot=False):
    """
    Run TenCrop with a single integer size, i.e. a 200x200 square crop
    """
    logger.info("test_ten_crop_op_square")
    square_side = 200
    util_test_ten_crop(crop_size=square_side, plot=plot)
81
82
def test_ten_crop_op_rectangle(plot=False):
    """
    Run TenCrop with an (h, w) pair, i.e. a 200x150 rectangular crop
    """
    logger.info("test_ten_crop_op_rectangle")
    height, width = 200, 150
    util_test_ten_crop(crop_size=(height, width), plot=plot)
90
91
def test_ten_crop_op_vertical_flip(plot=False):
    """
    Run TenCrop on a 200x200 square crop with use_vertical_flip enabled
    """
    logger.info("test_ten_crop_op_vertical_flip")
    util_test_ten_crop(crop_size=200, vertical_flip=True, plot=plot)
99
100
def test_ten_crop_md5():
    """
    Check that TenCrop output is reproducible across runs.

    TenCrop is deterministic, so the pipeline output for a fixed input is compared
    against stored md5 results via save_and_check_md5 (GENERATE_GOLDEN is passed
    through as its generate_golden flag).
    """
    logger.info("test_ten_crop_md5")

    dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    pipeline = mindspore.dataset.transforms.py_transforms.Compose([
        vision.Decode(),
        vision.TenCrop((200, 100), use_vertical_flip=True),
        # Convert each of the ten crops with ToTensor and stack them: 4D stack of 10 images
        lambda *images: np.stack([vision.ToTensor()(image) for image in images]),
    ])
    dataset = dataset.map(operations=pipeline, input_columns=["image"])

    # Compare with expected md5 from images
    golden_file = "ten_crop_01_result.npz"
    save_and_check_md5(dataset, golden_file, generate_golden=GENERATE_GOLDEN)
119
120
def test_ten_crop_list_size_error_msg():
    """
    Tests TenCrop error message when the size arg has more than 2 elements.

    The size validation fires when TenCrop is constructed, so only the
    constructor call is placed inside the pytest.raises block; no other
    object built there can be the source of the expected TypeError.
    """
    logger.info("test_ten_crop_list_size_error_msg")

    with pytest.raises(TypeError) as info:
        vision.TenCrop([200, 200, 200])
    error_msg = "Size should be a single integer or a list/tuple (h, w) of length 2."
    assert error_msg == str(info.value)
135
136
def test_ten_crop_invalid_size_error_msg():
    """
    Tests TenCrop error message when the size arg is not positive.

    Both zero and a negative size must be rejected by the TenCrop constructor
    with the same range error. Only the constructor call is placed inside the
    pytest.raises block, so no unrelated object can satisfy the expectation.
    """
    logger.info("test_ten_crop_invalid_size_error_msg")

    error_msg = "Input is not within the required interval of [1, 16777216]."
    for bad_size in (0, -10):
        with pytest.raises(ValueError) as info:
            vision.TenCrop(bad_size)
        assert error_msg == str(info.value), "unexpected message for size={}".format(bad_size)
160
161
def test_ten_crop_wrong_img_error_msg():
    """
    Tests TenCrop error message when the image is not in the correct format.

    TenCrop yields ten separate images. Feeding that result straight into
    ToTensor, which is called with a single image, makes fetching the first row
    fail: a RuntimeError whose message carries the underlying TypeError.
    """
    logger.info("test_ten_crop_wrong_img_error_msg")

    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        vision.Decode(),
        vision.TenCrop(200),
        vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data = data.map(operations=transform, input_columns=["image"])

    with pytest.raises(RuntimeError) as info:
        next(data.create_tuple_iterator(num_epochs=1))

    # error msg comes from ToTensor(): self + 10 images = 11 positional args.
    # Match the exception type and the argument-count text separately, because
    # newer CPython versions qualify the function name ("ToTensor.__call__()"),
    # which breaks a literal "TypeError: __call__() ..." substring match.
    err_text = str(info.value)
    assert "TypeError" in err_text
    assert "__call__() takes 2 positional arguments but 11 were given" in err_text
183
184
if __name__ == "__main__":
    # These tests can plot each decoded image next to its ten crops.
    for visual_test in (test_ten_crop_op_square,
                        test_ten_crop_op_rectangle,
                        test_ten_crop_op_vertical_flip):
        visual_test(plot=True)
    # Remaining tests take no arguments.
    for plain_test in (test_ten_crop_md5,
                       test_ten_crop_list_size_error_msg,
                       test_ten_crop_invalid_size_error_msg,
                       test_ten_crop_wrong_img_error_msg):
        plain_test()
193