1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16This is the test module for mindrecord
17"""
18import collections
19import json
20import math
21import os
22import re
23import string
24import pytest
25import numpy as np
26
27import mindspore.dataset as ds
28import mindspore.dataset.vision.c_transforms as vision
29from mindspore import log as logger
30from mindspore.dataset.vision import Inter
31from mindspore.mindrecord import FileWriter
32
33FILES_NUM = 4
34CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
35CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
36CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
37CV_DIR_NAME = "../data/mindrecord/testImageNetData"
38NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
39OLD_NLP_FILE_NAME = "../data/mindrecord/testOldVersion/aclImdb.mindrecord"
40NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
41NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"
42
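# The FileWriter below splits its output into FILES_NUM shard files, named by appending
# the shard index to the base name (plus a companion ".db" index file per shard), which
# is why the fixtures build their cleanup paths this way and why the tests open
# CV_FILE_NAME + "0". A hypothetical helper (not used by the tests, shown only to make
# the repeated cleanup pattern explicit):
def remove_shards(file_name, shard_count):
    """Remove every shard file and its companion .db index file, if present."""
    for x in range(shard_count):
        for path in ("{}{}".format(file_name, x), "{}{}.db".format(file_name, x)):
            if os.path.exists(path):
                os.remove(path)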
43
44@pytest.fixture
45def add_and_remove_cv_file():
46    """add/remove cv file"""
47    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
48             for x in range(FILES_NUM)]
49    try:
50        for x in paths:
51            if os.path.exists("{}".format(x)):
52                os.remove("{}".format(x))
53            if os.path.exists("{}.db".format(x)):
54                os.remove("{}.db".format(x))
55        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
56        data = get_data(CV_DIR_NAME)
57        cv_schema_json = {"id": {"type": "int32"},
58                          "file_name": {"type": "string"},
59                          "label": {"type": "int32"},
60                          "data": {"type": "bytes"}}
61        writer.add_schema(cv_schema_json, "img_schema")
62        writer.add_index(["file_name", "label"])
63        writer.write_raw_data(data)
64        writer.commit()
65        yield "yield_cv_data"
66    except Exception as error:
67        for x in paths:
68            os.remove("{}".format(x))
69            os.remove("{}.db".format(x))
70        raise error
71    else:
72        for x in paths:
73            os.remove("{}".format(x))
74            os.remove("{}.db".format(x))
75
76
77@pytest.fixture
78def add_and_remove_nlp_file():
79    """add/remove nlp file"""
80    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
81             for x in range(FILES_NUM)]
82    try:
83        for x in paths:
84            if os.path.exists("{}".format(x)):
85                os.remove("{}".format(x))
86            if os.path.exists("{}.db".format(x)):
87                os.remove("{}.db".format(x))
88        writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
        data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
90        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
91                           "rating": {"type": "float32"},
92                           "input_ids": {"type": "int64",
93                                         "shape": [-1]},
94                           "input_mask": {"type": "int64",
95                                          "shape": [1, -1]},
96                           "segment_ids": {"type": "int64",
97                                           "shape": [2, -1]}
98                           }
99        writer.set_header_size(1 << 14)
100        writer.set_page_size(1 << 15)
101        writer.add_schema(nlp_schema_json, "nlp_schema")
102        writer.add_index(["id", "rating"])
103        writer.write_raw_data(data)
104        writer.commit()
105        yield "yield_nlp_data"
106    except Exception as error:
107        for x in paths:
108            os.remove("{}".format(x))
109            os.remove("{}.db".format(x))
110        raise error
111    else:
112        for x in paths:
113            os.remove("{}".format(x))
114            os.remove("{}.db".format(x))
115
116
117@pytest.fixture
118def add_and_remove_nlp_compress_file():
119    """add/remove nlp file"""
120    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
121             for x in range(FILES_NUM)]
122    try:
123        for x in paths:
124            if os.path.exists("{}".format(x)):
125                os.remove("{}".format(x))
126            if os.path.exists("{}.db".format(x)):
127                os.remove("{}.db".format(x))
128        writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
129        data = []
130        for row_id in range(16):
131            data.append({
132                "label": row_id,
133                "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
134                                                255, 256, -32768, 32767, -32769, 32768, -2147483648,
135                                                2147483647], dtype=np.int32), [-1]),
136                "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
137                                                256, -32768, 32767, -32769, 32768,
138                                                -2147483648, 2147483647, -2147483649, 2147483649,
139                                                -922337036854775808, 9223372036854775807]), [1, -1]),
140                "array_c": str.encode("nlp data"),
141                "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
142            })
143        nlp_schema_json = {"label": {"type": "int32"},
144                           "array_a": {"type": "int32",
145                                       "shape": [-1]},
146                           "array_b": {"type": "int64",
147                                       "shape": [1, -1]},
148                           "array_c": {"type": "bytes"},
149                           "array_d": {"type": "int64",
150                                       "shape": [2, -1]}
151                           }
152        writer.set_header_size(1 << 14)
153        writer.set_page_size(1 << 15)
154        writer.add_schema(nlp_schema_json, "nlp_schema")
155        writer.write_raw_data(data)
156        writer.commit()
157        yield "yield_nlp_data"
158    except Exception as error:
159        for x in paths:
160            os.remove("{}".format(x))
161            os.remove("{}.db".format(x))
162        raise error
163    else:
164        for x in paths:
165            os.remove("{}".format(x))
166            os.remove("{}.db".format(x))
167
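# The rows above deliberately sweep across the int8/int16/int32 boundary values;
# array_a is pinned to int32, while array_b omits a dtype and lets numpy infer int64
# (matching the int64 schema), since entries such as 2147483649 do not fit in 32 bits.
# A small check of that dtype assumption:
assert np.array([2147483649, 9223372036854775807]).dtype == np.int64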
168
169def test_nlp_compress_data(add_and_remove_nlp_compress_file):
    """tutorial for nlp minddataset."""
171    data = []
172    for row_id in range(16):
173        data.append({
174            "label": row_id,
175            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
176                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
177                                            2147483647], dtype=np.int32), [-1]),
178            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
179                                            256, -32768, 32767, -32769, 32768,
180                                            -2147483648, 2147483647, -2147483649, 2147483649,
181                                            -922337036854775808, 9223372036854775807]), [1, -1]),
182            "array_c": str.encode("nlp data"),
183            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
184        })
185    num_readers = 1
186    data_set = ds.MindDataset(
187        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
188    assert data_set.get_dataset_size() == 16
189    num_iter = 0
190    for x, item in zip(data, data_set.create_dict_iterator(num_epochs=1, output_numpy=True)):
191        assert (item["array_a"] == x["array_a"]).all()
192        assert (item["array_b"] == x["array_b"]).all()
193        assert item["array_c"].tobytes() == x["array_c"]
194        assert (item["array_d"] == x["array_d"]).all()
195        assert item["label"] == x["label"]
196        num_iter += 1
197    assert num_iter == 16
198
199
200def test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file):
    """tutorial for nlp minddataset."""
202    num_readers = 1
203    data_set = ds.MindDataset(
204        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
205    old_data_set = ds.MindDataset(
206        OLD_NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
207    assert old_data_set.get_dataset_size() == 16
208    num_iter = 0
209    for x, item in zip(old_data_set.create_dict_iterator(num_epochs=1, output_numpy=True),
210                       data_set.create_dict_iterator(num_epochs=1, output_numpy=True)):
211        assert (item["array_a"] == x["array_a"]).all()
212        assert (item["array_b"] == x["array_b"]).all()
213        assert (item["array_c"] == x["array_c"]).all()
214        assert (item["array_d"] == x["array_d"]).all()
215        assert item["label"] == x["label"]
216        num_iter += 1
217    assert num_iter == 16
218
219
220def test_cv_minddataset_writer_tutorial():
221    """tutorial for cv dataset writer."""
222    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
223             for x in range(FILES_NUM)]
224    try:
225        for x in paths:
226            if os.path.exists("{}".format(x)):
227                os.remove("{}".format(x))
228            if os.path.exists("{}.db".format(x)):
229                os.remove("{}.db".format(x))
230        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
231        data = get_data(CV_DIR_NAME)
232        cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
233                          "data": {"type": "bytes"}}
234        writer.add_schema(cv_schema_json, "img_schema")
235        writer.add_index(["file_name", "label"])
236        writer.write_raw_data(data)
237        writer.commit()
238    except Exception as error:
239        for x in paths:
240            os.remove("{}".format(x))
241            os.remove("{}.db".format(x))
242        raise error
243    else:
244        for x in paths:
245            os.remove("{}".format(x))
246            os.remove("{}.db".format(x))
247
248
249def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
250    """tutorial for cv minddataset."""
251    columns_list = ["data", "file_name", "label"]
252    num_readers = 4
253
254    def partitions(num_shards):
255        for partition_id in range(num_shards):
256            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
257                                      num_shards=num_shards, shard_id=partition_id)
258            num_iter = 0
259            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
260                logger.info("-------------- partition : {} ------------------------".format(partition_id))
261                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
262                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
263                num_iter += 1
264        return num_iter
265
266    assert partitions(4) == 3
267    assert partitions(5) == 2
268    assert partitions(9) == 2
269
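# The partition sizes asserted above follow from distributed reads padding the
# 10-sample fixture so that every shard yields ceil(total / num_shards) rows.
# A quick arithmetic sketch of that expectation:
def expected_shard_size(total, num_shards):
    # assumes the padded/equalized sharding behaviour exercised by the test above
    return math.ceil(total / num_shards)

assert expected_shard_size(10, 4) == 3
assert expected_shard_size(10, 5) == 2
assert expected_shard_size(10, 9) == 2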
270
271def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
272    """tutorial for cv minddataset."""
273    columns_list = ["data", "file_name", "label"]
274    num_readers = 4
275
276    def partitions(num_shards):
277        for partition_id in range(num_shards):
278            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
279                                      num_shards=num_shards,
280                                      shard_id=partition_id, num_samples=1)
281
282            assert data_set.get_dataset_size() == 1
283            num_iter = 0
284            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
285                logger.info("-------------- partition : {} ------------------------".format(partition_id))
286                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
287                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
288                num_iter += 1
289        return num_iter
290
291    assert partitions(4) == 1
292    assert partitions(5) == 1
293    assert partitions(9) == 1
294
295
296def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
297    """tutorial for cv minddataset."""
298    columns_list = ["data", "file_name", "label"]
299    num_readers = 4
300
301    def partitions(num_shards):
302        for partition_id in range(num_shards):
303            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
304                                      num_shards=num_shards,
305                                      shard_id=partition_id, num_samples=2)
306
307            assert data_set.get_dataset_size() == 2
308            num_iter = 0
309            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
310                logger.info("-------------- partition : {} ------------------------".format(partition_id))
311                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
312                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
313                num_iter += 1
314        return num_iter
315
316    assert partitions(4) == 2
317    assert partitions(5) == 2
318    assert partitions(9) == 2
319
320
321def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
322    """tutorial for cv minddataset."""
323    columns_list = ["data", "file_name", "label"]
324    num_readers = 4
325
326    def partitions(num_shards, expect):
327        for partition_id in range(num_shards):
328            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
329                                      num_shards=num_shards,
330                                      shard_id=partition_id, num_samples=3)
331
332            assert data_set.get_dataset_size() == expect
333            num_iter = 0
334            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
335                logger.info("-------------- partition : {} ------------------------".format(partition_id))
336                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
337                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
338                num_iter += 1
339        return num_iter
340
341    assert partitions(4, 3) == 3
342    assert partitions(5, 2) == 2
343    assert partitions(9, 2) == 2
344
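# num_samples caps each shard, so the per-shard size becomes
# min(num_samples, ceil(total / num_shards)); that is where the `expect` values above
# (and the fixed sizes in the num_samples_0/_1 variants) come from. Sketch, assuming
# the 10-sample fixture:
def expected_with_num_samples(total, num_shards, num_samples):
    return min(num_samples, math.ceil(total / num_shards))

assert expected_with_num_samples(10, 4, 3) == 3
assert expected_with_num_samples(10, 5, 3) == 2
assert expected_with_num_samples(10, 9, 3) == 2
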
345def test_cv_minddataset_partition_num_samples_3(add_and_remove_cv_file):
346    """tutorial for cv minddataset."""
347    columns_list = ["data", "file_name", "label"]
348    num_readers = 4
349
350    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, num_shards=1, shard_id=0, num_samples=5)
351
352    assert data_set.get_dataset_size() == 5
353    num_iter = 0
354    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
355        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
356        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
357        num_iter += 1
358
359    assert num_iter == 5
360
361def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
362    """tutorial for cv minddataset."""
363    columns_list = ["data", "file_name", "label"]
364    num_readers = 4
365    num_shards = 3
366    epoch1 = []
367    epoch2 = []
368    epoch3 = []
369
370    for partition_id in range(num_shards):
371        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
372                                  num_shards=num_shards, shard_id=partition_id)
373
374        data_set = data_set.repeat(3)
375
376        num_iter = 0
377        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
378            logger.info("-------------- partition : {} ------------------------".format(partition_id))
379            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
380            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
381            num_iter += 1
382            if num_iter <= 4:
383                epoch1.append(item["file_name"])  # save epoch 1 list
384            elif num_iter <= 8:
385                epoch2.append(item["file_name"])  # save epoch 2 list
386            else:
387                epoch3.append(item["file_name"])  # save epoch 3 list
388        assert num_iter == 12
389        assert len(epoch1) == 4
390        assert len(epoch2) == 4
391        assert len(epoch3) == 4
392        assert epoch1 not in (epoch2, epoch3)
393        assert epoch2 not in (epoch1, epoch3)
394        assert epoch3 not in (epoch1, epoch2)
395        epoch1 = []
396        epoch2 = []
397        epoch3 = []
398
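# Each of the 3 shards holds ceil(10 / 3) == 4 padded rows, so repeat(3) produces the
# 12 rows asserted above, and the num_iter thresholds carve them into three 4-row
# epochs. An equivalent slicing sketch:
_rows = list(range(12))
_epoch_slices = [_rows[0:4], _rows[4:8], _rows[8:12]]
assert [len(e) for e in _epoch_slices] == [4, 4, 4]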
399
400def test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file):
401    """tutorial for cv minddataset."""
402    columns_list = ["data", "file_name", "label"]
403    num_readers = 4
404    num_shards = 3
405    epoch_result = [[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]],  # save partition 0 result
406                    [["", "", "", ""], ["", "", "", ""], ["", "", "", ""]],  # save partition 1 result
                    [["", "", "", ""], ["", "", "", ""], ["", "", "", ""]]]  # save partition 2 result
408
409    for partition_id in range(num_shards):
410        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
411                                  num_shards=num_shards, shard_id=partition_id)
412
413        data_set = data_set.repeat(3)
414
415        num_iter = 0
416        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
417            logger.info("-------------- partition : {} ------------------------".format(partition_id))
418            logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
419            logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
420            # total 3 partition, 4 result per epoch, total 12 result
421            epoch_result[partition_id][int(num_iter / 4)][num_iter % 4] = item["file_name"]  # save epoch result
422            num_iter += 1
423        assert num_iter == 12
424        assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
425        assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
426        assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
427        epoch_result[partition_id][0].sort()
428        epoch_result[partition_id][1].sort()
429        epoch_result[partition_id][2].sort()
430        assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
431        assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
432        assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
433
434
435def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
436    """tutorial for cv minddataset."""
437    columns_list = ["data", "file_name", "label"]
438    num_readers = 4
439
440    ds.config.set_seed(54321)
441    epoch1 = []
442    epoch2 = []
443    epoch3 = []
444
445    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
446    data_set = data_set.repeat(3)
447
448    num_iter = 0
449    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
450        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
451        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
452        num_iter += 1
453        if num_iter <= 10:
454            epoch1.append(item["file_name"])  # save epoch 1 list
455        elif num_iter <= 20:
456            epoch2.append(item["file_name"])  # save epoch 2 list
457        else:
458            epoch3.append(item["file_name"])  # save epoch 3 list
459    assert num_iter == 30
460    assert len(epoch1) == 10
461    assert len(epoch2) == 10
462    assert len(epoch3) == 10
463    assert epoch1 not in (epoch2, epoch3)
464    assert epoch2 not in (epoch1, epoch3)
465    assert epoch3 not in (epoch1, epoch2)
466
467    epoch1_new_dataset = []
468    epoch2_new_dataset = []
469    epoch3_new_dataset = []
470
471    data_set2 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
472    data_set2 = data_set2.repeat(3)
473
474    num_iter = 0
475    for item in data_set2.create_dict_iterator(num_epochs=1, output_numpy=True):
476        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
477        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
478        num_iter += 1
479        if num_iter <= 10:
480            epoch1_new_dataset.append(item["file_name"])  # save epoch 1 list
481        elif num_iter <= 20:
482            epoch2_new_dataset.append(item["file_name"])  # save epoch 2 list
483        else:
484            epoch3_new_dataset.append(item["file_name"])  # save epoch 3 list
485    assert num_iter == 30
486    assert len(epoch1_new_dataset) == 10
487    assert len(epoch2_new_dataset) == 10
488    assert len(epoch3_new_dataset) == 10
489    assert epoch1_new_dataset not in (epoch2_new_dataset, epoch3_new_dataset)
490    assert epoch2_new_dataset not in (epoch1_new_dataset, epoch3_new_dataset)
491    assert epoch3_new_dataset not in (epoch1_new_dataset, epoch2_new_dataset)
492
493    assert epoch1 == epoch1_new_dataset
494    assert epoch2 == epoch2_new_dataset
495    assert epoch3 == epoch3_new_dataset
496
497    ds.config.set_seed(12345)
498    epoch1_new_dataset2 = []
499    epoch2_new_dataset2 = []
500    epoch3_new_dataset2 = []
501
502    data_set3 = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
503    data_set3 = data_set3.repeat(3)
504
505    num_iter = 0
506    for item in data_set3.create_dict_iterator(num_epochs=1, output_numpy=True):
507        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
508        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
509        num_iter += 1
510        if num_iter <= 10:
511            epoch1_new_dataset2.append(item["file_name"])  # save epoch 1 list
512        elif num_iter <= 20:
513            epoch2_new_dataset2.append(item["file_name"])  # save epoch 2 list
514        else:
515            epoch3_new_dataset2.append(item["file_name"])  # save epoch 3 list
516    assert num_iter == 30
517    assert len(epoch1_new_dataset2) == 10
518    assert len(epoch2_new_dataset2) == 10
519    assert len(epoch3_new_dataset2) == 10
520    assert epoch1_new_dataset2 not in (epoch2_new_dataset2, epoch3_new_dataset2)
521    assert epoch2_new_dataset2 not in (epoch1_new_dataset2, epoch3_new_dataset2)
522    assert epoch3_new_dataset2 not in (epoch1_new_dataset2, epoch2_new_dataset2)
523
524    assert epoch1 != epoch1_new_dataset2
525    assert epoch2 != epoch2_new_dataset2
526    assert epoch3 != epoch3_new_dataset2
527
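# The three blocks above rely on dataset-level seeding: pipelines created under the
# same ds.config.set_seed() value shuffle identically, while a different seed yields a
# different order. A minimal sketch of that contract (assumes the fixture's shard
# files exist; the comparisons are left as comments because the helper is only
# illustrative):
def first_epoch_file_names(seed):
    ds.config.set_seed(seed)
    dataset = ds.MindDataset(CV_FILE_NAME + "0", ["file_name"], 4)
    return [item["file_name"] for item in
            dataset.create_dict_iterator(num_epochs=1, output_numpy=True)]
# first_epoch_file_names(54321) == first_epoch_file_names(54321)   # same seed, same order
# first_epoch_file_names(54321) != first_epoch_file_names(12345)   # different seed, (almost surely) different order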
528
529def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
530    """tutorial for cv minddataset."""
531    columns_list = ["data", "file_name", "label"]
532    num_readers = 4
533    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
534    assert data_set.get_dataset_size() == 10
535    repeat_num = 2
536    data_set = data_set.repeat(repeat_num)
537    num_iter = 0
538    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
539        logger.info(
540            "-------------- get dataset size {} -----------------".format(num_iter))
541        logger.info(
542            "-------------- item[label]: {} ---------------------".format(item["label"]))
543        logger.info(
544            "-------------- item[data]: {} ----------------------".format(item["data"]))
545        num_iter += 1
546    assert num_iter == 20
547    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
548                              num_shards=4, shard_id=3)
549    assert data_set.get_dataset_size() == 3
550
551
552def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
553    """tutorial for cv minddataset."""
554    columns_list = ["data", "label"]
555    num_readers = 4
556    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
557    decode_op = vision.Decode()
558    data_set = data_set.map(
559        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
560    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
561    data_set = data_set.map(operations=resize_op, input_columns="data",
562                            num_parallel_workers=2)
563    data_set = data_set.batch(2)
564    data_set = data_set.repeat(2)
565    num_iter = 0
566    labels = []
567    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
568        logger.info(
569            "-------------- get dataset size {} -----------------".format(num_iter))
570        logger.info(
571            "-------------- item[label]: {} ---------------------".format(item["label"]))
572        logger.info(
573            "-------------- item[data]: {} ----------------------".format(item["data"]))
574        num_iter += 1
575        labels.append(item["label"])
576    assert num_iter == 10
577    logger.info("repeat shuffle: {}".format(labels))
578    assert len(labels) == 10
    # the second pass is reshuffled, so its batch labels should differ from the first pass
    assert not all(np.array_equal(a, b) for a, b in zip(labels[0:5], labels[5:10]))
581
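# Pipeline arithmetic for the test above: 10 decoded samples batched by 2 give
# ceil(10 / 2) == 5 batches per pass, and repeat(2) doubles that to the 10 iterations
# asserted; the first 5 entries of `labels` therefore come from the first pass and the
# last 5 from the reshuffled second pass.
assert math.ceil(10 / 2) * 2 == 10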
582
583def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
584    """tutorial for cv minddataset."""
585    columns_list = ["data", "label"]
586    num_readers = 4
587    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
588    decode_op = vision.Decode()
589    data_set = data_set.map(
590        input_columns=["data"], operations=decode_op, num_parallel_workers=2)
591    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
592    data_set = data_set.map(operations=resize_op, input_columns="data",
593                            num_parallel_workers=2)
594    data_set = data_set.batch(32, drop_remainder=True)
595    num_iter = 0
596    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
597        logger.info(
598            "-------------- get dataset size {} -----------------".format(num_iter))
599        logger.info(
600            "-------------- item[label]: {} ---------------------".format(item["label"]))
601        logger.info(
602            "-------------- item[data]: {} ----------------------".format(item["data"]))
603        num_iter += 1
604    assert num_iter == 0
605
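# With drop_remainder=True a partial batch is discarded, so 10 records at batch size 32
# produce 10 // 32 == 0 batches, which is why the loop above never executes.
assert 10 // 32 == 0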
606
607def test_cv_minddataset_issue_888(add_and_remove_cv_file):
608    """issue 888 test."""
609    columns_list = ["data", "label"]
610    num_readers = 2
611    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
612    data_set = data_set.shuffle(2)
613    data_set = data_set.repeat(9)
614    num_iter = 0
615    for _ in data_set.create_dict_iterator(num_epochs=1):
616        num_iter += 1
617    assert num_iter == 18
618
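# Arithmetic behind the issue-888 case: 10 samples over 5 shards leave
# ceil(10 / 5) == 2 rows in shard 1, and repeat(9) turns that into the 18 iterated
# rows asserted above.
assert math.ceil(10 / 5) * 9 == 18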
619
620def test_cv_minddataset_reader_file_list(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
622    columns_list = ["data", "file_name", "label"]
623    num_readers = 4
624    data_set = ds.MindDataset([CV_FILE_NAME + str(x)
625                               for x in range(FILES_NUM)], columns_list, num_readers)
626    assert data_set.get_dataset_size() == 10
627    num_iter = 0
628    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
629        logger.info(
630            "-------------- cv reader basic: {} ------------------------".format(num_iter))
631        logger.info(
632            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
633        logger.info(
634            "-------------- item[data]: {} -----------------------------".format(item["data"]))
635        logger.info(
636            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
637        logger.info(
638            "-------------- item[label]: {} ----------------------------".format(item["label"]))
639        num_iter += 1
640    assert num_iter == 10
641
642
643def test_cv_minddataset_reader_one_partition(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
645    columns_list = ["data", "file_name", "label"]
646    num_readers = 4
647    data_set = ds.MindDataset([CV_FILE_NAME + "0"], columns_list, num_readers)
648    assert data_set.get_dataset_size() < 10
649    num_iter = 0
650    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
651        logger.info(
652            "-------------- cv reader basic: {} ------------------------".format(num_iter))
653        logger.info(
654            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
655        logger.info(
656            "-------------- item[data]: {} -----------------------------".format(item["data"]))
657        logger.info(
658            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
659        logger.info(
660            "-------------- item[label]: {} ----------------------------".format(item["label"]))
661        num_iter += 1
662    assert num_iter < 10
663
664
665def test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
667    try:
668        if os.path.exists(CV1_FILE_NAME):
669            os.remove(CV1_FILE_NAME)
670        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
671            os.remove("{}.db".format(CV1_FILE_NAME))
672        if os.path.exists(CV2_FILE_NAME):
673            os.remove(CV2_FILE_NAME)
674        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
675            os.remove("{}.db".format(CV2_FILE_NAME))
676        writer = FileWriter(CV1_FILE_NAME, 1)
677        data = get_data(CV_DIR_NAME)
678        cv_schema_json = {"id": {"type": "int32"},
679                          "file_name": {"type": "string"},
680                          "label": {"type": "int32"},
681                          "data": {"type": "bytes"}}
682        writer.add_schema(cv_schema_json, "CV1_schema")
683        writer.add_index(["file_name", "label"])
684        writer.write_raw_data(data)
685        writer.commit()
686
687        writer = FileWriter(CV2_FILE_NAME, 1)
688        data = get_data(CV_DIR_NAME)
689        cv_schema_json = {"id": {"type": "int32"},
690                          "file_name": {"type": "string"},
691                          "label": {"type": "int32"},
692                          "data": {"type": "bytes"}}
693        writer.add_schema(cv_schema_json, "CV2_schema")
694        writer.add_index(["file_name", "label"])
695        writer.write_raw_data(data)
696        writer.commit()
697        columns_list = ["data", "file_name", "label"]
698        num_readers = 4
699        data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(FILES_NUM)] + [CV1_FILE_NAME, CV2_FILE_NAME],
700                                  columns_list, num_readers)
701        assert data_set.get_dataset_size() == 30
702        num_iter = 0
703        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
704            logger.info(
705                "-------------- cv reader basic: {} ------------------------".format(num_iter))
706            logger.info(
707                "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
708            logger.info(
709                "-------------- item[data]: {} -----------------------------".format(item["data"]))
710            logger.info(
711                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
712            logger.info(
713                "-------------- item[label]: {} ----------------------------".format(item["label"]))
714            num_iter += 1
715        assert num_iter == 30
716    except Exception as error:
717        if os.path.exists(CV1_FILE_NAME):
718            os.remove(CV1_FILE_NAME)
719        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
720            os.remove("{}.db".format(CV1_FILE_NAME))
721        if os.path.exists(CV2_FILE_NAME):
722            os.remove(CV2_FILE_NAME)
723        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
724            os.remove("{}.db".format(CV2_FILE_NAME))
725        raise error
726    else:
727        if os.path.exists(CV1_FILE_NAME):
728            os.remove(CV1_FILE_NAME)
729        if os.path.exists("{}.db".format(CV1_FILE_NAME)):
730            os.remove("{}.db".format(CV1_FILE_NAME))
731        if os.path.exists(CV2_FILE_NAME):
732            os.remove(CV2_FILE_NAME)
733        if os.path.exists("{}.db".format(CV2_FILE_NAME)):
734            os.remove("{}.db".format(CV2_FILE_NAME))
735
736
737def test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file):
738    paths = ["{}{}".format(CV1_FILE_NAME, str(x).rjust(1, '0'))
739             for x in range(FILES_NUM)]
740    try:
741        for x in paths:
742            if os.path.exists("{}".format(x)):
743                os.remove("{}".format(x))
744            if os.path.exists("{}.db".format(x)):
745                os.remove("{}.db".format(x))
746        writer = FileWriter(CV1_FILE_NAME, FILES_NUM)
747        data = get_data(CV_DIR_NAME)
748        cv_schema_json = {"id": {"type": "int32"},
749                          "file_name": {"type": "string"},
750                          "label": {"type": "int32"},
751                          "data": {"type": "bytes"}}
752        writer.add_schema(cv_schema_json, "CV1_schema")
753        writer.add_index(["file_name", "label"])
754        writer.write_raw_data(data)
755        writer.commit()
756
757        columns_list = ["data", "file_name", "label"]
758        num_readers = 4
759        data_set = ds.MindDataset([CV_FILE_NAME + str(x) for x in range(2)] +
760                                  [CV1_FILE_NAME + str(x) for x in range(2, 4)],
761                                  columns_list, num_readers)
762        assert data_set.get_dataset_size() < 20
763        num_iter = 0
764        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
765            logger.info(
766                "-------------- cv reader basic: {} ------------------------".format(num_iter))
767            logger.info(
768                "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
769            logger.info(
770                "-------------- item[data]: {} -----------------------------".format(item["data"]))
771            logger.info(
772                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
773            logger.info(
774                "-------------- item[label]: {} ----------------------------".format(item["label"]))
775            num_iter += 1
776        assert num_iter < 20
777    except Exception as error:
778        for x in paths:
779            os.remove("{}".format(x))
780            os.remove("{}.db".format(x))
781        raise error
782    else:
783        for x in paths:
784            os.remove("{}".format(x))
785            os.remove("{}.db".format(x))
786
787
788def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
790    columns_list = ["data", "file_name", "label"]
791    num_readers = 4
792    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
793    assert data_set.get_dataset_size() == 10
794    num_iter = 0
795    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
796        logger.info(
797            "-------------- cv reader basic: {} ------------------------".format(num_iter))
798        logger.info(
799            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
800        logger.info(
801            "-------------- item[data]: {} -----------------------------".format(item["data"]))
802        logger.info(
803            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
804        logger.info(
805            "-------------- item[label]: {} ----------------------------".format(item["label"]))
806        num_iter += 1
807    assert num_iter == 10
808
809
810def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
    """tutorial for nlp minddataset."""
812    num_readers = 4
813    data_set = ds.MindDataset(NLP_FILE_NAME + "0", None, num_readers)
814    assert data_set.get_dataset_size() == 10
815    num_iter = 0
816    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
817        logger.info(
818            "-------------- cv reader basic: {} ------------------------".format(num_iter))
819        logger.info(
820            "-------------- num_iter: {} ------------------------".format(num_iter))
821        logger.info(
822            "-------------- item[id]: {} ------------------------".format(item["id"]))
823        logger.info(
824            "-------------- item[rating]: {} --------------------".format(item["rating"]))
825        logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
826            item["input_ids"], item["input_ids"].shape))
827        logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
828            item["input_mask"], item["input_mask"].shape))
829        logger.info("-------------- item[segment_ids]: {}, shape: {} -----------------".format(
830            item["segment_ids"], item["segment_ids"].shape))
831        assert item["input_ids"].shape == (50,)
832        assert item["input_mask"].shape == (1, 50)
833        assert item["segment_ids"].shape == (2, 25)
834        num_iter += 1
835    assert num_iter == 10
836
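# The shape asserts above follow from inputs() padding every review to maxlen=50 tokens
# and from the reshape targets declared in the schema: [-1] keeps a flat (50,) vector,
# [1, -1] gives (1, 50), and [2, -1] folds the 50 values into (2, 25). Sketch:
_padded = np.zeros(50, dtype=np.int64)
assert np.reshape(_padded, [-1]).shape == (50,)
assert np.reshape(_padded, [1, -1]).shape == (1, 50)
assert np.reshape(_padded, [2, -1]).shape == (2, 25)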
837
838def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
840    columns_list = ["data", "file_name", "label"]
841    num_readers = 4
842    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
843    assert data_set.get_dataset_size() == 10
844    for _ in range(5):
845        num_iter = 0
846        for data in data_set.create_tuple_iterator(output_numpy=True):
847            logger.info("data is {}".format(data))
848            num_iter += 1
849        assert num_iter == 10
850
851        data_set.reset()
852
853
854def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
856    columns_list = ["data", "label"]
857    num_readers = 4
858    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
859
860    resize_height = 32
861    resize_width = 32
862
863    # define map operations
864    decode_op = vision.Decode()
865    resize_op = vision.Resize((resize_height, resize_width))
866
867    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
868    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
869
870    data_set = data_set.batch(2)
871    assert data_set.get_dataset_size() == 5
872    for _ in range(5):
873        num_iter = 0
874        for data in data_set.create_tuple_iterator(output_numpy=True):
875            logger.info("data is {}".format(data))
876            num_iter += 1
877        assert num_iter == 5
878
879        data_set.reset()
880
881
882def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
884    data_set = ds.MindDataset(CV_FILE_NAME + "0")
885    assert data_set.get_dataset_size() == 10
886    num_iter = 0
887    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
888        logger.info(
889            "-------------- cv reader basic: {} ------------------------".format(num_iter))
890        logger.info(
891            "-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
892        logger.info(
893            "-------------- item[data]: {} -----------------------------".format(item["data"]))
894        logger.info(
895            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
896        logger.info(
897            "-------------- item[label]: {} ----------------------------".format(item["label"]))
898        num_iter += 1
899    assert num_iter == 10
900
901
902def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
904    columns_list = ["data", "file_name", "label"]
905    num_readers = 4
906    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
907    repeat_num = 2
908    data_set = data_set.repeat(repeat_num)
909    num_iter = 0
910    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
911        logger.info(
912            "-------------- repeat two test {} ------------------------".format(num_iter))
913        logger.info(
914            "-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
915        logger.info(
916            "-------------- item[data]: {} ----------------------------".format(item["data"]))
917        logger.info(
918            "-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
919        logger.info(
920            "-------------- item[label]: {} ---------------------------".format(item["label"]))
921        num_iter += 1
922    assert num_iter == 20
923
924
925def get_data(dir_name):
    """
    Return raw data of the ImageNet-style test dataset.

    Args:
        dir_name (str): Directory that contains the "images" folder and "annotation.txt".

    Returns:
        List
    """
932    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
934    img_dir = os.path.join(dir_name, "images")
935    ann_file = os.path.join(dir_name, "annotation.txt")
936    with open(ann_file, "r") as file_reader:
937        lines = file_reader.readlines()
938
939    data_list = []
940    for i, line in enumerate(lines):
941        try:
942            filename, label = line.split(",")
943            label = label.strip("\n")
944            with open(os.path.join(img_dir, filename), "rb") as file_reader:
945                img = file_reader.read()
946            data_json = {"id": i,
947                         "file_name": filename,
948                         "data": img,
949                         "label": int(label)}
950            data_list.append(data_json)
951        except FileNotFoundError:
952            continue
953    return data_list
954
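# get_data() expects annotation.txt to hold one "<file_name>,<label>" pair per line and
# silently skips lines whose image file is missing. Illustration with a hypothetical
# annotation line:
_file_name, _label = "example_0.jpg,3".split(",")
assert _file_name == "example_0.jpg" and int(_label.strip("\n")) == 3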
955
956def get_multi_bytes_data(file_name, bytes_num=3):
957    """
958    Return raw data of multi-bytes dataset.
959
960    Args:
961        file_name (str): String of multi-bytes dataset's path.
962        bytes_num (int): Number of bytes fields.
963
964    Returns:
965       List
966    """
967    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
969    dir_name = os.path.dirname(file_name)
970    with open(file_name, "r") as file_reader:
971        lines = file_reader.readlines()
972    data_list = []
973    row_num = 0
974    for line in lines:
975        try:
976            img10_path = line.strip('\n').split(" ")
977            img5 = []
978            for path in img10_path[:bytes_num]:
979                with open(os.path.join(dir_name, path), "rb") as file_reader:
980                    img5 += [file_reader.read()]
981            data_json = {"image_{}".format(i): img5[i]
982                         for i in range(len(img5))}
983            data_json.update({"id": row_num})
984            row_num += 1
985            data_list.append(data_json)
986        except FileNotFoundError:
987            continue
988    return data_list
989
990
991def get_mkv_data(dir_name):
992    """
993    Return raw data of Vehicle_and_Person dataset.
994
995    Args:
996        dir_name (str): String of Vehicle_and_Person dataset's path.
997
998    Returns:
999        List
1000    """
1001    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
1003    img_dir = os.path.join(dir_name, "Image")
1004    label_dir = os.path.join(dir_name, "prelabel")
1005
1006    data_list = []
1007    file_list = os.listdir(label_dir)
1008
1009    index = 1
1010    for item in file_list:
1011        if os.path.splitext(item)[1] == '.json':
1012            file_path = os.path.join(label_dir, item)
1013
1014            image_name = ''.join([os.path.splitext(item)[0], ".jpg"])
1015            image_path = os.path.join(img_dir, image_name)
1016
1017            with open(file_path, "r") as load_f:
1018                load_dict = json.load(load_f)
1019
1020            if os.path.exists(image_path):
1021                with open(image_path, "rb") as file_reader:
1022                    img = file_reader.read()
1023                data_json = {"file_name": image_name,
1024                             "prelabel": str(load_dict),
1025                             "data": img,
1026                             "id": index}
1027                data_list.append(data_json)
1028            index += 1
1029    logger.info('{} images are missing'.format(
1030        len(file_list) - len(data_list)))
1031    return data_list
1032
1033
1034def get_nlp_data(dir_name, vocab_file, num):
1035    """
1036    Return raw data of aclImdb dataset.
1037
1038    Args:
1039        dir_name (str): String of aclImdb dataset's path.
1040        vocab_file (str): String of dictionary's path.
1041        num (int): Number of sample.
1042
1043    Returns:
1044        List
1045    """
1046    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
1048    for root, _, files in os.walk(dir_name):
1049        for index, file_name_extension in enumerate(files):
1050            if index < num:
1051                file_path = os.path.join(root, file_name_extension)
1052                file_name, _ = file_name_extension.split('.', 1)
1053                id_, rating = file_name.split('_', 1)
1054                with open(file_path, 'r') as f:
1055                    raw_content = f.read()
1056
1057                dictionary = load_vocab(vocab_file)
1058                vectors = [dictionary.get('[CLS]')]
1059                vectors += [dictionary.get(i) if i in dictionary
1060                            else dictionary.get('[UNK]')
1061                            for i in re.findall(r"[\w']+|[{}]"
1062                                                .format(string.punctuation),
1063                                                raw_content)]
1064                vectors += [dictionary.get('[SEP]')]
1065                input_, mask, segment = inputs(vectors)
1066                input_ids = np.reshape(np.array(input_), [-1])
1067                input_mask = np.reshape(np.array(mask), [1, -1])
1068                segment_ids = np.reshape(np.array(segment), [2, -1])
1069                data = {
1070                    "label": 1,
1071                    "id": id_,
1072                    "rating": float(rating),
1073                    "input_ids": input_ids,
1074                    "input_mask": input_mask,
1075                    "segment_ids": segment_ids
1076                }
1077                yield data
1078
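# get_nlp_data() derives "id" and "rating" from the aclImdb file-name convention
# <id>_<rating>.txt; a quick illustration with a hypothetical file name:
_stem, _ = "123_9.txt".split('.', 1)
_id, _rating = _stem.split('_', 1)
assert (_id, float(_rating)) == ("123", 9.0)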
1079
1080def convert_to_uni(text):
1081    if isinstance(text, str):
1082        return text
1083    if isinstance(text, bytes):
1084        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s cannot be converted!" % type(text))
1086
1087
1088def load_vocab(vocab_file):
1089    """load vocabulary to translate statement."""
1090    vocab = collections.OrderedDict()
1091    vocab.setdefault('blank', 2)
1092    index = 0
1093    with open(vocab_file) as reader:
1094        while True:
1095            tmp = reader.readline()
1096            if not tmp:
1097                break
1098            token = convert_to_uni(tmp)
1099            token = token.strip()
1100            vocab[token] = index
1101            index += 1
1102    return vocab
1103
1104
1105def inputs(vectors, maxlen=50):
1106    length = len(vectors)
1107    if length > maxlen:
1108        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
1109    input_ = vectors + [0] * (maxlen - length)
1110    mask = [1] * length + [0] * (maxlen - length)
1111    segment = [0] * maxlen
1112    return input_, mask, segment
1113
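# inputs() pads or truncates a token-id vector to maxlen (default 50): ids are
# zero-padded, the mask flags the real tokens, and segment ids stay all zero.
# Example of that contract with a short, made-up vector:
_ids, _mask, _segment = inputs([11, 12, 13])
assert len(_ids) == 50 and _ids[:3] == [11, 12, 13]
assert sum(_mask) == 3
assert _segment == [0] * 50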
1114
1115def test_write_with_multi_bytes_and_array_and_read_by_MindDataset():
1116    mindrecord_file_name = "test.mindrecord"
1117    try:
1118        if os.path.exists("{}".format(mindrecord_file_name)):
1119            os.remove("{}".format(mindrecord_file_name))
1120        if os.path.exists("{}.db".format(mindrecord_file_name)):
1121            os.remove("{}.db".format(mindrecord_file_name))
1122        data = [{"file_name": "001.jpg", "label": 4,
1123                 "image1": bytes("image1 bytes abc", encoding='UTF-8'),
1124                 "image2": bytes("image1 bytes def", encoding='UTF-8'),
1125                 "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
1126                 "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1127                 "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
1128                 "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
1129                 "image5": bytes("image1 bytes mno", encoding='UTF-8'),
1130                 "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
1131                 "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
1132                 "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1133                 "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
1134                {"file_name": "002.jpg", "label": 5,
1135                 "image1": bytes("image2 bytes abc", encoding='UTF-8'),
1136                 "image2": bytes("image2 bytes def", encoding='UTF-8'),
1137                 "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
1138                 "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
1139                 "image5": bytes("image2 bytes mno", encoding='UTF-8'),
1140                 "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
1141                 "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1142                 "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
1143                 "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
1144                 "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1145                 "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
1146                {"file_name": "003.jpg", "label": 6,
1147                 "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
1148                 "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1149                 "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
1150                 "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
1151                 "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1152                 "image1": bytes("image3 bytes abc", encoding='UTF-8'),
1153                 "image2": bytes("image3 bytes def", encoding='UTF-8'),
1154                 "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
1155                 "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
1156                 "image5": bytes("image3 bytes mno", encoding='UTF-8'),
1157                 "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
1158                {"file_name": "004.jpg", "label": 7,
1159                 "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
1160                 "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1161                 "image1": bytes("image4 bytes abc", encoding='UTF-8'),
1162                 "image2": bytes("image4 bytes def", encoding='UTF-8'),
1163                 "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
1164                 "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
1165                 "image5": bytes("image4 bytes mno", encoding='UTF-8'),
1166                 "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
1167                 "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
1168                 "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1169                 "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
1170                {"file_name": "005.jpg", "label": 8,
1171                 "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
1172                 "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1173                 "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
1174                 "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
1175                 "image1": bytes("image5 bytes abc", encoding='UTF-8'),
1176                 "image2": bytes("image5 bytes def", encoding='UTF-8'),
1177                 "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
1178                 "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
1179                 "image5": bytes("image5 bytes mno", encoding='UTF-8'),
1180                 "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1181                 "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
1182                {"file_name": "006.jpg", "label": 9,
1183                 "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
1184                 "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1185                 "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
1186                 "image1": bytes("image6 bytes abc", encoding='UTF-8'),
1187                 "image2": bytes("image6 bytes def", encoding='UTF-8'),
1188                 "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
1189                 "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
1190                 "image5": bytes("image6 bytes mno", encoding='UTF-8'),
1191                 "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
1192                 "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1193                 "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
1194                ]
1195
1196        writer = FileWriter(mindrecord_file_name)
1197        schema = {"file_name": {"type": "string"},
1198                  "image1": {"type": "bytes"},
1199                  "image2": {"type": "bytes"},
1200                  "source_sos_ids": {"type": "int64", "shape": [-1]},
1201                  "source_sos_mask": {"type": "int64", "shape": [-1]},
1202                  "image3": {"type": "bytes"},
1203                  "image4": {"type": "bytes"},
1204                  "image5": {"type": "bytes"},
1205                  "target_sos_ids": {"type": "int64", "shape": [-1]},
1206                  "target_sos_mask": {"type": "int64", "shape": [-1]},
1207                  "target_eos_ids": {"type": "int64", "shape": [-1]},
1208                  "target_eos_mask": {"type": "int64", "shape": [-1]},
1209                  "label": {"type": "int32"}}
1210        writer.add_schema(schema, "data is so cool")
1211        writer.write_raw_data(data)
1212        writer.commit()
1213
1214        # build the expected output: bytes/string fields become numpy arrays, as MindDataset returns them
1215        data_value_to_list = []
1216        for item in data:
1217            new_data = {}
1218            new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
1219            new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
1220            new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
1221            new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
1222            new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
1223            new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
1224            new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
1225            new_data['source_sos_ids'] = item["source_sos_ids"]
1226            new_data['source_sos_mask'] = item["source_sos_mask"]
1227            new_data['target_sos_ids'] = item["target_sos_ids"]
1228            new_data['target_sos_mask'] = item["target_sos_mask"]
1229            new_data['target_eos_ids'] = item["target_eos_ids"]
1230            new_data['target_eos_mask'] = item["target_eos_mask"]
1231            data_value_to_list.append(new_data)
1232
1233        num_readers = 2
1234        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1235                                  num_parallel_workers=num_readers,
1236                                  shuffle=False)
1237        assert data_set.get_dataset_size() == 6
1238        num_iter = 0
1239        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1240            assert len(item) == 13
1241            for field in item:
1242                if isinstance(item[field], np.ndarray):
1243                    assert (item[field] ==
1244                            data_value_to_list[num_iter][field]).all()
1245                else:
1246                    assert item[field] == data_value_to_list[num_iter][field]
1247            num_iter += 1
1248        assert num_iter == 6
1249
1250        num_readers = 2
1251        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1252                                  columns_list=["source_sos_ids",
1253                                                "source_sos_mask", "target_sos_ids"],
1254                                  num_parallel_workers=num_readers,
1255                                  shuffle=False)
1256        assert data_set.get_dataset_size() == 6
1257        num_iter = 0
1258        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1259            assert len(item) == 3
1260            for field in item:
1261                if isinstance(item[field], np.ndarray):
1262                    assert (item[field] == data[num_iter][field]).all()
1263                else:
1264                    assert item[field] == data[num_iter][field]
1265            num_iter += 1
1266        assert num_iter == 6
1267
1268        num_readers = 1
1269        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1270                                  columns_list=["image2", "source_sos_mask", "image3", "target_sos_ids"],
1271                                  num_parallel_workers=num_readers,
1272                                  shuffle=False)
1273        assert data_set.get_dataset_size() == 6
1274        num_iter = 0
1275        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1276            assert len(item) == 4
1277            for field in item:
1278                if isinstance(item[field], np.ndarray):
1279                    assert (item[field] ==
1280                            data_value_to_list[num_iter][field]).all()
1281                else:
1282                    assert item[field] == data_value_to_list[num_iter][field]
1283            num_iter += 1
1284        assert num_iter == 6
1285
1286        num_readers = 3
1287        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1288                                  columns_list=["target_sos_ids",
1289                                                "image4", "source_sos_ids"],
1290                                  num_parallel_workers=num_readers,
1291                                  shuffle=False)
1292        assert data_set.get_dataset_size() == 6
1293        num_iter = 0
1294        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1295            assert len(item) == 3
1296            for field in item:
1297                if isinstance(item[field], np.ndarray):
1298                    assert (item[field] ==
1299                            data_value_to_list[num_iter][field]).all()
1300                else:
1301                    assert item[field] == data_value_to_list[num_iter][field]
1302            num_iter += 1
1303        assert num_iter == 6
1304
1305        num_readers = 3
1306        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1307                                  columns_list=["target_sos_ids", "image5",
1308                                                "image4", "image3", "source_sos_ids"],
1309                                  num_parallel_workers=num_readers,
1310                                  shuffle=False)
1311        assert data_set.get_dataset_size() == 6
1312        num_iter = 0
1313        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1314            assert len(item) == 5
1315            for field in item:
1316                if isinstance(item[field], np.ndarray):
1317                    assert (item[field] ==
1318                            data_value_to_list[num_iter][field]).all()
1319                else:
1320                    assert item[field] == data_value_to_list[num_iter][field]
1321            num_iter += 1
1322        assert num_iter == 6
1323
1324        num_readers = 1
1325        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1326                                  columns_list=["target_eos_mask", "image5",
1327                                                "image2", "source_sos_mask", "label"],
1328                                  num_parallel_workers=num_readers,
1329                                  shuffle=False)
1330        assert data_set.get_dataset_size() == 6
1331        num_iter = 0
1332        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1333            assert len(item) == 5
1334            for field in item:
1335                if isinstance(item[field], np.ndarray):
1336                    assert (item[field] ==
1337                            data_value_to_list[num_iter][field]).all()
1338                else:
1339                    assert item[field] == data_value_to_list[num_iter][field]
1340            num_iter += 1
1341        assert num_iter == 6
1342
1343        num_readers = 2
1344        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1345                                  columns_list=["label", "target_eos_mask", "image1", "target_eos_ids",
1346                                                "source_sos_mask", "image2", "image4", "image3",
1347                                                "source_sos_ids", "image5", "file_name"],
1348                                  num_parallel_workers=num_readers,
1349                                  shuffle=False)
1350        assert data_set.get_dataset_size() == 6
1351        num_iter = 0
1352        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1353            assert len(item) == 11
1354            for field in item:
1355                if isinstance(item[field], np.ndarray):
1356                    assert (item[field] ==
1357                            data_value_to_list[num_iter][field]).all()
1358                else:
1359                    assert item[field] == data_value_to_list[num_iter][field]
1360            num_iter += 1
1361        assert num_iter == 6
1362    except Exception as error:
1363        os.remove("{}".format(mindrecord_file_name))
1364        os.remove("{}.db".format(mindrecord_file_name))
1365        raise error
1366    else:
1367        os.remove("{}".format(mindrecord_file_name))
1368        os.remove("{}.db".format(mindrecord_file_name))
1369
1370
1371def test_write_with_multi_bytes_and_MindDataset():
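    """test MindDataset reading a file that mixes several bytes fields with string and int32 scalars"""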
1372    mindrecord_file_name = "test.mindrecord"
1373    try:
1374        data = [{"file_name": "001.jpg", "label": 43,
1375                 "image1": bytes("image1 bytes abc", encoding='UTF-8'),
1376                 "image2": bytes("image1 bytes def", encoding='UTF-8'),
1377                 "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
1378                 "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
1379                 "image5": bytes("image1 bytes mno", encoding='UTF-8')},
1380                {"file_name": "002.jpg", "label": 91,
1381                 "image1": bytes("image2 bytes abc", encoding='UTF-8'),
1382                 "image2": bytes("image2 bytes def", encoding='UTF-8'),
1383                 "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
1384                 "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
1385                 "image5": bytes("image2 bytes mno", encoding='UTF-8')},
1386                {"file_name": "003.jpg", "label": 61,
1387                 "image1": bytes("image3 bytes abc", encoding='UTF-8'),
1388                 "image2": bytes("image3 bytes def", encoding='UTF-8'),
1389                 "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
1390                 "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
1391                 "image5": bytes("image3 bytes mno", encoding='UTF-8')},
1392                {"file_name": "004.jpg", "label": 29,
1393                 "image1": bytes("image4 bytes abc", encoding='UTF-8'),
1394                 "image2": bytes("image4 bytes def", encoding='UTF-8'),
1395                 "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
1396                 "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
1397                 "image5": bytes("image4 bytes mno", encoding='UTF-8')},
1398                {"file_name": "005.jpg", "label": 78,
1399                 "image1": bytes("image5 bytes abc", encoding='UTF-8'),
1400                 "image2": bytes("image5 bytes def", encoding='UTF-8'),
1401                 "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
1402                 "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
1403                 "image5": bytes("image5 bytes mno", encoding='UTF-8')},
1404                {"file_name": "006.jpg", "label": 37,
1405                 "image1": bytes("image6 bytes abc", encoding='UTF-8'),
1406                 "image2": bytes("image6 bytes def", encoding='UTF-8'),
1407                 "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
1408                 "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
1409                 "image5": bytes("image6 bytes mno", encoding='UTF-8')}
1410                ]
1411        writer = FileWriter(mindrecord_file_name)
1412        schema = {"file_name": {"type": "string"},
1413                  "image1": {"type": "bytes"},
1414                  "image2": {"type": "bytes"},
1415                  "image3": {"type": "bytes"},
1416                  "label": {"type": "int32"},
1417                  "image4": {"type": "bytes"},
1418                  "image5": {"type": "bytes"}}
1419        writer.add_schema(schema, "data is so cool")
1420        writer.write_raw_data(data)
1421        writer.commit()
1422
1423        # build the expected output: bytes/string fields become numpy arrays, as MindDataset returns them
1424        data_value_to_list = []
1425        for item in data:
1426            new_data = {}
1427            new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
1428            new_data['label'] = np.asarray(list([item["label"]]), dtype=np.int32)
1429            new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
1430            new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
1431            new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
1432            new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
1433            new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
1434            data_value_to_list.append(new_data)
1435
1436        num_readers = 2
1437        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1438                                  num_parallel_workers=num_readers,
1439                                  shuffle=False)
1440        assert data_set.get_dataset_size() == 6
1441        num_iter = 0
1442        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1443            assert len(item) == 7
1444            for field in item:
1445                if isinstance(item[field], np.ndarray):
1446                    assert (item[field] ==
1447                            data_value_to_list[num_iter][field]).all()
1448                else:
1449                    assert item[field] == data_value_to_list[num_iter][field]
1450            num_iter += 1
1451        assert num_iter == 6
1452
1453        num_readers = 2
1454        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1455                                  columns_list=["image1", "image2", "image5"],
1456                                  num_parallel_workers=num_readers,
1457                                  shuffle=False)
1458        assert data_set.get_dataset_size() == 6
1459        num_iter = 0
1460        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1461            assert len(item) == 3
1462            for field in item:
1463                if isinstance(item[field], np.ndarray):
1464                    assert (item[field] ==
1465                            data_value_to_list[num_iter][field]).all()
1466                else:
1467                    assert item[field] == data_value_to_list[num_iter][field]
1468            num_iter += 1
1469        assert num_iter == 6
1470
1471        num_readers = 2
1472        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1473                                  columns_list=["image2", "image4"],
1474                                  num_parallel_workers=num_readers,
1475                                  shuffle=False)
1476        assert data_set.get_dataset_size() == 6
1477        num_iter = 0
1478        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1479            assert len(item) == 2
1480            for field in item:
1481                if isinstance(item[field], np.ndarray):
1482                    assert (item[field] ==
1483                            data_value_to_list[num_iter][field]).all()
1484                else:
1485                    assert item[field] == data_value_to_list[num_iter][field]
1486            num_iter += 1
1487        assert num_iter == 6
1488
1489        num_readers = 2
1490        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1491                                  columns_list=["image5", "image2"],
1492                                  num_parallel_workers=num_readers,
1493                                  shuffle=False)
1494        assert data_set.get_dataset_size() == 6
1495        num_iter = 0
1496        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1497            assert len(item) == 2
1498            for field in item:
1499                if isinstance(item[field], np.ndarray):
1500                    assert (item[field] ==
1501                            data_value_to_list[num_iter][field]).all()
1502                else:
1503                    assert item[field] == data_value_to_list[num_iter][field]
1504            num_iter += 1
1505        assert num_iter == 6
1506
1507        num_readers = 2
1508        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1509                                  columns_list=["image5", "image2", "label"],
1510                                  num_parallel_workers=num_readers,
1511                                  shuffle=False)
1512        assert data_set.get_dataset_size() == 6
1513        num_iter = 0
1514        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1515            assert len(item) == 3
1516            for field in item:
1517                if isinstance(item[field], np.ndarray):
1518                    assert (item[field] ==
1519                            data_value_to_list[num_iter][field]).all()
1520                else:
1521                    assert item[field] == data_value_to_list[num_iter][field]
1522            num_iter += 1
1523        assert num_iter == 6
1524
1525        num_readers = 2
1526        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1527                                  columns_list=["image4", "image5",
1528                                                "image2", "image3", "file_name"],
1529                                  num_parallel_workers=num_readers,
1530                                  shuffle=False)
1531        assert data_set.get_dataset_size() == 6
1532        num_iter = 0
1533        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1534            assert len(item) == 5
1535            for field in item:
1536                if isinstance(item[field], np.ndarray):
1537                    assert (item[field] ==
1538                            data_value_to_list[num_iter][field]).all()
1539                else:
1540                    assert item[field] == data_value_to_list[num_iter][field]
1541            num_iter += 1
1542        assert num_iter == 6
1543    except Exception as error:
1544        os.remove("{}".format(mindrecord_file_name))
1545        os.remove("{}.db".format(mindrecord_file_name))
1546        raise error
1547    else:
1548        os.remove("{}".format(mindrecord_file_name))
1549        os.remove("{}.db".format(mindrecord_file_name))
1550
1551
1552def test_write_with_multi_array_and_MindDataset():
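    """test MindDataset reading a file whose fields are all variable-length int64 arrays"""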
1553    mindrecord_file_name = "test.mindrecord"
1554    try:
1555        data = [{"source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int64),
1556                 "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1557                 "source_eos_ids": np.array([13, 14, 15, 16, 17, 18], dtype=np.int64),
1558                 "source_eos_mask": np.array([19, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1559                 "target_sos_ids": np.array([28, 29, 30, 31, 32], dtype=np.int64),
1560                 "target_sos_mask": np.array([33, 34, 35, 36, 37, 38], dtype=np.int64),
1561                 "target_eos_ids": np.array([39, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1562                 "target_eos_mask": np.array([48, 49, 50, 51], dtype=np.int64)},
1563                {"source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int64),
1564                 "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1565                 "source_eos_ids": np.array([113, 14, 15, 16, 17, 18], dtype=np.int64),
1566                 "source_eos_mask": np.array([119, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1567                 "target_sos_ids": np.array([128, 29, 30, 31, 32], dtype=np.int64),
1568                 "target_sos_mask": np.array([133, 34, 35, 36, 37, 38], dtype=np.int64),
1569                 "target_eos_ids": np.array([139, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1570                 "target_eos_mask": np.array([148, 49, 50, 51], dtype=np.int64)},
1571                {"source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int64),
1572                 "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1573                 "source_eos_ids": np.array([213, 14, 15, 16, 17, 18], dtype=np.int64),
1574                 "source_eos_mask": np.array([219, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1575                 "target_sos_ids": np.array([228, 29, 30, 31, 32], dtype=np.int64),
1576                 "target_sos_mask": np.array([233, 34, 35, 36, 37, 38], dtype=np.int64),
1577                 "target_eos_ids": np.array([239, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1578                 "target_eos_mask": np.array([248, 49, 50, 51], dtype=np.int64)},
1579                {"source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int64),
1580                 "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1581                 "source_eos_ids": np.array([313, 14, 15, 16, 17, 18], dtype=np.int64),
1582                 "source_eos_mask": np.array([319, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1583                 "target_sos_ids": np.array([328, 29, 30, 31, 32], dtype=np.int64),
1584                 "target_sos_mask": np.array([333, 34, 35, 36, 37, 38], dtype=np.int64),
1585                 "target_eos_ids": np.array([339, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1586                 "target_eos_mask": np.array([348, 49, 50, 51], dtype=np.int64)},
1587                {"source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int64),
1588                 "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1589                 "source_eos_ids": np.array([413, 14, 15, 16, 17, 18], dtype=np.int64),
1590                 "source_eos_mask": np.array([419, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1591                 "target_sos_ids": np.array([428, 29, 30, 31, 32], dtype=np.int64),
1592                 "target_sos_mask": np.array([433, 34, 35, 36, 37, 38], dtype=np.int64),
1593                 "target_eos_ids": np.array([439, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1594                 "target_eos_mask": np.array([448, 49, 50, 51], dtype=np.int64)},
1595                {"source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int64),
1596                 "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
1597                 "source_eos_ids": np.array([513, 14, 15, 16, 17, 18], dtype=np.int64),
1598                 "source_eos_mask": np.array([519, 20, 21, 22, 23, 24, 25, 26, 27], dtype=np.int64),
1599                 "target_sos_ids": np.array([528, 29, 30, 31, 32], dtype=np.int64),
1600                 "target_sos_mask": np.array([533, 34, 35, 36, 37, 38], dtype=np.int64),
1601                 "target_eos_ids": np.array([539, 40, 41, 42, 43, 44, 45, 46, 47], dtype=np.int64),
1602                 "target_eos_mask": np.array([548, 49, 50, 51], dtype=np.int64)}
1603                ]
1604        writer = FileWriter(mindrecord_file_name)
1605        schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
1606                  "source_sos_mask": {"type": "int64", "shape": [-1]},
1607                  "source_eos_ids": {"type": "int64", "shape": [-1]},
1608                  "source_eos_mask": {"type": "int64", "shape": [-1]},
1609                  "target_sos_ids": {"type": "int64", "shape": [-1]},
1610                  "target_sos_mask": {"type": "int64", "shape": [-1]},
1611                  "target_eos_ids": {"type": "int64", "shape": [-1]},
1612                  "target_eos_mask": {"type": "int64", "shape": [-1]}}
1613        writer.add_schema(schema, "data is so cool")
1614        writer.write_raw_data(data)
1615        writer.commit()
1616
1617        # build the expected output: array fields are returned unchanged by MindDataset
1618        data_value_to_list = []
1619        for item in data:
1620            new_data = {}
1621            new_data['source_sos_ids'] = item["source_sos_ids"]
1622            new_data['source_sos_mask'] = item["source_sos_mask"]
1623            new_data['source_eos_ids'] = item["source_eos_ids"]
1624            new_data['source_eos_mask'] = item["source_eos_mask"]
1625            new_data['target_sos_ids'] = item["target_sos_ids"]
1626            new_data['target_sos_mask'] = item["target_sos_mask"]
1627            new_data['target_eos_ids'] = item["target_eos_ids"]
1628            new_data['target_eos_mask'] = item["target_eos_mask"]
1629            data_value_to_list.append(new_data)
1630
1631        num_readers = 2
1632        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1633                                  num_parallel_workers=num_readers,
1634                                  shuffle=False)
1635        assert data_set.get_dataset_size() == 6
1636        num_iter = 0
1637        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1638            assert len(item) == 8
1639            for field in item:
1640                if isinstance(item[field], np.ndarray):
1641                    assert (item[field] ==
1642                            data_value_to_list[num_iter][field]).all()
1643                else:
1644                    assert item[field] == data_value_to_list[num_iter][field]
1645            num_iter += 1
1646        assert num_iter == 6
1647
1648        num_readers = 2
1649        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1650                                  columns_list=["source_eos_ids", "source_eos_mask",
1651                                                "target_sos_ids", "target_sos_mask",
1652                                                "target_eos_ids", "target_eos_mask"],
1653                                  num_parallel_workers=num_readers,
1654                                  shuffle=False)
1655        assert data_set.get_dataset_size() == 6
1656        num_iter = 0
1657        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1658            assert len(item) == 6
1659            for field in item:
1660                if isinstance(item[field], np.ndarray):
1661                    assert (item[field] ==
1662                            data_value_to_list[num_iter][field]).all()
1663                else:
1664                    assert item[field] == data_value_to_list[num_iter][field]
1665            num_iter += 1
1666        assert num_iter == 6
1667
1668        num_readers = 2
1669        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1670                                  columns_list=["source_sos_ids",
1671                                                "target_sos_ids",
1672                                                "target_eos_mask"],
1673                                  num_parallel_workers=num_readers,
1674                                  shuffle=False)
1675        assert data_set.get_dataset_size() == 6
1676        num_iter = 0
1677        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1678            assert len(item) == 3
1679            for field in item:
1680                if isinstance(item[field], np.ndarray):
1681                    assert (item[field] ==
1682                            data_value_to_list[num_iter][field]).all()
1683                else:
1684                    assert item[field] == data_value_to_list[num_iter][field]
1685            num_iter += 1
1686        assert num_iter == 6
1687
1688        num_readers = 2
1689        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1690                                  columns_list=["target_eos_mask",
1691                                                "source_eos_mask",
1692                                                "source_sos_mask"],
1693                                  num_parallel_workers=num_readers,
1694                                  shuffle=False)
1695        assert data_set.get_dataset_size() == 6
1696        num_iter = 0
1697        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1698            assert len(item) == 3
1699            for field in item:
1700                if isinstance(item[field], np.ndarray):
1701                    assert (item[field] ==
1702                            data_value_to_list[num_iter][field]).all()
1703                else:
1704                    assert item[field] == data_value_to_list[num_iter][field]
1705            num_iter += 1
1706        assert num_iter == 6
1707
1708        num_readers = 2
1709        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1710                                  columns_list=["target_eos_ids"],
1711                                  num_parallel_workers=num_readers,
1712                                  shuffle=False)
1713        assert data_set.get_dataset_size() == 6
1714        num_iter = 0
1715        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1716            assert len(item) == 1
1717            for field in item:
1718                if isinstance(item[field], np.ndarray):
1719                    assert (item[field] ==
1720                            data_value_to_list[num_iter][field]).all()
1721                else:
1722                    assert item[field] == data_value_to_list[num_iter][field]
1723            num_iter += 1
1724        assert num_iter == 6
1725
1726        num_readers = 1
1727        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1728                                  columns_list=["target_eos_mask", "target_eos_ids",
1729                                                "target_sos_mask", "target_sos_ids",
1730                                                "source_eos_mask", "source_eos_ids",
1731                                                "source_sos_mask", "source_sos_ids"],
1732                                  num_parallel_workers=num_readers,
1733                                  shuffle=False)
1734        assert data_set.get_dataset_size() == 6
1735        num_iter = 0
1736        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1737            assert len(item) == 8
1738            for field in item:
1739                if isinstance(item[field], np.ndarray):
1740                    assert (item[field] ==
1741                            data_value_to_list[num_iter][field]).all()
1742                else:
1743                    assert item[field] == data_value_to_list[num_iter][field]
1744            num_iter += 1
1745        assert num_iter == 6
1746    except Exception as error:
1747        os.remove("{}".format(mindrecord_file_name))
1748        os.remove("{}.db".format(mindrecord_file_name))
1749        raise error
1750    else:
1751        os.remove("{}".format(mindrecord_file_name))
1752        os.remove("{}.db".format(mindrecord_file_name))
1753
1754
1755def test_numpy_generic():
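    """test writing numpy generic scalars (int32/int64/float32/float64) and reading them back"""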
1756    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
1757             for x in range(FILES_NUM)]
1758    try:
1759        for x in paths:
1760            if os.path.exists("{}".format(x)):
1761                os.remove("{}".format(x))
1762            if os.path.exists("{}.db".format(x)):
1763                os.remove("{}.db".format(x))
1764        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
1765        cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"},
1766                          "label3": {"type": "float32"}, "label4": {"type": "float64"}}
1767        data = []
1768        for idx in range(10):
1769            row = {}
1770            row['label1'] = np.int32(idx)
1771            row['label2'] = np.int64(idx * 10)
1772            row['label3'] = np.float32(idx + 0.12345)
1773            row['label4'] = np.float64(idx + 0.12345789)
1774            data.append(row)
1775        writer.add_schema(cv_schema_json, "img_schema")
1776        writer.write_raw_data(data)
1777        writer.commit()
1778
1779        num_readers = 4
1780        data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, shuffle=False)
1781        assert data_set.get_dataset_size() == 10
1782        idx = 0
1783        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1784            assert item['label1'] == data[idx]['label1']
1785            assert item['label2'] == data[idx]['label2']
1786            assert item['label3'] == data[idx]['label3']
1787            assert item['label4'] == data[idx]['label4']
1788            idx += 1
1789        assert idx == 10
1790    except Exception as error:
1791        for x in paths:
1792            os.remove("{}".format(x))
1793            os.remove("{}.db".format(x))
1794        raise error
1795    else:
1796        for x in paths:
1797            os.remove("{}".format(x))
1798            os.remove("{}.db".format(x))
1799
1800
1801def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset():
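    """test MindDataset reading mixed float/int scalars and float32/float64/int32/int64 arrays"""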
1802    mindrecord_file_name = "test.mindrecord"
1803    try:
1804        data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
1805                 "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
1806                                            123414314.2141243, 87.1212122], dtype=np.float64),
1807                 "float32": 3456.12345,
1808                 "float64": 1987654321.123456785,
1809                 "int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32),
1810                 "int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64),
1811                 "int32": 3456,
1812                 "int64": 947654321123},
1813                {"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
1814                 "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
1815                                            123414314.2141243, 87.1212122], dtype=np.float64),
1816                 "float32": 3456.12445,
1817                 "float64": 1987654321.123456786,
1818                 "int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32),
1819                 "int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64),
1820                 "int32": 3466,
1821                 "int64": 957654321123},
1822                {"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
1823                 "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
1824                                            123414314.2141243, 87.1212122], dtype=np.float64),
1825                 "float32": 3456.12545,
1826                 "float64": 1987654321.123456787,
1827                 "int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32),
1828                 "int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64),
1829                 "int32": 3476,
1830                 "int64": 967654321123},
1831                {"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
1832                 "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
1833                                            123414314.2141243, 87.1212122], dtype=np.float64),
1834                 "float32": 3456.12645,
1835                 "float64": 1987654321.123456788,
1836                 "int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32),
1837                 "int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64),
1838                 "int32": 3486,
1839                 "int64": 977654321123},
1840                {"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
1841                 "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
1842                                            123414314.2141243, 87.1212122], dtype=np.float64),
1843                 "float32": 3456.12745,
1844                 "float64": 1987654321.123456789,
1845                 "int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32),
1846                 "int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64),
1847                 "int32": 3496,
1848                 "int64": 987654321123},
1849                ]
1850        writer = FileWriter(mindrecord_file_name)
1851        schema = {"float32_array": {"type": "float32", "shape": [-1]},
1852                  "float64_array": {"type": "float64", "shape": [-1]},
1853                  "float32": {"type": "float32"},
1854                  "float64": {"type": "float64"},
1855                  "int32_array": {"type": "int32", "shape": [-1]},
1856                  "int64_array": {"type": "int64", "shape": [-1]},
1857                  "int32": {"type": "int32"},
1858                  "int64": {"type": "int64"}}
1859        writer.add_schema(schema, "data is so cool")
1860        writer.write_raw_data(data)
1861        writer.commit()
1862
1863        # build the expected output: array fields are returned unchanged by MindDataset
1864        data_value_to_list = []
1865        for item in data:
1866            new_data = {}
1867            new_data['float32_array'] = item["float32_array"]
1868            new_data['float64_array'] = item["float64_array"]
1869            new_data['float32'] = item["float32"]
1870            new_data['float64'] = item["float64"]
1871            new_data['int32_array'] = item["int32_array"]
1872            new_data['int64_array'] = item["int64_array"]
1873            new_data['int32'] = item["int32"]
1874            new_data['int64'] = item["int64"]
1875            data_value_to_list.append(new_data)
1876
1877        num_readers = 2
1878        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1879                                  num_parallel_workers=num_readers,
1880                                  shuffle=False)
1881        assert data_set.get_dataset_size() == 5
1882        num_iter = 0
1883        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1884            assert len(item) == 8
1885            for field in item:
1886                if isinstance(item[field], np.ndarray):
1887                    if item[field].dtype == np.float32:
1888                        assert (item[field] ==
1889                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
1890                    else:
1891                        assert (item[field] ==
1892                                data_value_to_list[num_iter][field]).all()
1893                else:
1894                    assert item[field] == data_value_to_list[num_iter][field]
1895            num_iter += 1
1896        assert num_iter == 5
1897
1898        num_readers = 2
1899        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1900                                  columns_list=["float32", "int32"],
1901                                  num_parallel_workers=num_readers,
1902                                  shuffle=False)
1903        assert data_set.get_dataset_size() == 5
1904        num_iter = 0
1905        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1906            assert len(item) == 2
1907            for field in item:
1908                if isinstance(item[field], np.ndarray):
1909                    if item[field].dtype == np.float32:
1910                        assert (item[field] ==
1911                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
1912                    else:
1913                        assert (item[field] ==
1914                                data_value_to_list[num_iter][field]).all()
1915                else:
1916                    assert item[field] == data_value_to_list[num_iter][field]
1917            num_iter += 1
1918        assert num_iter == 5
1919
1920        num_readers = 2
1921        data_set = ds.MindDataset(dataset_file=mindrecord_file_name,
1922                                  columns_list=["float64", "int64"],
1923                                  num_parallel_workers=num_readers,
1924                                  shuffle=False)
1925        assert data_set.get_dataset_size() == 5
1926        num_iter = 0
1927        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
1928            assert len(item) == 2
1929            for field in item:
1930                if isinstance(item[field], np.ndarray):
1931                    if item[field].dtype == np.float32:
1932                        assert (item[field] ==
1933                                np.array(data_value_to_list[num_iter][field], np.float32)).all()
1934                    elif item[field].dtype == np.float64:
1935                        assert math.isclose(item[field],
1936                                            np.array(data_value_to_list[num_iter][field], np.float64),
1937                                            rel_tol=1e-14)
1938                    else:
1939                        assert (item[field] ==
1940                                data_value_to_list[num_iter][field]).all()
1941                else:
1942                    assert item[field] == data_value_to_list[num_iter][field]
1943            num_iter += 1
1944        assert num_iter == 5
1945    except Exception as error:
1946        os.remove("{}".format(mindrecord_file_name))
1947        os.remove("{}.db".format(mindrecord_file_name))
1948        raise error
1949    else:
1950        os.remove("{}".format(mindrecord_file_name))
1951        os.remove("{}.db".format(mindrecord_file_name))
1952
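# four MindRecord files and the number of records each one holds; used by the shuffle tests below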
1953FILES = ["0.mindrecord", "1.mindrecord", "2.mindrecord", "3.mindrecord"]
1954ITEMS = [10, 14, 8, 20]
1955FILES_ITEMS = {FILES[0]: ITEMS[0], FILES[1]: ITEMS[1], FILES[2]: ITEMS[2], FILES[3]: ITEMS[3]}
1956
1957@pytest.fixture
1958def create_multi_mindrecord_files():
1959    """files: {0.mindrecord : 10, 1.mindrecord : 14, 2.mindrecord : 8, 3.mindrecord : 20}"""
1960    try:
1961        index = 0
1962        for filename in FILES_ITEMS:
1963            key = filename
1964            if os.path.exists(key):
1965                os.remove("{}".format(key))
1966                os.remove("{}.db".format(key))
1967
1968            value = FILES_ITEMS[key]
1969            data_list = []
1970            for i in range(value):
1971                data = {}
1972                data['id'] = i + index
1973                data_list.append(data)
1974            index += value
1975
1976            writer = FileWriter(key)
1977            schema = {"id": {"type": "int32"}}
1978            writer.add_schema(schema, "data is so cool")
1979            writer.write_raw_data(data_list)
1980            writer.commit()
1981        yield "yield_create_multi_mindrecord_files"
1982    except Exception as error:
1983        for filename in FILES_ITEMS:
1984            if os.path.exists(filename):
1985                os.remove("{}".format(filename))
1986                os.remove("{}.db".format(filename))
1987        raise error
1988    else:
1989        for filename in FILES_ITEMS:
1990            if os.path.exists(filename):
1991                os.remove("{}".format(filename))
1992                os.remove("{}.db".format(filename))
1993
1994def test_shuffle_with_global_infile_files(create_multi_mindrecord_files):
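    """test MindDataset shuffle behaviour (default, False, True, GLOBAL, INFILE, FILES) across multiple files"""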
1995    datas_all = []
1996    index = 0
1997    for filename in FILES_ITEMS:
1998        value = FILES_ITEMS[filename]
1999        data_list = []
2000        for i in range(value):
2001            data = {}
2002            data['id'] = np.array(i + index, dtype=np.int32)
2003            data_list.append(data)
2004        index += value
2005        datas_all.append(data_list)
2006
2007    # no shuffle parameter
2008    num_readers = 2
2009    data_set = ds.MindDataset(dataset_file=FILES,
2010                              num_parallel_workers=num_readers)
2011    assert data_set.get_dataset_size() == 52
2012    num_iter = 0
2013    datas_all_minddataset = []
2014    data_list = []
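    # indices 9, 23, 31 and 51 are the last records of the four files (10, 14, 8 and 20 records respectively)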
2015    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2016        assert len(item) == 1
2017        data_list.append(item)
2018        if num_iter == 9:
2019            datas_all_minddataset.append(data_list)
2020            data_list = []
2021        elif num_iter == 23:
2022            datas_all_minddataset.append(data_list)
2023            data_list = []
2024        elif num_iter == 31:
2025            datas_all_minddataset.append(data_list)
2026            data_list = []
2027        elif num_iter == 51:
2028            datas_all_minddataset.append(data_list)
2029            data_list = []
2030        num_iter += 1
2031    assert num_iter == 52
2032
2033    assert len(datas_all) == len(datas_all_minddataset)
2034    for i, _ in enumerate(datas_all):
2035        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2036        assert datas_all[i] != datas_all_minddataset[i]
2037
2038    # shuffle=False
2039    num_readers = 2
2040    data_set = ds.MindDataset(dataset_file=FILES,
2041                              num_parallel_workers=num_readers,
2042                              shuffle=False)
2043    assert data_set.get_dataset_size() == 52
2044    num_iter = 0
2045    datas_all_minddataset = []
2046    data_list = []
2047    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2048        assert len(item) == 1
2049        data_list.append(item)
2050        if num_iter == 9:
2051            datas_all_minddataset.append(data_list)
2052            data_list = []
2053        elif num_iter == 23:
2054            datas_all_minddataset.append(data_list)
2055            data_list = []
2056        elif num_iter == 31:
2057            datas_all_minddataset.append(data_list)
2058            data_list = []
2059        elif num_iter == 51:
2060            datas_all_minddataset.append(data_list)
2061            data_list = []
2062        num_iter += 1
2063    assert num_iter == 52
2064
2065    assert len(datas_all) == len(datas_all_minddataset)
2066    for i, _ in enumerate(datas_all):
2067        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2068        assert datas_all[i] == datas_all_minddataset[i]
2069
2070    # shuffle=True
2071    num_readers = 2
2072    data_set = ds.MindDataset(dataset_file=FILES,
2073                              num_parallel_workers=num_readers,
2074                              shuffle=True)
2075    assert data_set.get_dataset_size() == 52
2076    num_iter = 0
2077    datas_all_minddataset = []
2078    data_list = []
2079    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2080        assert len(item) == 1
2081        data_list.append(item)
2082        if num_iter == 9:
2083            datas_all_minddataset.append(data_list)
2084            data_list = []
2085        elif num_iter == 23:
2086            datas_all_minddataset.append(data_list)
2087            data_list = []
2088        elif num_iter == 31:
2089            datas_all_minddataset.append(data_list)
2090            data_list = []
2091        elif num_iter == 51:
2092            datas_all_minddataset.append(data_list)
2093            data_list = []
2094        num_iter += 1
2095    assert num_iter == 52
2096
2097    assert len(datas_all) == len(datas_all_minddataset)
2098    for i, _ in enumerate(datas_all):
2099        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2100        assert datas_all[i] != datas_all_minddataset[i]
2101
2102    # shuffle=Shuffle.GLOBAL
2103    num_readers = 2
2104    data_set = ds.MindDataset(dataset_file=FILES,
2105                              num_parallel_workers=num_readers,
2106                              shuffle=ds.Shuffle.GLOBAL)
2107    assert data_set.get_dataset_size() == 52
2108    num_iter = 0
2109    datas_all_minddataset = []
2110    data_list = []
2111    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2112        assert len(item) == 1
2113        data_list.append(item)
2114        if num_iter == 9:
2115            datas_all_minddataset.append(data_list)
2116            data_list = []
2117        elif num_iter == 23:
2118            datas_all_minddataset.append(data_list)
2119            data_list = []
2120        elif num_iter == 31:
2121            datas_all_minddataset.append(data_list)
2122            data_list = []
2123        elif num_iter == 51:
2124            datas_all_minddataset.append(data_list)
2125            data_list = []
2126        num_iter += 1
2127    assert num_iter == 52
2128
2129    assert len(datas_all) == len(datas_all_minddataset)
2130    for i, _ in enumerate(datas_all):
2131        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2132        assert datas_all[i] != datas_all_minddataset[i]
2133
2134    # shuffle=Shuffle.INFILE
2135    num_readers = 2
2136    data_set = ds.MindDataset(dataset_file=FILES,
2137                              num_parallel_workers=num_readers,
2138                              shuffle=ds.Shuffle.INFILE)
2139    assert data_set.get_dataset_size() == 52
2140    num_iter = 0
2141    datas_all_minddataset = []
2142    data_list = []
2143    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2144        assert len(item) == 1
2145        data_list.append(item)
2146        if num_iter == 9:
2147            datas_all_minddataset.append(data_list)
2148            data_list = []
2149        elif num_iter == 23:
2150            datas_all_minddataset.append(data_list)
2151            data_list = []
2152        elif num_iter == 31:
2153            datas_all_minddataset.append(data_list)
2154            data_list = []
2155        elif num_iter == 51:
2156            datas_all_minddataset.append(data_list)
2157            data_list = []
2158        num_iter += 1
2159    assert num_iter == 52
2160
2161    def sort_list_with_dict(dict_in_list):
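        """rebuild the list of {'id': ...} records sorted by ascending id"""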
2162        keys = []
2163        for item in dict_in_list:
2164            for key in item:
2165                keys.append(int(item[key]))
2166        keys.sort()
2167        data_list = []
2168        for item in keys:
2169            data = {}
2170            data['id'] = np.array(item, dtype=np.int32)
2171            data_list.append(data)
2172        return data_list
2173
2174    assert len(datas_all) == len(datas_all_minddataset)
2175    for i, _ in enumerate(datas_all):
2176        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2177        assert datas_all[i] != datas_all_minddataset[i]
2178        # order the datas_all_minddataset
2179        new_datas_all_minddataset = sort_list_with_dict(datas_all_minddataset[i])
2180        assert datas_all[i] == new_datas_all_minddataset
2181
2182    # shuffle=Shuffle.FILES
2183    num_readers = 2
2184    data_set = ds.MindDataset(dataset_file=FILES,
2185                              num_parallel_workers=num_readers,
2186                              shuffle=ds.Shuffle.FILES)
2187    assert data_set.get_dataset_size() == 52
2188
2189    num_iter = 0
2190    data_list = []
2191
2192    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2193        assert len(item) == 1
2194        data_list.append(item)
2195        num_iter += 1
2196    assert num_iter == 52
2197
2198    current_shard_size = 0
2199    current_shard_index = 0
2200    shard_count = 0
2201    datas_index = 0
2202    origin_index = list(range(len(ITEMS)))
2203    current_index = []
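    # with Shuffle.FILES each file's records stay contiguous; map every block of output back to its source file and check that the file order changed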
2204    while shard_count < len(ITEMS):
2205        if data_list[datas_index]['id'] < 10:
2206            current_shard_index = 0
2207        elif data_list[datas_index]['id'] < 24:
2208            current_shard_index = 1
2209        elif data_list[datas_index]['id'] < 32:
2210            current_shard_index = 2
2211        elif data_list[datas_index]['id'] < 52:
2212            current_shard_index = 3
2213        else:
2214            raise ValueError("Index out of range")
2215        current_shard_size = ITEMS[current_shard_index]
2216
2217        tmp_datas = data_list[datas_index:datas_index + current_shard_size]
2218        current_index.append(current_shard_index)
2219        assert len(datas_all[current_shard_index]) == len(tmp_datas)
2220        assert datas_all[current_shard_index] == tmp_datas
2221
2222        datas_index += current_shard_size
2223        shard_count += 1
2224    assert origin_index != current_index
2225
2226def test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files):
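    """test MindDataset shuffle behaviour combined with num_shards/shard_id distributed sampling"""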
2227    datas_all = []
2228    datas_all_samples = []
2229    index = 0
2230    for filename in FILES_ITEMS:
2231        value = FILES_ITEMS[filename]
2232        data_list = []
2233        for i in range(value):
2234            data = {}
2235            data['id'] = np.array(i + index, dtype=np.int32)
2236            data_list.append(data)
2237            datas_all_samples.append(data)
2238        index += value
2239        datas_all.append(data_list)
2240
2241    # no shuffle parameter
2242    num_readers = 2
2243    data_set = ds.MindDataset(dataset_file=FILES,
2244                              num_parallel_workers=num_readers,
2245                              num_shards=4,
2246                              shard_id=3)
2247    assert data_set.get_dataset_size() == 13
2248    num_iter = 0
2249    data_list = []
2250    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2251        assert len(item) == 1
2252        data_list.append(item)
2253        num_iter += 1
2254    assert num_iter == 13
2255    assert data_list != datas_all_samples[3*13:]
2256
2257    # shuffle=False
2258    num_readers = 2
2259    data_set = ds.MindDataset(dataset_file=FILES,
2260                              num_parallel_workers=num_readers,
2261                              shuffle=False,
2262                              num_shards=4,
2263                              shard_id=2)
2264    assert data_set.get_dataset_size() == 13
2265    num_iter = 0
2266    data_list = []
2267    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2268        assert len(item) == 1
2269        data_list.append(item)
2270        num_iter += 1
2271    assert num_iter == 13
2272    assert data_list == datas_all_samples[2*13:3*13]
2273
2274    # shuffle=True
2275    num_readers = 2
2276    data_set = ds.MindDataset(dataset_file=FILES,
2277                              num_parallel_workers=num_readers,
2278                              shuffle=True,
2279                              num_shards=4,
2280                              shard_id=1)
2281    assert data_set.get_dataset_size() == 13
2282    num_iter = 0
2283    data_list = []
2284    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2285        assert len(item) == 1
2286        data_list.append(item)
2287        num_iter += 1
2288    assert num_iter == 13
2289    assert data_list != datas_all_samples[1*13:2*13]
2290
2291    # shuffle=Shuffle.GLOBAL
2292    num_readers = 2
2293    data_set = ds.MindDataset(dataset_file=FILES,
2294                              num_parallel_workers=num_readers,
2295                              shuffle=ds.Shuffle.GLOBAL,
2296                              num_shards=4,
2297                              shard_id=0)
2298    assert data_set.get_dataset_size() == 13
2299    num_iter = 0
2300    data_list = []
2301    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2302        assert len(item) == 1
2303        data_list.append(item)
2304        num_iter += 1
2305    assert num_iter == 13
2306    assert data_list != datas_all_samples[0:1*13]
2307
2308    # shuffle=Shuffle.INFILE
2309    output_datas = []
2310    for shard_id in range(4):
2311        num_readers = 2
2312        data_set = ds.MindDataset(dataset_file=FILES,
2313                                  num_parallel_workers=num_readers,
2314                                  shuffle=ds.Shuffle.INFILE,
2315                                  num_shards=4,
2316                                  shard_id=shard_id)
2317        assert data_set.get_dataset_size() == 13
2318        num_iter = 0
2319        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2320            assert len(item) == 1
2321            output_datas.append(item)
2322            num_iter += 1
2323        assert num_iter == 13
2324
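    # Regroup the 52 samples collected above into per-file chunks using the
    # cumulative file sizes implied by the boundaries below (10, 24, 32, 52).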
2325    num_iter = 0
2326    datas_all_minddataset = []
2327    data_list = []
2328    for item in output_datas:
2329        assert len(item) == 1
2330        data_list.append(item)
2331        if num_iter == 9:
2332            datas_all_minddataset.append(data_list)
2333            data_list = []
2334        elif num_iter == 23:
2335            datas_all_minddataset.append(data_list)
2336            data_list = []
2337        elif num_iter == 31:
2338            datas_all_minddataset.append(data_list)
2339            data_list = []
2340        elif num_iter == 51:
2341            datas_all_minddataset.append(data_list)
2342            data_list = []
2343        num_iter += 1
2344    assert num_iter == 52
2345
2346    def sort_list_with_dict(dict_in_list):
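        """Rebuild the records sorted by their integer 'id' so a chunk can be compared order-independently."""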
2347        keys = []
2348        for item in dict_in_list:
2349            for key in item:
2350                keys.append(int(item[key]))
2351        keys.sort()
2352        data_list = []
2353        for item in keys:
2354            data = {}
2355            data['id'] = np.array(item, dtype=np.int32)
2356            data_list.append(data)
2357        return data_list
2358
2359    assert len(datas_all) == len(datas_all_minddataset)
2360    for i, _ in enumerate(datas_all):
2361        assert len(datas_all[i]) == len(datas_all_minddataset[i])
2362        assert datas_all[i] != datas_all_minddataset[i]
2363        # sort datas_all_minddataset[i] by id before comparing against the reference
2364        new_datas_all_minddataset = sort_list_with_dict(datas_all_minddataset[i])
2365        assert datas_all[i] == new_datas_all_minddataset
2366
2367    # shuffle=Shuffle.FILES
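    # Shuffle.FILES is expected to shuffle the order in which files are read while
    # keeping the sample order inside each file, so the output should consist of
    # whole, unshuffled files arranged in a new order.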
2368    data_list = []
2369    for shard_id in range(4):
2370        num_readers = 2
2371        data_set = ds.MindDataset(dataset_file=FILES,
2372                                  num_parallel_workers=num_readers,
2373                                  shuffle=ds.Shuffle.FILES,
2374                                  num_shards=4,
2375                                  shard_id=shard_id)
2376        assert data_set.get_dataset_size() == 13
2377        num_iter = 0
2378        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
2379            assert len(item) == 1
2380            data_list.append(item)
2381            num_iter += 1
2382        assert num_iter == 13
2383    assert len(data_list) == 52
2384
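    # Walk the 52 collected samples in file-sized chunks: each chunk must equal one
    # complete file in its original order, while the sequence of files must differ
    # from the original file order.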
2385    current_shard_size = 0
2386    current_shard_index = 0
2387    shard_count = 0
2388    datas_index = 0
2389    origin_index = list(range(len(ITEMS)))
2390    current_index = []
2391    while shard_count < len(ITEMS):
2392        if data_list[datas_index]['id'] < 10:
2393            current_shard_index = 0
2394        elif data_list[datas_index]['id'] < 24:
2395            current_shard_index = 1
2396        elif data_list[datas_index]['id'] < 32:
2397            current_shard_index = 2
2398        elif data_list[datas_index]['id'] < 52:
2399            current_shard_index = 3
2400        else:
2401            raise ValueError("Index out of range")
2402        current_shard_size = ITEMS[current_shard_index]
2403
2404        tmp_datas = data_list[datas_index:datas_index + current_shard_size]
2405        current_index.append(current_shard_index)
2406        assert len(datas_all[current_shard_index]) == len(tmp_datas)
2407        assert datas_all[current_shard_index] == tmp_datas
2408
2409        datas_index += current_shard_size
2410        shard_count += 1
2411    assert origin_index != current_index
2412
2413def test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files):
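    """Check per-shard shuffle behaviour when one iterator is reused for several epochs."""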
2414    datas_all = []
2415    datas_all_samples = []
2416    index = 0
2417    for filename in FILES_ITEMS:
2418        value = FILES_ITEMS[filename]
2419        data_list = []
2420        for i in range(value):
2421            data = {}
2422            data['id'] = np.array(i + index, dtype=np.int32)
2423            data_list.append(data)
2424            datas_all_samples.append(data)
2425        index += value
2426        datas_all.append(data_list)
2427
2428    epoch_size = 3
2429
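    # Each block below creates one iterator with num_epochs=epoch_size and reuses it
    # across epochs; whenever shuffling is in effect, consecutive epochs should yield
    # the same 13 samples per shard in a different order.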
2430    # no shuffle parameter
2431    for shard_id in range(4):
2432        num_readers = 2
2433        data_set = ds.MindDataset(dataset_file=FILES,
2434                                  num_parallel_workers=num_readers,
2435                                  num_shards=4,
2436                                  shard_id=shard_id)
2437        assert data_set.get_dataset_size() == 13
2438        data_list = []
2439        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2440        for epoch in range(epoch_size):  # 3 epochs
2441            num_iter = 0
2442            new_datas = []
2443            for item in dataset_iter:
2444                assert len(item) == 1
2445                new_datas.append(item)
2446                num_iter += 1
2447            assert num_iter == 13
2448            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
2449            assert data_list != new_datas
2450            data_list = new_datas
2451
2452    # shuffle=False
2453    for shard_id in range(4):
2454        num_readers = 2
2455        data_set = ds.MindDataset(dataset_file=FILES,
2456                                  num_parallel_workers=num_readers,
2457                                  shuffle=False,
2458                                  num_shards=4,
2459                                  shard_id=shard_id)
2460        assert data_set.get_dataset_size() == 13
2461        data_list = []
2462        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2463        for epoch in range(epoch_size):  # 3 epochs
2464            num_iter = 0
2465            new_datas = []
2466            for item in dataset_iter:
2467                assert len(item) == 1
2468                new_datas.append(item)
2469                num_iter += 1
2470            assert num_iter == 13
2471            assert new_datas == datas_all_samples[shard_id*13:(shard_id+1)*13]
2472
2473    # shuffle=True
2474    for shard_id in range(4):
2475        num_readers = 2
2476        data_set = ds.MindDataset(dataset_file=FILES,
2477                                  num_parallel_workers=num_readers,
2478                                  shuffle=True,
2479                                  num_shards=4,
2480                                  shard_id=shard_id)
2481        assert data_set.get_dataset_size() == 13
2482        data_list = []
2483        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2484        for epoch in range(epoch_size):  # 3 epochs
2485            num_iter = 0
2486            new_datas = []
2487            for item in dataset_iter:
2488                assert len(item) == 1
2489                new_datas.append(item)
2490                num_iter += 1
2491            assert num_iter == 13
2492            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
2493            assert data_list != new_datas
2494            data_list = new_datas
2495
2496    # shuffle=Shuffle.GLOBAL
2497    for shard_id in range(4):
2498        num_readers = 2
2499        data_set = ds.MindDataset(dataset_file=FILES,
2500                                  num_parallel_workers=num_readers,
2501                                  shuffle=ds.Shuffle.GLOBAL,
2502                                  num_shards=4,
2503                                  shard_id=shard_id)
2504        assert data_set.get_dataset_size() == 13
2505        data_list = []
2506        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2507        for epoch in range(epoch_size):  # 3 epochs
2508            num_iter = 0
2509            new_datas = []
2510            for item in dataset_iter:
2511                assert len(item) == 1
2512                new_datas.append(item)
2513                num_iter += 1
2514            assert num_iter == 13
2515            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
2516            assert data_list != new_datas
2517            data_list = new_datas
2518
2519    # shuffle=Shuffle.INFILE
2520    for shard_id in range(4):
2521        num_readers = 2
2522        data_set = ds.MindDataset(dataset_file=FILES,
2523                                  num_parallel_workers=num_readers,
2524                                  shuffle=ds.Shuffle.INFILE,
2525                                  num_shards=4,
2526                                  shard_id=shard_id)
2527        assert data_set.get_dataset_size() == 13
2528        data_list = []
2529        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2530        for epoch in range(epoch_size):  # 3 epochs
2531            num_iter = 0
2532            new_datas = []
2533            for item in dataset_iter:
2534                assert len(item) == 1
2535                new_datas.append(item)
2536                num_iter += 1
2537            assert num_iter == 13
2538            assert new_datas != datas_all_samples[shard_id*13:(shard_id+1)*13]
2539            assert data_list != new_datas
2540            data_list = new_datas
2541
2542    # shuffle=Shuffle.FILES
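    # Collect each epoch's output separately; with the file order reshuffled per
    # epoch, the three epochs are expected to differ from one another.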
2543    datas_epoch1 = []
2544    datas_epoch2 = []
2545    datas_epoch3 = []
2546    for shard_id in range(4):
2547        num_readers = 2
2548        data_set = ds.MindDataset(dataset_file=FILES,
2549                                  num_parallel_workers=num_readers,
2550                                  shuffle=ds.Shuffle.FILES,
2551                                  num_shards=4,
2552                                  shard_id=shard_id)
2553        assert data_set.get_dataset_size() == 13
2554        dataset_iter = data_set.create_dict_iterator(num_epochs=epoch_size, output_numpy=True)
2555        for epoch in range(epoch_size):  # 3 epochs
2556            num_iter = 0
2557            for item in dataset_iter:
2558                assert len(item) == 1
2559                if epoch == 0:
2560                    datas_epoch1.append(item)
2561                elif epoch == 1:
2562                    datas_epoch2.append(item)
2563                elif epoch == 2:
2564                    datas_epoch3.append(item)
2565                num_iter += 1
2566            assert num_iter == 13
2567    assert datas_epoch1 not in (datas_epoch2, datas_epoch3)
2568    assert datas_epoch2 not in (datas_epoch1, datas_epoch3)
2569    assert datas_epoch3 not in (datas_epoch2, datas_epoch1)
2570
2571def test_field_is_null_numpy():
2572    """Write records whose array_d field is an empty numpy array and read them back with MindDataset."""
2573    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
2574             for x in range(FILES_NUM)]
2575    for x in paths:
2576        if os.path.exists("{}".format(x)):
2577            os.remove("{}".format(x))
2578        if os.path.exists("{}.db".format(x)):
2579            os.remove("{}.db".format(x))
2580
2581    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
2582    data = []
2583    # field array_d is an empty array in every row
2584    for row_id in range(16):
2585        data.append({
2586            "label": row_id,
2587            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
2588                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
2589                                            2147483647], dtype=np.int32), [-1]),
2590            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
2591                                            256, -32768, 32767, -32769, 32768,
2592                                            -2147483648, 2147483647, -2147483649, 2147483649,
2593                                            -922337036854775808, 9223372036854775807]), [1, -1]),
2594            "array_d": np.array([], dtype=np.int64)
2595        })
2596    nlp_schema_json = {"label": {"type": "int32"},
2597                       "array_a": {"type": "int32",
2598                                   "shape": [-1]},
2599                       "array_b": {"type": "int64",
2600                                   "shape": [1, -1]},
2601                       "array_d": {"type": "int64",
2602                                   "shape": [-1]}
2603                       }
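    # Write with a 16 KB header (1 << 14) and 32 KB pages (1 << 15) for this small file.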
2604    writer.set_header_size(1 << 14)
2605    writer.set_page_size(1 << 15)
2606    writer.add_schema(nlp_schema_json, "nlp_schema")
2607    writer.write_raw_data(data)
2608    writer.commit()
2609
2610    data_set = ds.MindDataset(dataset_file=NLP_FILE_NAME + "0",
2611                              columns_list=["label", "array_a", "array_b", "array_d"],
2612                              num_parallel_workers=2,
2613                              shuffle=False)
2614    assert data_set.get_dataset_size() == 16
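    # Expected shapes per column: label is a scalar ([]), array_a has 15 elements,
    # array_b is 1 x 19, and the empty array_d comes back with shape [].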
2615    assert data_set.output_shapes() == [[], [15], [1, 19], []]
2616    assert data_set.output_types()[0] == np.int32
2617    assert data_set.output_types()[1] == np.int32
2618    assert data_set.output_types()[2] == np.int64
2619    assert data_set.output_types()[3] == np.int64
2620
2621    for x in paths:
2622        os.remove("{}".format(x))
2623        os.remove("{}.db".format(x))
2624
2625def test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file):
2626    """test for loop dataset iterator"""
2627    data = []
2628    for row_id in range(16):
2629        data.append({
2630            "label": row_id,
2631            "array_a": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129,
2632                                            255, 256, -32768, 32767, -32769, 32768, -2147483648,
2633                                            2147483647], dtype=np.int32), [-1]),
2634            "array_b": np.reshape(np.array([0, 1, -1, 127, -128, 128, -129, 255,
2635                                            256, -32768, 32767, -32769, 32768,
2636                                            -2147483648, 2147483647, -2147483649, 2147483649,
2637                                            -922337036854775808, 9223372036854775807]), [1, -1]),
2638            "array_c": str.encode("nlp data"),
2639            "array_d": np.reshape(np.array([[-10, -127], [10, 127]]), [2, -1])
2640        })
2641    num_readers = 1
2642    data_set = ds.MindDataset(
2643        NLP_FILE_NAME + "0", None, num_readers, shuffle=False)
2644    assert data_set.get_dataset_size() == 16
2645
2646    # create_dict_iterator inside the for loop
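    # Each call to create_dict_iterator starts a fresh pass over the dataset, so
    # all 16 records are replayed from the beginning on every loop iteration.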
2647    for _ in range(10):
2648        num_iter = 0
2649        for x, item in zip(data, data_set.create_dict_iterator(num_epochs=1, output_numpy=True)):
2650            assert (item["array_a"] == x["array_a"]).all()
2651            assert (item["array_b"] == x["array_b"]).all()
2652            assert item["array_c"].tobytes() == x["array_c"]
2653            assert (item["array_d"] == x["array_d"]).all()
2654            assert item["label"] == x["label"]
2655            num_iter += 1
2656        assert num_iter == 16
2657
2658    # create_dict_iterator outside the for loop
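    # A single iterator created with num_epochs=10 is reused: each pass of the outer
    # loop consumes one epoch of 16 records, so zipping against the repeated data
    # still aligns record-by-record.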
2659    dataset_iter = data_set.create_dict_iterator(num_epochs=10, output_numpy=True)
2660    new_data = data * 10
2661    for _ in range(10):
2662        num_iter = 0
2663        for x, item in zip(new_data, dataset_iter):
2664            assert (item["array_a"] == x["array_a"]).all()
2665            assert (item["array_b"] == x["array_b"]).all()
2666            assert item["array_c"].tobytes() == x["array_c"]
2667            assert (item["array_d"] == x["array_d"]).all()
2668            assert item["label"] == x["label"]
2669            num_iter += 1
2670        assert num_iter == 16
2671
2672    # create multiple iterators explicitly; each advances independently
2673    dataset_iter2 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
2674    assert (next(dataset_iter2)["array_a"] == data[0]["array_a"]).all()
2675    assert (next(dataset_iter2)["array_a"] == data[1]["array_a"]).all()
2676
2677    dataset_iter3 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
2678    assert (next(dataset_iter3)["array_a"] == data[0]["array_a"]).all()
2679    assert (next(dataset_iter3)["array_a"] == data[1]["array_a"]).all()
2680    assert (next(dataset_iter3)["array_a"] == data[2]["array_a"]).all()
2681
2682    assert (next(dataset_iter2)["array_a"] == data[2]["array_a"]).all()
2683    assert (next(dataset_iter2)["array_a"] == data[3]["array_a"]).all()
2684
2685    dataset_iter4 = data_set.create_dict_iterator(num_epochs=1, output_numpy=True)
2686    assert (next(dataset_iter4)["array_a"] == data[0]["array_a"]).all()
2687    assert (next(dataset_iter4)["array_a"] == data[1]["array_a"]).all()
2688    assert (next(dataset_iter4)["array_a"] == data[2]["array_a"]).all()
2689
2690    assert (next(dataset_iter3)["array_a"] == data[3]["array_a"]).all()
2691    assert (next(dataset_iter3)["array_a"] == data[4]["array_a"]).all()
2692    assert (next(dataset_iter3)["array_a"] == data[5]["array_a"]).all()
2693
2694if __name__ == '__main__':
2695    test_nlp_compress_data(add_and_remove_nlp_compress_file)
2696    test_nlp_compress_data_old_version(add_and_remove_nlp_compress_file)
2697    test_cv_minddataset_writer_tutorial()
2698    test_cv_minddataset_partition_tutorial(add_and_remove_cv_file)
2699    test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file)
2700    test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file)
2701    test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file)
2702    test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file)
2703    test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file)
2704    test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file)
2705    test_cv_minddataset_dataset_size(add_and_remove_cv_file)
2706    test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file)
2707    test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file)
2708    test_cv_minddataset_issue_888(add_and_remove_cv_file)
2709    test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file)
2710    test_cv_minddataset_blockreader_some_field_not_in_index_tutorial(add_and_remove_cv_file)
2711    test_cv_minddataset_reader_file_list(add_and_remove_cv_file)
2712    test_cv_minddataset_reader_one_partition(add_and_remove_cv_file)
2713    test_cv_minddataset_reader_two_dataset(add_and_remove_cv_file)
2714    test_cv_minddataset_reader_two_dataset_partition(add_and_remove_cv_file)
2715    test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file)
2716    test_nlp_minddataset_reader_basic_tutorial(add_and_remove_cv_file)
2717    test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file)
2718    test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file)
2719    test_cv_minddataset_reader_no_columns(add_and_remove_cv_file)
2720    test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file)
2721    test_write_with_multi_bytes_and_array_and_read_by_MindDataset()
2722    test_write_with_multi_bytes_and_MindDataset()
2723    test_write_with_multi_array_and_MindDataset()
2724    test_numpy_generic()
2725    test_write_with_float32_float64_float32_array_float64_array_and_MindDataset()
2726    test_shuffle_with_global_infile_files(create_multi_mindrecord_files)
2727    test_distributed_shuffle_with_global_infile_files(create_multi_mindrecord_files)
2728    test_distributed_shuffle_with_multi_epochs(create_multi_mindrecord_files)
2729    test_field_is_null_numpy()
2730    test_for_loop_dataset_iterator(add_and_remove_nlp_compress_file)
2731