• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16Testing RgbToHsv and HsvToRgb op in DE
17"""
18
19import colorsys
20import numpy as np
21from numpy.testing import assert_allclose
22
23import mindspore.dataset as ds
24import mindspore.dataset.transforms.py_transforms
25import mindspore.dataset.vision.py_transforms as vision
26import mindspore.dataset.vision.py_transforms_util as util
27
# JSON schema file handed to TFRecordDataset together with DATA_DIR
# (despite the ``_DIR`` suffix this is a path to a single file).
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# TFRecord data file(s) read by the pipeline test; kept as a list because it
# is passed as the dataset-files argument of TFRecordDataset.
DATA_DIR = [
    "../data/dataset/test_tf_file_3_images/train-0000-of-0001.data",
]
30
31
def generate_numpy_random_rgb(shape):
    """Return random RGB channel values scaled into [0, 1].

    Every element is ``n / 255`` for an integer ``n`` drawn uniformly from
    [0, 255], i.e. the set of values an 8-bit channel can take once
    normalised. Values are drawn from NumPy's global RNG, so results depend on
    the current ``np.random`` state.

    Args:
        shape: Output shape, forwarded unchanged to ``np.random.randint``.

    Returns:
        A float64 ``np.ndarray`` of the requested shape.
    """
    # Only generate fractions n / 255 (0 <= n <= 255) rather than arbitrary
    # floats, since they model 8-bit RGB pixels. Some low-precision floating
    # point types in this test can't handle arbitrary precision values well.
    return np.random.randint(0, 256, shape) / 255.
37
38
def test_rgb_hsv_hwc():
    """Round-trip one 8x8 HWC image through ``rgb_to_hsvs`` / ``hsv_to_rgbs``.

    Each direction is compared against a per-pixel ``colorsys`` reference
    evaluated in float64.
    """
    pixels = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    image_hwc = pixels.reshape((8, 8, 3))

    # Reference HSV: colorsys applied to every pixel after promotion to float64.
    expected_hsv = np.array(
        [colorsys.rgb_to_hsv(*px) for px in pixels.astype(np.float64)]
    ).reshape((8, 8, 3))
    actual_hsv = util.rgb_to_hsvs(image_hwc, True)
    assert expected_hsv.shape == actual_hsv.shape
    assert_allclose(expected_hsv.flatten(), actual_hsv.flatten(), rtol=1e-5, atol=0)

    # Reference RGB: convert the reference HSV back, again pixel by pixel.
    expected_rgb = np.array(
        [colorsys.hsv_to_rgb(*px) for px in expected_hsv.reshape(64, 3)]
    ).reshape((8, 8, 3))
    actual_rgb = util.hsv_to_rgbs(expected_hsv, True)
    assert expected_rgb.shape == actual_rgb.shape
    assert_allclose(expected_rgb.flatten(), actual_rgb.flatten(), rtol=1e-5, atol=0)
62
63
def test_rgb_hsv_batch_hwc():
    """Round-trip a (4, 2, 8, 3) batch of HWC images through the HSV helpers.

    Same checks as the single-image HWC test, but the input carries an extra
    leading axis; references are computed per pixel with ``colorsys``.
    """
    batch_shape = (4, 2, 8, 3)
    pixels = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    batch = pixels.reshape(batch_shape)

    hsv_rows = [colorsys.rgb_to_hsv(*px) for px in pixels.astype(np.float64)]
    expected_hsv = np.array(hsv_rows).reshape(batch_shape)
    actual_hsv = util.rgb_to_hsvs(batch, True)
    assert expected_hsv.shape == actual_hsv.shape
    assert_allclose(expected_hsv.flatten(), actual_hsv.flatten(), rtol=1e-5, atol=0)

    rgb_rows = [colorsys.hsv_to_rgb(*px) for px in expected_hsv.reshape((64, 3))]
    expected_rgb = np.array(rgb_rows).reshape(batch_shape)
    actual_rgb = util.hsv_to_rgbs(expected_hsv, True)
    assert actual_rgb.shape == expected_rgb.shape
    assert_allclose(expected_rgb.flatten(), actual_rgb.flatten(), rtol=1e-5, atol=0)
87
88
def test_rgb_hsv_chw():
    """Round-trip one 3x8x8 CHW image through ``rgb_to_hsvs`` / ``hsv_to_rgbs``.

    The ``colorsys`` reference is applied element-wise to the three channel
    planes via ``np.vectorize``, and the resulting planes are stacked back
    along axis 0.
    """
    image_chw = generate_numpy_random_rgb((64, 3)).astype(np.float32).reshape((3, 8, 8))
    planes_rgb_to_hsv = np.vectorize(colorsys.rgb_to_hsv)
    planes_hsv_to_rgb = np.vectorize(colorsys.hsv_to_rgb)

    # Unpacking a CHW array yields its channel planes (axis 0), promoted to float64.
    expected_hsv = np.stack(planes_rgb_to_hsv(*image_chw.astype(np.float64)))
    actual_hsv = util.rgb_to_hsvs(image_chw, False)
    assert expected_hsv.shape == actual_hsv.shape
    assert_allclose(expected_hsv.flatten(), actual_hsv.flatten(), rtol=1e-5, atol=0)

    expected_rgb = np.stack(planes_hsv_to_rgb(*expected_hsv.astype(np.float64)))
    actual_rgb = util.hsv_to_rgbs(expected_hsv, False)
    assert actual_rgb.shape == expected_rgb.shape
    assert_allclose(expected_rgb.flatten(), actual_rgb.flatten(), rtol=1e-5, atol=0)
110
111
def test_rgb_hsv_batch_chw():
    """Round-trip a (4, 3, 2, 8) batch of CHW images through the HSV helpers.

    For each image in the batch, the ``colorsys`` reference is vectorized over
    its three channel planes; the per-image results are then collected back
    into a batch array.
    """
    images = generate_numpy_random_rgb((64, 3)).astype(np.float32).reshape((4, 3, 2, 8))
    to_hsv = np.vectorize(colorsys.rgb_to_hsv)
    to_rgb = np.vectorize(colorsys.hsv_to_rgb)

    # `*img` unpacks each CHW image into its channel planes.
    expected_hsv = np.array([to_hsv(*img.astype(np.float64)) for img in images])
    actual_hsv = util.rgb_to_hsvs(images, False)
    assert expected_hsv.shape == actual_hsv.shape
    assert_allclose(expected_hsv.flatten(), actual_hsv.flatten(), rtol=1e-5, atol=0)

    expected_rgb = np.array([to_rgb(*img.astype(np.float64)) for img in expected_hsv])
    actual_rgb = util.hsv_to_rgbs(expected_hsv, False)
    assert expected_rgb.shape == actual_rgb.shape
    assert_allclose(expected_rgb.flatten(), actual_rgb.flatten(), rtol=1e-5, atol=0)
132
133
def test_rgb_hsv_pipeline():
    """Check that RgbToHsv followed by HsvToRgb reproduces the input images.

    Two pipelines read the same TFRecord data (DATA_DIR / SCHEMA_DIR) with
    shuffling disabled, so their rows line up:

    * the first decodes, resizes to 64x64 and converts to a tensor;
    * the second applies the same ops, then RgbToHsv and HsvToRgb.

    Each pair of images must have the same shape and agree to rtol=1e-5, and
    at least one pair must actually be compared.
    """
    # First dataset: reference images.
    transform1 = mindspore.dataset.transforms.py_transforms.Compose([
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.ToTensor(),
    ])
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=transform1, input_columns=["image"])

    # Second dataset: same preprocessing plus an RGB -> HSV -> RGB round trip.
    transform2 = mindspore.dataset.transforms.py_transforms.Compose([
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.ToTensor(),
        vision.RgbToHsv(),
        vision.HsvToRgb(),
    ])
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=transform2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        # Check shape first so a layout mismatch is reported directly instead
        # of surfacing through the flattened value comparison.
        assert ori_img.shape == cvt_img.shape
        assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)

    # zip() yields nothing if either pipeline is empty; without this check the
    # test would pass without comparing a single image.
    assert num_iter > 0
164
165
if __name__ == "__main__":
    # Run every test in this module, in definition order, when executed directly;
    # the first failing test raises and stops the run.
    for _test_fn in (
            test_rgb_hsv_hwc,
            test_rgb_hsv_batch_hwc,
            test_rgb_hsv_chw,
            test_rgb_hsv_batch_chw,
            test_rgb_hsv_pipeline,
    ):
        _test_fn()
172