# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import matplotlib.pyplot as plt
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision


DATASET_DIR = "../data/dataset/testDIV2KData/div2k"


def test_div2k_basic(plot=False):
    """
    Read every sample of the train split (expected: 5 hr/lr image pairs)
    and optionally visualize them side by side.
    """
    usage = "train"  # train, valid, all
    downgrade = "bicubic"  # bicubic, unknown, mild, difficult, wild
    scale = 2  # 2, 3, 4, 8

    data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
    count = 0
    hr_images_list = []
    lr_images_list = []
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        hr_images_list.append(item['hr_image'])
        lr_images_list.append(item['lr_image'])
        count = count + 1
    assert count == 5
    if plot:
        flag = "{}_{}_{}".format(usage, scale, downgrade)
        visualize_dataset(hr_images_list, lr_images_list, flag)


def visualize_dataset(hr_images_list, lr_images_list, flag):
    """
    Helper function to visualize the dataset samples: each figure shows the
    high-resolution image next to its low-resolution counterpart and is saved
    to disk.
    """
    image_num = len(hr_images_list)
    for i in range(image_num):
        plt.subplot(121)
        plt.imshow(hr_images_list[i])
        plt.title('Original')
        plt.subplot(122)
        plt.imshow(lr_images_list[i])
        plt.title(flag)
        plt.savefig('./div2k_{}_{}.jpg'.format(flag, str(i)))


def test_div2k_basic_func():
    """
    Exercise basic dataset functionality on DIV2KDataset: usage='all',
    num_samples, repeat, batch (with and without drop_remainder) and
    get_col_names.
    """
    # case 0: test usage equal to `all`
    usage = "all"  # train, valid, all
    downgrade = "bicubic"  # bicubic, unknown, mild, difficult, wild
    scale = 2  # 2, 3, 4, 8

    data0 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
    num_iter0 = 0
    for _ in data0.create_dict_iterator(num_epochs=1):
        num_iter0 += 1
    assert num_iter0 == 6

    # case 1: test num_samples
    usage = "train"  # train, valid, all

    data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=4)
    num_iter1 = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter1 += 1
    assert num_iter1 == 4

    # case 2: test repeat
    data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=3)
    data2 = data2.repeat(5)
    num_iter2 = 0
    for _ in data2.create_dict_iterator(num_epochs=1):
        num_iter2 += 1
    assert num_iter2 == 15

    # case 3: test batch with drop_remainder=False
    data3 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
    assert data3.get_dataset_size() == 5
    assert data3.get_batch_size() == 1
    # images have varying sizes, so resize before batching
    resize_op = c_vision.Resize([100, 100])
    data3 = data3.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
    data3 = data3.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
    data3 = data3.batch(batch_size=3)  # drop_remainder is default to be False
    assert data3.get_dataset_size() == 2
    assert data3.get_batch_size() == 3
    num_iter3 = 0
    for _ in data3.create_dict_iterator(num_epochs=1):
        num_iter3 += 1
    assert num_iter3 == 2

    # case 4: test batch with drop_remainder=True
    data4 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, decode=True)
    assert data4.get_dataset_size() == 5
    assert data4.get_batch_size() == 1
    data4 = data4.map(operations=resize_op, input_columns=["hr_image"], num_parallel_workers=1)
    data4 = data4.map(operations=resize_op, input_columns=["lr_image"], num_parallel_workers=1)
    data4 = data4.batch(batch_size=3, drop_remainder=True)  # the rest of incomplete batch will be dropped
    assert data4.get_dataset_size() == 1
    assert data4.get_batch_size() == 3
    num_iter4 = 0
    for _ in data4.create_dict_iterator(num_epochs=1):
        num_iter4 += 1
    assert num_iter4 == 1

    # case 5: test get_col_names
    data5 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_samples=1)
    assert data5.get_col_names() == ["hr_image", "lr_image"]


