• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RandomSolarizeOp op in DE
17"""
18import pytest
19import mindspore.dataset as ds
20import mindspore.dataset.vision.c_transforms as vision
21from mindspore import log as logger
22from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \
23    visualize_one_channel_dataset
24
# Dataset locations, given as paths relative to the directory the tests are run from.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
MNIST_DATA_DIR = "../data/dataset/testMnistData"

# Forwarded as `generate_golden` to save_and_check_md5; presumably setting it to True
# regenerates the golden .npz files instead of checking against them — confirm in util.
GENERATE_GOLDEN = False
30
31
def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True):
    """
    Test RandomSolarize on decoded TFRecord images.

    Two pipelines are built over the same TFRecord file (read with shuffle=False):
    one decodes and then applies RandomSolarize, the other only decodes. The solarized
    pipeline is optionally checked against a golden md5 file. The original and solarized
    images can optionally be plotted side by side.

    Args:
        threshold: Range forwarded to vision.RandomSolarize. If None, the op is built with
            no argument, so its own default threshold is used.
        plot: If True, visualize original vs. solarized images.
        run_golden: If True, check the solarized pipeline against
            "random_solarize_01_result.npz".
    """
    logger.info("Test RandomSolarize")

    # First dataset: decode, then solarize
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()

    # Pin the global seed and worker count; the helpers return the previous values,
    # which are restored in the `finally` below so a failing check cannot leak this
    # configuration into later tests.
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    try:
        if threshold is None:
            solarize_op = vision.RandomSolarize()
        else:
            solarize_op = vision.RandomSolarize(threshold)

        data1 = data1.map(operations=decode_op, input_columns=["image"])
        data1 = data1.map(operations=solarize_op, input_columns=["image"])

        # Second dataset: decode only, used as the visual reference
        data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data2 = data2.map(operations=decode_op, input_columns=["image"])

        if run_golden:
            filename = "random_solarize_01_result.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        image_solarized = []
        image = []

        # Iterate both pipelines in lockstep; copies detach the arrays from the iterator.
        for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                                data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
            image_solarized.append(item1["image"].copy())
            image.append(item2["image"].copy())
        if plot:
            visualize_list(image, image_solarized)
    finally:
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
73
74
def test_random_solarize_mnist(plot=False, run_golden=True):
    """
    Test RandomSolarize op with MNIST dataset (Grayscale images)

    The first two MNIST samples are read twice: once untouched and once passed through
    RandomSolarize((0, 255)).

    Args:
        plot: If True, visualize original vs. solarized images with their labels.
        run_golden: If True, check the solarized pipeline against
            "random_solarize_02_result.npz".
    """

    mnist_1 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    mnist_2 = ds.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    mnist_2 = mnist_2.map(operations=vision.RandomSolarize((0, 255)), input_columns="image")

    images = []
    images_trans = []
    labels = []

    # Each row unpacks into (image, label) Tensors; the transformed row's label is unused.
    for (image_orig, label_orig), (image_trans, _) in zip(mnist_1, mnist_2):
        images.append(image_orig.asnumpy())
        labels.append(label_orig.asnumpy())
        images_trans.append(image_trans.asnumpy())

    if plot:
        visualize_one_channel_dataset(images, images_trans, labels)

    if run_golden:
        filename = "random_solarize_02_result.npz"
        save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN)
101
102
def test_random_solarize_errors():
    """
    Check that RandomSolarize rejects malformed thresholds with the expected
    exception type and error message.
    """
    # (bad threshold, expected exception type, substring expected in its message)
    invalid_cases = [
        ((12, 1), ValueError, "threshold must be in min max format numbers"),
        ((12, 1000), ValueError, "Input is not within the required interval of [0, 255]."),
        ((122.1, 140), TypeError, "Argument threshold[0] with value 122.1 is not of type [<class 'int'>]"),
        ((122, 100, 30), ValueError, "threshold must be a sequence of two numbers"),
        ((120,), ValueError, "threshold must be a sequence of two numbers"),
    ]

    for bad_threshold, expected_exc, expected_msg in invalid_cases:
        with pytest.raises(expected_exc) as error_info:
            vision.RandomSolarize(bad_threshold)
        assert expected_msg in str(error_info.value)
126
127
if __name__ == "__main__":
    test_random_solarize_op((10, 150), plot=True, run_golden=True)
    test_random_solarize_op((12, 120), plot=True, run_golden=False)
    # Pass None explicitly so RandomSolarize() is built with its own default threshold;
    # relying on the parameter default would only repeat the (10, 150) case above.
    test_random_solarize_op(threshold=None, plot=True, run_golden=False)
    test_random_solarize_mnist(plot=True, run_golden=True)
    test_random_solarize_errors()
134