# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing CenterCrop op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import diff_mse, visualize_list, save_and_check_md5

GENERATE_GOLDEN = False

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_center_crop_op(height=375, width=375, plot=False):
    """
    Test C CenterCrop: every cropped image must have spatial size (height, width).

    Args:
        height (int): Requested crop height.
        width (int): Requested crop width.
        plot (bool): If True, show original and cropped images side by side.
    """
    logger.info("Test CenterCrop")

    # First dataset: decode, then center crop.
    # shuffle=False keeps both pipelines in the same order, so zip() below
    # pairs each cropped image with its own original.
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    # 3 images [375, 500] [600, 500] [512, 512]
    center_crop_op = vision.CenterCrop([height, width])
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=center_crop_op, input_columns=["image"])

    # Second dataset: decode only, used as the visual reference.
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=decode_op, input_columns=["image"])

    image_cropped = []
    image = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Whether the source is larger or smaller than the crop, the output
        # must have exactly the requested spatial size.
        assert item1["image"].shape[:2] == (height, width)
        image_cropped.append(item1["image"].copy())
        image.append(item2["image"].copy())
    if plot:
        visualize_list(image, image_cropped)


def test_center_crop_md5(height=375, width=375):
    """
    Test C CenterCrop output against the golden md5 stored in center_crop_01_result.npz.

    GENERATE_GOLDEN is passed through to save_and_check_md5 as generate_golden.
    """
    logger.info("Test CenterCrop")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    # 3 images [375, 500] [600, 500] [512, 512]
    center_crop_op = vision.CenterCrop([height, width])
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=center_crop_op, input_columns=["image"])
    # Compare with expected md5 from images
    filename = "center_crop_01_result.npz"
    save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_center_crop_comp(height=375, width=375, plot=False):
    """
    Test that C and Python CenterCrop produce (almost) the same images.

    The per-image MSE between the two outputs must stay below 0.001.
    """
    logger.info("Test CenterCrop")

    # First dataset: C decode + C center crop (HWC uint8 output).
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()
    center_crop_op = vision.CenterCrop([height, width])
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=center_crop_op, input_columns=["image"])

    # Second dataset: Python decode + center crop + ToTensor (CHW float).
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    transforms = [
        py_vision.Decode(),
        py_vision.CenterCrop([height, width]),
        py_vision.ToTensor()
    ]
    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data2 = data2.map(operations=transform, input_columns=["image"])

    image_c_cropped = []
    image_py_cropped = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        c_image = item1["image"]
        # Convert the Python output back to HWC uint8 so it is comparable to the C output.
        py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        # Note: The images aren't exactly the same due to rounding error
        assert diff_mse(py_image, c_image) < 0.001
        image_c_cropped.append(c_image.copy())
        image_py_cropped.append(py_image.copy())
    if plot:
        visualize_list(image_c_cropped, image_py_cropped, visualize_mode=2)


def test_crop_grayscale(height=375, width=375):
    """
    Test that C CenterCrop keeps a single channel when given grayscale images.
    """

    # Note: image.transpose performs channel swap to allow py transforms to
    # work with c transforms
    transforms = [
        py_vision.Decode(),
        py_vision.Grayscale(1),
        py_vision.ToTensor(),
        (lambda image: (image.transpose(1, 2, 0) * 255).astype(np.uint8))
    ]

    transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform, input_columns=["image"])

    # If input is grayscale, the output dimensions should be single channel
    crop_gray = vision.CenterCrop([height, width])
    data1 = data1.map(operations=crop_gray, input_columns=["image"])

    for item1 in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        c_image = item1["image"]

        # Check that the image is grayscale
        assert (c_image.ndim == 3 and c_image.shape[2] == 1)


def test_center_crop_errors():
    """
    Test that CenterCrop raises a RuntimeError for a crop size that would need
    padding more than 3 times the original image size.
    """
    try:
        test_center_crop_op(16777216, 16777216)
    except RuntimeError as e:
        assert "CenterCropOp padding size is more than 3 times the original size" in \
            str(e)
    else:
        # Without this branch the test would silently pass if no error were raised.
        raise AssertionError("CenterCrop with an oversized crop size should raise RuntimeError")


if __name__ == "__main__":
    test_center_crop_op(600, 600, plot=True)
    test_center_crop_op(300, 600)
    test_center_crop_op(600, 300)
    test_center_crop_md5()
    test_center_crop_comp(plot=True)
    test_crop_grayscale()
    test_center_crop_errors()