# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
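"""
Tests for TextFileDataset: reading single and multiple files, shuffle modes,
sampling, sharding, repeating, and error handling.
"""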
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers, config_get_set_seed


# 1.txt contains 3 lines of text; the glob pattern matches 5 lines in total.
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"


def test_textline_dataset_one_file():
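    """
    Feature: TextFileDataset
    Description: Read a single text file
    Expectation: 3 rows are read
    """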
    data = ds.TextFileDataset(DATA_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 3


def test_textline_dataset_all_file():
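    """
    Feature: TextFileDataset
    Description: Read every file matched by a glob pattern
    Expectation: 5 rows are read across all files
    """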
    data = ds.TextFileDataset(DATA_ALL_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 5


def test_textline_dataset_num_samples_none():
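    """
    Feature: TextFileDataset
    Description: Omit num_samples so it defaults to None
    Expectation: All 3 rows of the file are read
    """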
    # Do not provide a num_samples argument, so it would be None by default
    data = ds.TextFileDataset(DATA_FILE)
    count = 0
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 3


def test_textline_dataset_shuffle_false4():
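    """
    Feature: TextFileDataset
    Description: Read with shuffle=False and 4 parallel workers
    Expectation: Rows arrive interleaved across files in the expected order
    """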
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(987)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
    count = 0
    line = ["This is a text file.", "Another file.",
            "Be happy every day.", "End of file.", "Good luck to everyone."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_false1():
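    """
    Feature: TextFileDataset
    Description: Read with shuffle=False and 1 parallel worker
    Expectation: Rows arrive file by file in their original order
    """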
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(987)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
    count = 0
    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "Another file.", "End of file."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_files4():
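    """
    Feature: TextFileDataset
    Description: Read with Shuffle.FILES (file-level shuffle) and 4 parallel workers
    Expectation: Rows match the expected order for seed 135
    """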
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(135)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    count = 0
    line = ["This is a text file.", "Another file.",
            "Be happy every day.", "End of file.", "Good luck to everyone."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_files1():
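    """
    Feature: TextFileDataset
    Description: Read with Shuffle.FILES (file-level shuffle) and 1 parallel worker
    Expectation: Rows match the expected order for seed 135
    """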
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(135)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    count = 0
    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "Another file.", "End of file."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_global4():
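    """
    Feature: TextFileDataset
    Description: Read with Shuffle.GLOBAL (row-level shuffle) and 4 parallel workers
    Expectation: Rows match the expected order for seed 246
    """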
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(246)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    count = 0
    line = ["Another file.", "Good luck to everyone.", "End of file.",
            "This is a text file.", "Be happy every day."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_global1():
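    """
    Feature: TextFileDataset
    Description: Read with Shuffle.GLOBAL (row-level shuffle) and 1 parallel worker
    Expectation: Rows match the expected order for seed 246
    """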
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(246)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    count = 0
    line = ["Another file.", "Good luck to everyone.", "This is a text file.",
            "End of file.", "Be happy every day."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_num_samples():
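    """
    Feature: TextFileDataset
    Description: Limit reading with num_samples=2
    Expectation: Exactly 2 rows are read
    """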
    data = ds.TextFileDataset(DATA_FILE, num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_textline_dataset_distribution():
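    """
    Feature: TextFileDataset
    Description: Shard the dataset with num_shards=2 and shard_id=1
    Expectation: The shard contains 3 rows
    """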
    data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3


def test_textline_dataset_repeat():
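    """
    Feature: TextFileDataset
    Description: Repeat the dataset 3 times
    Expectation: The 3 rows are read 3 times over, 9 rows in total
    """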
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.repeat(3)
    count = 0
    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "This is a text file.", "Be happy every day.", "Good luck to everyone.",
            "This is a text file.", "Be happy every day.", "Good luck to everyone."]
    for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 9


def test_textline_dataset_get_datasetsize():
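    """
    Feature: TextFileDataset
    Description: Query the dataset size with get_dataset_size
    Expectation: Size equals 3
    """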
    data = ds.TextFileDataset(DATA_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_textline_dataset_to_device():
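    """
    Feature: TextFileDataset
    Description: Transfer the dataset with to_device and send
    Expectation: The transfer starts without raising an error
    """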
    data = ds.TextFileDataset(DATA_FILE, shuffle=False)
    data = data.to_device()
    data.send()


def test_textline_dataset_exceptions():
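    """
    Feature: TextFileDataset
    Description: Invalid num_samples, non-matching file patterns, and a failing map PyFunc
    Expectation: The expected ValueError or RuntimeError is raised with a matching message
    """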
    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
    assert "num_samples exceeds the boundary" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("does/not/exist/no.txt")
    assert "The following patterns did not match any files" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("")
    assert "The following patterns did not match any files" in str(error_info.value)

    def exception_func(item):
        raise Exception("Error occurred!")

    with pytest.raises(RuntimeError) as error_info:
        data = ds.TextFileDataset(DATA_FILE)
        data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            pass
    assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)


if __name__ == "__main__":
    test_textline_dataset_one_file()
    test_textline_dataset_all_file()
    test_textline_dataset_num_samples_none()
    test_textline_dataset_shuffle_false4()
    test_textline_dataset_shuffle_false1()
    test_textline_dataset_shuffle_files4()
    test_textline_dataset_shuffle_files1()
    test_textline_dataset_shuffle_global4()
    test_textline_dataset_shuffle_global1()
    test_textline_dataset_num_samples()
    test_textline_dataset_distribution()
    test_textline_dataset_repeat()
    test_textline_dataset_get_datasetsize()
    test_textline_dataset_to_device()
    test_textline_dataset_exceptions()