• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing the MixUpBatch op in DE
17"""
18import numpy as np
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.vision.c_transforms as vision
22import mindspore.dataset.transforms.c_transforms as data_trans
23from mindspore import log as logger
24from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
25    config_get_set_num_parallel_workers
26
27DATA_DIR = "../data/dataset/testCifar10Data"
28DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
29DATA_DIR3 = "../data/dataset/testCelebAData/"
30
31GENERATE_GOLDEN = False
32
33
def test_mixup_batch_success1(plot=False):
    """
    Run MixUpBatch with an explicit alpha on batched Cifar10 data.

    The un-augmented images and the mixed-up images are gathered from two
    otherwise identical pipelines, optionally visualized, and the mean of the
    per-image MSE between them is logged. No assertion is made on the MSE.
    """
    logger.info("test_mixup_batch_success1")

    # Reference pipeline: plain batches, no augmentation
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    reference = reference.batch(5, drop_remainder=True)
    reference_batches = [image.asnumpy() for image, _ in reference]
    images_original = np.concatenate(reference_batches, axis=0) if reference_batches else None

    # MixUp pipeline: one-hot the labels, batch, then mix the batch
    mixed = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    mixup = vision.MixUpBatch(2)
    mixed = mixed.batch(5, drop_remainder=True)
    mixed = mixed.map(operations=mixup, input_columns=["image", "label"])
    mixed_batches = [image.asnumpy() for image, _ in mixed]
    images_mixup = np.concatenate(mixed_batches, axis=0) if mixed_batches else None

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-image error between the mixed and the untouched image
    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for sample in range(num_samples):
        mse[sample] = diff_mse(images_mixup[sample], images_original[sample])
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
74
75
def test_mixup_batch_success2(plot=False):
    """
    Run MixUpBatch with an explicit alpha on a decoded ImageFolderDataset.

    The un-augmented images and the mixed-up images are gathered from two
    otherwise identical pipelines, optionally visualized, and the mean of the
    per-image MSE between them is logged. No assertion is made on the MSE.
    """
    logger.info("test_mixup_batch_success2")

    # Reference pipeline: decode and batch only
    reference = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    reference = reference.map(operations=[vision.Decode()], input_columns=["image"])
    reference = reference.batch(4, pad_info={}, drop_remainder=True)
    reference_batches = [image.asnumpy() for image, _ in reference]
    images_original = np.concatenate(reference_batches, axis=0) if reference_batches else None

    # MixUp pipeline: decode, one-hot the labels, batch, then mix the batch
    mixed = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    mixed = mixed.map(operations=[vision.Decode()], input_columns=["image"])
    mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    mixup = vision.MixUpBatch(2.0)
    mixed = mixed.batch(4, pad_info={}, drop_remainder=True)
    mixed = mixed.map(operations=mixup, input_columns=["image", "label"])
    mixed_batches = [image.asnumpy() for image, _ in mixed]
    images_mixup = np.concatenate(mixed_batches, axis=0) if mixed_batches else None

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-image error between the mixed and the untouched image
    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for sample in range(num_samples):
        mse[sample] = diff_mse(images_mixup[sample], images_original[sample])
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
122
123
def test_mixup_batch_success3(plot=False):
    """
    Run MixUpBatch on batched Cifar10 data without passing alpha.

    The op is constructed with no arguments, so whatever default alpha the op
    defines is used. Original and mixed-up images are gathered, optionally
    visualized, and the mean per-image MSE is logged. No assertion is made on
    the MSE.
    """
    logger.info("test_mixup_batch_success3")

    # Reference pipeline: plain batches, no augmentation
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    reference = reference.batch(5, drop_remainder=True)
    reference_batches = [image.asnumpy() for image, _ in reference]
    images_original = np.concatenate(reference_batches, axis=0) if reference_batches else None

    # MixUp pipeline: one-hot the labels, batch, then mix the batch
    mixed = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    mixup = vision.MixUpBatch()
    mixed = mixed.batch(5, drop_remainder=True)
    mixed = mixed.map(operations=mixup, input_columns=["image", "label"])
    mixed_batches = [image.asnumpy() for image, _ in mixed]
    images_mixup = np.concatenate(mixed_batches, axis=0) if mixed_batches else np.array([])

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-image error between the mixed and the untouched image
    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for sample in range(num_samples):
        mse[sample] = diff_mse(images_mixup[sample], images_original[sample])
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
165
166
def test_mixup_batch_success4(plot=False):
    """
    Run MixUpBatch on CelebA data, mixing the one-hot encoded "attr" column.

    Intended to cover the case where OneHot yields a 2D label per sample.
    The op is constructed with no arguments, so its default alpha is used.
    Original and mixed-up images are gathered, optionally visualized, and the
    mean per-image MSE is logged. No assertion is made on the MSE.
    """
    logger.info("test_mixup_batch_success4")

    # Reference pipeline: decode and batch only
    reference = ds.CelebADataset(DATA_DIR3, shuffle=False)
    reference = reference.map(operations=[vision.Decode()], input_columns=["image"])
    reference = reference.batch(2, drop_remainder=True)
    reference_batches = [image.asnumpy() for image, _ in reference]
    images_original = np.concatenate(reference_batches, axis=0) if reference_batches else None

    # MixUp pipeline: decode, one-hot the attributes, batch, then mix the batch
    mixed = ds.CelebADataset(DATA_DIR3, shuffle=False)
    mixed = mixed.map(operations=[vision.Decode()], input_columns=["image"])
    mixed = mixed.map(operations=data_trans.OneHot(num_classes=100), input_columns=["attr"])
    mixup = vision.MixUpBatch()
    mixed = mixed.batch(2, drop_remainder=True)
    mixed = mixed.map(operations=mixup, input_columns=["image", "attr"])
    mixed_batches = [image.asnumpy() for image, _ in mixed]
    images_mixup = np.concatenate(mixed_batches, axis=0) if mixed_batches else np.array([])

    if plot:
        visualize_list(images_original, images_mixup)

    # Per-image error between the mixed and the untouched image
    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for sample in range(num_samples):
        mse[sample] = diff_mse(images_mixup[sample], images_original[sample])
    mean_mse = np.mean(mse)
    logger.info("MSE= {}".format(str(mean_mse)))
214
215
def test_mixup_batch_md5():
    """
    Test MixUpBatch with MD5

    Pins the seed to 0 and the number of parallel workers to 1, runs the
    MixUpBatch pipeline on batched Cifar10 data and compares the result with
    the golden file via save_and_check_md5. The previous seed and worker
    count are restored even when the comparison fails, so a failure here
    does not leak the modified global config into other tests.
    """
    logger.info("test_mixup_batch_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # MixUp Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        mixup_batch_op = vision.MixUpBatch()
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=mixup_batch_op, input_columns=["image", "label"])

        filename = "mixup_batch_c_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
239
240
def test_mixup_batch_fail1():
    """
    Test MixUpBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_mixup_batch_fail1")

    # MixUp Images (deliberately not batched before MixUpBatch is applied)
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch(0.1)
    with pytest.raises(RuntimeError) as error:
        data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
        # Iterating the pipeline is what is expected to raise
        for _ in data1:
            pass
    # NOTE(review): only the exception type is verified. Statements placed after
    # the raising call inside the `with` body never execute, so a message check
    # belongs here, after the block - add one once the exact error text is confirmed.
    logger.info("Got an exception in DE: {}".format(str(error.value)))
