# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the rescale op in DE
"""
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
from util import visualize_image, diff_mse, save_and_check_md5

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# When True, save_and_check_md5 is asked to (re)generate the golden file
# instead of only comparing against it.
GENERATE_GOLDEN = False


def rescale_np(image):
    """
    Apply the rescale with NumPy arithmetic.

    Computes ``image / 255.0 - 1.0``, i.e. the same scale and shift that the
    tests in this file pass to ``vision.Rescale(1.0 / 255.0, -1.0)``.

    Args:
        image: decoded image array (anything supporting ``/`` and ``-`` with
            a float).

    Returns:
        The rescaled image.
    """
    image = image / 255.0
    image = image - 1.0
    return image


def get_rescaled(image_id):
    """
    Reads the image using DE ops and then rescales using Numpy.

    Args:
        image_id (int): zero-based position of the wanted image in the
            unshuffled dataset.

    Returns:
        The NumPy-rescaled image at position ``image_id``, or None when the
        dataset yields fewer than ``image_id + 1`` images.
    """
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    for num_iter, item in enumerate(data1.create_dict_iterator(num_epochs=1)):
        if num_iter == image_id:
            # Only the requested image needs converting to a NumPy array.
            return rescale_np(item["image"].asnumpy())

    return None


def test_rescale_op(plot=False):
    """
    Test rescale.

    Runs Decode followed by Rescale(1.0 / 255.0, -1.0) through the dataset
    pipeline and checks each result against the NumPy reference produced by
    get_rescaled.

    Args:
        plot (bool): when True, visualize the original, DE-rescaled and
            NumPy-rescaled images for every sample.
    """
    logger.info("Test rescale")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    # define map operations
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(1.0 / 255.0, -1.0)

    # apply map operations on images
    data1 = data1.map(operations=decode_op, input_columns=["image"])

    # data2 extends the decoded pipeline, so data1 still yields the
    # decoded-only images for comparison and plotting.
    data2 = data1.map(operations=rescale_op, input_columns=["image"])

    paired_items = zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                       data2.create_dict_iterator(num_epochs=1, output_numpy=True))
    for num_iter, (item1, item2) in enumerate(paired_items):
        image_original = item1["image"]
        image_de_rescaled = item2["image"]
        image_np_rescaled = get_rescaled(num_iter)
        mse = diff_mse(image_de_rescaled, image_np_rescaled)
        assert mse < 0.001  # rounding error
        logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
        if plot:
            visualize_image(image_original, image_de_rescaled, mse, image_np_rescaled)


def test_rescale_md5():
    """
    Test Rescale with md5 comparison.

    Builds the Decode + Rescale(1.0 / 255.0, -1.0) pipeline and hands it to
    save_and_check_md5 together with the golden file name.
    """
    logger.info("Test Rescale with md5 comparison")

    # generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    rescale_op = vision.Rescale(1.0 / 255.0, -1.0)

    # apply map operations on images
    data = data.map(operations=decode_op, input_columns=["image"])
    data = data.map(operations=rescale_op, input_columns=["image"])

    # check results with md5 comparison
    filename = "rescale_01_result.npz"
    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)


if __name__ == "__main__":
    test_rescale_op(plot=True)
    test_rescale_md5()