# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test module for MindDataset sampler support over mindrecord files.
"""
import os
import pytest
import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.text import to_str
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"

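# The fixture below shards the dataset into FILES_NUM files named
# imagenet.mindrecord0 .. imagenet.mindrecord3, each with a companion ".db"
# index file. Tests open only the first shard ("...mindrecord0"); the size
# asserts suggest MindDataset then picks up the sibling shards on its own.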
@pytest.fixture
def add_and_remove_cv_file():
    """Create the mindrecord shard files before a test and remove them afterwards."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME, True)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        # Guard the removals: the failure may have occurred before every
        # shard (or its ".db" index) was written.
        for x in paths:
            if os.path.exists("{}".format(x)):
                os.remove("{}".format(x))
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))

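# PKSampler(num_val, num_class=None, shuffle=False, class_column='label',
# num_samples=None) draws num_val elements from each class. Judging by the
# size asserts below, the test data appears to contain 3 distinct labels,
# so sampling 2 per class yields 6 rows.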
def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
    """Test MindDataset + PKSampler without an explicit columns_list."""
    num_readers = 4
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1

def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
    """Test MindDataset + PKSampler with 2 samples per class."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[data]: \
                {}------------------------".format(item["data"][:10]))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1

def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
    """Test MindDataset + PKSampler with shuffling, 3 samples per class."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 9

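# The 4th and 5th positional arguments to PKSampler below are class_column
# and num_samples; num_samples caps the total, so the 9 per-class samples
# are trimmed to 5 in the first test and left untouched in the second.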
def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
    """Test PKSampler with num_samples smaller than the sampled total."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True, 'label', 5)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5

def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
    """Test PKSampler with num_samples larger than the sampled total."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 9

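# PKSampler out-of-range cases: when num_val exceeds the class size the
# sampler appears to keep drawing (repeating elements within a class), since
# 5 per class over 3 classes yields 15 rows from only 10 images below.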
def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
    """Test PKSampler with num_val exceeding the class size."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 15

def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
    """Test PKSampler num_val overflow with num_samples above the total."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'label', 20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 15

def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
    """Test PKSampler num_val overflow with num_samples capping the total."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'label', 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
                {}------------------------".format(to_str(item["file_name"])))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10

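# SubsetSampler visits the given indices in order, while SubsetRandomSampler
# visits the same indices in a random order; each test below exercises both
# samplers with a single loop.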
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
    """Test MindDataset with SubsetSampler and SubsetRandomSampler."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {}  -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5

def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
    """Test subset samplers with duplicated indices."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 2, 5, 7, 9]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 6
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {}  -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 6

def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
    """Test subset samplers with an empty index list."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = []
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 0
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {}  -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 0

def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file):
    """Test subset samplers with indices past the dataset size."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, 11, 13]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {}  -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5

def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
    """Test subset samplers with negative indices."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 4, -1, -2]
    samplers = (ds.SubsetRandomSampler(indices), ds.SubsetSampler(indices))
    for sampler in samplers:
        data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                  sampler=sampler)
        assert data_set.get_dataset_size() == 5
        num_iter = 0
        for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            logger.info(
                "-------------- cv reader basic: {} ------------------------".format(num_iter))
            logger.info(
                "-------------- item[data]: {}  -----------------------------".format(item["data"]))
            logger.info(
                "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
            logger.info(
                "-------------- item[label]: {} ----------------------------".format(item["label"]))
            num_iter += 1
        assert num_iter == 5

def test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file):
    """Test MindDataset with RandomSampler."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler()
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    new_dataset = []
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        new_dataset.append(item['file_name'])
    assert num_iter == 10
    assert new_dataset != [x['file_name'] for x in data]

def test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file):
    """Test that RandomSampler reshuffles across repeat epochs."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler()
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    ds1 = data_set.repeat(3)
    num_iter = 0
    epoch1_dataset = []
    epoch2_dataset = []
    epoch3_dataset = []
    for item in ds1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 10:
            epoch1_dataset.append(item['file_name'])
        elif num_iter <= 20:
            epoch2_dataset.append(item['file_name'])
        else:
            epoch3_dataset.append(item['file_name'])
    assert num_iter == 30
    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)

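# RandomSampler(replacement=True, num_samples=n) draws n rows with
# replacement; with replacement=False, num_samples is clipped to the dataset
# size, so requesting 20 of 10 rows below still yields 10.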
def test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file):
    """Test RandomSampler drawing with replacement."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=True, num_samples=5)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 5

def test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file):
    """Test RandomSampler without replacement, num_samples below the dataset size."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=False, num_samples=2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 2
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 2

def test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file):
    """Test RandomSampler without replacement, num_samples above the dataset size."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.RandomSampler(replacement=False, num_samples=20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10

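# SequentialSampler(start_index, num_samples) reads num_samples rows starting
# at start_index and wraps modulo the dataset size when it runs past the end,
# which is what the (num_iter + offset) % dataset_size asserts below rely on.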
def test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file):
    """Test SequentialSampler over a slice of the dataset."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(1, 4)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 4
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter + 1]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 4

def test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file):
    """Test SequentialSampler with an offset that wraps past the end."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(2, 10)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    dataset_size = data_set.get_dataset_size()
    assert dataset_size == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 10

def test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file):
    """Test SequentialSampler with num_samples exceeding the dataset size."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.SequentialSampler(2, 20)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    dataset_size = data_set.get_dataset_size()
    assert dataset_size == 10
    num_iter = 0
    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[(num_iter + 2) % dataset_size]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 10

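# Dataset.split(sizes, randomize=False) slices the rows in order; sizes may
# be absolute counts ([8, 2]) or fractions summing to 1 ([0.8, 0.2]), and
# fuzzy fractions such as [0.41, 0.59] are rounded to whole rows (4 and 6).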
def test_cv_minddataset_split_basic(add_and_remove_cv_file):
    """Test split() with absolute sizes and no randomization."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([8, 2], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 2

def test_cv_minddataset_split_exact_percent(add_and_remove_cv_file):
    """Test split() with exact fractional sizes."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([0.8, 0.2], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 8]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 2

def test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file):
    """Test split() with fractional sizes that require rounding."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    d1, d2 = d.split([0.41, 0.59], randomize=False)
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 4
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(
            data[num_iter]['file_name'], dtype='S')
        num_iter += 1
    assert num_iter == 4
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        assert item['file_name'] == np.array(data[num_iter + 4]['file_name'],
                                             dtype='S')
        num_iter += 1
    assert num_iter == 6

def test_cv_minddataset_split_deterministic(add_and_remove_cv_file):
    """Test that a seeded randomized split produces disjoint subsets."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    # should set seed to avoid data overlap
    ds.config.set_seed(111)
    d1, d2 = d.split([0.8, 0.2])
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2

    d1_dataset = []
    d2_dataset = []
    num_iter = 0
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        d1_dataset.append(item['file_name'])
        num_iter += 1
    assert num_iter == 8
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        d2_dataset.append(item['file_name'])
        num_iter += 1
    assert num_iter == 2
    inter_dataset = [x for x in d1_dataset if x in d2_dataset]
    assert inter_dataset == []  # the intersection of d1 and d2 must be empty

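# DistributedSampler(num_shards, shard_id) partitions a dataset into
# num_shards shards and yields only the rows belonging to shard_id, so the
# 8-row split below shrinks to 4 rows per shard.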
def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
    """Test DistributedSampler applied to a split, plus repeat behavior."""
    data = get_data(CV_DIR_NAME, True)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    d = ds.MindDataset(CV_FILE_NAME + "0", columns_list,
                       num_readers, shuffle=False)
    # should set seed to avoid data overlap
    ds.config.set_seed(111)
    d1, d2 = d.split([0.8, 0.2])
    assert d.get_dataset_size() == 10
    assert d1.get_dataset_size() == 8
    assert d2.get_dataset_size() == 2
    distributed_sampler = ds.DistributedSampler(2, 0)
    d1.use_sampler(distributed_sampler)
    assert d1.get_dataset_size() == 4

    num_iter = 0
    d1_shard1 = []
    for item in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        d1_shard1.append(item['file_name'])
    assert num_iter == 4
    assert d1_shard1 != [x['file_name'] for x in data[0:4]]

    distributed_sampler = ds.DistributedSampler(2, 1)
    d1.use_sampler(distributed_sampler)
    assert d1.get_dataset_size() == 4

    d1s = d1.repeat(3)
    epoch1_dataset = []
    epoch2_dataset = []
    epoch3_dataset = []
    num_iter = 0
    for item in d1s.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(
            "-------------- item[data]: {}  -----------------------------".format(item["data"]))
        logger.info(
            "-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info(
            "-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
        if num_iter <= 4:
            epoch1_dataset.append(item['file_name'])
        elif num_iter <= 8:
            epoch2_dataset.append(item['file_name'])
        else:
            epoch3_dataset.append(item['file_name'])
    assert len(epoch1_dataset) == 4
    assert len(epoch2_dataset) == 4
    assert len(epoch3_dataset) == 4
    inter_dataset = [x for x in d1_shard1 if x in epoch1_dataset]
    assert inter_dataset == []  # intersection of d1's shard 0 and shard 1
    assert epoch1_dataset not in (epoch2_dataset, epoch3_dataset)
    assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
    assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)

    epoch1_dataset.sort()
    epoch2_dataset.sort()
    epoch3_dataset.sort()
    assert epoch1_dataset != epoch2_dataset
    assert epoch2_dataset != epoch3_dataset
    assert epoch3_dataset != epoch1_dataset

def get_data(dir_name, sampler=False):
    """
    Read the ImageNet-style test data from dir_name.

    Args:
        dir_name: directory containing an "images" folder and annotation files.
        sampler: if True, read "annotation_sampler.txt"; otherwise "annotation.txt".
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    if sampler:
        ann_file = os.path.join(dir_name, "annotation_sampler.txt")
    else:
        ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list

if __name__ == '__main__':
    test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file)
    test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file)
    test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_basic(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_repeat(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement_false_1(add_and_remove_cv_file)
    test_cv_minddataset_random_sampler_replacement_false_2(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_basic(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_offset(add_and_remove_cv_file)
    test_cv_minddataset_sequential_sampler_exceed_size(add_and_remove_cv_file)
    test_cv_minddataset_split_basic(add_and_remove_cv_file)
    test_cv_minddataset_split_exact_percent(add_and_remove_cv_file)
    test_cv_minddataset_split_fuzzy_percent(add_and_remove_cv_file)
    test_cv_minddataset_split_deterministic(add_and_remove_cv_file)
    test_cv_minddataset_split_sharding(add_and_remove_cv_file)