# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing HWC2CHW op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import diff_mse, visualize_list, save_and_check_md5

GENERATE_GOLDEN = False

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

# Message expected in the RuntimeError when HWC2CHW is called eagerly with
# more than one tensor.
ONE_TO_ONE_ERR_MSG = "The op is OneToOne, can only accept one tensor as input."


def test_HWC2CHW_callable():
    """
    Test HWC2CHW is callable.

    Calling the C op eagerly on one (H, W, C) array must return a (C, H, W)
    array; calling it with two tensors must raise RuntimeError.
    """
    logger.info("Test HWC2CHW callable")
    img = np.zeros([50, 50, 3])
    assert img.shape == (50, 50, 3)

    # test one tensor
    img1 = c_vision.HWC2CHW()(img)
    assert img1.shape == (3, 50, 50)

    # test input multiple tensors, unpacked from a list and passed positionally.
    # Only the call itself sits inside pytest.raises so nothing else in the
    # context can be the source of the expected exception.
    imgs = [img, img]
    with pytest.raises(RuntimeError) as info:
        _ = c_vision.HWC2CHW()(*imgs)
    assert ONE_TO_ONE_ERR_MSG in str(info.value)

    with pytest.raises(RuntimeError) as info:
        _ = c_vision.HWC2CHW()(img, img)
    assert ONE_TO_ONE_ERR_MSG in str(info.value)


def test_HWC2CHW(plot=False):
    """
    Test HWC2CHW.

    Compares a Decode -> HWC2CHW pipeline against a Decode-only pipeline: each
    HWC2CHW output must have exactly the shape and values of the decoded image
    transposed with (2, 0, 1).

    Args:
        plot (bool): if True, visualize the decoded images next to the
            HWC2CHW outputs transposed back to (H, W, C).
    """
    logger.info("Test HWC2CHW")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    hwc2chw_op = c_vision.HWC2CHW()
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=decode_op, input_columns=["image"])

    image_transposed = []
    image = []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        transposed_item = item1["image"].copy()
        original_item = item2["image"].copy()
        image_transposed.append(transposed_item.transpose(1, 2, 0))
        image.append(original_item)

        # check if the shape of data is transposed correctly
        # transpose the original image from shape (H,W,C) to (C,H,W)
        expected = original_item.transpose(2, 0, 1)
        # Assert the shape explicitly: diff_mse alone may not reject a shape
        # mismatch (NOTE(review): util.diff_mse not shown; may broadcast).
        assert transposed_item.shape == expected.shape
        mse = diff_mse(transposed_item, expected)
        assert mse == 0
        num_iter += 1

    # Guard against a vacuous pass if the pipelines yield no rows.
    assert num_iter > 0
    if plot:
        visualize_list(image, image_transposed)


def test_HWC2CHW_md5():
    """
    Test HWC2CHW(md5).

    Runs Decode -> HWC2CHW and checks the output against the golden file
    HWC2CHW_01_result.npz (regenerated when GENERATE_GOLDEN is True).
    """
    logger.info("Test HWC2CHW with md5 comparison")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    hwc2chw_op = c_vision.HWC2CHW()
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])

    # Compare with expected md5 from images
    filename = "HWC2CHW_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_HWC2CHW_comp(plot=False):
    """
    Test HWC2CHW between python and c image augmentation.

    Compares the C Decode -> HWC2CHW pipeline with the Python
    Decode -> ToTensor -> HWC2CHW pipeline, after rescaling the Python output
    to uint8.

    Args:
        plot (bool): if True, visualize both results transposed to (H, W, C).
    """
    logger.info("Test HWC2CHW with c_transform and py_transform comparison")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = c_vision.Decode()
    hwc2chw_op = c_vision.HWC2CHW()
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.ToTensor(),
        py_vision.HWC2CHW()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data2 = data2.map(operations=transform, input_columns=["image"])

    image_c_transposed = []
    image_py_transposed = []
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        c_image = item1["image"]
        # NOTE(review): py_vision.ToTensor appears to already emit (C, H, W)
        # floats in [0, 1], so the extra py HWC2CHW yields (W, C, H);
        # transpose(1, 2, 0) restores (C, H, W) to line up with the C output.
        # Confirm against the ToTensor docs. The * 255 / astype(uint8) step
        # truncates, which is why the comparison below is not exact.
        py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)

        # Compare images between that applying c_transform and py_transform
        assert py_image.shape == c_image.shape
        mse = diff_mse(py_image, c_image)
        # Note: The images aren't exactly the same due to rounding error
        assert mse < 0.001
        image_c_transposed.append(c_image.transpose(1, 2, 0))
        image_py_transposed.append(py_image.transpose(1, 2, 0))
        num_iter += 1

    # Guard against a vacuous pass if the pipelines yield no rows.
    assert num_iter > 0
    if plot:
        visualize_list(image_c_transposed, image_py_transposed, visualize_mode=2)


if __name__ == '__main__':
    test_HWC2CHW_callable()
    test_HWC2CHW(True)
    test_HWC2CHW_md5()
    test_HWC2CHW_comp(True)