# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.py_transforms as f
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger

DATA_DIR = "../data/dataset/testImageNetData/train"
DATA_DIR_2 = "../data/dataset/testImageNetData2/train"


def _build_batched_pipeline(num_classes, batch_size):
    """
    Build the pipeline shared by the MixUp tests.

    Reads DATA_DIR_2 as an image folder dataset, decodes and resizes every
    image to 224x224 with linear interpolation, one-hot encodes the label
    column with `num_classes` classes and groups the samples into batches of
    `batch_size`, dropping the incomplete remainder.
    """
    target_size = (224, 224)

    pipeline = ds.ImageFolderDataset(DATA_DIR_2)
    pipeline = pipeline.map(operations=c_vision.Decode(), input_columns=["image"])
    pipeline = pipeline.map(operations=c_vision.Resize(target_size, c_vision.Inter.LINEAR),
                            input_columns=["image"])
    pipeline = pipeline.map(operations=c.OneHot(num_classes), input_columns=["label"])
    return pipeline.batch(batch_size, drop_remainder=True)


def test_one_hot_op():
    """
    Test one hot encoding op
    """
    logger.info("Test one hot encoding op")

    num_classes = 2
    epsilon_para = 0.1

    # Label pipeline: a composed py_transforms OneHotOp applied to "label".
    label_transform = f.Compose([f.OneHotOp(num_classes=num_classes, smoothing_rate=epsilon_para)])
    dataset = ds.ImageFolderDataset(DATA_DIR).map(operations=label_transform, input_columns=["label"])

    # Reference vector: epsilon / num_classes everywhere, except index 1.
    golden_label = np.ones(num_classes) * epsilon_para / num_classes
    golden_label[1] = 1 - epsilon_para / num_classes

    for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        label = row["label"]
        logger.info("label is {}".format(label))
        logger.info("golden_label is {}".format(golden_label))
        # NOTE(review): `.all()` collapses each array to a single bool, so this
        # compares two booleans rather than the label values element-wise.
        # Tightening it (e.g. np.testing.assert_allclose) needs the golden
        # values confirmed against the dataset labels first.
        assert label.all() == golden_label.all()
    logger.info("====test one hot op ok====")


def test_mix_up_single():
    """
    Test single batch mix up op
    """
    logger.info("Test single batch mix up op")

    batch_size = 3
    alpha = 0.2

    # `plain_ds` is the un-mixed pipeline; `mixed_ds` adds MixUp on top of it.
    plain_ds = _build_batched_pipeline(num_classes=10, batch_size=batch_size)
    mix_ops = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=True)]
    mixed_ds = plain_ds.map(operations=mix_ops, input_columns=["image", "label"])

    mixed_iter = mixed_ds.create_dict_iterator(num_epochs=1, output_numpy=True)
    plain_iter = plain_ds.create_dict_iterator(num_epochs=1, output_numpy=True)
    # NOTE(review): the two iterators are created independently; this loop
    # presumes they yield the same samples in the same order - confirm.
    for mixed_row, plain_row in zip(mixed_iter, plain_iter):
        image1 = mixed_row["image"]
        label = mixed_row["label"]
        logger.info("label is {}".format(label))

        image2 = plain_row["image"]
        label2 = plain_row["label"]
        logger.info("label2 is {}".format(label2))

        label_gap = np.abs(label - label2)
        for index in range(batch_size - 1):
            # Skip samples whose label was left unchanged by the mix.
            if np.square(label_gap[index]).mean() == 0:
                continue
            # Mixing weight derived from the label difference, then used to
            # blend each plain image with its successor in the batch.
            lam_value = 1 - np.sum(label_gap[index]) / 2
            img_golden = lam_value * image2[index] + (1 - lam_value) * image2[index + 1]
            # NOTE(review): compares two booleans (`.all()` results), not the
            # pixel values themselves.
            assert image1[index].all() == img_golden.all()
            logger.info("====test single batch mixup ok====")


def test_mix_up_multi():
    """
    Test multi batch mix up op
    """
    logger.info("Test several batch mix up op")

    batch_size = 3
    alpha = 0.2

    # `plain_ds` is the un-mixed pipeline; `mixed_ds` adds MixUp on top of it.
    plain_ds = _build_batched_pipeline(num_classes=3, batch_size=batch_size)
    mix_ops = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=False)]
    mixed_ds = plain_ds.map(operations=mix_ops, input_columns=["image", "label"])

    # Mixed images of the first batch, kept to build the second batch's golden.
    batch1_image1 = 0
    mixed_iter = mixed_ds.create_dict_iterator(num_epochs=1, output_numpy=True)
    plain_iter = plain_ds.create_dict_iterator(num_epochs=1, output_numpy=True)
    # NOTE(review): the two iterators are created independently; this loop
    # presumes they yield the same samples in the same order - confirm.
    for num_iter, (mixed_row, plain_row) in enumerate(zip(mixed_iter, plain_iter)):
        image1 = mixed_row["image"]
        label1 = mixed_row["label"]
        logger.info("label: {}".format(label1))

        image2 = plain_row["image"]
        label2 = plain_row["label"]
        logger.info("label2: {}".format(label2))

        if num_iter == 0:
            batch1_image1 = image1

        if num_iter == 1:
            label_gap = np.abs(label2 - label1)
            logger.info("lam value in multi: {}".format(label_gap))
            for index in range(batch_size):
                # Skip samples whose label was left unchanged by the mix.
                if np.square(label_gap[index]).mean() == 0:
                    continue
                # Blend the plain second-batch image with the mixed image at
                # the same position in the first batch.
                lam_value = 1 - np.sum(label_gap[index]) / 2
                img_golden = lam_value * image2[index] + (1 - lam_value) * batch1_image1[index]
                # NOTE(review): compares two booleans (`.all()` results), not
                # the pixel values themselves.
                assert image1[index].all() == img_golden.all()
                logger.info("====test several batch mixup ok====")
            # Only the first two batches are inspected.
            break


if __name__ == "__main__":
    test_one_hot_op()
    test_mix_up_single()
    test_mix_up_multi()