# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the OneHot Op
"""
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
import mindspore.dataset.transforms.py_transforms as py_trans
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import dataset_equal_with_function

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
30
def one_hot(index, depth):
    """
    Build a (1, depth) int32 one-hot row vector with a 1 at *index*.
    """
    encoded = np.zeros((1, depth), dtype=np.int32)
    encoded[0][index] = 1
    return encoded
38
39
def test_one_hot():
    """
    Test OneHot Tensor Operator
    """
    logger.info("test_one_hot")

    num_classes = 10

    # Dataset with the C OneHot op applied to the label column.
    onehot_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    onehot_data = onehot_data.map(operations=data_trans.OneHot(num_classes=num_classes),
                                  input_columns=["label"],
                                  column_order=["label"])

    # Reference dataset carrying only the raw labels.
    raw_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["label"], shuffle=False)

    # The op's output must match applying the Python one_hot helper to the raw labels.
    assert dataset_equal_with_function(onehot_data, raw_data, 0, one_hot, num_classes)
57
58
def test_one_hot_post_aug():
    """
    Test One Hot Encoding after Multiple Data Augmentation Operators
    """
    logger.info("test_one_hot_post_aug")
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)

    # Image pipeline: decode, rescale pixels to [0, 1], resize to 224x224.
    image_ops = [
        c_vision.Decode(),
        c_vision.Rescale(1.0 / 255.0, 0.0),
        c_vision.Resize((224, 224)),
    ]
    for image_op in image_ops:
        data1 = data1.map(operations=image_op, input_columns=["image"])

    # One-hot encode the labels over 4 classes.
    data1 = data1.map(operations=data_trans.OneHot(4), input_columns=["label"])

    # Deterministic shuffle (fixed seed), then batch; drop_remainder keeps
    # only complete batches of 2.
    ds.config.set_seed(10)
    data1 = data1.shuffle(buffer_size=100)
    data1 = data1.batch(2, drop_remainder=True)

    batch_count = 0
    for row in data1.create_dict_iterator(num_epochs=1):
        logger.info("image is: {}".format(row["image"]))
        logger.info("label is: {}".format(row["label"]))
        batch_count += 1

    # 3 images batched by 2 with drop_remainder -> exactly one batch.
    assert batch_count == 1
101
def test_one_hot_success():
    # success
    class _ScalarLabelGen:
        """Yields (data, scalar int label) pairs for 5 samples; label i for sample i."""

        def __init__(self):
            np.random.seed(58)
            self._data = np.random.sample((5, 2))
            self._labels = [np.array(i) for i in range(5)]

        def __getitem__(self, index):
            return self._data[index], self._labels[index]

        def __len__(self):
            return len(self._data)

    dataset = ds.GeneratorDataset(_ScalarLabelGen(), ["data", "label"], shuffle=False)

    # Python OneHotOp over 10 classes, wrapped in a Compose.
    encode = py_trans.Compose([py_trans.OneHotOp(10)])
    dataset = dataset.map(operations=encode, input_columns=["label"])

    # Sample i must have a 1.0 at position i of its one-hot label.
    for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        assert item["label"][index] == 1.0
126
def test_one_hot_success2():
    # success
    class _VectorLabelGen:
        """Yields (data, 1-element array label) pairs for 5 samples; label [i] for sample i."""

        def __init__(self):
            np.random.seed(58)
            self._data = np.random.sample((5, 2))
            self._labels = [np.array([i]) for i in range(5)]

        def __getitem__(self, index):
            return self._data[index], self._labels[index]

        def __len__(self):
            return len(self._data)

    dataset = ds.GeneratorDataset(_VectorLabelGen(), ["data", "label"], shuffle=False)

    # Python OneHotOp over 10 classes, wrapped in a Compose.
    encode = py_trans.Compose([py_trans.OneHotOp(10)])
    dataset = dataset.map(operations=encode, input_columns=["label"])

    # Label shape is (1, 10): row 0, column i carries the 1.0 for sample i.
    for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        logger.info(item)
        assert item["label"][0][index] == 1.0
152
def test_one_hot_success3():
    # success
    class _ColumnLabelGen:
        """Yields (data, 10x1 int32 column of 0..9) pairs for 5 samples."""

        def __init__(self):
            np.random.seed(58)
            self._data = np.random.sample((5, 2))
            # Column vector [[0], [1], ..., [9]], one fresh copy per sample.
            column = np.arange(10, dtype=np.int32).reshape(10, 1)
            self._labels = [column.copy() for _ in range(5)]

        def __getitem__(self, index):
            return self._data[index], self._labels[index]

        def __len__(self):
            return len(self._data)

    dataset = ds.GeneratorDataset(_ColumnLabelGen(), ["data", "label"], shuffle=False)

    # Python OneHotOp over 10 classes, wrapped in a Compose.
    encode = py_trans.Compose([py_trans.OneHotOp(10)])
    dataset = dataset.map(operations=encode, input_columns=["label"])

    # Each of the 10 rows encodes its own index: entry [i][0][i] is the 1.0.
    for item in dataset.create_dict_iterator(output_numpy=True):
        logger.info(item)
        for pos in range(10):
            assert item["label"][pos][0][pos] == 1.0
182
def test_one_hot_type_error():
    """
    Test that OneHotOp raises RuntimeError when labels are floats.

    OneHotOp requires integer input; float labels must fail at iteration time.
    """
    class GetDatasetGenerator:
        def __init__(self):
            np.random.seed(58)
            self.__data = np.random.sample((5, 2))
            self.__label = []
            for index in range(5):
                # Deliberately float labels to trigger the type error.
                self.__label.append(np.array(float(index)))

        def __getitem__(self, index):
            return (self.__data[index], self.__label[index])

        def __len__(self):
            return len(self.__data)

    dataset = ds.GeneratorDataset(GetDatasetGenerator(), ["data", "label"], shuffle=False)

    one_hot_encode = py_trans.OneHotOp(10)
    trans = py_trans.Compose([one_hot_encode])
    dataset = dataset.map(operations=trans, input_columns=["label"])

    # BUG FIX: the previous version silently passed when no error was raised
    # (a successful one-hot of float label i puts 1.0 at position i, so the
    # inner asserts would pass). Track whether the expected error occurred
    # and fail explicitly if it never did.
    raised = False
    try:
        for index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
            assert item["label"][index] == 1.0
    except RuntimeError as e:
        raised = True
        assert "the input numpy type should be int" in str(e)
    assert raised, "OneHotOp should raise RuntimeError for float labels"
210
if __name__ == "__main__":
    # Run every OneHot test case in order when executed as a script.
    for _case in (test_one_hot,
                  test_one_hot_post_aug,
                  test_one_hot_success,
                  test_one_hot_success2,
                  test_one_hot_success3,
                  test_one_hot_type_error):
        _case()