#!/usr/bin/env python
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

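"""Tests for error and exception handling when reading MindRecord files with MindDataset."""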
import os
import pytest

import mindspore.dataset as ds

from mindspore.mindrecord import FileWriter

CV_FILE_NAME = "./imagenet.mindrecord"
CV1_FILE_NAME = "./imagenet1.mindrecord"


def create_cv_mindrecord(files_num):
    """Create a CV MindRecord file with the standard file_name/label/data schema."""
    if os.path.exists(CV_FILE_NAME):
        os.remove(CV_FILE_NAME)
    if os.path.exists("{}.db".format(CV_FILE_NAME)):
        os.remove("{}.db".format(CV_FILE_NAME))
    writer = FileWriter(CV_FILE_NAME, files_num)
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name": "001.jpg", "label": 43,
             "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_diff_schema_cv_mindrecord(files_num):
    """Create a CV MindRecord file with a differing schema (file_name_1 instead of file_name)."""
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
        os.remove("{}.db".format(CV1_FILE_NAME))
    writer = FileWriter(CV1_FILE_NAME, files_num)
    cv_schema_json = {"file_name_1": {"type": "string"},
                      "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name_1": "001.jpg", "label": 43,
             "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name_1", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_diff_page_size_cv_mindrecord(files_num):
    """Create a CV MindRecord file written with a non-default page size."""
    if os.path.exists(CV1_FILE_NAME):
        os.remove(CV1_FILE_NAME)
    if os.path.exists("{}.db".format(CV1_FILE_NAME)):
        os.remove("{}.db".format(CV1_FILE_NAME))
    writer = FileWriter(CV1_FILE_NAME, files_num)
    writer.set_page_size(1 << 26)  # 64 MB
    cv_schema_json = {"file_name": {"type": "string"},
                      "label": {"type": "int32"}, "data": {"type": "bytes"}}
    data = [{"file_name": "001.jpg", "label": 43,
             "data": bytes('0xffsafdafda', encoding='utf-8')}]
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()


def test_cv_lack_json():
    """Test MindDataset error when a non-existent JSON file is passed as an argument."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(Exception):
        ds.MindDataset(CV_FILE_NAME, "no_exist.json",
                       columns_list, num_readers)
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_cv_lack_mindrecord():
    """Test MindDataset error when the MindRecord file does not exist."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(Exception, match="does not exist or permission denied"):
        _ = ds.MindDataset("no_exist.mindrecord", columns_list, num_readers)


def test_invalid_mindrecord():
    """Test MindDataset error when the MindRecord file content is invalid."""
    with open('dummy.mindrecord', 'w') as f:
        f.write('just for test')
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid file "
                       "content, incorrect file or file header is exceeds the upper limit."):
        data_set = ds.MindDataset(
            'dummy.mindrecord', columns_list, num_readers)
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    os.remove('dummy.mindrecord')


def test_minddataset_lack_db():
    """Test MindDataset error when the companion .db metadata file is missing."""
    create_cv_mindrecord(1)
    os.remove("{}.db".format(CV_FILE_NAME))
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file, path:"):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        try:
            assert num_iter == 0
        except Exception as error:
            os.remove(CV_FILE_NAME)
            raise error
        else:
            os.remove(CV_FILE_NAME)


def test_cv_minddataset_pk_sample_error_class_column():
    """Test MindDataset with a PKSampler whose class_column does not exist in the schema."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True, 'no_exist_column')
    with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_cv_minddataset_pk_sample_exclusive_shuffle():
    """Test that a sampler and the shuffle argument cannot be specified together."""
    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    with pytest.raises(Exception, match="sampler and shuffle cannot be specified at the same time."):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                  sampler=sampler, shuffle=False)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


def test_cv_minddataset_reader_different_schema():
    """Test MindDataset error when reading files whose schemas are inconsistent."""
    create_cv_mindrecord(1)
    create_diff_schema_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
                       "MindRecord files meta data is not consistent."):
        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                  num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
    os.remove(CV1_FILE_NAME)
    os.remove("{}.db".format(CV1_FILE_NAME))


def test_cv_minddataset_reader_different_page_size():
    """Test MindDataset error when reading files written with different page sizes."""
    create_cv_mindrecord(1)
    create_diff_page_size_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
                       "MindRecord files meta data is not consistent."):
        data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
                                  num_readers)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))
    os.remove(CV1_FILE_NAME)
    os.remove("{}.db".format(CV1_FILE_NAME))


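# Note on the shard tests below: ds.MindDataset is called with positional arguments.
# Assuming this MindSpore version's signature, the positions map to
# (dataset_files, columns_list, num_parallel_workers, shuffle, num_shards, shard_id),
# so a call like ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
# requests shard_id=2 for a single shard, which is rejected as out of range.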
def test_minddataset_invalidate_num_shards():
    """Test MindDataset error when shard_id is out of range for num_shards=1."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 0].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))


def test_minddataset_invalidate_shard_id():
    """Test MindDataset error when a negative shard_id is passed."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 0].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))


def test_minddataset_shard_id_bigger_than_num_shard():
    """Test MindDataset error when shard_id is greater than or equal to num_shards."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4
    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 1].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
        raise error

    with pytest.raises(Exception) as error_info:
        data_set = ds.MindDataset(
            CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1):
            num_iter += 1
    try:
        assert 'Input shard_id is not within the required interval of [0, 1].' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))


def test_cv_minddataset_partition_num_samples_equals_0():
    """Test that an invalid (negative) num_samples value raises ValueError."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=-1)
            num_iter = 0
            for _ in data_set.create_dict_iterator(num_epochs=1):
                num_iter += 1

    with pytest.raises(ValueError) as error_info:
        partitions(5)
    try:
        assert 'num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)' in str(
            error_info.value)
    except Exception as error:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))
        raise error
    else:
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))


def test_mindrecord_exception():
    """Test that an exception raised inside a map op surfaces as a RuntimeError during iteration."""

    def exception_func(item):
        raise Exception("Error occurred!")

    create_cv_mindrecord(1)
    columns_list = ["data", "file_name", "label"]
    with pytest.raises(RuntimeError, match="The corresponding data files"):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
        data_set = data_set.map(operations=exception_func, input_columns=["data"],
                                num_parallel_workers=1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    with pytest.raises(RuntimeError, match="The corresponding data files"):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
        data_set = data_set.map(operations=exception_func, input_columns=["file_name"],
                                num_parallel_workers=1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    with pytest.raises(RuntimeError, match="The corresponding data files"):
        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
        data_set = data_set.map(operations=exception_func, input_columns=["label"],
                                num_parallel_workers=1)
        num_iter = 0
        for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
    os.remove(CV_FILE_NAME)
    os.remove("{}.db".format(CV_FILE_NAME))


if __name__ == '__main__':
    test_cv_lack_json()
    test_cv_lack_mindrecord()
    test_invalid_mindrecord()
    test_minddataset_lack_db()
    test_cv_minddataset_pk_sample_error_class_column()
    test_cv_minddataset_pk_sample_exclusive_shuffle()
    test_cv_minddataset_reader_different_schema()
    test_cv_minddataset_reader_different_page_size()
    test_minddataset_invalidate_num_shards()
    test_minddataset_invalidate_shard_id()
    test_minddataset_shard_id_bigger_than_num_shard()
    test_cv_minddataset_partition_num_samples_equals_0()
    test_mindrecord_exception()