# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict

# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
GENERATE_GOLDEN = False
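# save_and_check_dict (from the local test util module) compares the pipeline
# output against a previously saved golden .npz file; flipping GENERATE_GOLDEN
# to True is expected to regenerate the golden files instead of verifying
# against them.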


def test_batch_01():
    """
    Test batch: batch_size>1, drop_remainder=True, no remainder exists
    """
    logger.info("test_batch_01")
    # define parameters
    batch_size = 2
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    assert sum([1 for _ in data1]) == 6
    filename = "batch_01_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_02():
    """
    Test batch: batch_size>1, drop_remainder=True, remainder exists
    """
    logger.info("test_batch_02")
    # define parameters
    batch_size = 5
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_02_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_03():
    """
    Test batch: batch_size>1, drop_remainder=False, no remainder exists
    """
    logger.info("test_batch_03")
    # define parameters
    batch_size = 3
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 4
    filename = "batch_03_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_04():
    """
    Test batch: batch_size>1, drop_remainder=False, remainder exists
    """
    logger.info("test_batch_04")
    # define parameters
    batch_size = 7
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_04_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_05():
    """
    Test batch: batch_size=1 (minimum valid size), drop_remainder default
    """
    logger.info("test_batch_05")
    # define parameters
    batch_size = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 12
    filename = "batch_05_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_06():
    """
    Test batch: batch_size = number-of-rows-in-dataset, drop_remainder=False, reorder params
    """
    logger.info("test_batch_06")
    # define parameters
    batch_size = 12
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_06_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_07():
    """
    Test batch: num_parallel_workers>1, drop_remainder=False, reorder params
    """
    logger.info("test_batch_07")
    # define parameters
    batch_size = 4
    drop_remainder = False
    num_parallel_workers = 2

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
                        batch_size=batch_size)

    assert sum([1 for _ in data1]) == 3
    filename = "batch_07_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_08():
    """
    Test batch: num_parallel_workers=1, drop_remainder default
    """
    logger.info("test_batch_08")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)

    assert sum([1 for _ in data1]) == 2
    filename = "batch_08_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_09():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=False
    """
    logger.info("test_batch_09")
    # define parameters
    batch_size = 13
    drop_remainder = False

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_09_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_10():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=True
    """
    logger.info("test_batch_10")
    # define parameters
    batch_size = 99
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 0
    filename = "batch_10_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_11():
    """
    Test batch: batch_size=1 and dataset-size=1
    """
    logger.info("test_batch_11")
    # define parameters
    batch_size = 1

    # apply dataset operations
    # Use schema file with 1 row
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema1Row.json"
    data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 1
    filename = "batch_11_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_12():
    """
    Test batch: batch_size is the boolean True, which is treated as the valid value 1
    """
    logger.info("test_batch_12")
    # define parameters
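    # bool is a subclass of int in Python, so True is accepted and behaves as 1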
    batch_size = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size)

    assert sum([1 for _ in data1]) == 12
    filename = "batch_12_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_13():
    """
    Test batch: python_multiprocessing=True is a no-op when per_batch_map is None
    """
    logger.info("test_batch_13")
    # define parameters
    batch_size = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, python_multiprocessing=True)

    assert sum([1 for _ in data1]) == 12
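    # Reuse the golden file from test_batch_12: batch_size=True behaves as 1 and
    # python_multiprocessing has no effect without per_batch_map, so the output
    # is identical.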
    filename = "batch_12_result.npz"
    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_exception_01():
    """
    Test batch exception: num_parallel_workers=0
    """
    logger.info("test_batch_exception_01")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_02():
    """
    Test batch exception: num_parallel_workers<0
    """
    logger.info("test_batch_exception_02")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_03():
    """
    Test batch exception: batch_size=0
    """
    logger.info("test_batch_exception_03")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_04():
    """
    Test batch exception: batch_size<0
    """
    logger.info("test_batch_exception_04")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_05():
    """
    Test batch exception: batch_size is the boolean False, which is treated as the invalid value 0
    """
    logger.info("test_batch_exception_05")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_07():
    """
    Test batch exception: drop_remainder wrong type
    """
    logger.info("test_batch_exception_07")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_08():
    """
    Test batch exception: num_parallel_workers wrong type
    """
    logger.info("test_batch_exception_08")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_09():
    """
    Test batch exception: Missing mandatory batch_size
    """
    logger.info("test_batch_exception_09")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_10():
    """
    Test batch exception: num_parallel_workers>>1
    """
    logger.info("test_batch_exception_10")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_11():
    """
    Test batch exception: wrong input order, num_parallel_workers wrongly used as drop_remainder
    """
    logger.info("test_batch_exception_11")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, num_parallel_workers)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_12():
    """
    Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
    """
    logger.info("test_batch_exception_12")
    # define parameters
    batch_size = 1
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(drop_remainder, batch_size)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_13():
    """
    Test batch exception: invalid input parameter
    """
    logger.info("test_batch_exception_13")
    # define parameters
    batch_size = 4

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, shard_id=1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "shard_id" in str(e)


def test_batch_exception_14():
    """
    Test per_batch_map and input column name
    """
    logger.info("test_batch_exception_14")
    batch_size = 2
    input_columns = ["num"]
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        _ = data1.batch(batch_size=batch_size, input_columns=input_columns)
    except ValueError as e:
        assert "per_batch_map and input_columns need to be passed in together." in str(e)

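# For context, input_columns is only meaningful when paired with per_batch_map.
# A minimal sketch of the intended pairing (the callback below and the
# "col_sint64" column name are illustrative assumptions, not exercised here):
#
#     def add_one(col, batch_info):
#         return ([value + 1 for value in col],)
#
#     data = data.batch(batch_size=2, per_batch_map=add_one,
#                       input_columns=["col_sint64"])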

def test_batch_exception_15():
    """
    Test batch_size = int32 max value + 1
    """
    logger.info("test_batch_exception_15")
    batch_size = 2147483647 + 1
    input_columns = ["num"]
    data1 = ds.TFRecordDataset(DATA_DIR)
    err_msg = ""
    try:
        _ = data1.batch(batch_size=batch_size, input_columns=input_columns)
    except ValueError as e:
        err_msg = str(e)
    assert "batch_size is not within the required interval of [1, 2147483647]" in err_msg


if __name__ == '__main__':
    test_batch_01()
    test_batch_02()
    test_batch_03()
    test_batch_04()
    test_batch_05()
    test_batch_06()
    test_batch_07()
    test_batch_08()
    test_batch_09()
    test_batch_10()
    test_batch_11()
    test_batch_12()
    test_batch_13()
    test_batch_exception_01()
    test_batch_exception_02()
    test_batch_exception_03()
    test_batch_exception_04()
    test_batch_exception_05()
    test_batch_exception_07()
    test_batch_exception_08()
    test_batch_exception_09()
    test_batch_exception_10()
    test_batch_exception_11()
    test_batch_exception_12()
    test_batch_exception_13()
    test_batch_exception_14()
    test_batch_exception_15()
    logger.info('\n')