1# Copyright 2019 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15import numpy as np 16 17import mindspore.dataset as ds 18from mindspore import log as logger 19 20 21def test_batch_corner_cases(): 22 def gen(num): 23 for i in range(num): 24 yield (np.array([i]),) 25 26 def test_repeat_batch(gen_num, repeats, batch_size, drop, res): 27 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(repeats).batch(batch_size, drop) 28 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 29 res.append(item["num"]) 30 31 def test_batch_repeat(gen_num, repeats, batch_size, drop, res): 32 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, drop).repeat(repeats) 33 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 34 res.append(item["num"]) 35 36 tst1, tst2, tst3, tst4 = [], [], [], [] 37 # case 1 & 2, where batch_size is greater than the entire epoch, with drop equals to both val 38 test_repeat_batch(gen_num=2, repeats=4, batch_size=7, drop=False, res=tst1) 39 np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0], [1], [0]]), tst1[0], "\nATTENTION BATCH FAILED\n") 40 np.testing.assert_array_equal(np.array([[1]]), tst1[1], "\nATTENTION TEST BATCH FAILED\n") 41 assert len(tst1) == 2, "\nATTENTION TEST BATCH FAILED\n" 42 test_repeat_batch(gen_num=2, repeats=4, batch_size=5, drop=True, res=tst2) 43 np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0]]), tst2[0], "\nATTENTION BATCH FAILED\n") 44 assert len(tst2) == 1, "\nATTENTION TEST BATCH FAILED\n" 45 # case 3 & 4, batch before repeat with different drop 46 test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=True, res=tst3) 47 np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst3[0], "\nATTENTION BATCH FAILED\n") 48 np.testing.assert_array_equal(tst3[0], tst3[1], "\nATTENTION BATCH FAILED\n") 49 assert len(tst3) == 2, "\nATTENTION BATCH FAILED\n" 50 test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=False, res=tst4) 51 np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst4[0], "\nATTENTION BATCH FAILED\n") 52 np.testing.assert_array_equal(tst4[0], tst4[2], "\nATTENTION BATCH FAILED\n") 53 np.testing.assert_array_equal(tst4[1], np.array([[4]]), "\nATTENTION BATCH FAILED\n") 54 np.testing.assert_array_equal(tst4[1], tst4[3], "\nATTENTION BATCH FAILED\n") 55 assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n" 56 57 58# each sub-test in this function is tested twice with exact parameter except that the second test passes each row 59# to a pyfunc which makes a deep copy of the row 60def test_variable_size_batch(): 61 def check_res(arr1, arr2): 62 for ind, _ in enumerate(arr1): 63 if not np.array_equal(arr1[ind], np.array(arr2[ind])): 64 return False 65 return len(arr1) == len(arr2) 66 67 def gen(num): 68 for i in range(num): 69 yield (np.array([i]),) 70 71 def add_one_by_batch_num(batchInfo): 72 return batchInfo.get_batch_num() + 1 73 74 def add_one_by_epoch(batchInfo): 75 return batchInfo.get_epoch_num() + 1 76 77 def simple_copy(colList, batchInfo): 78 _ = batchInfo 79 return ([np.copy(arr) for arr in colList],) 80 81 def test_repeat_batch(gen_num, r, drop, func, res): 82 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r).batch(batch_size=func, 83 drop_remainder=drop) 84 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 85 res.append(item["num"]) 86 87 # same as test_repeat_batch except each row is passed through via a map which makes a copy of each element 88 def test_repeat_batch_with_copy_map(gen_num, r, drop, func): 89 res = [] 90 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r) \ 91 .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy) 92 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 93 res.append(item["num"]) 94 return res 95 96 def test_batch_repeat(gen_num, r, drop, func, res): 97 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size=func, drop_remainder=drop).repeat( 98 r) 99 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 100 res.append(item["num"]) 101 102 # same as test_batch_repeat except each row is passed through via a map which makes a copy of each element 103 def test_batch_repeat_with_copy_map(gen_num, r, drop, func): 104 res = [] 105 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]) \ 106 .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy).repeat(r) 107 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 108 res.append(item["num"]) 109 return res 110 111 tst1, tst2, tst3, tst4, tst5, tst6, tst7 = [], [], [], [], [], [], [] 112 113 # no repeat, simple var size, based on batch_num 114 test_repeat_batch(7, 1, True, add_one_by_batch_num, tst1) 115 assert check_res(tst1, [[[0]], [[1], [2]], [[3], [4], [5]]]), "\nATTENTION VAR BATCH FAILED\n" 116 assert check_res(tst1, test_repeat_batch_with_copy_map(7, 1, True, add_one_by_batch_num)), "\nMAP FAILED\n" 117 test_repeat_batch(9, 1, False, add_one_by_batch_num, tst2) 118 assert check_res(tst2, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]]), "\nATTENTION VAR BATCH FAILED\n" 119 assert check_res(tst2, test_repeat_batch_with_copy_map(9, 1, False, add_one_by_batch_num)), "\nMAP FAILED\n" 120 # batch after repeat, cross epoch batch 121 test_repeat_batch(7, 2, False, add_one_by_batch_num, tst3) 122 assert check_res(tst3, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [0], [1], [2]], 123 [[3], [4], [5], [6]]]), "\nATTENTION VAR BATCH FAILED\n" 124 assert check_res(tst3, test_repeat_batch_with_copy_map(7, 2, False, add_one_by_batch_num)), "\nMAP FAILED\n" 125 # repeat after batch, no cross epoch batch, remainder dropped 126 test_batch_repeat(9, 7, True, add_one_by_batch_num, tst4) 127 assert check_res(tst4, [[[0]], [[1], [2]], [[3], [4], [5]]] * 7), "\nATTENTION VAR BATCH FAILED\n" 128 assert check_res(tst4, test_batch_repeat_with_copy_map(9, 7, True, add_one_by_batch_num)), "\nAMAP FAILED\n" 129 # repeat after batch, no cross epoch batch, remainder kept 130 test_batch_repeat(9, 3, False, add_one_by_batch_num, tst5) 131 assert check_res(tst5, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]] * 3), "\nATTENTION VAR BATCH FAILED\n" 132 assert check_res(tst5, test_batch_repeat_with_copy_map(9, 3, False, add_one_by_batch_num)), "\nMAP FAILED\n" 133 # batch_size based on epoch number, drop 134 test_batch_repeat(4, 4, True, add_one_by_epoch, tst6) 135 assert check_res(tst6, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], 136 [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" 137 assert check_res(tst6, test_batch_repeat_with_copy_map(4, 4, True, add_one_by_epoch)), "\nMAP FAILED\n" 138 # batch_size based on epoch number, no drop 139 test_batch_repeat(4, 4, False, add_one_by_epoch, tst7) 140 assert check_res(tst7, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], [[3]], 141 [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + str(tst7) 142 assert check_res(tst7, test_batch_repeat_with_copy_map(4, 4, False, add_one_by_epoch)), "\nMAP FAILED\n" 143 144 145def test_basic_batch_map(): 146 def check_res(arr1, arr2): 147 for ind, _ in enumerate(arr1): 148 if not np.array_equal(arr1[ind], np.array(arr2[ind])): 149 return False 150 return len(arr1) == len(arr2) 151 152 def gen(num): 153 for i in range(num): 154 yield (np.array([i]),) 155 156 def invert_sign_per_epoch(colList, batchInfo): 157 return ([np.copy(((-1) ** batchInfo.get_epoch_num()) * arr) for arr in colList],) 158 159 def invert_sign_per_batch(colList, batchInfo): 160 return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],) 161 162 def batch_map_config(num, r, batch_size, func, res): 163 data1 = ds.GeneratorDataset((lambda: gen(num)), ["num"]) \ 164 .batch(batch_size=batch_size, input_columns=["num"], per_batch_map=func).repeat(r) 165 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 166 res.append(item["num"]) 167 168 tst1, tst2, = [], [] 169 batch_map_config(4, 2, 2, invert_sign_per_epoch, tst1) 170 assert check_res(tst1, [[[0], [1]], [[2], [3]], [[0], [-1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str( 171 tst1) 172 # each batch, the sign of a row is changed, test map is corrected performed according to its batch_num 173 batch_map_config(4, 2, 2, invert_sign_per_batch, tst2) 174 assert check_res(tst2, 175 [[[0], [1]], [[-2], [-3]], [[0], [1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2) 176 177 178def test_batch_multi_col_map(): 179 def check_res(arr1, arr2): 180 for ind, _ in enumerate(arr1): 181 if not np.array_equal(arr1[ind], np.array(arr2[ind])): 182 return False 183 return len(arr1) == len(arr2) 184 185 def gen(num): 186 for i in range(num): 187 yield (np.array([i]), np.array([i ** 2])) 188 189 def col1_col2_add_num(col1, col2, batchInfo): 190 _ = batchInfo 191 return ([[np.copy(arr + 100) for arr in col1], 192 [np.copy(arr + 300) for arr in col2]]) 193 194 def invert_sign_per_batch(colList, batchInfo): 195 return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],) 196 197 def invert_sign_per_batch_multi_col(col1, col2, batchInfo): 198 return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col1], 199 [np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col2]) 200 201 def batch_map_config(num, r, batch_size, func, col_names, res): 202 data1 = ds.GeneratorDataset((lambda: gen(num)), ["num", "num_square"]) \ 203 .batch(batch_size=batch_size, input_columns=col_names, per_batch_map=func).repeat(r) 204 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 205 res.append(np.array([item["num"], item["num_square"]])) 206 207 tst1, tst2, tst3, tst4 = [], [], [], [] 208 batch_map_config(4, 2, 2, invert_sign_per_batch, ["num_square"], tst1) 209 assert check_res(tst1, [[[[0], [1]], [[0], [1]]], [[[2], [3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]], 210 [[[2], [3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst1) 211 212 batch_map_config(4, 2, 2, invert_sign_per_batch_multi_col, ["num", "num_square"], tst2) 213 assert check_res(tst2, [[[[0], [1]], [[0], [1]]], [[[-2], [-3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]], 214 [[[-2], [-3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2) 215 216 # the two tests below verify the order of the map. 217 # num_square column adds 100, num column adds 300. 218 batch_map_config(4, 3, 2, col1_col2_add_num, ["num_square", "num"], tst3) 219 assert check_res(tst3, [[[[300], [301]], [[100], [101]]], 220 [[[302], [303]], [[104], [109]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst3) 221 # num column adds 100, num_square column adds 300. 222 batch_map_config(4, 3, 2, col1_col2_add_num, ["num", "num_square"], tst4) 223 assert check_res(tst4, [[[[100], [101]], [[300], [301]]], 224 [[[102], [103]], [[304], [309]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst4) 225 226 227def test_var_batch_multi_col_map(): 228 def check_res(arr1, arr2): 229 for ind, _ in enumerate(arr1): 230 if not np.array_equal(arr1[ind], np.array(arr2[ind])): 231 return False 232 return len(arr1) == len(arr2) 233 234 # gen 3 columns 235 # first column: 0, 3, 6, 9 ... ... 236 # second column:1, 4, 7, 10 ... ... 237 # third column: 2, 5, 8, 11 ... ... 238 def gen_3_cols(num): 239 for i in range(num): 240 yield (np.array([i * 3]), np.array([i * 3 + 1]), np.array([i * 3 + 2])) 241 242 # first epoch batch_size per batch: 1, 2 ,3 ... ... 243 # second epoch batch_size per batch: 2, 4, 6 ... ... 244 # third epoch batch_size per batch: 3, 6 ,9 ... ... 245 def batch_func(batchInfo): 246 return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1) 247 248 # multiply first col by batch_num, multiply second col by -batch_num 249 def map_func(col1, col2, batchInfo): 250 return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1], 251 [np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2]) 252 253 def batch_map_config(num, r, fbatch, fmap, col_names, res): 254 data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \ 255 .batch(batch_size=fbatch, input_columns=col_names, per_batch_map=fmap).repeat(r) 256 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 257 res.append(np.array([item["col1"], item["col2"], item["col3"]])) 258 259 tst1 = [] 260 tst1_res = [[[[0]], [[-1]], [[2]]], [[[6], [12]], [[-8], [-14]], [[5], [8]]], 261 [[[27], [36], [45]], [[-30], [-39], [-48]], [[11], [14], [17]]], 262 [[[72], [84], [96], [108]], [[-76], [-88], [-100], [-112]], [[20], [23], [26], [29]]]] 263 batch_map_config(10, 1, batch_func, map_func, ["col1", "col2"], tst1) 264 assert check_res(tst1, tst1_res), "test_var_batch_multi_col_map FAILED" 265 266 267def test_var_batch_var_resize(): 268 # fake resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25) 269 def np_psedo_resize(col, batchInfo): 270 s = (batchInfo.get_batch_num() + 1) ** 2 271 return ([np.copy(c[0:s, 0:s, :]) for c in col],) 272 273 def add_one(batchInfo): 274 return batchInfo.get_batch_num() + 1 275 276 data1 = ds.ImageFolderDataset("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True) 277 data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"], per_batch_map=np_psedo_resize) 278 # i-th batch has shape [i, i^2, i^2, 3] 279 i = 1 280 for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 281 assert item["image"].shape == (i, i ** 2, i ** 2, 3), "\ntest_var_batch_var_resize FAILED\n" 282 i += 1 283 284 285def test_exception(): 286 def gen(num): 287 for i in range(num): 288 yield (np.array([i]),) 289 290 def bad_batch_size(batchInfo): 291 raise StopIteration 292 # return batchInfo.get_batch_num() 293 294 def bad_map_func(col, batchInfo): 295 raise StopIteration 296 # return (col,) 297 298 data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size) 299 try: 300 for _ in data1.create_dict_iterator(num_epochs=1): 301 pass 302 assert False 303 except RuntimeError: 304 pass 305 306 data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func) 307 try: 308 for _ in data2.create_dict_iterator(num_epochs=1): 309 pass 310 assert False 311 except RuntimeError: 312 pass 313 314 315def test_multi_col_map(): 316 def gen_2_cols(num): 317 for i in range(1, 1 + num): 318 yield (np.array([i]), np.array([i ** 2])) 319 320 def split_col(col, batchInfo): 321 return ([np.copy(arr) for arr in col], [np.copy(-arr) for arr in col]) 322 323 def merge_col(col1, col2, batchInfo): 324 merged = [] 325 for k, v in enumerate(col1): 326 merged.append(np.array(v + col2[k])) 327 return (merged,) 328 329 def swap_col(col1, col2, batchInfo): 330 return ([np.copy(a) for a in col2], [np.copy(b) for b in col1]) 331 332 def batch_map_config(num, s, f, in_nms, out_nms, col_order=None): 333 try: 334 dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"]) 335 dst = dst.batch(batch_size=s, input_columns=in_nms, output_columns=out_nms, per_batch_map=f, 336 column_order=col_order) 337 res = [] 338 for row in dst.create_dict_iterator(num_epochs=1, output_numpy=True): 339 res.append(row) 340 return res 341 except (ValueError, RuntimeError, TypeError) as e: 342 return str(e) 343 344 # split 1 col into 2 cols 345 res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"])[0] 346 assert np.array_equal(res["col1"], [[1], [2]]) 347 assert np.array_equal(res["col_x"], [[1], [4]]) and np.array_equal(res["col_y"], [[-1], [-4]]) 348 349 # merge 2 cols into 1 col 350 res = batch_map_config(4, 4, merge_col, ["col1", "col2"], ["merged"])[0] 351 assert np.array_equal(res["merged"], [[2], [6], [12], [20]]) 352 353 # swap once 354 res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col1", "col2"])[0] 355 assert np.array_equal(res["col1"], [[1], [4], [9]]) and np.array_equal(res["col2"], [[1], [2], [3]]) 356 357 # swap twice 358 res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col2", "col1"])[0] 359 assert np.array_equal(res["col2"], [[1], [4], [9]]) and np.array_equal(res["col1"], [[1], [2], [3]]) 360 361 # test project after map 362 res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col_x", "col_y", "col1"])[0] 363 assert list(res.keys()) == ["col_x", "col_y", "col1"] 364 365 # test the insertion order is maintained 366 res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col1", "col_x", "col_y"])[0] 367 assert list(res.keys()) == ["col1", "col_x", "col_y"] 368 369 # test exceptions 370 assert "output_columns with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], 233) 371 assert "column_order with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], ["col1"], 233) 372 assert "output_columns in batch is not set correctly" in batch_map_config(2, 2, split_col, ["col2"], ["col1"]) 373 assert "Incorrect number of columns" in batch_map_config(2, 2, split_col, ["col2"], ["col3", "col4", "col5"]) 374 assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"]) 375 376 377def test_exceptions_2(): 378 def gen(num): 379 for i in range(num): 380 yield (np.array([i]),) 381 382 def simple_copy(colList, batchInfo): 383 return ([np.copy(arr) for arr in colList],) 384 385 def concat_copy(colList, batchInfo): 386 # this will duplicate the number of rows returned, which would be wrong! 387 return ([np.copy(arr) for arr in colList] * 2,) 388 389 def shrink_copy(colList, batchInfo): 390 # this will duplicate the number of rows returned, which would be wrong! 391 return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],) 392 393 def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map): 394 data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols, 395 per_batch_map=per_batch_map) 396 try: 397 for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True): 398 pass 399 return "success" 400 except RuntimeError as e: 401 return str(e) 402 403 # test exception where column name is incorrect 404 assert "col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy) 405 assert "expects: 2 rows returned from per_batch_map, got: 4" in test_exceptions_config(4, 2, ["num"], concat_copy) 406 assert "expects: 4 rows returned from per_batch_map, got: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy) 407 408 409if __name__ == '__main__': 410 logger.info("Running test_var_batch_map.py test_batch_corner_cases() function") 411 test_batch_corner_cases() 412 413 logger.info("Running test_var_batch_map.py test_variable_size_batch() function") 414 test_variable_size_batch() 415 416 logger.info("Running test_var_batch_map.py test_basic_batch_map() function") 417 test_basic_batch_map() 418 419 logger.info("Running test_var_batch_map.py test_batch_multi_col_map() function") 420 test_batch_multi_col_map() 421 422 logger.info("Running test_var_batch_map.py tesgit t_var_batch_multi_col_map() function") 423 test_var_batch_multi_col_map() 424 425 logger.info("Running test_var_batch_map.py test_var_batch_var_resize() function") 426 test_var_batch_var_resize() 427 428 logger.info("Running test_var_batch_map.py test_exception() function") 429 test_exception() 430 431 logger.info("Running test_var_batch_map.py test_multi_col_map() function") 432 test_multi_col_map() 433 434 logger.info("Running test_var_batch_map.py test_exceptions_2() function") 435 test_exceptions_2() 436