• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RandomChoice op in DE
17"""
18import numpy as np
19import mindspore.dataset as ds
20import mindspore.dataset.transforms.py_transforms as py_transforms
21import mindspore.dataset.vision.py_transforms as py_vision
22from mindspore import log as logger
23from util import visualize_list, diff_mse
24
# Root folder of the three-image TFRecord fixture shared by every test below.
_TF_3_IMAGES_DIR = "../data/dataset/test_tf_file_3_images"
# TFRecord data file(s) passed as the first argument of TFRecordDataset.
DATA_DIR = [_TF_3_IMAGES_DIR + "/train-0000-of-0001.data"]
# Schema file passed to TFRecordDataset alongside DATA_DIR.
SCHEMA_DIR = _TF_3_IMAGES_DIR + "/datasetSchema.json"
27
28
def test_random_choice_op(plot=False):
    """
    Test RandomChoice in python transformations.

    Each image is run through Decode -> RandomChoice([CenterCrop(64),
    RandomRotation(30)]) -> ToTensor and paired with the same image run
    through Decode -> ToTensor only. Exactly one candidate transform is
    applied per image, so the output must either be a 64x64 center crop or
    keep the original spatial size (RandomRotation with its default
    expand=False does not resize).

    Args:
        plot (bool): if True, visualize original vs. RandomChoice outputs.
    """
    logger.info("test_random_choice_op")
    # define map operations
    transforms_list = [py_vision.CenterCrop(64), py_vision.RandomRotation(30)]
    transforms1 = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform1 = py_transforms.Compose(transforms1)

    transforms2 = [
        py_vision.Decode(),
        py_vision.ToTensor()
    ]
    transform2 = py_transforms.Compose(transforms2)

    #  First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data1 = data1.map(operations=transform1, input_columns=["image"])
    #  Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data2 = data2.map(operations=transform2, input_columns=["image"])

    image_choice = []
    image_original = []
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        # Transpose the ToTensor output from (C, H, W) to (H, W, C) and rescale to uint8.
        image1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        image2 = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
        # Either CenterCrop(64) ran (64x64 output) or RandomRotation ran
        # (spatial size unchanged); anything else means RandomChoice misbehaved.
        assert image1.shape[:2] in ((64, 64), image2.shape[:2]), \
            "unexpected RandomChoice output shape {}".format(image1.shape)
        image_choice.append(image1)
        image_original.append(image2)
    # Guard against a silently empty pipeline making the checks above vacuous.
    assert image_choice, "RandomChoice pipeline produced no images"
    if plot:
        visualize_list(image_original, image_choice)
66
67
def test_random_choice_comp(plot=False):
    """
    Compare RandomChoice over a single CenterCrop with a plain CenterCrop.

    With only one candidate, RandomChoice has nothing to choose between, so
    its output must match the direct CenterCrop(64) pipeline pixel-for-pixel
    (MSE of zero) for every image.

    Args:
        plot (bool): if True, visualize the reference and RandomChoice outputs.
    """
    logger.info("test_random_choice_comp")

    # Pipeline under test: RandomChoice wrapping a single CenterCrop.
    choice_pipeline = py_transforms.Compose([
        py_vision.Decode(),
        py_transforms.RandomChoice([py_vision.CenterCrop(64)]),
        py_vision.ToTensor()
    ])
    # Reference pipeline: the same crop applied directly.
    reference_pipeline = py_transforms.Compose([
        py_vision.Decode(),
        py_vision.CenterCrop(64),
        py_vision.ToTensor()
    ])

    choice_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    choice_data = choice_data.map(operations=choice_pipeline, input_columns=["image"])
    reference_data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    reference_data = reference_data.map(operations=reference_pipeline, input_columns=["image"])

    def to_hwc_uint8(row):
        # Transpose (C, H, W) -> (H, W, C) and rescale to 8-bit for comparison/plotting.
        return (row["image"].transpose(1, 2, 0) * 255).astype(np.uint8)

    image_choice = []
    image_original = []
    choice_rows = choice_data.create_dict_iterator(num_epochs=1, output_numpy=True)
    reference_rows = reference_data.create_dict_iterator(num_epochs=1, output_numpy=True)
    for choice_row, reference_row in zip(choice_rows, reference_rows):
        cropped = to_hwc_uint8(choice_row)
        expected = to_hwc_uint8(reference_row)
        image_choice.append(cropped)
        image_original.append(expected)

        assert diff_mse(cropped, expected) == 0
    if plot:
        visualize_list(image_original, image_choice)
109
110
def test_random_choice_exception_random_crop_badinput():
    """
    Test RandomChoice: hit error in RandomCrop with greater crop size,
    expected to raise error.

    The only candidate transform is RandomCrop(5000), which is larger than
    the input image, so fetching the first row must raise a RuntimeError
    whose message mentions "Crop size". The test fails if no error is raised.
    """
    logger.info("test_random_choice_exception_random_crop_badinput")
    # define map operations
    # note: crop size[5000, 5000] > image size[4032, 2268]
    transforms_list = [py_vision.RandomCrop(5000)]
    transforms = [
        py_vision.Decode(),
        py_transforms.RandomChoice(transforms_list),
        py_vision.ToTensor()
    ]
    transform = py_transforms.Compose(transforms)
    #  Generate dataset
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    data = data.map(operations=transform, input_columns=["image"])
    try:
        # Pulling the first row forces the mapped transforms to execute.
        next(data.create_dict_iterator(num_epochs=1))
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Crop size" in str(e)
    else:
        # Without this branch a pipeline that unexpectedly succeeds would pass silently.
        raise AssertionError("Expected RuntimeError for a crop size larger than the image")
134
135
if __name__ == '__main__':
    # Run the visual tests with plotting enabled, then the error-path test.
    for visual_test in (test_random_choice_op, test_random_choice_comp):
        visual_test(plot=True)
    test_random_choice_exception_random_crop_badinput()
140