# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the OneHot Op
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.py_transforms as py_trans
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import dataset_equal_with_function

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def one_hot(index, depth):
    """
    Reference one-hot implementation used to validate the dataset op.

    Returns a (1, depth) int32 array with a single 1 at column `index`.
    """
    arr = np.zeros([1, depth], dtype=np.int32)
    arr[0, index] = 1
    return arr


def test_one_hot():
    """
    Test OneHot Tensor Operator
    """
    logger.info("test_one_hot")

    depth = 10

    # First dataset: labels one-hot encoded by the C++ OneHot op
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    one_hot_op = data_trans.OneHot(num_classes=depth)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"], column_order=["label"])

    # Second dataset: raw labels, encoded on the fly by the Python reference
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)

    assert dataset_equal_with_function(data1, data2, 0, one_hot, depth)


def test_one_hot_post_aug():
    """
    Test One Hot Encoding after Multiple Data Augmentation Operators
    """
    logger.info("test_one_hot_post_aug")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    # Define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224

    # Define map operations
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(rescale, shift)
    resize_op = c_vision.Resize((resize_height, resize_width))

    # Apply map operations on images
    data1 = data1.map(operations=decode_op, input_columns=["image"])
    data1 = data1.map(operations=rescale_op, input_columns=["image"])
    data1 = data1.map(operations=resize_op, input_columns=["image"])

    # Apply one-hot encoding on labels
    depth = 4
    one_hot_encode = data_trans.OneHot(depth)
    data1 = data1.map(operations=one_hot_encode, input_columns=["label"])

    # Apply datasets ops
    buffer_size = 100
    seed = 10
    batch_size = 2
    ds.config.set_seed(seed)
    data1 = data1.shuffle(buffer_size=buffer_size)
    data1 = data1.batch(batch_size, drop_remainder=True)

    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1):
        logger.info("image is: {}".format(item["image"]))
        logger.info("label is: {}".format(item["label"]))
        num_iter += 1

    # 3 images with drop_remainder=True at batch_size=2 -> exactly 1 batch
    assert num_iter == 1


def test_one_hot_success():
    """
    Test py_transforms OneHotOp on scalar (0-d) integer labels.
    """
    # success
    class GetDatasetGenerator:
        def __init__(self):
            np.random.seed(58)
            self.__data = np.random.sample((5, 2))
            self.__label = []
            for index in range(5):
                self.__label.append(np.array(index))

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)

    one_hot_encode = py_trans.OneHotOp(10)
    trans = py_trans.Compose([one_hot_encode])
    dataset = dataset.map(operations=trans, input_columns=["label"])

    # Row i must be one-hot at position i
    for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        assert item["label"][index] == 1.0


def test_one_hot_success2():
    """
    Test py_transforms OneHotOp on 1-d (shape [1]) integer labels.
    """
    # success
    class GetDatasetGenerator:
        def __init__(self):
            np.random.seed(58)
            self.__data = np.random.sample((5, 2))
            self.__label = []
            for index in range(5):
                self.__label.append(np.array([index]))

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)

    one_hot_encode = py_trans.OneHotOp(10)
    trans = py_trans.Compose([one_hot_encode])
    dataset = dataset.map(operations=trans, input_columns=["label"])

    for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        logger.info(item)
        assert item["label"][0][index] == 1.0


def test_one_hot_success3():
    """
    Test py_transforms OneHotOp on 2-d (shape [10, 1]) integer labels.
    """
    # success
    class GetDatasetGenerator:
        def __init__(self):
            np.random.seed(58)
            self.__data = np.random.sample((5, 2))
            self.__label = []
            for _ in range(5):
                value = np.ones([10, 1], dtype=np.int32)
                for i in range(10):
                    value[i][0] = i
                self.__label.append(value)

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)

    one_hot_encode = py_trans.OneHotOp(10)
    trans = py_trans.Compose([one_hot_encode])
    dataset = dataset.map(operations=trans, input_columns=["label"])

    for item in dataset.create_dict_iterator(output_numpy=True):
        logger.info(item)
        for i in range(10):
            assert item["label"][i][0][i] == 1.0


def test_one_hot_type_error():
    """
    Test py_transforms OneHotOp rejects float labels with a RuntimeError.
    """
    # type error
    class GetDatasetGenerator:
        def __init__(self):
            np.random.seed(58)
            self.__data = np.random.sample((5, 2))
            self.__label = []
            for index in range(5):
                self.__label.append(np.array(float(index)))

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)

    one_hot_encode = py_trans.OneHotOp(10)
    trans = py_trans.Compose([one_hot_encode])
    dataset = dataset.map(operations=trans, input_columns=["label"])

    # NOTE(review): if no RuntimeError is raised this test passes vacuously;
    # consider pytest.raises if pytest is available in this test suite.
    try:
        for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
            assert item["label"][index] == 1.0
    except RuntimeError as e:
        assert "the input numpy type should be int" in str(e)


if __name__ == "__main__":
    test_one_hot()
    test_one_hot_post_aug()
    test_one_hot_success()
    test_one_hot_success2()
    test_one_hot_success3()
    test_one_hot_type_error()