274
275
def test_mixup_batch_fail2():
    """
    Test MixUpBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_mixup_batch_fail2")

    # The op is expected to reject the argument at construction time,
    # so no dataset pipeline is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(-1)
    # NOTE(review): only the exception type is verified. Statements placed after
    # the raising call inside the `with` body never execute, so a message check
    # belongs here, after the block - add one once the exact error text is confirmed.
    logger.info("Got an exception in DE: {}".format(str(error.value)))
303
304
def test_mixup_batch_fail3():
    """
    Test MixUpBatch Fail 3
    We expect this to fail because label column is not passed to mixup_batch
    """
    logger.info("test_mixup_batch_fail3")

    # MixUp Images
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    # Only the image column is handed to MixUpBatch
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image"])

    with pytest.raises(RuntimeError) as error:
        # Iterating the pipeline is what is expected to raise
        for _ in data1:
            pass
    error_message = "size of input data should be 2 (including images or labels)"
    assert error_message in str(error.value)
340
341
def test_mixup_batch_fail4():
    """
    Test MixUpBatch Fail 4
    We expect this to fail because alpha is zero
    """
    logger.info("test_mixup_batch_fail4")

    # The op is expected to reject the argument at construction time,
    # so no dataset pipeline is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(0.0)
    # NOTE(review): only the exception type is verified. Statements placed after
    # the raising call inside the `with` body never execute, so a message check
    # belongs here, after the block - add one once the exact error text is confirmed.
    logger.info("Got an exception in DE: {}".format(str(error.value)))
369
370
def test_mixup_batch_fail5():
    """
    Test MixUpBatch Fail 5
    We expect this to fail because labels are not OneHot encoded
    """
    logger.info("test_mixup_batch_fail5")

    # MixUp Images (no OneHot op is applied to the label column)
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # Iterating the pipeline is what is expected to raise
        for _ in data1:
            pass
    error_message = "wrong labels shape. The second column (labels) must have a shape of NC or NLC"
    assert error_message in str(error.value)
405
406
if __name__ == "__main__":
    # Success cases are run with plotting enabled
    for plotted_case in (test_mixup_batch_success1,
                         test_mixup_batch_success2,
                         test_mixup_batch_success3,
                         test_mixup_batch_success4):
        plotted_case(plot=True)

    # Remaining cases take no arguments
    for plain_case in (test_mixup_batch_md5,
                       test_mixup_batch_fail1,
                       test_mixup_batch_fail2,
                       test_mixup_batch_fail3,
                       test_mixup_batch_fail4,
                       test_mixup_batch_fail5):
        plain_case()
418