# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RgbToHsv and HsvToRgb op in DE (mindspore.dataset), checking both the
py_transforms_util helpers and the vision.py_transforms ops against the
standard-library colorsys reference implementation.
"""

import colorsys
import numpy as np
from numpy.testing import assert_allclose

import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.py_transforms as vision
import mindspore.dataset.vision.py_transforms_util as util

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def generate_numpy_random_rgb(shape):
    """Return a float64 array of the given shape with random RGB values in [0, 1].

    Args:
        shape: Output shape, passed straight to ``np.random.randint``.

    Returns:
        np.ndarray of values ``n / 255`` for random integers ``n`` in [0, 255].
    """
    # Only generate floating points that are fractions n / 255 (n an integer in
    # [0, 255]), since they stand in for 8-bit RGB pixels. Some low-precision
    # floating point types in this test can't handle arbitrary precision
    # floating points well.
    return np.random.randint(0, 256, shape) / 255.
def test_rgb_hsv_hwc():
    """Round-trip one channel-last (8, 8, 3) image through rgb_to_hsvs / hsv_to_rgbs.

    Each direction is compared pixel by pixel against ``colorsys`` evaluated in
    float64.
    """
    rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    rgb_np = rgb_flat.reshape((8, 8, 3))
    hsv_base = np.array([
        colorsys.rgb_to_hsv(
            r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
        for r, g, b in rgb_flat
    ])
    hsv_base = hsv_base.reshape((8, 8, 3))
    # True is passed for channel-last (HWC) input; the CHW tests pass False.
    hsv_de = util.rgb_to_hsvs(rgb_np, True)
    assert hsv_base.shape == hsv_de.shape
    assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)

    hsv_flat = hsv_base.reshape((64, 3))
    rgb_base = np.array([
        colorsys.hsv_to_rgb(
            h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
        for h, s, v in hsv_flat
    ])
    rgb_base = rgb_base.reshape((8, 8, 3))
    rgb_de = util.hsv_to_rgbs(hsv_base, True)
    assert rgb_base.shape == rgb_de.shape
    assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)


def test_rgb_hsv_batch_hwc():
    """Same as test_rgb_hsv_hwc, but on a batch of four channel-last (2, 8, 3) images."""
    rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    rgb_np = rgb_flat.reshape((4, 2, 8, 3))
    hsv_base = np.array([
        colorsys.rgb_to_hsv(
            r.astype(np.float64), g.astype(np.float64), b.astype(np.float64))
        for r, g, b in rgb_flat
    ])
    hsv_base = hsv_base.reshape((4, 2, 8, 3))
    hsv_de = util.rgb_to_hsvs(rgb_np, True)
    assert hsv_base.shape == hsv_de.shape
    assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)

    hsv_flat = hsv_base.reshape((64, 3))
    rgb_base = np.array([
        colorsys.hsv_to_rgb(
            h.astype(np.float64), s.astype(np.float64), v.astype(np.float64))
        for h, s, v in hsv_flat
    ])
    rgb_base = rgb_base.reshape((4, 2, 8, 3))
    rgb_de = util.hsv_to_rgbs(hsv_base, True)
    assert rgb_base.shape == rgb_de.shape
    assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)


def test_rgb_hsv_chw():
    """Round-trip one channel-first (3, 8, 8) image through rgb_to_hsvs / hsv_to_rgbs."""
    rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    rgb_np = rgb_flat.reshape((3, 8, 8))
    # np.vectorize over a tuple-returning function yields one (8, 8) array per
    # output channel; stacking them gives the (3, 8, 8) channel-first layout.
    hsv_base = np.stack(np.vectorize(colorsys.rgb_to_hsv)(
        rgb_np[0, :, :].astype(np.float64),
        rgb_np[1, :, :].astype(np.float64),
        rgb_np[2, :, :].astype(np.float64)))
    hsv_de = util.rgb_to_hsvs(rgb_np, False)
    assert hsv_base.shape == hsv_de.shape
    assert_allclose(hsv_base.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)

    rgb_base = np.stack(np.vectorize(colorsys.hsv_to_rgb)(
        hsv_base[0, :, :].astype(np.float64),
        hsv_base[1, :, :].astype(np.float64),
        hsv_base[2, :, :].astype(np.float64)))
    rgb_de = util.hsv_to_rgbs(hsv_base, False)
    assert rgb_base.shape == rgb_de.shape
    assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)


def test_rgb_hsv_batch_chw():
    """Same as test_rgb_hsv_chw, but on a batch of four channel-first (3, 2, 8) images."""
    rgb_flat = generate_numpy_random_rgb((64, 3)).astype(np.float32)
    rgb_imgs = rgb_flat.reshape((4, 3, 2, 8))
    hsv_base_imgs = np.array([
        np.vectorize(colorsys.rgb_to_hsv)(
            img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
        for img in rgb_imgs
    ])
    hsv_de = util.rgb_to_hsvs(rgb_imgs, False)
    assert hsv_base_imgs.shape == hsv_de.shape
    assert_allclose(hsv_base_imgs.flatten(), hsv_de.flatten(), rtol=1e-5, atol=0)

    rgb_base = np.array([
        np.vectorize(colorsys.hsv_to_rgb)(
            img[0, :, :].astype(np.float64), img[1, :, :].astype(np.float64), img[2, :, :].astype(np.float64))
        for img in hsv_base_imgs
    ])
    rgb_de = util.hsv_to_rgbs(hsv_base_imgs, False)
    assert rgb_base.shape == rgb_de.shape
    assert_allclose(rgb_base.flatten(), rgb_de.flatten(), rtol=1e-5, atol=0)


def test_rgb_hsv_pipeline():
    """Check that RgbToHsv followed by HsvToRgb in a dataset pipeline is an identity."""
    # Reference pipeline: decode, resize, convert to tensor.
    transforms1 = [
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.ToTensor()
    ]
    transform1 = mindspore.dataset.transforms.py_transforms.Compose(transforms1)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=transform1, input_columns=["image"])

    # Same pipeline with an RGB -> HSV -> RGB round trip appended.
    transforms2 = [
        vision.Decode(),
        vision.Resize([64, 64]),
        vision.ToTensor(),
        vision.RgbToHsv(),
        vision.HsvToRgb()
    ]
    transform2 = mindspore.dataset.transforms.py_transforms.Compose(transforms2)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=transform2, input_columns=["image"])

    num_iter = 0
    for data1, data2 in zip(ds1.create_dict_iterator(num_epochs=1), ds2.create_dict_iterator(num_epochs=1)):
        num_iter += 1
        ori_img = data1["image"].asnumpy()
        cvt_img = data2["image"].asnumpy()
        # Check shapes first so a layout change is reported clearly rather than
        # as a flattened-length mismatch.
        assert ori_img.shape == cvt_img.shape
        assert_allclose(ori_img.flatten(), cvt_img.flatten(), rtol=1e-5, atol=0)
    # Without this, an empty dataset would make the test pass without comparing anything.
    assert num_iter > 0


if __name__ == "__main__":
    test_rgb_hsv_hwc()
    test_rgb_hsv_batch_hwc()
    test_rgb_hsv_chw()
    test_rgb_hsv_batch_chw()
    test_rgb_hsv_pipeline()