• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18import mindspore.dataset.transforms.c_transforms as c
19import mindspore.dataset.transforms.py_transforms as f
20import mindspore.dataset.vision.c_transforms as c_vision
21import mindspore.dataset.vision.py_transforms as py_vision
22from mindspore import log as logger
23
# Relative image-folder roots used by the tests in this module.
# They are resolved against the current working directory, presumably the
# dataset unit-test directory — TODO confirm.
#   DATA_DIR   -> source images for test_one_hot_op
#   DATA_DIR_2 -> source images for test_mix_up_single / test_mix_up_multi
DATA_DIR, DATA_DIR_2 = (
    "../data/dataset/testImageNetData/train",
    "../data/dataset/testImageNetData2/train",
)
26
27
def test_one_hot_op():
    """
    Test one hot encoding op

    Applies py_transforms.OneHotOp with label smoothing to the "label" column
    and checks every row element-wise against a smoothed one-hot vector built
    from that row's raw label: ``epsilon / num_classes`` in every slot, plus
    ``1 - epsilon`` at the raw label's index.
    """
    logger.info("Test one hot encoding op")

    num_classes = 2
    epsilon_para = 0.1

    # Two pipelines over the same folder: one untouched (to read the raw
    # label), one encoded. shuffle=False makes both yield rows in the same
    # order so they can be zipped pairwise.
    raw_dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=False)
    dataset = ds.ImageFolderDataset(DATA_DIR, shuffle=False)

    # define map operations
    transforms = [f.OneHotOp(num_classes=num_classes, smoothing_rate=epsilon_para)]
    transform_label = f.Compose(transforms)
    dataset = dataset.map(operations=transform_label, input_columns=["label"])

    num_rows = 0
    for raw, data in zip(raw_dataset.create_dict_iterator(num_epochs=1, output_numpy=True),
                         dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # ravel so either a (num_classes,) or a (1, num_classes) output can be
        # compared against the 1-D golden vector.
        label = np.ravel(data["label"])

        golden_label = np.full(num_classes, epsilon_para / num_classes)
        golden_label[int(raw["label"])] += 1 - epsilon_para

        logger.info("label is {}".format(label))
        logger.info("golden_label is {}".format(golden_label))
        # Compare values; comparing the results of `.all()` would only compare
        # two booleans and could never fail.
        np.testing.assert_allclose(label, golden_label, rtol=1e-6)
        num_rows += 1

    # Guard against a silently empty dataset making the test vacuous.
    assert num_rows > 0
    logger.info("====test one hot op ok====")
53
54
def test_mix_up_single():
    """
    Test single batch mix up op

    Runs py_vision.MixUp(is_single=True) on batches of one-hot labelled images
    and compares each mixed row with the same batch taken before the MixUp map.
    The mixing weight is recovered from the label change. The test expects row
    ``index`` to be blended with row ``index + 1`` of the same batch.
    """
    logger.info("Test single batch mix up op")

    resize_height = 224
    resize_width = 224

    # Create dataset and define map operations.
    # shuffle=False so the mixed (ds1) and un-mixed (ds2) pipelines, which are
    # iterated independently below, yield rows in the same order.
    ds1 = ds.ImageFolderDataset(DATA_DIR_2, shuffle=False)

    num_classes = 10
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
    one_hot_encode = c.OneHot(num_classes)  # num_classes is input argument

    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.map(operations=resize_op, input_columns=["image"])
    ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])

    # apply batch operations
    batch_size = 3
    ds1 = ds1.batch(batch_size, drop_remainder=True)

    # ds2 keeps the pipeline without MixUp as the reference.
    ds2 = ds1
    alpha = 0.2
    transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=True)]
    ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])

    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            ds2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image1 = data1["image"]
        label = data1["label"]
        logger.info("label is {}".format(label))

        image2 = data2["image"]
        label2 = data2["label"]
        logger.info("label2 is {}".format(label2))

        label_diff = np.abs(label - label2)
        # The last row is not checked: it has no row index + 1 in the batch.
        for index in range(batch_size - 1):
            # Skip rows whose label did not (numerically) change: they carry no
            # recoverable weight. An exact `!= 0` test would let float
            # rounding from mixing two identical one-hot rows through.
            if np.allclose(label_diff[index], 0, atol=1e-6):
                continue
            # If the mixed label is lam * l[index] + (1 - lam) * l[index + 1]
            # for one-hot rows of different classes, the summed absolute
            # change is 2 * (1 - lam).
            lam_value = 1 - np.sum(label_diff[index]) / 2
            img_golden = lam_value * image2[index] + (1 - lam_value) * image2[index + 1]
            # Element-wise comparison with a tolerance, since lam_value is
            # recovered in floating point. Comparing `.all()` results would
            # only compare two booleans.
            np.testing.assert_allclose(image1[index], img_golden, rtol=1e-5, atol=1e-3)
            logger.info("====test single batch mixup ok====")
