# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
validators for text ops
"""
from functools import wraps
import numpy as np

import mindspore._c_dataengine as cde
import mindspore.common.dtype as mstype
from mindspore._c_expression import typing

import mindspore.dataset.text as text
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
    INT32_MAX, check_value, check_positive, check_pos_int32, check_filename, check_non_negative_int32

# NOTE: every ``check_*`` decorator below follows the same shape: it unpacks the
# first element returned by ``parse_user_args(method, *args, **kwargs)`` into the
# wrapped method's parameters, validates them, and then calls the wrapped method
# with the caller's original ``*args``/``**kwargs`` (``check_ngram`` is the one
# exception, see its docstring).


def check_add_token(method):
    """Wrapper method to check the parameters of add token.

    Checks that `token` is a str and `begin` is a bool before calling `method`.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [token, begin], _ = parse_user_args(method, *args, **kwargs)
        type_check(token, (str,), "token")
        type_check(begin, (bool,), "begin")
        return method(self, *args, **kwargs)

    return new_method


def check_unique_list_of_words(words, arg_name):
    """Check that words is a list and each element is a str without any duplication.

    Args:
        words: Object expected to be a list of str.
        arg_name (str): Name used for `words` in error messages.

    Returns:
        set, the distinct words found in `words`.

    Raises:
        ValueError: If a word occurs more than once in `words`.
    """

    type_check(words, (list,), arg_name)
    words_set = set()
    for word in words:
        type_check(word, (str,), arg_name)
        if word in words_set:
            raise ValueError(arg_name + " contains duplicate word: " + word + ".")
        words_set.add(word)
    return words_set


def check_lookup(method):
    """A wrapper that wraps a parameter checker to the original function.

    Checks `vocab` (text.Vocab wrapping a cde.Vocab), the optional str
    `unknown_token` and `data_type` (typing.Type).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)

        if unknown_token is not None:
            type_check(unknown_token, (str,), "unknown_token")

        type_check(vocab, (text.Vocab,), "vocab is not an instance of text.Vocab.")
        type_check(vocab.c_vocab, (cde.Vocab,), "vocab.c_vocab is not an instance of cde.Vocab.")
        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return new_method


def check_from_file(method):
    """A wrapper that wraps a parameter checker to the original function.

    Checks `file_path` and `delimiter` (str), the optional `vocab_size`
    (positive), the optional `special_tokens` (list of unique str) and
    `special_first` (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                               **kwargs)
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
        if vocab_size is not None:
            check_positive(vocab_size, "vocab_size")
        # The third argument is the parameter *name* used in the error message,
        # so it must be the literal string rather than the value under test.
        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method


def check_vocab(c_vocab):
    """Check the c_vocab of Vocab is initialized or not.

    Raises:
        RuntimeError: If `c_vocab` is not an instance of cde.Vocab.
    """

    if not isinstance(c_vocab, cde.Vocab):
        error = "The Vocab has not built yet, got type {0}. ".format(type(c_vocab))
        suggestion = "Use Vocab.from_dataset(), Vocab.from_list(), Vocab.from_file() or Vocab.from_dict() " \
                     "to build a Vocab."
        raise RuntimeError(error + suggestion)


def check_tokens_to_ids(method):
    """A wrapper that wraps a parameter checker to the original function.

    `tokens` must be a str, list or numpy.ndarray; when it is a list every
    element must be a str or numpy.str_. Elements of an ndarray are not
    inspected here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [tokens], _ = parse_user_args(method, *args, **kwargs)
        type_check(tokens, (str, list, np.ndarray), "tokens")
        if isinstance(tokens, list):
            param_names = ["tokens[{0}]".format(i) for i in range(len(tokens))]
            type_check_list(tokens, (str, np.str_), param_names)

        return method(self, *args, **kwargs)

    return new_method


def check_ids_to_tokens(method):
    """A wrapper that wraps a parameter checker to the original function.

    `ids` must be an int, list or numpy.ndarray. A single int, and every
    element of a list, is range-checked against (0, INT32_MAX). Elements of
    an ndarray are not inspected here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ids], _ = parse_user_args(method, *args, **kwargs)
        type_check(ids, (int, list, np.ndarray), "ids")
        if isinstance(ids, int):
            check_value(ids, (0, INT32_MAX), "ids")
        if isinstance(ids, list):
            for index, id_ in enumerate(ids):
                type_check(id_, (int, np.int_), "ids[{}]".format(index))
                check_value(id_, (0, INT32_MAX), "ids[{}]".format(index))

        return method(self, *args, **kwargs)

    return new_method


