# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds

DATA_FILE = '../data/dataset/testCSV/1.csv'


def test_csv_dataset_basic():
    """
    Test CSVDataset with repeat and skip operations chained on top.
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'

    buffer = []
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=',',
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
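    # 1.csv holds 3 rows (see test_csv_dataset_one_file), so repeat(2)
    # yields 6 samples and skip(2) leaves 4.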
    data = data.repeat(2)
    data = data.skip(2)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 4


def test_csv_dataset_one_file():
    """
    Test reading a single CSV file.
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 3


def test_csv_dataset_all_file():
    """
    Test reading multiple CSV files in a single dataset.
    """
    APPEND_FILE = '../data/dataset/testCSV/2.csv'
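    # 1.csv contributes 3 rows and 2.csv the remaining 7, for 10 in total.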
    data = ds.CSVDataset(
        [DATA_FILE, APPEND_FILE],
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 10


def test_csv_dataset_num_samples():
    """
    Test that num_samples caps the number of rows read.
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_csv_dataset_distribution():
    """
    Test sharding with num_shards and shard_id.
    """
    TEST_FILE = '../data/dataset/testCSV/1.csv'
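    # Splitting the 3 rows of 1.csv across 2 shards gives shard 0 two samples.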
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False, num_shards=2, shard_id=0)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_csv_dataset_quoted():
    """
    Test parsing of quoted fields.
    """
    TEST_FILE = '../data/dataset/testCSV/quoted.csv'
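    # The surrounding quotes are stripped during parsing, so the plain
    # field values come back.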
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_separated():
    """
    Test a custom field delimiter ('|').
    """
    TEST_FILE = '../data/dataset/testCSV/separated.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        field_delim='|',
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_embedded():
    """
    Test special characters embedded in quoted fields.
    """
    TEST_FILE = '../data/dataset/testCSV/embedded.csv'
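    # Quoted fields may embed the delimiter, quote characters, newlines and
    # leading/trailing spaces, as the expected values below show.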
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a,b', 'c"d', 'e\nf', ' g ']


def test_csv_dataset_chinese():
    """
    Test fields containing UTF-8 (Chinese) text.
    """
    TEST_FILE = '../data/dataset/testCSV/chinese.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4', 'col5'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8"),
                       d['col5'].item().decode("utf8")])
    assert buffer == ['大家', '早上好', '中午好', '下午好', '晚上好']


def test_csv_dataset_header():
    """
    Test reading column names from the file header.
    """
    TEST_FILE = '../data/dataset/testCSV/header.csv'
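    # column_names is intentionally omitted here: the first row of header.csv
    # supplies the names col1..col4 used below.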
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item().decode("utf8"),
                       d['col2'].item().decode("utf8"),
                       d['col3'].item().decode("utf8"),
                       d['col4'].item().decode("utf8")])
    assert buffer == ['a', 'b', 'c', 'd']


def test_csv_dataset_number():
    """
    Test parsing of numeric columns.
    """
    TEST_FILE = '../data/dataset/testCSV/number.csv'
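    # The types in column_defaults (float, float, int, float) decide how each
    # column is parsed.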
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    buffer = []
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.extend([d['col1'].item(),
                       d['col2'].item(),
                       d['col3'].item(),
                       d['col4'].item()])
    assert np.allclose(buffer, [3.0, 0.3, 4, 55.5])


def test_csv_dataset_field_delim_none():
    """
    Test CSV with field_delim=None.
    """
    TRAIN_FILE = '../data/dataset/testCSV/1.csv'

    buffer = []
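    # With field_delim=None the default ',' separator applies, so the result
    # matches test_csv_dataset_basic.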
    data = ds.CSVDataset(
        TRAIN_FILE,
        field_delim=None,
        column_defaults=["0", 0, 0.0, "0"],
        column_names=['1', '2', '3', '4'],
        shuffle=False)
    data = data.repeat(2)
    data = data.skip(2)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append(d)
    assert len(buffer) == 4


def test_csv_dataset_size():
    """
    Test get_dataset_size on a CSV dataset.
    """
    TEST_FILE = '../data/dataset/testCSV/size.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=[0.0, 0.0, 0, 0.0],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    assert data.get_dataset_size() == 5


def test_csv_dataset_type_error():
    """
    Test that a field whose content does not match its column_defaults type
    raises an error.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", 0, "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
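    # col2 is declared as int via column_defaults; exception.csv carries a
    # value there that cannot be parsed as one.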
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "type does not match" in str(err.value)


def test_csv_dataset_exception():
    """
    Test error handling for a malformed CSV file and for exceptions raised
    inside a mapped Python function.
    """
    TEST_FILE = '../data/dataset/testCSV/exception.csv'
    data = ds.CSVDataset(
        TEST_FILE,
        column_defaults=["", "", "", ""],
        column_names=['col1', 'col2', 'col3', 'col4'],
        shuffle=False)
    with pytest.raises(Exception) as err:
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "failed to parse file" in str(err.value)

    TEST_FILE1 = '../data/dataset/testCSV/quoted.csv'

    def exception_func(item):
        raise Exception("Error occurred!")

    # Applying a raising PyFunc to any column should surface a RuntimeError
    # that names the failed map operation.
    for column in ['col1', 'col2', 'col3', 'col4']:
        data = ds.CSVDataset(
            TEST_FILE1,
            column_defaults=["", "", "", ""],
            column_names=['col1', 'col2', 'col3', 'col4'],
            shuffle=False)
        data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
        with pytest.raises(RuntimeError) as e:
            for _ in data.__iter__():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e.value)


def test_csv_dataset_duplicate_columns():
    """
    Test that duplicate entries in column_names are rejected.
    """
    data = ds.CSVDataset(
        DATA_FILE,
        column_defaults=["1", "2", "3", "4"],
        column_names=['col1', 'col2', 'col3', 'col4', 'col1', 'col2', 'col3', 'col4'],
        shuffle=False)
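    # Each of the four names appears twice above; the duplication is reported
    # when the iterator is created.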
    with pytest.raises(RuntimeError) as info:
        _ = data.create_dict_iterator(num_epochs=1, output_numpy=True)
    assert "Invalid parameter, duplicate column names are not allowed: col1" in str(info.value)
    assert "column_names" in str(info.value)


if __name__ == "__main__":
    test_csv_dataset_basic()
    test_csv_dataset_one_file()
    test_csv_dataset_all_file()
    test_csv_dataset_num_samples()
    test_csv_dataset_distribution()
    test_csv_dataset_quoted()
    test_csv_dataset_separated()
    test_csv_dataset_embedded()
    test_csv_dataset_chinese()
    test_csv_dataset_header()
    test_csv_dataset_number()
    test_csv_dataset_field_delim_none()
    test_csv_dataset_size()
    test_csv_dataset_type_error()
    test_csv_dataset_exception()
    test_csv_dataset_duplicate_columns()