103
104
def test_mix_up_multi():
    """
    Test multi batch mix up op

    Runs py_vision.MixUp(is_single=False) on batches of one-hot labelled images
    and checks the second batch only. Each of its mixed rows is compared with a
    blend of the same row of the un-mixed second batch and the same row of the
    first mixed batch. The blend weight is recovered from the label change.
    """
    logger.info("Test several batch mix up op")

    resize_height = 224
    resize_width = 224

    # Create dataset and define map operations.
    # shuffle=False so the mixed (ds1) and un-mixed (ds2) pipelines, which are
    # iterated independently below, yield rows in the same order.
    ds1 = ds.ImageFolderDataset(DATA_DIR_2, shuffle=False)

    num_classes = 3
    decode_op = c_vision.Decode()
    resize_op = c_vision.Resize((resize_height, resize_width), c_vision.Inter.LINEAR)
    one_hot_encode = c.OneHot(num_classes)  # num_classes is input argument

    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.map(operations=resize_op, input_columns=["image"])
    ds1 = ds1.map(operations=one_hot_encode, input_columns=["label"])

    # apply batch operations
    batch_size = 3
    ds1 = ds1.batch(batch_size, drop_remainder=True)

    # ds2 keeps the pipeline without MixUp as the reference.
    ds2 = ds1
    alpha = 0.2
    transforms = [py_vision.MixUp(batch_size=batch_size, alpha=alpha, is_single=False)]
    ds1 = ds1.map(operations=transforms, input_columns=["image", "label"])

    first_mixed_image = None  # mixed images of the first batch, set at num_iter == 0
    for num_iter, (data1, data2) in enumerate(
            zip(ds1.create_dict_iterator(num_epochs=1, output_numpy=True),
                ds2.create_dict_iterator(num_epochs=1, output_numpy=True))):
        image1 = data1["image"]
        label1 = data1["label"]
        logger.info("label: {}".format(label1))

        image2 = data2["image"]
        label2 = data2["label"]
        logger.info("label2: {}".format(label2))

        if num_iter == 0:
            first_mixed_image = image1

        if num_iter == 1:
            label_diff = np.abs(label2 - label1)
            logger.info("lam value in multi: {}".format(label_diff))
            for index in range(batch_size):
                # Skip rows whose label did not (numerically) change: they
                # carry no recoverable weight. An exact `!= 0` test would let
                # float rounding through.
                if np.allclose(label_diff[index], 0, atol=1e-6):
                    continue
                # For one-hot rows of different classes, the summed absolute
                # label change is 2 * (1 - lam).
                lam_value = 1 - np.sum(label_diff[index]) / 2
                img_golden = lam_value * image2[index] + (1 - lam_value) * first_mixed_image[index]
                # Element-wise comparison with a tolerance, since lam_value is
                # recovered in floating point. Comparing `.all()` results would
                # only compare two booleans.
                np.testing.assert_allclose(image1[index], img_golden, rtol=1e-5, atol=1e-3)
                logger.info("====test several batch mixup ok====")
            break
161
162
if __name__ == "__main__":
    # When executed as a script, run each test of this module in file order.
    for _run_test in (test_one_hot_op, test_mix_up_single, test_mix_up_multi):
        _run_test()
167