• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import numpy as np
16
17import mindspore.dataset as ds
18from mindspore import log as logger
19
20
def test_batch_corner_cases():
    """Verify batch() edge cases where batch_size exceeds or does not evenly divide the epoch."""

    def gen(num):
        # Yield single-element rows: [0], [1], ..., [num-1].
        for i in range(num):
            yield (np.array([i]),)

    def test_repeat_batch(gen_num, repeats, batch_size, drop, res):
        # Pipeline order: generator -> repeat -> batch. Collect batched rows into res.
        pipeline = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(repeats).batch(batch_size, drop)
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(row["num"])

    def test_batch_repeat(gen_num, repeats, batch_size, drop, res):
        # Pipeline order: generator -> batch -> repeat. Collect batched rows into res.
        pipeline = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, drop).repeat(repeats)
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(row["num"])

    tst1, tst2, tst3, tst4 = [], [], [], []
    # case 1 & 2, where batch_size is greater than the entire epoch, with drop equals to both val
    test_repeat_batch(gen_num=2, repeats=4, batch_size=7, drop=False, res=tst1)
    np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0], [1], [0]]), tst1[0], "\nATTENTION BATCH FAILED\n")
    np.testing.assert_array_equal(np.array([[1]]), tst1[1], "\nATTENTION TEST BATCH FAILED\n")
    assert len(tst1) == 2, "\nATTENTION TEST BATCH FAILED\n"
    test_repeat_batch(gen_num=2, repeats=4, batch_size=5, drop=True, res=tst2)
    np.testing.assert_array_equal(np.array([[0], [1], [0], [1], [0]]), tst2[0], "\nATTENTION BATCH FAILED\n")
    assert len(tst2) == 1, "\nATTENTION TEST BATCH FAILED\n"
    # case 3 & 4, batch before repeat with different drop
    test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=True, res=tst3)
    np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst3[0], "\nATTENTION BATCH FAILED\n")
    np.testing.assert_array_equal(tst3[0], tst3[1], "\nATTENTION BATCH FAILED\n")
    assert len(tst3) == 2, "\nATTENTION BATCH FAILED\n"
    test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=False, res=tst4)
    np.testing.assert_array_equal(np.array([[0], [1], [2], [3]]), tst4[0], "\nATTENTION BATCH FAILED\n")
    np.testing.assert_array_equal(tst4[0], tst4[2], "\nATTENTION BATCH FAILED\n")
    np.testing.assert_array_equal(tst4[1], np.array([[4]]), "\nATTENTION BATCH FAILED\n")
    np.testing.assert_array_equal(tst4[1], tst4[3], "\nATTENTION BATCH FAILED\n")
    assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n"