def test_div2k_sequential_sampler():
    """
    Test DIV2KDataset with SequentialSampler: an explicit SequentialSampler
    must yield the same rows as shuffle=False with the same num_samples.
    """
    usage = "train"  # train, valid, all
    downgrade = "bicubic"  # bicubic, unknown, mild, difficult, wild
    scale = 2  # 2, 3, 4, 8

    num_samples = 2
    sampler = ds.SequentialSampler(num_samples=num_samples)
    data1 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, sampler=sampler)
    data2 = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
                            num_samples=num_samples)
    num_iter = 0
    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
        np.testing.assert_array_equal(item1["hr_image"], item2["hr_image"])
        np.testing.assert_array_equal(item1["lr_image"], item2["lr_image"])
        num_iter += 1
    assert num_iter == num_samples


def test_div2k_exception():
    """
    Test error cases for DIV2KDataset: invalid paths and arguments,
    conflicting sampler/shuffle/sharding options, and pyfunc failures
    propagated through map on each output column.
    """
    usage = "train"  # train, valid, all
    downgrade = "bicubic"  # bicubic, unknown, mild, difficult, wild
    scale = 2  # 2, 3, 4, 8

    error_msg_1 = "does not exist or is not a directory or permission denied!"
    with pytest.raises(ValueError, match=error_msg_1):
        ds.DIV2KDataset("NoExistsDir", usage=usage, downgrade=downgrade, scale=scale)

    error_msg_2 = r"Input usage is not within the valid set of \['train', 'valid', 'all'\]."
    with pytest.raises(ValueError, match=error_msg_2):
        ds.DIV2KDataset(DATASET_DIR, usage="test", downgrade=downgrade, scale=scale)

    error_msg_3 = r"Input scale is not within the valid set of \[2, 3, 4, 8\]."
    with pytest.raises(ValueError, match=error_msg_3):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=16, downgrade=downgrade)

    error_msg_4 = r"Input downgrade is not within the valid set of .*"
    with pytest.raises(ValueError, match=error_msg_4):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, scale=scale, downgrade="downgrade")

    error_msg_5 = "sampler and shuffle cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_5):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
                        sampler=ds.PKSampler(3))

    error_msg_6 = "sampler and sharding cannot be specified at the same time"
    with pytest.raises(RuntimeError, match=error_msg_6):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=0,
                        sampler=ds.PKSampler(3))

    error_msg_7 = "num_shards is specified and currently requires shard_id as well"
    with pytest.raises(RuntimeError, match=error_msg_7):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=10)

    error_msg_8 = "shard_id is specified but num_shards is not"
    with pytest.raises(RuntimeError, match=error_msg_8):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shard_id=0)

    error_msg_9 = "Input shard_id is not within the required interval"
    with pytest.raises(ValueError, match=error_msg_9):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=-1)
    with pytest.raises(ValueError, match=error_msg_9):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=5, shard_id=5)
    with pytest.raises(ValueError, match=error_msg_9):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id=5)

    error_msg_10 = "num_parallel_workers exceeds"
    with pytest.raises(ValueError, match=error_msg_10):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
                        num_parallel_workers=0)
    with pytest.raises(ValueError, match=error_msg_10):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
                        num_parallel_workers=256)
    with pytest.raises(ValueError, match=error_msg_10):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, shuffle=False,
                        num_parallel_workers=-2)

    error_msg_11 = "Argument shard_id"
    with pytest.raises(TypeError, match=error_msg_11):
        ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # pyfunc failure on the hr_image column must surface as a RuntimeError
    try:
        data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
        data = data.map(operations=exception_func, input_columns=["hr_image"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)

    # same check on the lr_image column (the original duplicated the hr_image
    # case here, leaving the second column untested)
    try:
        data = ds.DIV2KDataset(DATASET_DIR, usage=usage, downgrade=downgrade, scale=scale)
        data = data.map(operations=exception_func, input_columns=["lr_image"], num_parallel_workers=1)
        num_rows = 0
        for _ in data.create_dict_iterator():
            num_rows += 1
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files:" in str(e)


if __name__ == "__main__":
    test_div2k_basic()
    test_div2k_basic_func()
    test_div2k_sequential_sampler()
    test_div2k_exception()