# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for saveOp.
"""
import os
from string import punctuation
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter

TEMP_FILE = "../data/mindrecord/testMindDataSet/temp.mindrecord"
AUTO_FILE = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1
num_readers = 1
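

def remove_file(file_name):
    """Remove a mindrecord file and its companion index (.db) file, if present."""
    if os.path.exists(file_name):
        os.remove(file_name)
    if os.path.exists(file_name + ".db"):
        os.remove(file_name + ".db")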


@pytest.fixture(name="add_remove_file")
def fixture_remove():
    """Add and remove the generated mindrecord files around each test."""
    remove_file(TEMP_FILE)
    remove_file(AUTO_FILE)
    yield "yield_cv_data"
    remove_file(TEMP_FILE)
    remove_file(AUTO_FILE)


def test_case_00(add_remove_file):
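    """Save a dataset that contains only bytes fields and verify the reloaded contents."""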
    data = [{"image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}]
    schema = {
        "image1": {"type": "bytes"},
        "image2": {"type": "bytes"},
        "image3": {"type": "bytes"},
        "image4": {"type": "bytes"},
        "image5": {"type": "bytes"}}
    writer = FileWriter(TEMP_FILE, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(TEMP_FILE, None, num_readers, shuffle=False)
    d1.save(AUTO_FILE, FILES_NUM)
    data_value_to_list = []

    for item in data:
        new_data = {}
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 5
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 5
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 5


def test_case_01(add_remove_file):
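    """Save a dataset that contains only scalar raw fields and verify the reloaded contents."""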
    data = [{"file_name": "001.jpg", "label": 43},
            {"file_name": "002.jpg", "label": 91},
            {"file_name": "003.jpg", "label": 61},
            {"file_name": "004.jpg", "label": 29},
            {"file_name": "005.jpg", "label": 78},
            {"file_name": "006.jpg", "label": 37}]
    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"}
              }

    writer = FileWriter(TEMP_FILE, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(TEMP_FILE, None, num_readers, shuffle=False)
    d1.save(AUTO_FILE, FILES_NUM)

    data_value_to_list = []
    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['label'] = np.asarray([item["label"]], dtype=np.int32)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info(item)
        assert len(item) == 2
        for field in item:
            if isinstance(item[field], np.ndarray):
                assert (item[field] ==
                        data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6


def test_case_02(add_remove_file):
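    """Save a dataset that mixes multiple bytes, string, scalar and array fields and verify the reloaded contents."""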
    data = [{"file_name": "001.jpg", "label": 43,
             "float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12345,
             "float64": 1987654321.123456785,
             "source_sos_ids": np.array([1, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([6, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image1 bytes abc", encoding='UTF-8'),
             "image2": bytes("image1 bytes def", encoding='UTF-8'),
             "image3": bytes("image1 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image1 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image1 bytes mno", encoding='UTF-8')},
            {"file_name": "002.jpg", "label": 91,
             "float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12445,
             "float64": 1987654321.123456786,
             "source_sos_ids": np.array([11, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([16, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image2 bytes abc", encoding='UTF-8'),
             "image2": bytes("image2 bytes def", encoding='UTF-8'),
             "image3": bytes("image2 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image2 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image2 bytes mno", encoding='UTF-8')},
            {"file_name": "003.jpg", "label": 61,
             "float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12545,
             "float64": 1987654321.123456787,
             "source_sos_ids": np.array([21, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([26, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image3 bytes abc", encoding='UTF-8'),
             "image2": bytes("image3 bytes def", encoding='UTF-8'),
             "image3": bytes("image3 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image3 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image3 bytes mno", encoding='UTF-8')},
            {"file_name": "004.jpg", "label": 29,
             "float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12645,
             "float64": 1987654321.123456788,
             "source_sos_ids": np.array([31, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([36, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image4 bytes abc", encoding='UTF-8'),
             "image2": bytes("image4 bytes def", encoding='UTF-8'),
             "image3": bytes("image4 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image4 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image4 bytes mno", encoding='UTF-8')},
            {"file_name": "005.jpg", "label": 78,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([41, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([46, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image5 bytes abc", encoding='UTF-8'),
             "image2": bytes("image5 bytes def", encoding='UTF-8'),
             "image3": bytes("image5 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image5 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image5 bytes mno", encoding='UTF-8')},
            {"file_name": "006.jpg", "label": 37,
             "float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32),
             "float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471,
                                        123414314.2141243, 87.1212122], dtype=np.float64),
             "float32": 3456.12745,
             "float64": 1987654321.123456789,
             "source_sos_ids": np.array([51, 2, 3, 4, 5], dtype=np.int32),
             "source_sos_mask": np.array([56, 7, 8, 9, 10, 11, 12], dtype=np.int64),
             "image1": bytes("image6 bytes abc", encoding='UTF-8'),
             "image2": bytes("image6 bytes def", encoding='UTF-8'),
             "image3": bytes("image6 bytes ghi", encoding='UTF-8'),
             "image4": bytes("image6 bytes jkl", encoding='UTF-8'),
             "image5": bytes("image6 bytes mno", encoding='UTF-8')}
            ]
    schema = {"file_name": {"type": "string"},
              "float32_array": {"type": "float32", "shape": [-1]},
              "float64_array": {"type": "float64", "shape": [-1]},
              "float32": {"type": "float32"},
              "float64": {"type": "float64"},
              "source_sos_ids": {"type": "int32", "shape": [-1]},
              "source_sos_mask": {"type": "int64", "shape": [-1]},
              "image1": {"type": "bytes"},
              "image2": {"type": "bytes"},
              "image3": {"type": "bytes"},
              "label": {"type": "int32"},
              "image4": {"type": "bytes"},
              "image5": {"type": "bytes"}}
    writer = FileWriter(TEMP_FILE, FILES_NUM)
    writer.add_schema(schema, "schema")
    writer.write_raw_data(data)
    writer.commit()

    d1 = ds.MindDataset(TEMP_FILE, None, num_readers, shuffle=False)
    d1.save(AUTO_FILE, FILES_NUM)
    data_value_to_list = []

    for item in data:
        new_data = {}
        new_data['file_name'] = np.asarray(item["file_name"], dtype='S')
        new_data['float32_array'] = item["float32_array"]
        new_data['float64_array'] = item["float64_array"]
        new_data['float32'] = item["float32"]
        new_data['float64'] = item["float64"]
        new_data['source_sos_ids'] = item["source_sos_ids"]
        new_data['source_sos_mask'] = item["source_sos_mask"]
        new_data['label'] = np.asarray([item["label"]], dtype=np.int32)
        new_data['image1'] = np.asarray(list(item["image1"]), dtype=np.uint8)
        new_data['image2'] = np.asarray(list(item["image2"]), dtype=np.uint8)
        new_data['image3'] = np.asarray(list(item["image3"]), dtype=np.uint8)
        new_data['image4'] = np.asarray(list(item["image4"]), dtype=np.uint8)
        new_data['image5'] = np.asarray(list(item["image5"]), dtype=np.uint8)
        data_value_to_list.append(new_data)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    assert d2.get_dataset_size() == 6
    num_iter = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert len(item) == 13
        for field in item:
            if isinstance(item[field], np.ndarray):
                if item[field].dtype == np.float32:
                    assert (item[field] ==
                            np.array(data_value_to_list[num_iter][field], np.float32)).all()
                else:
                    assert (item[field] ==
                            data_value_to_list[num_iter][field]).all()
            else:
                assert item[field] == data_value_to_list[num_iter][field]
        num_iter += 1
    assert num_iter == 6


def generator_1d():
    for i in range(10):
        yield (np.array([i]),)


def test_case_03(add_remove_file):
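    """Save a 1-D GeneratorDataset and verify the reloaded values."""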

    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    d1.save(AUTO_FILE)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    # each data is a dictionary
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        golden = np.array([i])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    assert i == 10


def generator_with_type(t):
    for i in range(64):
        yield (np.array([i], dtype=t),)


def type_tester(t):
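    """Save a batched and repeated GeneratorDataset of the given dtype and verify the reloaded values."""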
    logger.info("Test with Type {}".format(t.__name__))

    # apply dataset operations
    data1 = ds.GeneratorDataset((lambda: generator_with_type(t)), ["data"], shuffle=False)

    data1 = data1.batch(4)

    data1 = data1.repeat(3)

    data1.save(AUTO_FILE)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    num_repeat = 0
    # each data is a dictionary
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        golden = np.array([[i], [i + 1], [i + 2], [i + 3]], dtype=t)
        logger.info(item)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 4
        if i == 64:
            i = 0
            num_repeat += 1
    assert num_repeat == 3
    remove_file(AUTO_FILE)


def test_case_04():
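    """Round-trip every supported numeric dtype through save."""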
    # uint8 is excluded: mindrecord stores uint8 as bytes, so its shape would be dropped
    types = [np.int8, np.int16, np.int32, np.int64,
             np.uint16, np.uint32, np.float32, np.float64]

    for t in types:
        type_tester(t)


def test_case_05(add_remove_file):
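    """Saving with an invalid num_files should raise an error."""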

    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    with pytest.raises(Exception, match="num_files should between 0 and 1000."):
        d1.save(AUTO_FILE, 0)


def test_case_06(add_remove_file):
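    """Saving to the tfrecord format should raise an error."""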

    d1 = ds.GeneratorDataset(generator_1d, ["data"], shuffle=False)

    with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
        d1.save(AUTO_FILE, 1, "tfrecord")


def cast_name(key):
    """
    Cast schema names containing special characters to valid names.
    """
    special_symbols = set(punctuation + ' ')
    special_symbols.remove('_')
    new_key = ['_' if x in special_symbols else x for x in key]
    casted_key = ''.join(new_key)
    return casted_key


def test_case_07():
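    """Save a TFRecordDataset as mindrecord and verify both datasets hold the same data."""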
    remove_file(AUTO_FILE)
    d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
    tf_data = []
    for x in d1.create_dict_iterator(num_epochs=1, output_numpy=True):
        tf_data.append(x)
    d1.save(AUTO_FILE, FILES_NUM)
    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)
    mr_data = []
    for x in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        mr_data.append(x)
    count = 0
    for x in tf_data:
        for k, v in x.items():
            if isinstance(v, np.ndarray):
                assert (v == mr_data[count][cast_name(k)]).all()
            else:
                assert v == mr_data[count][cast_name(k)]
        count += 1
    assert count == 10

    remove_file(AUTO_FILE)


def generator_dynamic_1d():
    arr = []
    for i in range(10):
        if i % 5 == 0:
            arr = []
        arr += [i]
        yield (np.array(arr),)


def generator_dynamic_2d_0():
    for i in range(10):
        if i < 5:
            yield (np.arange(5).reshape([1, 5]),)
        else:
            yield (np.arange(10).reshape([2, 5]),)


def generator_dynamic_2d_1():
    for i in range(10):
        if i < 5:
            yield (np.arange(5).reshape([5, 1]),)
        else:
            yield (np.arange(10).reshape([5, 2]),)


def test_case_08(add_remove_file):
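    """Save a GeneratorDataset with dynamic 1-D shapes and verify the reloaded values."""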

    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_1d, ["data"], shuffle=False)

    d1.save(AUTO_FILE)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    arr = []
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        if i % 5 == 0:
            arr = []
        arr += [i]
        golden = np.array(arr)
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    assert i == 10


def test_case_09(add_remove_file):
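    """Save a GeneratorDataset whose rows vary only in dimension 0 and verify the reloaded values."""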

    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_2d_0, ["data"], shuffle=False)

    d1.save(AUTO_FILE)

    d2 = ds.MindDataset(dataset_file=AUTO_FILE,
                        num_parallel_workers=num_readers,
                        shuffle=False)

    i = 0
    for item in d2.create_dict_iterator(num_epochs=1, output_numpy=True):
        if i < 5:
            golden = np.arange(5).reshape([1, 5])
        else:
            golden = np.arange(10).reshape([2, 5])
        np.testing.assert_array_equal(item["data"], golden)
        i = i + 1
    assert i == 10


def test_case_10(add_remove_file):
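    """Saving rows whose non-zero dimensions differ should raise an error."""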

    # apply dataset operations
    d1 = ds.GeneratorDataset(generator_dynamic_2d_1, ["data"], shuffle=False)

    with pytest.raises(Exception,
                       match="Error: besides dimension 0, other dimension "
                             "shape is different from the previous's"):
        d1.save(AUTO_FILE)