• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RgbToBgr op in DE
17"""
18
19import numpy as np
20from numpy.testing import assert_allclose
21import mindspore.dataset as ds
22import mindspore.dataset.transforms.py_transforms
23import mindspore.dataset.vision.c_transforms as vision
24import mindspore.dataset.vision.py_transforms as py_vision
25import mindspore.dataset.vision.py_transforms_util as util
26
# TFRecord data file(s) holding the sample images read by the pipeline tests.
# NOTE(review): relative paths resolve against the current working directory,
# so the tests presumably need to be launched from this file's directory.
DATA_DIR = [
    "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data",
]
# Schema describing the columns stored in the TFRecord file above.
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
29
30
def generate_numpy_random_rgb(shape):
    """Return random RGB pixel values normalized to the closed range [0, 1].

    Every element equals ``n / 255.`` for an integer ``n`` drawn uniformly
    from ``[0, 255]``, i.e. an 8-bit channel intensity rescaled to float.

    Args:
        shape: Shape of the array to generate; forwarded as the ``size``
            argument of ``np.random.randint``.

    Returns:
        A float64 ``numpy.ndarray`` with the requested shape. Callers cast it
        to the precision they need (e.g. ``.astype(np.float32)``).
    """
    # Draw only the 256 levels an 8-bit channel can take, not arbitrary floats,
    # so the data looks like real decoded pixels. The divisor is 255 (not 256),
    # so the brightest level maps to exactly 1.0 and the range is [0, 1].
    return np.random.randint(0, 256, shape) / 255.
36
37
def test_rgb_bgr_hwc_py():
    """Eager check of ``util.rgb_to_bgrs`` on a channel-last (HWC) image.

    The expected output is the input with its last (channel) axis reversed.
    """
    height, width, channels = 8, 8, 3
    pixels = generate_numpy_random_rgb((height * width, channels)).astype(np.float32)
    image = pixels.reshape((height, width, channels))

    # Second argument True selects the HWC layout (the CHW test passes False).
    converted = util.rgb_to_bgrs(image, True)
    expected = image[:, :, ::-1]

    assert converted.shape == image.shape
    assert_allclose(converted.flatten(), expected.flatten(), rtol=1e-5, atol=0)
51
52
def test_rgb_bgr_hwc_c():
    """Eager check of the C++ ``vision.RgbToBgr`` op on an HWC float32 image.

    Calling the op object directly on a NumPy array runs it eagerly, outside
    any dataset pipeline. The result must be the input with its channel
    (last) axis reversed.
    """
    pixels = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    image = pixels.reshape((8, 8, 3))

    to_bgr = vision.RgbToBgr()
    converted = to_bgr(image)
    expected = image[..., ::-1]

    assert converted.shape == image.shape
    assert_allclose(converted.flatten(), expected.flatten(), rtol=1e-5, atol=0)
67
68
def test_rgb_bgr_chw_py():
    """Eager check of ``util.rgb_to_bgrs`` on a channel-first (CHW) image.

    With the channel axis first, the expected output is the input with
    axis 0 reversed.
    """
    flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    image_chw = flat.reshape((3, 8, 8))

    # Second argument False means "not HWC", i.e. the array is CHW.
    bgr_pred = util.rgb_to_bgrs(image_chw, False)
    bgr_expected = np.flip(image_chw, axis=0)

    assert bgr_pred.shape == image_chw.shape
    assert_allclose(bgr_pred.flatten(), bgr_expected.flatten(), rtol=1e-5, atol=0)
80
81
def test_rgb_bgr_pipeline_py():
    """Pipeline check of ``py_vision.RgbToBgr``.

    Two TFRecord pipelines decode and resize the same images; the second one
    additionally applies RgbToBgr. ``ToTensor`` output is channel-first (CHW),
    which is why the expected result reverses axis 0 of the reference image.
    """
    # First dataset: reference images without colour conversion.
    transforms1 = mindspore.dataset.transforms.py_transforms.Compose([
        py_vision.Decode(),
        py_vision.Resize([64, 64]),
        py_vision.ToTensor(),
    ])
    ds1 = ds.TFRecordDataset(DATA_DIR,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds1 = ds1.map(operations=transforms1, input_columns=["image"])

    # Second dataset: identical pipeline followed by RgbToBgr.
    transforms2 = mindspore.dataset.transforms.py_transforms.Compose([
        py_vision.Decode(),
        py_vision.Resize([64, 64]),
        py_vision.ToTensor(),
        py_vision.RgbToBgr(),
    ])
    ds2 = ds.TFRecordDataset(DATA_DIR,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds2 = ds2.map(operations=transforms2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        # Compare shapes first so a layout mismatch gives a clear failure.
        assert ori_img.shape == cvt_img.shape
        cvt_img_gt = ori_img[::-1, :, :]
        assert_allclose(cvt_img_gt.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)

    # Without this, an empty dataset would skip every assertion above and the
    # test would pass vacuously.
    assert num_iter > 0
120
121
def test_rgb_bgr_pipeline_c():
    """Pipeline check of the C++ ``vision.RgbToBgr`` op.

    Two TFRecord pipelines decode and resize the same images; the second one
    additionally applies RgbToBgr. The C ops keep the channel-last (HWC)
    layout, so the expected result reverses the last axis of the reference
    image.
    """
    # The C ops are passed to map() as a plain list rather than wrapped in
    # py_transforms.Compose. Wrapping them turned them into a Python callable
    # run eagerly, so the native pipeline path this test targets was never
    # exercised.

    # First dataset: reference images without colour conversion.
    transforms1 = [
        vision.Decode(),
        vision.Resize([64, 64]),
    ]
    ds1 = ds.TFRecordDataset(DATA_DIR,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds1 = ds1.map(operations=transforms1, input_columns=["image"])

    # Second dataset: identical pipeline followed by RgbToBgr.
    transforms2 = [
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.RgbToBgr(),
    ]
    ds2 = ds.TFRecordDataset(DATA_DIR,
                             SCHEMA_DIR,
                             columns_list=["image"],
                             shuffle=False)
    ds2 = ds2.map(operations=transforms2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1),
                            ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        # Compare shapes first so a layout mismatch gives a clear failure.
        assert ori_img.shape == cvt_img.shape
        cvt_img_gt = ori_img[:, :, ::-1]
        assert_allclose(cvt_img_gt.flatten(),
                        cvt_img.flatten(),
                        rtol=1e-5,
                        atol=0)

    # Without this, an empty dataset would skip every assertion above and the
    # test would pass vacuously.
    assert num_iter > 0
162
163
if __name__ == "__main__":
    # Allow running this module directly, outside pytest; tests run in the
    # order they are defined above.
    for _test in (test_rgb_bgr_hwc_py,
                  test_rgb_bgr_hwc_c,
                  test_rgb_bgr_chw_py,
                  test_rgb_bgr_pipeline_py,
                  test_rgb_bgr_pipeline_c):
        _test()
170