# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomChoice op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import visualize_list, diff_mse

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def _to_hwc_uint8(image):
    """
    Convert a channel-first float image (as produced after ToTensor) back to
    a channel-last uint8 image suitable for comparison/plotting.

    The transpose(1, 2, 0) and the *255 scaling mirror what the tests below
    need; this assumes the input is CHW with values in [0, 1] -- TODO confirm
    against py_vision.ToTensor if that transform changes.
    """
    return (image.transpose(1, 2, 0) * 255).astype(np.uint8)


def test_random_choice_op(plot=False):
    """
    Test RandomChoice in python transformations
    """
    logger.info("test_random_choice_op")
    # define map operations
    transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)]
    transforms1 = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform1 = py_transforms.Compose(transforms1)

    transforms2 = [
        py_vision.Decode(),
        py_vision.ToTensor()
    ]
    transform2 = py_transforms.Compose(transforms2)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_choice = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image_choice.append(_to_hwc_uint8(item1["image"]))
        image_original.append(_to_hwc_uint8(item2["image"]))
    if plot:
        visualize_list(image_original, image_choice)


def test_random_choice_comp(plot=False):
    """
    Test RandomChoice and compare with single CenterCrop results

    RandomChoice over a single-element list must always apply that element,
    so every output image has to match the plain CenterCrop pipeline exactly.
    """
    logger.info("test_random_choice_comp")
    # define map operations
    transforms_list = [py_vision.CenterCrop(64)]
    transforms1 = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform1 = py_transforms.Compose(transforms1)

    transforms2 = [
        py_vision.Decode(),
        py_vision.CenterCrop(64),
        py_vision.ToTensor()
    ]
    transform2 = py_transforms.Compose(transforms2)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_choice = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = _to_hwc_uint8(item1["image"])
        image2 = _to_hwc_uint8(item2["image"])
        image_choice.append(image1)
        image_original.append(image2)

        # Compare every pair, not just the last one the loop happened to leave behind.
        mse = diff_mse(image1, image2)
        assert mse == 0

    # Guard against a vacuous pass (or a confusing NameError) on an empty dataset.
    assert image_choice, "no images were produced by the dataset pipelines"
    if plot:
        visualize_list(image_original, image_choice)


def test_random_choice_exception_random_crop_badinput():
    """
    Test RandomChoice: hit error in RandomCrop with greater crop size,
    expected to raise error

    The test fails if fetching the first row does not raise RuntimeError,
    or if the raised error message does not mention "Crop size".
    """
    logger.info("test_random_choice_exception_random_crop_badinput")
    # define map operations
    # note: crop size[5000, 5000] > image size[4032, 2268]
    transforms_list = [py_vision.RandomCrop(5000)]
    transforms = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform = py_transforms.Compose(transforms)
    # Generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data = data.map(operations=transform, input_columns=["image"])
    try:
        # The map is executed lazily, so the error only surfaces when a row is fetched.
        _ = next(data.create_dict_iterator(num_epochs=1))
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Crop size" in str(e)
    else:
        # Without this branch a missing exception would let the test pass silently.
        raise AssertionError("Expected RuntimeError for oversized RandomCrop, but none was raised")


if __name__ == '__main__':
    test_random_choice_op(plot=True)
    test_random_choice_comp(plot=True)
    test_random_choice_exception_random_crop_badinput()