from io import BytesIO
import copy
import os
import numpy as np
import pytest
from PIL import Image

import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter
import mindspore.dataset.vision.c_transforms as V_C

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"


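# Helper generators yielding single-column NumPy samples over disjoint integer
# ranges; used to build GeneratorDatasets for the concat/shard tests below.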
def generator_5():
    for i in range(0, 5):
        yield (np.array([i]),)


def generator_8():
    for i in range(5, 8):
        yield (np.array([i]),)


def generator_10():
    for i in range(0, 10):
        yield (np.array([i]),)


def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


def generator_30():
    for i in range(20, 30):
        yield (np.array([i]),)


def test_TFRecord_Padded():
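    """Concatenate a TFRecordDataset with a PaddedDataset and check that a
    DistributedSampler over 4 shards sees the expected per-sample image sizes."""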
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    result_list = [[159109, 2], [192607, 3], [179251, 4], [1, 5]]
    verify_list = []
    shard_num = 4
    for i in range(shard_num):
        data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"],
                                  shuffle=False, shard_equal_rows=True)

        padded_samples = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
                          {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
                          {'image': np.zeros(5, np.uint8)}]

        padded_ds = ds.PaddedDataset(padded_samples)
        concat_ds = data + padded_ds
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        concat_ds.use_sampler(testsampler)
        shard_list = []
        for item in concat_ds.create_dict_iterator(num_epochs=1, output_numpy=True):
            shard_list.append(len(item['image']))
        verify_list.append(shard_list)
    assert verify_list == result_list


def test_GeneratorDataSet_Padded():
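    """Concatenate two GeneratorDatasets and verify that 10-way sharding yields
    one sample from each source dataset per shard."""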
    result_list = [[i, 10 + i] for i in range(10)]

    verify_list = []
    data1 = ds.GeneratorDataset(generator_20, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = data2 + data1
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)

    assert verify_list == result_list


def test_Repeat_afterPadded():
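    """Apply repeat() after sharding a concatenated PaddedDataset and verify
    the shard's samples are produced once per repetition."""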
    result_list = [1, 3, 5, 7]
    verify_list = []

    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2

    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)
    repeat_num = 2
    ds3 = ds3.repeat(repeat_num)
    for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
        verify_list.append(len(item['image']))

    assert verify_list == result_list * repeat_num


def test_batch_afterPadded():
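    """Batch a sharded, concatenated PaddedDataset and verify the batch count:
    the shard holds 4 of the 8 samples, so batch(2) yields 2 batches."""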
    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]
    data2 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(1, np.uint8)},
             {'image': np.zeros(1, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2

    testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
    ds3.use_sampler(testsampler)

    ds4 = ds3.batch(2)
    assert sum([1 for _ in ds4]) == 2


def test_Unevenly_distributed():
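    """Shard 8 concatenated samples across 3 shards and verify the uneven
    split: the shards receive 3, 3, and 2 samples respectively."""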
    result_list = [[1, 4, 7], [2, 5, 8], [3, 6]]
    verify_list = []

    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2
    shard_num = 3
    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(len(item['image']))
        verify_list.append(tem_list)
    assert verify_list == result_list


def test_three_datasets_connected():
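    """Chain three GeneratorDatasets with '+' and verify that 10-way sharding
    yields one sample from each of the three sources per shard."""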
    result_list = [[i, 10 + i, 20 + i] for i in range(10)]

    verify_list = []
    data1 = ds.GeneratorDataset(generator_10, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])
    data3 = ds.GeneratorDataset(generator_30, ["col1"])
    data4 = data1 + data2 + data3
    shard_num = 10
    for i in range(shard_num):
        distributed_sampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data4.use_sampler(distributed_sampler)
        tem_list = []
        for ele in data4.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(ele['col1'][0])
        verify_list.append(tem_list)

    assert verify_list == result_list


def test_raise_error():
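    """Verify that use_sampler() on a concatenated dataset with a batched
    child rejects its samplers: the first two cases raise TypeError, while
    shuffle=True and an explicit num_samples raise ValueError."""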
    data1 = [{'image': np.zeros(0, np.uint8)}, {'image': np.zeros(0, np.uint8)},
             {'image': np.zeros(0, np.uint8)}, {'image': np.zeros(0, np.uint8)},
             {'image': np.zeros(0, np.uint8)}]
    data2 = [{'image': np.zeros(0, np.uint8)}, {'image': np.zeros(0, np.uint8)},
             {'image': np.zeros(0, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds4 = ds1.batch(2)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds4 + ds2

    with pytest.raises(TypeError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is TypeError

    with pytest.raises(TypeError) as excinfo:
        other_sampler = ds.SequentialSampler()
        ds3.use_sampler(other_sampler)
    assert excinfo.type is TypeError

    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=True, num_samples=None)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError

    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=2, shard_id=0, shuffle=False, num_samples=5)
        ds3.use_sampler(testsampler)
    assert excinfo.type is ValueError


def test_imagefolder_error():
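    """A child dataset created with num_samples set cannot be concatenated and
    sharded; use_sampler() must raise ValueError."""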
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDataset(DATA_DIR, num_samples=14)

    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]

    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    with pytest.raises(ValueError) as excinfo:
        testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
    assert excinfo.type is ValueError


def test_imagefolder_padded():
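    """Shard an ImageFolderDataset padded with 6 extra samples across 5 shards
    and verify the last shard ends with the expected padded image sizes."""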
    DATA_DIR = "../data/dataset/testPK/data"
    data = ds.ImageFolderDataset(DATA_DIR)

    data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)},
             {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)},
             {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}]

    data2 = ds.PaddedDataset(data1)
    data3 = data + data2
    testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None)
    data3.use_sampler(testsampler)
    assert sum([1 for _ in data3]) == 10
    verify_list = []

    for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
        verify_list.append(len(ele['image']))
    assert verify_list[8] == 1
    assert verify_list[9] == 6


def test_imagefolder_padded_with_decode():
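    """Pad an ImageFolderDataset with in-memory JPEG samples (label -1) and
    verify every shard decodes cleanly, totalling 48 samples over 5 shards."""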
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDataset(DATA_DIR)

        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)

        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2

        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        data3 = data3.map(operations=V_C.Decode(), input_columns="image")
        shard_sample_count = 0
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
    assert count == 48


def test_imagefolder_padded_with_decode_and_get_dataset_size():
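    """Same as test_imagefolder_padded_with_decode, additionally checking that
    get_dataset_size() matches the iterated per-shard sample count."""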
    num_shards = 5
    count = 0
    for shard_id in range(num_shards):
        DATA_DIR = "../data/dataset/testPK/data"
        data = ds.ImageFolderDataset(DATA_DIR)

        white_io = BytesIO()
        Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG')
        padded_sample = {}
        padded_sample['image'] = np.array(bytearray(white_io.getvalue()), dtype='uint8')
        padded_sample['label'] = np.array(-1, np.int32)

        white_samples = [padded_sample, padded_sample, padded_sample, padded_sample]
        data2 = ds.PaddedDataset(white_samples)
        data3 = data + data2

        testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        shard_dataset_size = data3.get_dataset_size()
        data3 = data3.map(operations=V_C.Decode(), input_columns="image")
        shard_sample_count = 0
        for ele in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            print("label: {}".format(ele['label']))
            count += 1
            shard_sample_count += 1
        assert shard_sample_count in (9, 10)
        assert shard_dataset_size == shard_sample_count
    assert count == 48


def test_more_shard_padded():
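    """Use more shards (9) than there are samples (8) and verify the surplus
    shard comes back empty, for both GeneratorDataset and PaddedDataset."""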
    result_list = [1] * 8 + [0]

    data1 = ds.GeneratorDataset(generator_5, ["col1"])
    data2 = ds.GeneratorDataset(generator_8, ["col1"])
    data3 = data1 + data2
    verify_list = []
    shard_num = 9
    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        data3.use_sampler(testsampler)
        for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(item['col1'])
        verify_list.append(tem_list)

    assert [len(ele) for ele in verify_list] == result_list

    verify_list1 = []
    result_list1 = [[i + 1] for i in range(8)] + [[]]

    data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)},
             {'image': np.zeros(3, np.uint8)}, {'image': np.zeros(4, np.uint8)},
             {'image': np.zeros(5, np.uint8)}]
    data2 = [{'image': np.zeros(6, np.uint8)}, {'image': np.zeros(7, np.uint8)},
             {'image': np.zeros(8, np.uint8)}]

    ds1 = ds.PaddedDataset(data1)
    ds2 = ds.PaddedDataset(data2)
    ds3 = ds1 + ds2

    for i in range(shard_num):
        tem_list = []
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds3.use_sampler(testsampler)
        for item in ds3.create_dict_iterator(num_epochs=1, output_numpy=True):
            tem_list.append(len(item['image']))
        verify_list1.append(tem_list)

    assert verify_list1 == result_list1


def get_data(dir_name):
    """
    Read images and labels from an ImageNet-style directory.

    Args:
        dir_name: directory containing an "images" folder and an
            "annotation.txt" file mapping file names to labels.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for i, line in enumerate(lines):
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"id": i,
                         "file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


397
@pytest.fixture(name="remove_mindrecord_file")
def add_and_remove_cv_file():
    """Write the ImageNet mindrecord files for a test, then clean them up."""
    paths = ["{}{}".format(CV_FILE_NAME, x) for x in range(FILES_NUM)]
    try:
        for x in paths:
            if os.path.exists(x):
                os.remove(x)
            if os.path.exists("{}.db".format(x)):
                os.remove("{}.db".format(x))
        writer = FileWriter(CV_FILE_NAME, FILES_NUM)
        data = get_data(CV_DIR_NAME)
        cv_schema_json = {"id": {"type": "int32"},
                          "file_name": {"type": "string"},
                          "label": {"type": "int32"},
                          "data": {"type": "bytes"}}
        writer.add_schema(cv_schema_json, "img_schema")
        writer.add_index(["file_name", "label"])
        writer.write_raw_data(data)
        writer.commit()
        yield "yield_cv_data"
    except Exception as error:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))
        raise error
    else:
        for x in paths:
            os.remove(x)
            os.remove("{}.db".format(x))


430
def test_Mindrecord_Padded(remove_mindrecord_file):
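    """Concatenate a MindDataset with a PaddedDataset of extra file names and
    verify the image indices produced by each of the 8 shards."""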
    result_list = []
    verify_list = [[1, 2], [3, 4], [5, 11], [6, 12], [7, 13], [8, 14], [9], [10]]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", ['file_name'], num_readers, shuffle=False)
    data1 = [{'file_name': np.array(b'image_00011.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00012.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00013.jpg', dtype='|S15')},
             {'file_name': np.array(b'image_00014.jpg', dtype='|S15')}]
    ds1 = ds.PaddedDataset(data1)
    ds2 = data_set + ds1
    shard_num = 8
    for i in range(shard_num):
        testsampler = ds.DistributedSampler(num_shards=shard_num, shard_id=i, shuffle=False, num_samples=None)
        ds2.use_sampler(testsampler)
        tem_list = []
        for ele in ds2.create_dict_iterator(num_epochs=1, output_numpy=True):
            # Extract the numeric index from names like "image_00011.jpg".
            tem_list.append(int(ele['file_name'].tobytes().decode().lstrip('image_').rstrip('.jpg')))
        result_list.append(tem_list)
    assert result_list == verify_list


def test_clue_padded_and_skip_with_0_samples():
    """
    Pad a CLUE dataset, shard it, then skip all remaining samples; the
    now-empty dataset cannot be concatenated again.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3

    data_copy1 = copy.deepcopy(data)

    sample = {"label": np.array(1, np.string_),
              "sentence1": np.array(1, np.string_),
              "sentence2": np.array(1, np.string_)}
    samples = [sample]
    padded_ds = ds.PaddedDataset(samples)
    dataset = data + padded_ds
    testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    dataset.use_sampler(testsampler)
    assert dataset.get_dataset_size() == 2
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2

    dataset = dataset.skip(count=2)  # the dataset now has no samples left
    count = 0
    for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 0

    with pytest.raises(ValueError, match="There are no samples in the "):
        dataset = dataset.concat(data_copy1)
        count = 0
        for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
            count += 1


def test_celeba_padded():
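    """Pad a CelebA dataset with one sample, take shard 1 of 2, repeat twice,
    and verify the total sample count."""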
    data = ds.CelebADataset("../data/dataset/testCelebAData/")

    padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
    padded_ds = ds.PaddedDataset(padded_samples)
    data = data + padded_ds
    dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
    data.use_sampler(dis_sampler)
    data = data.repeat(2)

    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 4


if __name__ == '__main__':
    test_TFRecord_Padded()
    test_GeneratorDataSet_Padded()
    test_Repeat_afterPadded()
    test_batch_afterPadded()
    test_Unevenly_distributed()
    test_three_datasets_connected()
    test_raise_error()
    test_imagefolder_padded()
    test_more_shard_padded()
    test_Mindrecord_Padded(add_and_remove_cv_file)