56
57
# each sub-test in this function is tested twice with exact parameter except that the second test passes each row
# to a pyfunc which makes a deep copy of the row
def test_variable_size_batch():
    """Test batch() with a callable batch_size varying by batch number or epoch number.

    Each sub-case runs twice with identical parameters: once plainly, and once with a
    per_batch_map pyfunc (simple_copy) that deep-copies every row. Both runs must agree.
    Fix: the tst4 assert message previously read "\\nAMAP FAILED\\n" (typo), now
    consistent with every other "\\nMAP FAILED\\n" message.
    """

    def check_res(arr1, arr2):
        # Element-wise array comparison; True only if all entries match and lengths agree.
        for ind, _ in enumerate(arr1):
            if not np.array_equal(arr1[ind], np.array(arr2[ind])):
                return False
        return len(arr1) == len(arr2)

    def gen(num):
        # Yield single-element rows: [0], [1], ..., [num-1].
        for i in range(num):
            yield (np.array([i]),)

    def add_one_by_batch_num(batchInfo):
        # batch_size grows with the batch number within an epoch: 1, 2, 3, ...
        return batchInfo.get_batch_num() + 1

    def add_one_by_epoch(batchInfo):
        # batch_size grows with the epoch number: 1, 2, 3, ...
        return batchInfo.get_epoch_num() + 1

    def simple_copy(colList, batchInfo):
        # Identity per_batch_map that deep-copies every array in the column.
        _ = batchInfo
        return ([np.copy(arr) for arr in colList],)

    def test_repeat_batch(gen_num, r, drop, func, res):
        # Pipeline order: generator -> repeat -> batch (batch may cross epoch boundaries).
        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r).batch(batch_size=func,
                                                                                     drop_remainder=drop)
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(item["num"])

    # same as test_repeat_batch except each row is passed through via a map which makes a copy of each element
    def test_repeat_batch_with_copy_map(gen_num, r, drop, func):
        res = []
        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r) \
            .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy)
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(item["num"])
        return res

    def test_batch_repeat(gen_num, r, drop, func, res):
        # Pipeline order: generator -> batch -> repeat (no cross-epoch batches).
        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size=func, drop_remainder=drop).repeat(
            r)
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(item["num"])

    # same as test_batch_repeat except each row is passed through via a map which makes a copy of each element
    def test_batch_repeat_with_copy_map(gen_num, r, drop, func):
        res = []
        data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]) \
            .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy).repeat(r)
        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(item["num"])
        return res

    tst1, tst2, tst3, tst4, tst5, tst6, tst7 = [], [], [], [], [], [], []

    # no repeat, simple var size, based on batch_num
    test_repeat_batch(7, 1, True, add_one_by_batch_num, tst1)
    assert check_res(tst1, [[[0]], [[1], [2]], [[3], [4], [5]]]), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst1, test_repeat_batch_with_copy_map(7, 1, True, add_one_by_batch_num)), "\nMAP FAILED\n"
    test_repeat_batch(9, 1, False, add_one_by_batch_num, tst2)
    assert check_res(tst2, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]]), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst2, test_repeat_batch_with_copy_map(9, 1, False, add_one_by_batch_num)), "\nMAP FAILED\n"
    # batch after repeat, cross epoch batch
    test_repeat_batch(7, 2, False, add_one_by_batch_num, tst3)
    assert check_res(tst3, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [0], [1], [2]],
                            [[3], [4], [5], [6]]]), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst3, test_repeat_batch_with_copy_map(7, 2, False, add_one_by_batch_num)), "\nMAP FAILED\n"
    # repeat after batch, no cross epoch batch, remainder dropped
    test_batch_repeat(9, 7, True, add_one_by_batch_num, tst4)
    assert check_res(tst4, [[[0]], [[1], [2]], [[3], [4], [5]]] * 7), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst4, test_batch_repeat_with_copy_map(9, 7, True, add_one_by_batch_num)), "\nMAP FAILED\n"
    # repeat after batch, no cross epoch batch, remainder kept
    test_batch_repeat(9, 3, False, add_one_by_batch_num, tst5)
    assert check_res(tst5, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]] * 3), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst5, test_batch_repeat_with_copy_map(9, 3, False, add_one_by_batch_num)), "\nMAP FAILED\n"
    # batch_size based on epoch number, drop
    test_batch_repeat(4, 4, True, add_one_by_epoch, tst6)
    assert check_res(tst6, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]],
                            [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n"
    assert check_res(tst6, test_batch_repeat_with_copy_map(4, 4, True, add_one_by_epoch)), "\nMAP FAILED\n"
    # batch_size based on epoch number, no drop
    test_batch_repeat(4, 4, False, add_one_by_epoch, tst7)
    assert check_res(tst7, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], [[3]],
                            [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + str(tst7)
    assert check_res(tst7, test_batch_repeat_with_copy_map(4, 4, False, add_one_by_epoch)), "\nMAP FAILED\n"
143
144
def test_basic_batch_map():
    """Verify per_batch_map sees the correct epoch number and batch number."""

    def check_res(actual, expected):
        # Element-wise array comparison; True only if all entries match and lengths agree.
        for idx in range(len(actual)):
            if not np.array_equal(actual[idx], np.array(expected[idx])):
                return False
        return len(actual) == len(expected)

    def gen(num):
        # Yield single-element rows: [0], [1], ..., [num-1].
        for i in range(num):
            yield (np.array([i]),)

    def invert_sign_per_epoch(colList, batchInfo):
        # Negate rows on odd epochs.
        sign = (-1) ** batchInfo.get_epoch_num()
        return ([np.copy(sign * arr) for arr in colList],)

    def invert_sign_per_batch(colList, batchInfo):
        # Negate rows on odd batches.
        sign = (-1) ** batchInfo.get_batch_num()
        return ([np.copy(sign * arr) for arr in colList],)

    def batch_map_config(num, r, batch_size, func, res):
        pipeline = ds.GeneratorDataset((lambda: gen(num)), ["num"]) \
            .batch(batch_size=batch_size, input_columns=["num"], per_batch_map=func).repeat(r)
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(row["num"])

    tst1, tst2, = [], []
    batch_map_config(4, 2, 2, invert_sign_per_epoch, tst1)
    assert check_res(tst1, [[[0], [1]], [[2], [3]], [[0], [-1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str(
        tst1)
    # each batch, the sign of a row is changed, test map is corrected performed according to its batch_num
    batch_map_config(4, 2, 2, invert_sign_per_batch, tst2)
    assert check_res(tst2,
                     [[[0], [1]], [[-2], [-3]], [[0], [1]], [[-2], [-3]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2)
176
177
def test_batch_multi_col_map():
    """Verify per_batch_map over multiple input columns, including column-order mapping."""

    def check_res(actual, expected):
        # Element-wise array comparison; True only if all entries match and lengths agree.
        for idx in range(len(actual)):
            if not np.array_equal(actual[idx], np.array(expected[idx])):
                return False
        return len(actual) == len(expected)

    def gen(num):
        # Yield two-column rows: ([i], [i^2]).
        for i in range(num):
            yield (np.array([i]), np.array([i ** 2]))

    def col1_col2_add_num(col1, col2, batchInfo):
        # First input column gets +100, second input column gets +300.
        _ = batchInfo
        shifted1 = [np.copy(arr + 100) for arr in col1]
        shifted2 = [np.copy(arr + 300) for arr in col2]
        return [shifted1, shifted2]

    def invert_sign_per_batch(colList, batchInfo):
        # Negate rows on odd batches (single column).
        sign = (-1) ** batchInfo.get_batch_num()
        return ([np.copy(sign * arr) for arr in colList],)

    def invert_sign_per_batch_multi_col(col1, col2, batchInfo):
        # Negate rows of both columns on odd batches.
        sign = (-1) ** batchInfo.get_batch_num()
        return ([np.copy(sign * arr) for arr in col1],
                [np.copy(sign * arr) for arr in col2])

    def batch_map_config(num, r, batch_size, func, col_names, res):
        pipeline = ds.GeneratorDataset((lambda: gen(num)), ["num", "num_square"]) \
            .batch(batch_size=batch_size, input_columns=col_names, per_batch_map=func).repeat(r)
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(np.array([row["num"], row["num_square"]]))

    tst1, tst2, tst3, tst4 = [], [], [], []
    batch_map_config(4, 2, 2, invert_sign_per_batch, ["num_square"], tst1)
    assert check_res(tst1, [[[[0], [1]], [[0], [1]]], [[[2], [3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]],
                            [[[2], [3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst1)

    batch_map_config(4, 2, 2, invert_sign_per_batch_multi_col, ["num", "num_square"], tst2)
    assert check_res(tst2, [[[[0], [1]], [[0], [1]]], [[[-2], [-3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]],
                            [[[-2], [-3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2)

    # the two tests below verify the order of the map.
    # num_square column adds 100, num column adds 300.
    batch_map_config(4, 3, 2, col1_col2_add_num, ["num_square", "num"], tst3)
    assert check_res(tst3, [[[[300], [301]], [[100], [101]]],
                            [[[302], [303]], [[104], [109]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst3)
    # num column adds 100, num_square column adds 300.
    batch_map_config(4, 3, 2, col1_col2_add_num, ["num", "num_square"], tst4)
    assert check_res(tst4, [[[[100], [101]], [[300], [301]]],
                            [[[102], [103]], [[304], [309]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst4)
225
226
def test_var_batch_multi_col_map():
    """Combine a variable (batch- and epoch-dependent) batch_size with a multi-column per_batch_map."""

    def check_res(actual, expected):
        # Element-wise array comparison; True only if all entries match and lengths agree.
        for idx in range(len(actual)):
            if not np.array_equal(actual[idx], np.array(expected[idx])):
                return False
        return len(actual) == len(expected)

    # gen 3 columns
    # first column: 0, 3, 6, 9 ... ...
    # second column:1, 4, 7, 10 ... ...
    # third column: 2, 5, 8, 11 ... ...
    def gen_3_cols(num):
        for i in range(num):
            base = i * 3
            yield (np.array([base]), np.array([base + 1]), np.array([base + 2]))

    # first epoch batch_size per batch: 1, 2 ,3 ... ...
    # second epoch batch_size per batch: 2, 4, 6 ... ...
    # third epoch batch_size per batch: 3, 6 ,9 ... ...
    def batch_func(batchInfo):
        return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)

    # multiply first col by batch_num, multiply second col by -batch_num
    def map_func(col1, col2, batchInfo):
        factor = 1 + batchInfo.get_batch_num()
        return ([np.copy(factor * arr) for arr in col1],
                [np.copy(-factor * arr) for arr in col2])

    def batch_map_config(num, r, fbatch, fmap, col_names, res):
        pipeline = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
            .batch(batch_size=fbatch, input_columns=col_names, per_batch_map=fmap).repeat(r)
        for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
            res.append(np.array([row["col1"], row["col2"], row["col3"]]))

    tst1 = []
    tst1_res = [[[[0]], [[-1]], [[2]]], [[[6], [12]], [[-8], [-14]], [[5], [8]]],
                [[[27], [36], [45]], [[-30], [-39], [-48]], [[11], [14], [17]]],
                [[[72], [84], [96], [108]], [[-76], [-88], [-100], [-112]], [[20], [23], [26], [29]]]]
    batch_map_config(10, 1, batch_func, map_func, ["col1", "col2"], tst1)
    assert check_res(tst1, tst1_res), "test_var_batch_multi_col_map FAILED"
265
266
def test_var_batch_var_resize():
    """Variable batch_size combined with a per_batch_map that crops every image in the batch."""

    # fake resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
    def np_pseudo_resize(col, batchInfo):
        side = (batchInfo.get_batch_num() + 1) ** 2
        return ([np.copy(img[0:side, 0:side, :]) for img in col],)

    def add_one(batchInfo):
        # batch_size grows with the batch number: 1, 2, 3, ...
        return batchInfo.get_batch_num() + 1

    data1 = ds.ImageFolderDataset("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True)
    data1 = data1.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"],
                        per_batch_map=np_pseudo_resize)
    # i-th batch has shape [i, i^2, i^2, 3]
    for i, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True), start=1):
        assert item["image"].shape == (i, i ** 2, i ** 2, 3), "\ntest_var_batch_var_resize FAILED\n"
283
284
def test_exception():
    """Exceptions raised inside batch_size / per_batch_map callables must surface as RuntimeError."""

    def gen(num):
        for i in range(num):
            yield (np.array([i]),)

    def bad_batch_size(batchInfo):
        # Deliberately broken batch_size callable.
        raise StopIteration

    def bad_map_func(col, batchInfo):
        # Deliberately broken per_batch_map callable.
        raise StopIteration

    data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size)
    try:
        for _ in data1.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except RuntimeError:
        pass

    data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"], per_batch_map=bad_map_func)
    try:
        for _ in data2.create_dict_iterator(num_epochs=1):
            pass
        assert False
    except RuntimeError:
        pass
313
314
def test_multi_col_map():
    """Exercise per_batch_map combinations of input_columns/output_columns/column_order,
    including column split, merge, swap, projection, and parameter-validation errors."""

    def gen_2_cols(num):
        # Yield two-column rows: ([i], [i^2]) for i in 1..num.
        for i in range(1, 1 + num):
            yield (np.array([i]), np.array([i ** 2]))

    def split_col(col, batchInfo):
        # One column in, two columns out: (col, -col).
        return ([np.copy(arr) for arr in col], [np.copy(-arr) for arr in col])

    def merge_col(col1, col2, batchInfo):
        # Two columns in, one column out: element-wise sum.
        return ([np.array(a + b) for a, b in zip(col1, col2)],)

    def swap_col(col1, col2, batchInfo):
        # Two columns in, two columns out, order reversed.
        return ([np.copy(a) for a in col2], [np.copy(b) for b in col1])

    def batch_map_config(num, s, f, in_nms, out_nms, col_order=None):
        # Run the pipeline; on failure return the exception text so callers can assert on it.
        try:
            dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"])
            dst = dst.batch(batch_size=s, input_columns=in_nms, output_columns=out_nms, per_batch_map=f,
                            column_order=col_order)
            return list(dst.create_dict_iterator(num_epochs=1, output_numpy=True))
        except (ValueError, RuntimeError, TypeError) as e:
            return str(e)

    # split 1 col into 2 cols
    res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"])[0]
    assert np.array_equal(res["col1"], [[1], [2]])
    assert np.array_equal(res["col_x"], [[1], [4]]) and np.array_equal(res["col_y"], [[-1], [-4]])

    # merge 2 cols into 1 col
    res = batch_map_config(4, 4, merge_col, ["col1", "col2"], ["merged"])[0]
    assert np.array_equal(res["merged"], [[2], [6], [12], [20]])

    # swap once
    res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col1", "col2"])[0]
    assert np.array_equal(res["col1"], [[1], [4], [9]]) and np.array_equal(res["col2"], [[1], [2], [3]])

    # swap twice
    res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col2", "col1"])[0]
    assert np.array_equal(res["col2"], [[1], [4], [9]]) and np.array_equal(res["col1"], [[1], [2], [3]])

    # test project after map
    res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col_x", "col_y", "col1"])[0]
    assert list(res.keys()) == ["col_x", "col_y", "col1"]

    # test the insertion order is maintained
    res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col1", "col_x", "col_y"])[0]
    assert list(res.keys()) == ["col1", "col_x", "col_y"]

    # test exceptions
    assert "output_columns with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], 233)
    assert "column_order with value 233 is not of type" in batch_map_config(2, 2, split_col, ["col2"], ["col1"], 233)
    assert "output_columns in batch is not set correctly" in batch_map_config(2, 2, split_col, ["col2"], ["col1"])
    assert "Incorrect number of columns" in batch_map_config(2, 2, split_col, ["col2"], ["col3", "col4", "col5"])
    assert "col-1 doesn't exist" in batch_map_config(2, 2, split_col, ["col-1"], ["col_x", "col_y"])
375
376
def test_exceptions_2():
    """per_batch_map must return exactly as many rows as it received, and input
    column names must exist; otherwise the iterator raises RuntimeError."""

    def gen(num):
        for i in range(num):
            yield (np.array([i]),)

    def simple_copy(colList, batchInfo):
        # Returns the same number of rows -- valid.
        return ([np.copy(arr) for arr in colList],)

    def concat_copy(colList, batchInfo):
        # Doubles the number of rows returned, which is invalid.
        return ([np.copy(arr) for arr in colList] * 2,)

    def shrink_copy(colList, batchInfo):
        # Halves the number of rows returned, which is invalid.
        return ([np.copy(arr) for arr in colList][0:int(len(colList) / 2)],)

    def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map):
        # Run the pipeline to completion; return "success" or the RuntimeError message.
        pipeline = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols,
                                                                              per_batch_map=per_batch_map)
        try:
            for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
                pass
            return "success"
        except RuntimeError as e:
            return str(e)

    # test exception where column name is incorrect
    assert "col:num1 doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy)
    assert "expects: 2 rows returned from per_batch_map, got: 4" in test_exceptions_config(4, 2, ["num"], concat_copy)
    assert "expects: 4 rows returned from per_batch_map, got: 2" in test_exceptions_config(4, 4, ["num"], shrink_copy)
407
408
if __name__ == '__main__':
    # Run every test in this file in order, logging each one before it starts.
    logger.info("Running test_var_batch_map.py test_batch_corner_cases() function")
    test_batch_corner_cases()

    logger.info("Running test_var_batch_map.py test_variable_size_batch() function")
    test_variable_size_batch()

    logger.info("Running test_var_batch_map.py test_basic_batch_map() function")
    test_basic_batch_map()

    logger.info("Running test_var_batch_map.py test_batch_multi_col_map() function")
    test_batch_multi_col_map()

    # fix: log message previously read "tesgit t_var_batch_multi_col_map" (typo)
    logger.info("Running test_var_batch_map.py test_var_batch_multi_col_map() function")
    test_var_batch_multi_col_map()

    logger.info("Running test_var_batch_map.py test_var_batch_var_resize() function")
    test_var_batch_var_resize()

    logger.info("Running test_var_batch_map.py test_exception() function")
    test_exception()

    logger.info("Running test_var_batch_map.py test_multi_col_map() function")
    test_multi_col_map()

    logger.info("Running test_var_batch_map.py test_exceptions_2() function")
    test_exceptions_2()
436