def check_from_list(method):
    """A wrapper that wraps a parameter checker to the original function.

    `word_list` and the optional `special_tokens` must each be a list of
    unique str, and the two must not share any word. `special_first` must be
    a bool.

    Raises:
        ValueError: If `special_tokens` and `word_list` overlap.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)

        word_set = check_unique_list_of_words(word_list, "word_list")
        if special_tokens is not None:
            token_set = check_unique_list_of_words(special_tokens, "special_tokens")

            intersect = word_set.intersection(token_set)

            if intersect != set():
                raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")

        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method


def check_from_dict(method):
    """A wrapper that wraps a parameter checker to the original function.

    `word_dict` must be a dict whose keys are str and whose values are int,
    each value range-checked against (0, INT32_MAX).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_dict], _ = parse_user_args(method, *args, **kwargs)

        type_check(word_dict, (dict,), "word_dict")

        for word, word_id in word_dict.items():
            type_check(word, (str,), "word")
            type_check(word_id, (int,), "word_id")
            check_value(word_id, (0, INT32_MAX), "word_id")
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_init(method):
    """Wrapper method to check the parameters of jieba init.

    Raises:
        ValueError: If `hmm_path` or `mp_path` is None.
        TypeError: If `hmm_path` or `mp_path` is not a str, or `with_offsets`
            is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The third parameter of the wrapped method is not validated here.
        [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs)

        if hmm_path is None:
            raise ValueError("The dict of HMMSegment in cppjieba is not provided.")
        if not isinstance(hmm_path, str):
            raise TypeError("Wrong input type for hmm_path, should be string.")
        if mp_path is None:
            raise ValueError("The dict of MPSegment in cppjieba is not provided.")
        if not isinstance(mp_path, str):
            raise TypeError("Wrong input type for mp_path, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_add_word(method):
    """Wrapper method to check the parameters of jieba add word.

    Raises:
        ValueError: If `word` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word, freq], _ = parse_user_args(method, *args, **kwargs)
        if word is None:
            raise ValueError("word is not provided.")
        if freq is not None:
            check_uint32(freq)
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_add_dict(method):
    """Wrapper method to check the parameters of add dict."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # No per-argument checks: the parsed values are discarded.
        # NOTE(review): presumably this call exists so that parse_user_args can
        # reject a malformed call - confirm against validator_helpers.
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method


def check_with_offsets(method):
    """Wrapper method to check if with_offsets is the only one parameter.

    Raises:
        TypeError: If `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_unicode_script_tokenizer(method):
    """Wrapper method to check the parameter of UnicodeScriptTokenizer.

    Raises:
        TypeError: If `keep_whitespace` or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_wordpiece_tokenizer(method):
    """Wrapper method to check the parameter of WordpieceTokenizer.

    Raises:
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a text.Vocab, `suffix_indicator` or
            `unknown_token` is not a str, or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, text.Vocab):
            raise TypeError("Wrong input type for vocab, should be text.Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        check_uint32(max_bytes_per_token)
        return method(self, *args, **kwargs)

    return new_method


def check_regex_replace(method):
    """Wrapper method to check the parameter of RegexReplace.

    Checks that `pattern` and `replace` are str and `replace_all` is a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [pattern, replace, replace_all], _ = parse_user_args(method, *args, **kwargs)
        type_check(pattern, (str,), "pattern")
        type_check(replace, (str,), "replace")
        type_check(replace_all, (bool,), "replace_all")
        return method(self, *args, **kwargs)

    return new_method


def check_regex_tokenizer(method):
    """Wrapper method to check the parameter of RegexTokenizer.

    Raises:
        ValueError: If `delim_pattern` is None.
        TypeError: If `delim_pattern` or `keep_delim_pattern` is not a str, or
            `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if delim_pattern is None:
            raise ValueError("delim_pattern is not provided.")
        if not isinstance(delim_pattern, str):
            raise TypeError("Wrong input type for delim_pattern, should be string.")
        if not isinstance(keep_delim_pattern, str):
            raise TypeError("Wrong input type for keep_delim_pattern, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_basic_tokenizer(method):
    """Wrapper method to check the parameter of BasicTokenizer.

    Raises:
        TypeError: If `lower_case`, `keep_whitespace`, the preserve-unused-token
            flag or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The third parameter of the wrapped method is not validated here.
        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_bert_tokenizer(method):
    """Wrapper method to check the parameter of BertTokenizer.

    Raises:
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a text.Vocab, `suffix_indicator` or
            `unknown_token` is not a str, `max_bytes_per_token` is not an int,
            or any of `lower_case`, `keep_whitespace`, `preserve_unused_token`,
            `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The seventh parameter of the wrapped method is not validated here.
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
         preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, text.Vocab):
            raise TypeError("Wrong input type for vocab, should be text.Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(max_bytes_per_token, int):
            raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
        check_uint32(max_bytes_per_token)

        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused_token, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_from_dataset(method):
    """A wrapper that wraps a parameter checker to the original function.

    Checks `columns` (str or list of str), `freq_range` (tuple of two items,
    each an int or None), `top_k` (positive int or None), `special_tokens`
    (list of unique str) and `special_first` (bool).

    Raises:
        ValueError: If `freq_range` does not have exactly two elements, holds
            an element that is neither int nor None, or holds two ints (a, b)
            that violate 0 <= a <= b.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                            **kwargs)
        if columns is not None:
            # A single column name is accepted; wrap it so it can be checked as a list.
            if not isinstance(columns, list):
                columns = [columns]
            type_check_list(columns, (str,), "col")

        if freq_range is not None:
            type_check(freq_range, (tuple,), "freq_range")

            if len(freq_range) != 2:
                raise ValueError("freq_range needs to be a tuple of 2 element.")

            for num in freq_range:
                if num is not None and (not isinstance(num, int)):
                    raise ValueError(
                        "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")

            # The ordering/sign check only applies when both bounds are given.
            if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
                if freq_range[0] > freq_range[1] or freq_range[0] < 0:
                    raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")

        type_check(top_k, (int, type(None)), "top_k")

        if isinstance(top_k, int):
            check_positive(top_k, "top_k")
        type_check(special_first, (bool,), "special_first")

        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")

        return method(self, *args, **kwargs)

    return new_method


def check_slidingwindow(method):
    """A wrapper that wraps a parameter checker to the original function(sliding window operation).

    Checks `width` with check_pos_int32 and that `axis` is an int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [width, axis], _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")
        return method(self, *args, **kwargs)

    return new_method


def check_ngram(method):
    """A wrapper that wraps a parameter checker to the original function.

    Unlike the other wrappers in this module, the wrapped method is called
    with keyword arguments only (`n`, `left_pad`, `right_pad`, `separator`),
    and an int `n` is normalised to a one-element list first.

    Raises:
        ValueError: If `n` is not an int or a non-empty list, if `left_pad` or
            `right_pad` is not a (str, int) tuple, or if a pad width is
            negative.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)

        if isinstance(n, int):
            n = [n]

        if not (isinstance(n, list) and n != []):
            raise ValueError("n needs to be a non-empty list of positive integers.")

        for i, gram in enumerate(n):
            type_check(gram, (int,), "gram[{0}]".format(i))
            check_positive(gram, "gram_{}".format(i))

        if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
                left_pad[1], int)):
            raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance(
                right_pad[1], int)):
            raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        # A pad width of zero is accepted, despite the wording of the message.
        if not (left_pad[1] >= 0 and right_pad[1] >= 0):
            raise ValueError("padding width need to be positive numbers.")

        type_check(separator, (str,), "separator")

        # Forward every parameter by keyword so the normalised `n` reaches the method.
        kwargs["n"] = n
        kwargs["left_pad"] = left_pad
        kwargs["right_pad"] = right_pad
        kwargs["separator"] = separator

        return method(self, **kwargs)

    return new_method


def check_truncate(method):
    """Wrapper method to check the parameters of number of truncate.

    Checks `max_seq_len` with check_pos_int32.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [max_seq_len], _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(max_seq_len, "max_seq_len")
        return method(self, *args, **kwargs)

    return new_method


def check_pair_truncate(method):
    """Wrapper method to check the parameters of number of pair truncate."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # No per-argument checks: the parsed values are discarded.
        # NOTE(review): presumably this call exists so that parse_user_args can
        # reject a malformed call - confirm against validator_helpers.
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method


def check_to_number(method):
    """A wrapper that wraps a parameter check to the original function (ToNumber).

    Raises:
        TypeError: If `data_type` is not a member of mstype.number_type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(data_type, (typing.Type,), "data_type")

        if data_type not in mstype.number_type:
            raise TypeError("data_type: " + str(data_type) + " is not numeric data type.")

        return method(self, *args, **kwargs)

    return new_method


def check_python_tokenizer(method):
    """A wrapper that wraps a parameter check to the original function (PythonTokenizer).

    Raises:
        TypeError: If `tokenizer` is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [tokenizer], _ = parse_user_args(method, *args, **kwargs)

        if not callable(tokenizer):
            raise TypeError("tokenizer is not a callable Python function.")

        return method(self, *args, **kwargs)

    return new_method


def check_from_dataset_sentencepiece(method):
    """A wrapper that wraps a parameter checker to the original function (from_dataset).

    `vocab_size` is mandatory; `col_names`, `character_coverage`, `model_type`
    and `params` are only checked when they are not None.

    Raises:
        TypeError: If `vocab_size` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        if col_names is not None:
            type_check_list(col_names, (str,), "col_names")

        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")
        else:
            raise TypeError("vocab_size must be provided.")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method


def check_from_file_sentencepiece(method):
    """A wrapper that wraps a parameter checker to the original function (from_file).

    Every parameter is optional here and is only checked when it is not None;
    note that `file_path` is expected to be a list.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        if file_path is not None:
            type_check(file_path, (list,), "file_path")

        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method


def check_save_model(method):
    """A wrapper that wraps a parameter checker to the original function (save_model).

    Each of `vocab` (text.SentencePieceVocab), `path` (str) and `filename`
    (str) is only checked when it is not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)

        if vocab is not None:
            type_check(vocab, (text.SentencePieceVocab,), "vocab")

        if path is not None:
            type_check(path, (str,), "path")

        if filename is not None:
            type_check(filename, (str,), "filename")

        return method(self, *args, **kwargs)

    return new_method


def check_sentence_piece_tokenizer(method):

    """A wrapper that wraps a parameter checker to the original function.

    Checks that `mode` is a str or text.SentencePieceVocab and that `out_type`
    is an SPieceTokenizerOutType.
    """

    # Imported here rather than at module level, so it runs when the decorator is applied.
    from .utils import SPieceTokenizerOutType
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)

        type_check(mode, (str, text.SentencePieceVocab), "mode is not an instance of str or text.SentencePieceVocab.")
        type_check(out_type, (SPieceTokenizerOutType,), "out_type is not an instance of SPieceTokenizerOutType")

        return method(self, *args, **kwargs)

    return new_method


def check_from_file_vectors(method):
    """A wrapper that wraps a parameter checker to from_file of class Vectors.

    Checks `file_path` (str, then check_filename) and, when not None,
    `max_vectors` (int, then check_non_negative_int32).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, max_vectors], _ = parse_user_args(method, *args, **kwargs)

        type_check(file_path, (str,), "file_path")
        check_filename(file_path)
        if max_vectors is not None:
            type_check(max_vectors, (int,), "max_vectors")
            check_non_negative_int32(max_vectors, "max_vectors")

        return method(self, *args, **kwargs)

    return new_method


def check_to_vectors(method):
    """A wrapper that wraps a parameter checker to ToVectors.

    Checks `vectors` (cde.Vectors), the optional `unk_init` (list or tuple of
    int/float) and `lower_case_backup` (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vectors, unk_init, lower_case_backup], _ = parse_user_args(method, *args, **kwargs)

        type_check(vectors, (cde.Vectors,), "vectors")
        if unk_init is not None:
            type_check(unk_init, (list, tuple), "unk_init")
            for i, value in enumerate(unk_init):
                type_check(value, (int, float), "unk_init[{0}]".format(i))
        type_check(lower_case_backup, (bool,), "lower_case_backup")
        return method(self, *args, **kwargs)

    return new_method