• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing the CutMixBatch op in DE
17"""
18import numpy as np
19import pytest
20import mindspore.dataset as ds
21import mindspore.dataset.vision.c_transforms as vision
22import mindspore.dataset.transforms.c_transforms as data_trans
23import mindspore.dataset.vision.utils as mode
24from mindspore import log as logger
25from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
26    config_get_set_num_parallel_workers
27
# Test data roots (relative paths; presumably resolved from the test directory).
DATA_DIR = "../data/dataset/testCifar10Data"           # read with Cifar10Dataset
DATA_DIR2 = "../data/dataset/testImageNetData2/train/"  # read with ImageFolderDataset
DATA_DIR3 = "../data/dataset/testCelebAData/"          # read with CelebADataset

# Passed as generate_golden to save_and_check_md5 by the MD5 tests.
GENERATE_GOLDEN = False
33
34
def test_cutmix_batch_success1(plot=False):
    """
    Apply CutMixBatch with explicit alpha (2.0) and prob (0.5) to batches of CHW Cifar10 images.

    The CutMix output is transposed back to NHWC so it lines up with (and can be plotted
    next to) the untouched original batches. The per-image MSE is only logged, not asserted.
    """
    logger.info("test_cutmix_batch_success1")

    # Reference pipeline: raw images, batched exactly like the CutMix pipeline below.
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)
    original_batches = [img.asnumpy() for img, _ in reference]
    images_original = np.concatenate(original_batches, axis=0) if original_batches else None

    # CutMix pipeline: HWC -> CHW, one-hot labels, batch, then mix with an NCHW batch format.
    augmented = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    augmented = augmented.map(operations=vision.HWC2CHW(), input_columns=["image"])
    augmented = augmented.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    cutmix = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
    augmented = augmented.batch(5, drop_remainder=True)
    augmented = augmented.map(operations=cutmix, input_columns=["image", "label"])

    # Transpose each NCHW batch to NHWC so it matches the layout of the originals.
    cutmix_batches = [img.asnumpy().transpose(0, 2, 3, 1) for img, _ in augmented]
    images_cutmix = np.concatenate(cutmix_batches, axis=0) if cutmix_batches else None
    if plot:
        visualize_list(images_original, images_cutmix)

    count = images_original.shape[0]
    per_image_mse = np.array([diff_mse(images_cutmix[k], images_original[k]) for k in range(count)],
                             dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
75
76
def test_cutmix_batch_success2(plot=False):
    """
    Apply CutMixBatch with its default alpha and prob to batches of rescaled HWC Cifar10 images.

    The images stay in NHWC layout, so no transpose is needed before comparing them with the
    original batches. The per-image MSE is only logged, not asserted.

    NOTE(review): the CutMix images are rescaled by 1/255 while the originals are raw pixel
    values, so the logged MSE compares images on different scales -- confirm this is intended.
    """
    logger.info("test_cutmix_batch_success2")

    # Reference pipeline: raw (un-rescaled) images in the same batches.
    reference = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)
    original_batches = [img.asnumpy() for img, _ in reference]
    images_original = np.concatenate(original_batches, axis=0) if original_batches else None

    # CutMix pipeline: one-hot labels, rescale pixels, batch, then mix with an NHWC batch format.
    augmented = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    augmented = augmented.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    augmented = augmented.map(operations=vision.Rescale((1.0 / 255.0), 0.0), input_columns=["image"])
    cutmix = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    augmented = augmented.batch(5, drop_remainder=True)
    augmented = augmented.map(operations=cutmix, input_columns=["image", "label"])

    cutmix_batches = [img.asnumpy() for img, _ in augmented]
    images_cutmix = np.concatenate(cutmix_batches, axis=0) if cutmix_batches else None
    if plot:
        visualize_list(images_original, images_cutmix)

    count = images_original.shape[0]
    per_image_mse = np.array([diff_mse(images_cutmix[k], images_original[k]) for k in range(count)],
                             dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
118
119
def test_cutmix_batch_success3(plot=False):
    """
    Apply CutMixBatch with its default alpha and prob to batches of decoded, resized HWC
    images read through ImageFolderDataset.

    Both pipelines decode, resize to 224x224 and batch by 4, so the CutMix output can be
    compared image-by-image with the originals. The per-image MSE is only logged.
    """
    logger.info("test_cutmix_batch_success3")

    # Reference pipeline: decode + resize only.
    reference = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    reference = reference.map(operations=[vision.Decode()], input_columns=["image"])
    reference = reference.map(operations=[vision.Resize([224, 224])], input_columns=["image"])
    reference = reference.batch(4, pad_info={}, drop_remainder=True)
    original_batches = [img.asnumpy() for img, _ in reference]
    images_original = np.concatenate(original_batches, axis=0) if original_batches else None

    # CutMix pipeline: same preprocessing, plus one-hot labels and CutMixBatch after batching.
    augmented = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
    augmented = augmented.map(operations=[vision.Decode()], input_columns=["image"])
    augmented = augmented.map(operations=[vision.Resize([224, 224])], input_columns=["image"])
    augmented = augmented.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
    cutmix = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    augmented = augmented.batch(4, pad_info={}, drop_remainder=True)
    augmented = augmented.map(operations=cutmix, input_columns=["image", "label"])

    cutmix_batches = [img.asnumpy() for img, _ in augmented]
    images_cutmix = np.concatenate(cutmix_batches, axis=0) if cutmix_batches else None
    if plot:
        visualize_list(images_original, images_cutmix)

    count = images_original.shape[0]
    per_image_mse = np.array([diff_mse(images_cutmix[k], images_original[k]) for k in range(count)],
                             dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
170
171
def test_cutmix_batch_success4(plot=False):
    """
    Apply CutMixBatch (alpha 0.5, prob 0.9) to CelebA images, using the "attr" column as labels.

    OneHot on "attr" is expected to yield a 2D label per sample, so after batching the labels
    handed to CutMixBatch are 3D. The per-image MSE against the untouched originals is only logged.
    """
    logger.info("test_cutmix_batch_success4")

    # Reference pipeline: decode + resize, batched by 2.
    reference = ds.CelebADataset(DATA_DIR3, shuffle=False)
    reference = reference.map(operations=[vision.Decode()], input_columns=["image"])
    reference = reference.map(operations=[vision.Resize([224, 224])], input_columns=["image"])
    reference = reference.batch(2, drop_remainder=True)
    original_batches = [img.asnumpy() for img, _ in reference]
    images_original = np.concatenate(original_batches, axis=0) if original_batches else None

    # CutMix pipeline: same preprocessing, one-hot "attr" with 100 classes, then mix after batching.
    augmented = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)
    augmented = augmented.map(operations=[vision.Decode()], input_columns=["image"])
    augmented = augmented.map(operations=[vision.Resize([224, 224])], input_columns=["image"])
    augmented = augmented.map(operations=data_trans.OneHot(num_classes=100), input_columns=["attr"])
    cutmix = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
    augmented = augmented.batch(2, drop_remainder=True)
    augmented = augmented.map(operations=cutmix, input_columns=["image", "attr"])

    cutmix_batches = [img.asnumpy() for img, _ in augmented]
    images_cutmix = np.concatenate(cutmix_batches, axis=0) if cutmix_batches else None
    if plot:
        visualize_list(images_original, images_cutmix)

    count = images_original.shape[0]
    per_image_mse = np.array([diff_mse(images_cutmix[k], images_original[k]) for k in range(count)],
                             dtype=np.float64)
    logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
222
223
def test_cutmix_batch_nhwc_md5():
    """
    Test CutMixBatch on a batch of HWC images, checking the output against a golden MD5.

    The global seed and number of parallel workers are pinned (presumably so the randomized
    CutMix output is reproducible) and are restored in a ``finally`` so that a failing
    check does not leak this configuration into later tests.
    """
    logger.info("test_cutmix_batch_nhwc_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # CutMixBatch Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])

        filename = "cutmix_batch_c_nhwc_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even when the MD5 check fails.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
247
248
def test_cutmix_batch_nchw_md5():
    """
    Test CutMixBatch on a batch of CHW images, checking the output against a golden MD5.

    The global seed and number of parallel workers are pinned (presumably so the randomized
    CutMix output is reproducible) and are restored in a ``finally`` so that a failing
    check does not leak this configuration into later tests.
    """
    logger.info("test_cutmix_batch_nchw_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # CutMixBatch Images
        data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
        hwc2chw_op = vision.HWC2CHW()
        data = data.map(operations=hwc2chw_op, input_columns=["image"])
        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(operations=one_hot_op, input_columns=["label"])
        cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
        data = data.batch(5, drop_remainder=True)
        data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])

        filename = "cutmix_batch_c_nchw_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even when the MD5 check fails.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)
273
274
def test_cutmix_batch_fail1():
    """
    Test CutMixBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_cutmix_batch_fail1")

    # CutMixBatch Images -- note there is no batch() before CutMixBatch; that is the error under test.
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # The error surfaces while rows are pulled through the pipeline.
        for _ in data1:
            pass
    # Checked after the with-block: inside it, this line would never run because the
    # loop above raises.
    error_message = "You must make sure images are HWC or CHW and batch "
    assert error_message in str(error.value)
