# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
validators for text ops

Every ``check_*`` function below (except ``check_unique_list_of_words``) is a decorator: it wraps
``method`` in a ``new_method`` that first obtains the call's arguments from ``parse_user_args`` and
validates them, and only then forwards the call to ``method``.
"""

from functools import wraps
import mindspore.common.dtype as mstype

import mindspore._c_dataengine as cde
from mindspore._c_expression import typing

from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
    INT32_MAX, check_value, check_positive, check_pos_int32


def check_unique_list_of_words(words, arg_name):
    """
    Check that words is a list and each element is a str without any duplication.

    Args:
        words: Object expected to be a list of unique strings.
        arg_name (str): Name used for the argument in error messages.

    Returns:
        set, the words that were seen.

    Raises:
        ValueError: If the same word appears more than once.
    """

    type_check(words, (list,), arg_name)
    words_set = set()
    for word in words:
        type_check(word, (str,), arg_name)
        if word in words_set:
            raise ValueError(arg_name + " contains duplicate word: " + word + ".")
        words_set.add(word)
    return words_set


def check_lookup(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    Checks `vocab` (cde.Vocab), `unknown_token` (str, only when not None) and `data_type` (typing.Type).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)

        if unknown_token is not None:
            type_check(unknown_token, (str,), "unknown_token")

        type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.")
        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return new_method


def check_from_file(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    Checks `file_path` and `delimiter` (str), `vocab_size` (positive, only when not None),
    `special_tokens` (list of unique str, only when not None) and `special_first` (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                               **kwargs)
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
        if vocab_size is not None:
            check_positive(vocab_size, "vocab_size")
        # The third argument is the argument *name* used in the error message, not the value itself.
        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method


def check_from_list(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    Checks `word_list` (list of unique str), `special_tokens` (list of unique str, only when not None)
    and `special_first` (bool).

    Raises:
        ValueError: If `special_tokens` and `word_list` share at least one word.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)

        word_set = check_unique_list_of_words(word_list, "word_list")
        if special_tokens is not None:
            token_set = check_unique_list_of_words(special_tokens, "special_tokens")

            # A word must not be both a regular word and a special token.
            intersect = word_set.intersection(token_set)

            if intersect:
                raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".")

        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method


def check_from_dict(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    Checks that `word_dict` is a dict whose keys are str and whose values are int checked
    against the bounds (0, INT32_MAX) with check_value.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_dict], _ = parse_user_args(method, *args, **kwargs)

        type_check(word_dict, (dict,), "word_dict")

        for word, word_id in word_dict.items():
            type_check(word, (str,), "word")
            type_check(word_id, (int,), "word_id")
            check_value(word_id, (0, INT32_MAX), "word_id")
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_init(method):
    """
    Wrapper method to check the parameters of jieba init.

    Raises:
        ValueError: If `hmm_path` or `mp_path` is None.
        TypeError: If `hmm_path` or `mp_path` is not a str, or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The third parsed argument is not validated here.
        [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs)

        if hmm_path is None:
            raise ValueError("The dict of HMMSegment in cppjieba is not provided.")
        if not isinstance(hmm_path, str):
            raise TypeError("Wrong input type for hmm_path, should be string.")
        if mp_path is None:
            raise ValueError("The dict of MPSegment in cppjieba is not provided.")
        if not isinstance(mp_path, str):
            raise TypeError("Wrong input type for mp_path, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_add_word(method):
    """
    Wrapper method to check the parameters of jieba add word.

    Raises:
        ValueError: If `word` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word, freq], _ = parse_user_args(method, *args, **kwargs)
        if word is None:
            raise ValueError("word is not provided.")
        if freq is not None:
            check_uint32(freq)
        return method(self, *args, **kwargs)

    return new_method


def check_jieba_add_dict(method):
    """
    Wrapper method to check the parameters of add dict.

    Only runs parse_user_args on the call; no per-argument validation is done here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method


def check_with_offsets(method):
    """
    Wrapper method to check if with_offsets is the only one parameter.

    Raises:
        TypeError: If `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_unicode_script_tokenizer(method):
    """
    Wrapper method to check the parameter of UnicodeScriptTokenizer.

    Raises:
        TypeError: If `keep_whitespace` or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_wordpiece_tokenizer(method):
    """
    Wrapper method to check the parameter of WordpieceTokenizer.

    `max_bytes_per_token` is additionally validated with check_uint32.

    Raises:
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a cde.Vocab, `suffix_indicator` or `unknown_token` is not a str,
            or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        check_uint32(max_bytes_per_token)
        return method(self, *args, **kwargs)

    return new_method


def check_regex_replace(method):
    """
    Wrapper method to check the parameter of RegexReplace.

    Checks `pattern` and `replace` (str) and `replace_all` (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [pattern, replace, replace_all], _ = parse_user_args(method, *args, **kwargs)
        type_check(pattern, (str,), "pattern")
        type_check(replace, (str,), "replace")
        type_check(replace_all, (bool,), "replace_all")
        return method(self, *args, **kwargs)

    return new_method


def check_regex_tokenizer(method):
    """
    Wrapper method to check the parameter of RegexTokenizer.

    Raises:
        ValueError: If `delim_pattern` is None.
        TypeError: If `delim_pattern` or `keep_delim_pattern` is not a str, or `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if delim_pattern is None:
            raise ValueError("delim_pattern is not provided.")
        if not isinstance(delim_pattern, str):
            raise TypeError("Wrong input type for delim_pattern, should be string.")
        if not isinstance(keep_delim_pattern, str):
            raise TypeError("Wrong input type for keep_delim_pattern, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_basic_tokenizer(method):
    """
    Wrapper method to check the parameter of BasicTokenizer.

    Raises:
        TypeError: If `lower_case`, `keep_whitespace`, the preserve-unused-token flag or `with_offsets`
            is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The third parsed argument is not validated here.
        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_bert_tokenizer(method):
    """
    Wrapper method to check the parameter of BertTokenizer.

    `max_bytes_per_token` is additionally validated with check_uint32.

    Raises:
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a cde.Vocab, `suffix_indicator` or `unknown_token` is not a str,
            `max_bytes_per_token` is not an int, or any of `lower_case`, `keep_whitespace`,
            `preserve_unused_token`, `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # The seventh parsed argument is not validated here.
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
         preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(max_bytes_per_token, int):
            raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
        check_uint32(max_bytes_per_token)

        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused_token, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method


def check_from_dataset(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    Checks `columns` (str or list of str, only when not None), `freq_range` (2-tuple of int/None,
    only when not None), `top_k` (None or positive int), `special_first` (bool) and `special_tokens`
    (list of unique str, only when not None). The first parsed argument is not validated here.

    Raises:
        ValueError: If `freq_range` does not have exactly two elements, contains an element that is
            neither None nor an int, or has two ints (a, b) violating 0 <= a <= b.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                            **kwargs)
        if columns is not None:
            # A single column name is wrapped so that one list check covers both forms.
            if not isinstance(columns, list):
                columns = [columns]
            type_check_list(columns, (str,), "col")

        if freq_range is not None:
            type_check(freq_range, (tuple,), "freq_range")

            if len(freq_range) != 2:
                raise ValueError("freq_range needs to be a tuple of 2 element.")

            for num in freq_range:
                if num is not None and (not isinstance(num, int)):
                    raise ValueError(
                        "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")

            # The ordering constraint is only checked when both bounds are given.
            if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
                if freq_range[0] > freq_range[1] or freq_range[0] < 0:
                    raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")

        type_check(top_k, (int, type(None)), "top_k")

        if isinstance(top_k, int):
            check_positive(top_k, "top_k")
        type_check(special_first, (bool,), "special_first")

        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")

        return method(self, *args, **kwargs)

    return new_method


def check_slidingwindow(method):
    """
    A wrapper that wraps a parameter checker to the original function(sliding window operation).

    Checks `width` with check_pos_int32 and `axis` (int).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [width, axis], _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")
        return method(self, *args, **kwargs)

    return new_method


def check_ngram(method):
    """
    A wrapper that wraps a parameter checker to the original function.

    A single int `n` is normalised to a one-element list. The wrapped method is then called with
    `n`, `left_pad`, `right_pad` and `separator` passed as keyword arguments.

    Raises:
        ValueError: If `n` is not an int or a non-empty list, if `left_pad` or `right_pad` is not a
            (str, int) tuple, or if a pad width is negative.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)

        if isinstance(n, int):
            n = [n]

        if not (isinstance(n, list) and n):
            raise ValueError("n needs to be a non-empty list of positive integers.")

        for i, gram in enumerate(n):
            type_check(gram, (int,), "gram[{0}]".format(i))
            check_positive(gram, "gram_{}".format(i))

        if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance(
                left_pad[1], int)):
            raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if not (isinstance(right_pad, tuple) and len(right_pad) == 2 and isinstance(right_pad[0], str) and isinstance(
                right_pad[1], int)):
            raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if not (left_pad[1] >= 0 and right_pad[1] >= 0):
            raise ValueError("padding width need to be positive numbers.")

        type_check(separator, (str,), "separator")

        # Forward everything by keyword so the normalised `n` replaces whatever the caller passed.
        kwargs["n"] = n
        kwargs["left_pad"] = left_pad
        kwargs["right_pad"] = right_pad
        kwargs["separator"] = separator

        return method(self, **kwargs)

    return new_method


def check_pair_truncate(method):
    """
    Wrapper method to check the parameters of number of pair truncate.

    Only runs parse_user_args on the call; no per-argument validation is done here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method


def check_to_number(method):
    """
    A wrapper that wraps a parameter check to the original function (ToNumber).

    Raises:
        TypeError: If `data_type` is not a member of mstype.number_type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(data_type, (typing.Type,), "data_type")

        if data_type not in mstype.number_type:
            raise TypeError("data_type: " + str(data_type) + " is not numeric data type.")

        return method(self, *args, **kwargs)

    return new_method


def check_python_tokenizer(method):
    """
    A wrapper that wraps a parameter check to the original function (PythonTokenizer).

    Raises:
        TypeError: If `tokenizer` is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [tokenizer], _ = parse_user_args(method, *args, **kwargs)

        if not callable(tokenizer):
            raise TypeError("tokenizer is not a callable Python function.")

        return method(self, *args, **kwargs)

    return new_method


def check_from_dataset_sentencepiece(method):
    """
    A wrapper that wraps a parameter checker to the original function (from_dataset).

    `vocab_size` is mandatory; `col_names`, `character_coverage`, `model_type` and `params` are
    only checked when not None. The first parsed argument is not validated here.

    Raises:
        TypeError: If `vocab_size` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        if col_names is not None:
            type_check_list(col_names, (str,), "col_names")

        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")
        else:
            raise TypeError("vocab_size must be provided.")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method


def check_from_file_sentencepiece(method):
    """
    A wrapper that wraps a parameter checker to the original function (from_file).

    Each of `file_path` (list), `vocab_size`, `character_coverage` (float), `model_type`
    (str or SentencePieceModel) and `params` (dict) is only checked when not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        if file_path is not None:
            type_check(file_path, (list,), "file_path")

        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method


def check_save_model(method):
    """
    A wrapper that wraps a parameter checker to the original function (save_model).

    Each of `vocab` (cde.SentencePieceVocab), `path` (str) and `filename` (str) is only checked
    when not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)

        if vocab is not None:
            type_check(vocab, (cde.SentencePieceVocab,), "vocab")

        if path is not None:
            type_check(path, (str,), "path")

        if filename is not None:
            type_check(filename, (str,), "filename")

        return method(self, *args, **kwargs)

    return new_method


def check_sentence_piece_tokenizer(method):

    """
    A wrapper that wraps a parameter checker to the original function.

    Checks `mode` (str or cde.SentencePieceVocab) and `out_type` (SPieceTokenizerOutType).
    """

    # Imported when the decorator is applied, not at module import time.
    from .utils import SPieceTokenizerOutType
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)

        type_check(mode, (str, cde.SentencePieceVocab), "mode is not an instance of str or cde.SentencePieceVocab.")
        type_check(out_type, (SPieceTokenizerOutType,), "out_type is not an instance of SPieceTokenizerOutType")

        return method(self, *args, **kwargs)

    return new_method