# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pytest
import mindspore.dataset as ds


def test_clue():
    """
    Test CLUE dataset with repeat and skip operations
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    # 3 fixture samples repeated twice is 6 rows; skipping 3 leaves 3.
    assert len(buffer) == 3


def test_clue_num_shards():
    """
    Test num_shards parameter of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    # One of the 3 fixture samples lands in shard 1.
    assert len(buffer) == 1
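
# A minimal companion sketch to the sharding test above (not part of the
# original suite), assuming the same 3-sample AFQMC fixture: iterating every
# shard_id should recover the whole file, since num_shards=3 partitions the
# 3 samples one per shard.
def test_clue_num_shards_all_shards():
    """
    Sketch: verify that all shards together cover the full CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    total = 0
    for shard_id in range(3):
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train',
                              num_shards=3, shard_id=shard_id)
        for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
            total += 1
    assert total == 3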


def test_clue_num_samples():
    """
    Test num_samples parameter of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 2


def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size of TextFileDataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.TextFileDataset(TRAIN_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']]
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords':
                d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_to_device():
    """
    Test CLUE with to_device
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.to_device()
    data.send()
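
# Illustrative sketch (not part of the original suite): the invalid-files test
# below expects a "patterns did not match any files" error, which implies that
# dataset_files is resolved as a glob pattern. A minimal positive check of the
# same behavior, assuming the 3-sample AFQMC fixture used throughout:
def test_clue_glob_pattern():
    """
    Sketch: read the AFQMC train fixture through a glob pattern
    """
    pattern = '../data/dataset/testCLUE/afqmc/train.*'
    data = ds.CLUEDataset(pattern, task='AFQMC', usage='train', shuffle=False)
    count = 0
    for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        count += 1
    assert count == 3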


def test_clue_invalid_files():
    """
    Test CLUE with invalid files
    """
    AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
    # Pass the directory itself instead of a JSON file; no file pattern matches.
    afqmc_train_json = os.path.join(AFQMC_DIR)
    with pytest.raises(ValueError) as info:
        _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
    assert "The following patterns did not match any files" in str(info.value)
    assert AFQMC_DIR in str(info.value)


def test_clue_exception_file_path():
    """
    Test that the error message contains file path info when an exception occurs in CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    def exception_func(item):
        raise Exception("Error occurred!")

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["sentence1"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    try:
        data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
        data = data.map(operations=exception_func, input_columns=["sentence2"], num_parallel_workers=1)
        for _ in data.create_dict_iterator():
            pass
        assert False
    except RuntimeError as e:
        assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)


if __name__ == "__main__":
    test_clue()
    test_clue_num_shards()
    test_clue_num_samples()
    test_textline_dataset_get_datasetsize()
    test_clue_afqmc()
    test_clue_cmnli()
    test_clue_csl()
    test_clue_iflytek()
    test_clue_tnews()
    test_clue_wsc()
    test_clue_to_device()
    test_clue_invalid_files()
    test_clue_exception_file_path()