# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""
16This is the test module for mindrecord
17"""
18import collections
19import os
20import re
21import string
22
23import numpy as np
24import pytest
25
26import mindspore.dataset as ds
27from mindspore import log as logger
28from mindspore.mindrecord import FileWriter
29
FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV1_FILE_NAME = "../data/mindrecord/imagenet1.mindrecord"
CV2_FILE_NAME = "../data/mindrecord/imagenet2.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"


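# Note: FileWriter(CV_FILE_NAME, FILES_NUM) shards its output into FILES_NUM files
# named imagenet.mindrecord0 .. imagenet.mindrecord3, each with a companion ".db"
# index file; the fixtures below create and clean up exactly those files.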
@pytest.fixture
def add_and_remove_cv_file():
    """Create the sharded imagenet MindRecord files for a test and remove them afterwards."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists(x):
                os.remove(x)
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))


@pytest.fixture
def add_and_remove_nlp_file():
    """Create the sharded aclImdb MindRecord files for a test and remove them afterwards."""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists(x):
                os.remove(x)
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
        data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
        nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                           "rating": {"type": "float32"},
                           "input_ids": {"type": "int64",
                                         "shape": [-1]},
                           "input_mask": {"type": "int64",
                                          "shape": [1, -1]},
                           "segment_ids": {"type": "int64",
                                           "shape": [2, -1]}
                           }
        writer.set_header_size(1 << 14)
        writer.set_page_size(1 << 15)
        writer.add_schema(nlp_schema_json, "nlp_schema")
        writer.add_index(["id", "rating"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_nlp_data"
    except Exception as error:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))


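# padded_sample/num_padded append num_padded copies of a caller-supplied synthetic
# sample to the dataset, so that the total row count (here 10 real samples plus the
# padding) can be split evenly across num_shards; the tests below exercise that
# behaviour and its error cases.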
def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
    """tutorial for cv minddataset with padded samples."""
    columns_list = ["label", "file_name", "data"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -1
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, padded_sample=padded_sample, num_padded=5)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    num_padded_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        if item['label'] == -1:
            num_padded_iter += 1
            assert item['file_name'] == bytes(padded_sample['file_name'],
                                              encoding='utf8')
            assert item['label'] == padded_sample['label']
            assert (item['data'] == np.array(list(padded_sample['data']))).all()
        num_iter += 1
    assert num_padded_iter == 5
    assert num_iter == 15


def test_cv_minddataset_reader_basic_padded_samples_type_cast(add_and_remove_cv_file):
    """tutorial for cv minddataset with padded samples and type cast."""
    columns_list = ["label", "file_name", "data"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -1
    padded_sample['file_name'] = 99999
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, padded_sample=padded_sample, num_padded=5)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    num_padded_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        if item['label'] == -1:
            num_padded_iter += 1
            assert item['file_name'] == bytes(str(padded_sample['file_name']),
                                              encoding='utf8')
            assert item['label'] == padded_sample['label']
            assert (item['data'] == np.array(list(padded_sample['data']))).all()
        num_iter += 1
    assert num_padded_iter == 5
    assert num_iter == 15


def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -2
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded, dataset_size):
        num_padded_iter = 0
        num_iter = 0
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            assert data_set.get_dataset_size() == dataset_size
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                if item['label'] == -2:
                    num_padded_iter += 1
                    assert item['file_name'] == bytes(padded_sample['file_name'], encoding='utf8')
                    assert item['label'] == padded_sample['label']
                    assert (item['data'] == np.array(list(padded_sample['data']))).all()
                num_iter += 1
        assert num_padded_iter == num_padded
        assert num_iter == dataset_size * num_shards

    partitions(4, 2, 3)
    partitions(5, 5, 3)
    partitions(9, 8, 2)


def test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -2
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded, dataset_size):
        repeat_size = 5
        num_padded_iter = 0
        num_iter = 0
        for partition_id in range(num_shards):
            epoch1_shuffle_result = []
            epoch2_shuffle_result = []
            epoch3_shuffle_result = []
            epoch4_shuffle_result = []
            epoch5_shuffle_result = []
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            assert data_set.get_dataset_size() == dataset_size
            data_set = data_set.repeat(repeat_size)
            local_index = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                if item['label'] == -2:
                    num_padded_iter += 1
                    assert item['file_name'] == bytes(padded_sample['file_name'], encoding='utf8')
                    assert item['label'] == padded_sample['label']
                    assert (item['data'] == np.array(list(padded_sample['data']))).all()
                if local_index < dataset_size:
                    epoch1_shuffle_result.append(item["file_name"])
                elif local_index < dataset_size * 2:
                    epoch2_shuffle_result.append(item["file_name"])
                elif local_index < dataset_size * 3:
                    epoch3_shuffle_result.append(item["file_name"])
                elif local_index < dataset_size * 4:
                    epoch4_shuffle_result.append(item["file_name"])
                elif local_index < dataset_size * 5:
                    epoch5_shuffle_result.append(item["file_name"])
                local_index += 1
                num_iter += 1
            assert len(epoch1_shuffle_result) == dataset_size
            assert len(epoch2_shuffle_result) == dataset_size
            assert len(epoch3_shuffle_result) == dataset_size
            assert len(epoch4_shuffle_result) == dataset_size
            assert len(epoch5_shuffle_result) == dataset_size
            assert local_index == dataset_size * repeat_size

            # With only 2 samples per shard, shuffling too often reproduces the same order,
            # so only compare the per-epoch orders for larger shards.
            if dataset_size > 2:
                assert epoch1_shuffle_result != epoch2_shuffle_result
                assert epoch2_shuffle_result != epoch3_shuffle_result
                assert epoch3_shuffle_result != epoch4_shuffle_result
                assert epoch4_shuffle_result != epoch5_shuffle_result
        assert num_padded_iter == num_padded * repeat_size
        assert num_iter == dataset_size * num_shards * repeat_size

    partitions(4, 2, 3)
    partitions(5, 5, 3)
    partitions(9, 8, 2)


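# The sum of real and padded samples must be divisible by num_shards; the negative
# tests below check that MindDataset rejects configurations where it is not, or
# where padded_sample, num_padded and columns_list are inconsistently specified.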
def test_cv_minddataset_partition_padded_samples_no_divisible(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -2
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                num_iter += 1
        return num_iter

    with pytest.raises(RuntimeError):
        partitions(4, 1)


def test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_and_remove_cv_file):
    columns_list = ["data", "file_name", "label"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -2
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            with pytest.raises(RuntimeError):
                data_set.get_dataset_size()

    partitions(4, 1)


def test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_remove_cv_file):
    columns_list = ["data", "file_name", "label"]

    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample.pop('label', None)
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))

    with pytest.raises(Exception, match="padded_sample cannot match columns_list."):
        partitions(4, 2)


def test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_cv_file):
    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['label'] = -2
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))

    with pytest.raises(Exception, match="padded_sample is specified and requires columns_list as well."):
        partitions(4, 2)


def test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv_file):
    columns_list = ["data", "file_name", "label"]
    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample)
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))

    with pytest.raises(Exception, match="padded_sample is specified and requires num_padded as well."):
        partitions(4, 2)


def test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remove_cv_file):
    columns_list = ["data", "file_name", "label"]
    data = get_data(CV_DIR_NAME)
    padded_sample = data[0]
    padded_sample['file_name'] = 'dummy.jpg'
    num_readers = 4

    def partitions(num_shards, num_padded):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      num_padded=num_padded)
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info("-------------- partition : {} ------------------------".format(partition_id))
            logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
            logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
            logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))

    with pytest.raises(Exception, match="num_padded is specified but padded_sample is not."):
        partitions(4, 2)


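# The NLP tests below repeat the padded-sample checks on a schema with
# variable-shape fields ("shape": [-1]), so the padded input_ids array can differ
# in length from the real samples.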
def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
    columns_list = ["input_ids", "id", "rating"]

    data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
    padded_sample = data[0]
    padded_sample['id'] = "-1"
    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
    padded_sample['rating'] = 1.0
    num_readers = 4

    def partitions(num_shards, num_padded, dataset_size):
        num_padded_iter = 0
        num_iter = 0
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            assert data_set.get_dataset_size() == dataset_size
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
                logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
                logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
                    item["input_ids"],
                    item["input_ids"].shape))
                if item['id'] == bytes('-1', encoding='utf-8'):
                    num_padded_iter += 1
                    assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
                    assert (item['input_ids'] == padded_sample['input_ids']).all()
                    assert (item['rating'] == padded_sample['rating']).all()
                num_iter += 1
        assert num_padded_iter == num_padded
        assert num_iter == dataset_size * num_shards

    partitions(4, 6, 4)
    partitions(5, 5, 3)
    partitions(9, 8, 2)


def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file):
    columns_list = ["input_ids", "id", "rating"]

    data = list(get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10))
    padded_sample = data[0]
    padded_sample['id'] = "-1"
    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
    padded_sample['rating'] = 1.0
    num_readers = 4
    repeat_size = 3

    def partitions(num_shards, num_padded, dataset_size):
        num_padded_iter = 0
        num_iter = 0

        for partition_id in range(num_shards):
            epoch1_shuffle_result = []
            epoch2_shuffle_result = []
            epoch3_shuffle_result = []
            data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            assert data_set.get_dataset_size() == dataset_size
            data_set = data_set.repeat(repeat_size)

            local_index = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
                logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
                logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
                    item["input_ids"],
                    item["input_ids"].shape))
                if item['id'] == bytes('-1', encoding='utf-8'):
                    num_padded_iter += 1
                    assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
                    assert (item['input_ids'] == padded_sample['input_ids']).all()
                    assert (item['rating'] == padded_sample['rating']).all()

                if local_index < dataset_size:
                    epoch1_shuffle_result.append(item['id'])
                elif local_index < dataset_size * 2:
                    epoch2_shuffle_result.append(item['id'])
                elif local_index < dataset_size * 3:
                    epoch3_shuffle_result.append(item['id'])
                local_index += 1
                num_iter += 1
            assert len(epoch1_shuffle_result) == dataset_size
            assert len(epoch2_shuffle_result) == dataset_size
            assert len(epoch3_shuffle_result) == dataset_size
            assert local_index == dataset_size * repeat_size

            # With only 2 samples per shard, shuffling too often reproduces the same order,
            # so only compare the per-epoch orders for larger shards.
            if dataset_size > 2:
                assert epoch1_shuffle_result != epoch2_shuffle_result
                assert epoch2_shuffle_result != epoch3_shuffle_result
        assert num_padded_iter == num_padded * repeat_size
        assert num_iter == dataset_size * num_shards * repeat_size

    partitions(4, 6, 4)
    partitions(5, 5, 3)
    partitions(9, 8, 2)


def test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_result_per_epoch(add_and_remove_nlp_file):
    columns_list = ["input_ids", "id", "rating"]

    padded_sample = {}
    padded_sample['id'] = "-1"
    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
    padded_sample['rating'] = 1.0
    num_readers = 4
    repeat_size = 3

    def partitions(num_shards, num_padded, dataset_size):
        num_padded_iter = 0
        num_iter = 0

        epoch_result = [[["" for _ in range(dataset_size)] for _ in range(repeat_size)] for _ in range(num_shards)]

        for partition_id in range(num_shards):
            data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id,
                                      padded_sample=padded_sample,
                                      num_padded=num_padded)
            assert data_set.get_dataset_size() == dataset_size
            data_set = data_set.repeat(repeat_size)
            inner_num_iter = 0
            for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
                logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
                logger.info("-------------- item[input_ids]: {}, shape: {} -----------------"
                            .format(item["input_ids"], item["input_ids"].shape))
                if item['id'] == bytes('-1', encoding='utf-8'):
                    num_padded_iter += 1
                    assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
                    assert (item['input_ids'] == padded_sample['input_ids']).all()
                    assert (item['rating'] == padded_sample['rating']).all()
                # save epoch result
                epoch_result[partition_id][inner_num_iter // dataset_size][inner_num_iter % dataset_size] = item["id"]
                num_iter += 1
                inner_num_iter += 1
            assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
            assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
            assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
            if dataset_size > 2:
                epoch_result[partition_id][0].sort()
                epoch_result[partition_id][1].sort()
                epoch_result[partition_id][2].sort()
                assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
                assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
                assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
        assert num_padded_iter == num_padded * repeat_size
        assert num_iter == dataset_size * num_shards * repeat_size

    partitions(4, 6, 4)
    partitions(5, 5, 3)
    partitions(9, 8, 2)


def get_data(dir_name):
    """
    Return raw data of the imagenet test dataset.

    Args:
        dir_name (str): Directory containing an "images" folder and an "annotation.txt" file.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): Path of the aclImdb dataset.
        vocab_file (str): Path of the vocabulary file.
        num (int): Number of samples to yield.

    Returns:
        Generator over dicts, one per sample.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, dirs, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [-1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [2, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


def convert_to_uni(text):
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s cannot be converted to text." % type(text))


def load_vocab(vocab_file):
    """Load the vocabulary file into an ordered token-to-index mapping."""
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


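# Pad or truncate a token-id vector to maxlen, returning (input ids, mask,
# segment ids) lists of equal length.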
def inputs(vectors, maxlen=50):
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment


if __name__ == '__main__':
    test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_no_divisible(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv_file)
    test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remove_cv_file)
    test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file)
    test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file)
    test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_result_per_epoch(add_and_remove_nlp_file)