• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing the rescale op in DE
17"""
18import mindspore.dataset as ds
19import mindspore.dataset.vision.c_transforms as vision
20from mindspore import log as logger
21from util import visualize_image, diff_mse, save_and_check_md5
22
# TFRecord test data (data file + schema) read by every test in this module.
_DATASET_ROOT = "../data/dataset/test_tf_file_3_images"
DATA_DIR = [_DATASET_ROOT + "/train-0000-of-0001.data"]
SCHEMA_DIR = _DATASET_ROOT + "/datasetSchema.json"

# Forwarded to save_and_check_md5 as ``generate_golden``; presumably True
# regenerates the golden result files instead of checking them — see util.
GENERATE_GOLDEN = False
27
28
def rescale_np(image):
    """
    NumPy reference for ``vision.Rescale(1.0 / 255.0, -1.0)``.

    Divides by 255.0 and then subtracts 1.0, so 0 maps to -1.0, 255 to 0.0
    and 510 to 1.0. Works on scalars and NumPy arrays alike; a new value is
    returned and the input is not modified.
    """
    return image / 255.0 - 1.0
36
37
def get_rescaled(image_id):
    """
    Read one image with DE ops and rescale it with NumPy.

    A fresh, unshuffled TFRecord pipeline with only the Decode op is built on
    every call. The rescale itself is done by ``rescale_np``, so the result is
    an independent reference for the DE ``vision.Rescale`` op.

    Args:
        image_id (int): zero-based position of the image in the dataset.

    Returns:
        The rescaled image as a NumPy array, or None when no row has that
        position (negative id or past the last row).
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    # output_numpy=True yields NumPy arrays directly, the same way test_rescale_op
    # iterates, so no per-row .asnumpy() conversion is needed.
    rows = data1.create_dict_iterator(num_epochs=1, output_numpy=True)
    for index, item in enumerate(rows):
        if index == image_id:
            return rescale_np(item["image"])

    return None
53
54
def test_rescale_op(plot=False):
    """
    Test Rescale: compare the DE Rescale op against a NumPy reference.

    Every decoded image is rescaled by ``vision.Rescale(1.0 / 255.0, -1.0)``.
    Each result is compared by MSE with the matching image from
    ``get_rescaled``, which is computed from a separate pipeline.

    Args:
        plot (bool): if True, visualize the original, DE-rescaled and
            NumPy-rescaled image for every sample.
    """
    logger.info("Test rescale")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    # define map operations
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(1.0 / 255.0, -1.0)

    # apply map operations on images
    data1 = data1.map(operations=decode_op, input_columns=["image"])

    # data2 branches off the decoded pipeline, so data1 still yields the
    # pre-rescale image for the same row.
    data2 = data1.map(operations=rescale_op, input_columns=["image"])

    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_original = item1["image"]
        image_de_rescaled = item2["image"]
        image_np_rescaled = get_rescaled(num_iter)
        # get_rescaled returns None when the index has no row; fail clearly here
        # rather than passing None into diff_mse.
        assert image_np_rescaled is not None, "no reference image for index {}".format(num_iter)
        mse = diff_mse(image_de_rescaled, image_np_rescaled)
        assert mse < 0.001  # rounding error
        logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
        num_iter += 1
        if plot:
            visualize_image(image_original, image_de_rescaled, mse, image_np_rescaled)

    # Without this, an empty pipeline would skip every assertion above and the
    # test would pass vacuously.
    assert num_iter > 0
83
84
def test_rescale_md5():
    """
    Check the Decode -> Rescale(1/255, -1) pipeline output by md5 comparison.

    The pipeline output is handed to ``save_and_check_md5`` together with the
    golden file name "rescale_01_result.npz". GENERATE_GOLDEN is forwarded as
    ``generate_golden``; presumably True rewrites the golden file — see util.
    """
    logger.info("Test Rescale with md5 comparison")

    # Build the ops up front, then chain them onto the unshuffled source.
    decode = vision.Decode()
    rescale = vision.Rescale(1.0 / 255.0, -1.0)

    pipeline = (
        ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        .map(operations=decode, input_columns=["image"])
        .map(operations=rescale, input_columns=["image"])
    )

    golden_file = "rescale_01_result.npz"
    save_and_check_md5(pipeline, golden_file, generate_golden=GENERATE_GOLDEN)
103
104
if __name__ == "__main__":
    # Direct script runs (outside pytest) also plot every original / DE / NumPy
    # comparison, then run the md5 check.
    test_rescale_op(plot=True)
    test_rescale_md5()
108