• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RandomGrayscale op in DE
17"""
18import numpy as np
19
20import mindspore.dataset.transforms.py_transforms
21import mindspore.dataset.vision.py_transforms as py_vision
22import mindspore.dataset as ds
23from mindspore import log as logger
24from util import save_and_check_md5, visualize_list, \
25    config_get_set_seed, config_get_set_num_parallel_workers
26
# Passed as `generate_golden` to util.save_and_check_md5 in the md5 tests below;
# presumably True (re)writes the golden .npz files instead of checking against them.
GENERATE_GOLDEN = False

# TFRecord data file(s) and the matching schema for the small 3-image test dataset.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
31
32
def test_random_grayscale_valid_prob(plot=False):
    """
    Test RandomGrayscale Op: valid input, expect to pass

    With prob=1 every decoded (3-channel) image must come out grayscale, i.e.
    still 3 channels but with identical values in each channel. The result is
    compared against the same images decoded without RandomGrayscale.

    Args:
        plot (bool): if True, visualize original vs. grayscale images.
    """
    logger.info("test_random_grayscale_valid_prob")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms1 = [
        py_vision.Decode(),
        # Note: prob is 1 so the output should always be grayscale images
        py_vision.RandomGrayscale(1),
        py_vision.ToTensor()
    ]
    transform1 = mindspore.dataset.transforms.py_transforms.Compose(transforms1)
    data1 = data1.map(operations=transform1, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms2 = [
        py_vision.Decode(),
        py_vision.ToTensor()
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(transforms2)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_gray = []
    image = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Transpose the ToTensor output back to HWC and rescale to uint8.
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image_gray.append(image1)
        image.append(image2)

        # The test previously asserted nothing; check the output really is grayscale:
        # same shape as the RGB original, and every channel carries the same values.
        assert image1.shape == image2.shape
        assert np.array_equal(image1[..., 0], image1[..., 1])
        assert np.array_equal(image1[..., 1], image1[..., 2])

    # Guard against an empty pipeline making the loop (and its asserts) a no-op.
    assert image_gray, "expected at least one image from the dataset"
    if plot:
        visualize_list(image, image_gray)
69
70
def test_random_grayscale_input_grayscale_images():
    """
    Test RandomGrayscale Op: valid parameter with grayscale images as input, expect to pass

    Images converted to 1-channel grayscale before RandomGrayscale must remain
    1-channel, while the plain decoded images keep 3 channels. The global
    dataset seed and worker count are changed for the test and always restored.
    """
    logger.info("test_random_grayscale_input_grayscale_images")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # First dataset
        data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        transforms1 = [
            py_vision.Decode(),
            py_vision.Grayscale(1),
            # Note: the input to RandomGrayscale is already a 1-channel grayscale image.
            py_vision.RandomGrayscale(0.5),
            py_vision.ToTensor()
        ]
        transform1 = mindspore.dataset.transforms.py_transforms.Compose(transforms1)
        data1 = data1.map(operations=transform1, input_columns=["image"])

        # Second dataset
        data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        transforms2 = [
            py_vision.Decode(),
            py_vision.ToTensor()
        ]
        transform2 = mindspore.dataset.transforms.py_transforms.Compose(transforms2)
        data2 = data2.map(operations=transform2, input_columns=["image"])

        for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
            # Transpose the ToTensor output back to HWC and rescale to uint8.
            image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
            image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)

            assert len(image1.shape) == 3
            assert image1.shape[2] == 1
            assert len(image2.shape) == 3
            assert image2.shape[2] == 3
    finally:
        # Restore config even when an assertion fails, so later tests are unaffected
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
117
118
def test_random_grayscale_md5_valid_input():
    """
    Test RandomGrayscale with md5 comparison: valid parameter, expect to pass

    Runs RandomGrayscale(0.8) with a fixed seed and a single worker, then checks
    the output against the golden file "random_grayscale_01_result.npz". The
    global dataset config is always restored, even if the comparison fails.
    """
    logger.info("test_random_grayscale_md5_valid_input")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Generate dataset
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        transforms = [
            py_vision.Decode(),
            py_vision.RandomGrayscale(0.8),
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data = data.map(operations=transform, input_columns=["image"])

        # Check output images with md5 comparison
        filename = "random_grayscale_01_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config even on failure so later tests see the original settings
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
144
145
def test_random_grayscale_md5_no_param():
    """
    Test RandomGrayscale with md5 comparison: no parameter given, expect to pass

    Runs RandomGrayscale() with its default prob, a fixed seed and a single
    worker, then checks the output against the golden file
    "random_grayscale_02_result.npz". The global dataset config is always
    restored, even if the comparison fails.
    """
    logger.info("test_random_grayscale_md5_no_param")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # Generate dataset
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        transforms = [
            py_vision.Decode(),
            py_vision.RandomGrayscale(),
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data = data.map(operations=transform, input_columns=["image"])

        # Check output images with md5 comparison
        filename = "random_grayscale_02_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config even on failure so later tests see the original settings
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
171
172
def test_random_grayscale_invalid_param():
    """
    Test RandomGrayscale: invalid parameter given, expect to raise error

    prob=1.5 lies outside [0.0, 1.0], so building the pipeline must raise a
    ValueError with the interval message. The test fails if no error is raised
    at all (previously that case passed silently).
    """
    logger.info("test_random_grayscale_invalid_param")

    # Generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    try:
        transforms = [
            py_vision.Decode(),
            py_vision.RandomGrayscale(1.5),
            py_vision.ToTensor()
        ]
        transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
        data = data.map(operations=transform, input_columns=["image"])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Input prob is not within the required interval of [0.0, 1.0]." in str(e)
    else:
        # Reaching here means the invalid prob was accepted: that is a failure.
        raise AssertionError("RandomGrayscale(1.5) should raise ValueError for prob outside [0.0, 1.0]")
192
193
if __name__ == "__main__":
    # Run the whole suite when executed as a script; only the first test plots.
    test_random_grayscale_valid_prob(plot=True)
    for _test_fn in (test_random_grayscale_input_grayscale_images,
                     test_random_grayscale_md5_valid_input,
                     test_random_grayscale_md5_no_param,
                     test_random_grayscale_invalid_param):
        _test_fn()
200