# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
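"""Unit tests for reading the CLUE benchmark with mindspore.dataset.CLUEDataset."""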
import pytest
import mindspore.dataset as ds


def test_clue():
    """
    Test CLUEDataset with repeat and skip operations.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
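    # The test file holds 3 rows: repeat(2) yields 6, skip(3) leaves 3.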
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_num_shards():
    """
    Test the num_shards parameter of CLUEDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
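    # 3 rows split across 3 shards: shard 1 should see exactly 1 row.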
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 1


def test_clue_num_samples():
    """
    Test the num_samples parameter of CLUEDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
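    # num_samples=2 caps the pipeline at 2 of the 3 available rows.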
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size by reading the CLUE train file with TextFileDataset.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.TextFileDataset(TRAIN_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
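    # 'keyword' arrives as an array of strings, so decode each element.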
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
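    # 'keywords' may be empty for a row; decode only when the tensor holds data.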
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
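    # to_device() attaches the pipeline to the device queue; send() starts
    # pushing rows, so no host-side iteration is needed here.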
    data = data.to_device()
    data.send()


def test_clue_invalid_files():
    """
    Test CLUE with invalid files
    """
    AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
    # A directory instead of a file should fail the file-pattern match.
    with pytest.raises(ValueError) as info:
        _ = ds.CLUEDataset(AFQMC_DIR, task='AFQMC', usage='train', shuffle=False)
    assert "The following patterns did not match any files" in str(info.value)
    assert AFQMC_DIR in str(info.value)


def test_clue_exception_file_path():
    """
    Test that the error message names the data file when an exception
    occurs inside a CLUEDataset pipeline.
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    def exception_func(item):
        raise Exception("Error occurred!")

    # Map the same failing PyFunc over each column; the RuntimeError
    # should point back at the offending data files.
    for column in ["label", "sentence1", "sentence2"]:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
        with pytest.raises(RuntimeError) as info:
            for _ in data.create_dict_iterator(num_epochs=1):
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(info.value)


if __name__ == "__main__":
    test_clue()
    test_clue_num_shards()
    test_clue_num_samples()
    test_textline_dataset_get_datasetsize()
    test_clue_afqmc()
    test_clue_cmnli()
    test_clue_csl()
    test_clue_iflytek()
    test_clue_tnews()
    test_clue_wsc()
    test_clue_to_device()
    test_clue_invalid_files()
    test_clue_exception_file_path()