297
298
def test_cutmix_batch_fail2():
    """
    Test CutMixBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_cutmix_batch_fail2")

    # The parameter check happens in the CutMixBatch constructor, so no dataset is needed.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
    # Checked after the with-block: inside it, this line would be skipped because the
    # constructor raises.
    # NOTE(review): matched without the leading "Input" because the validator may insert the
    # argument name after it -- confirm against the validator's message.
    error_message = "is not within the required interval"
    assert error_message in str(error.value)
315
316
def test_cutmix_batch_fail3():
    """
    Test CutMixBatch Fail 3
    We expect this to fail because prob is larger than 1
    """
    logger.info("test_cutmix_batch_fail3")

    # The parameter check happens in the CutMixBatch constructor, so no dataset is needed.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
    # Checked after the with-block: inside it, this line would be skipped because the
    # constructor raises.
    # NOTE(review): matched without the leading "Input" because the validator may insert the
    # argument name (e.g. "prob") after it -- confirm against the validator's message.
    error_message = "is not within the required interval"
    assert error_message in str(error.value)
333
334
def test_cutmix_batch_fail4():
    """
    Test CutMixBatch Fail 4
    We expect this to fail because prob is negative
    """
    logger.info("test_cutmix_batch_fail4")

    # The parameter check happens in the CutMixBatch constructor, so no dataset is needed.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
    # Checked after the with-block: inside it, this line would be skipped because the
    # constructor raises.
    # NOTE(review): matched without the leading "Input" because the validator may insert the
    # argument name (e.g. "prob") after it -- confirm against the validator's message.
    error_message = "is not within the required interval"
    assert error_message in str(error.value)
351
352
def test_cutmix_batch_fail5():
    """
    Test CutMixBatch op
    We expect this to fail because label column is not passed to cutmix_batch
    """
    logger.info("test_cutmix_batch_fail5")

    # CutMixBatch Images -- only the "image" column is handed to CutMixBatch.
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image"])

    with pytest.raises(RuntimeError) as error:
        # Draining the pipeline is enough to trigger the op; the rows themselves are not needed.
        for _ in data1:
            pass
    error_message = "size of input should be 2 (including image and label)"
    assert error_message in str(error.value)
378
379
def test_cutmix_batch_fail6():
    """
    Test CutMixBatch op
    We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
    """
    logger.info("test_cutmix_batch_fail6")

    # CutMixBatch Images -- no HWC2CHW is applied, yet the op is told the batch is NCHW.
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    one_hot_op = data_trans.OneHot(num_classes=10)
    data1 = data1.map(operations=one_hot_op, input_columns=["label"])
    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # Draining the pipeline is enough to trigger the op; the rows themselves are not needed.
        for _ in data1:
            pass
    error_message = "image doesn't match the <N,C,H,W> format"
    assert error_message in str(error.value)
405
406
def test_cutmix_batch_fail7():
    """
    Test CutMixBatch op
    We expect this to fail because labels are not in one-hot format
    """
    logger.info("test_cutmix_batch_fail7")

    # CutMixBatch Images -- no OneHot map on "label", so the labels keep their raw shape.
    data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

    cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])

    with pytest.raises(RuntimeError) as error:
        # Draining the pipeline is enough to trigger the op; the rows themselves are not needed.
        for _ in data1:
            pass
    error_message = "wrong labels shape. The second column (labels) must have a shape of NC or NLC"
    assert error_message in str(error.value)
430
431
def test_cutmix_batch_fail8():
    """
    Test CutMixBatch Fail 8
    We expect this to fail because alpha is zero
    """
    logger.info("test_cutmix_batch_fail8")

    # The parameter check happens in the CutMixBatch constructor, so no dataset is needed.
    with pytest.raises(ValueError) as error:
        vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
    # Checked after the with-block: inside it, this line would be skipped because the
    # constructor raises.
    # NOTE(review): a zero alpha may be rejected by a positivity check rather than an
    # interval check, so either wording is accepted -- confirm against the validator.
    message = str(error.value)
    assert "is not within the required interval" in message or "must be greater than 0" in message
448
449
if __name__ == "__main__":
    # Visual tests first, with plotting enabled so the CutMix results can be inspected.
    for visual_test in (test_cutmix_batch_success1,
                        test_cutmix_batch_success2,
                        test_cutmix_batch_success3,
                        test_cutmix_batch_success4):
        visual_test(plot=True)

    # Then the argument-free tests, in the same order as before.
    for plain_test in (test_cutmix_batch_nchw_md5,
                       test_cutmix_batch_nhwc_md5,
                       test_cutmix_batch_fail1,
                       test_cutmix_batch_fail2,
                       test_cutmix_batch_fail3,
                       test_cutmix_batch_fail4,
                       test_cutmix_batch_fail5,
                       test_cutmix_batch_fail6,
                       test_cutmix_batch_fail7,
                       test_cutmix_batch_fail8):
        plain_test()
465