• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16validators for text ops
17"""
18
19from functools import wraps
20import mindspore.common.dtype as mstype
21
22import mindspore._c_dataengine as cde
23from mindspore._c_expression import typing
24
25from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
26    INT32_MAX, check_value, check_positive, check_pos_int32
27
28
def check_unique_list_of_words(words, arg_name):
    """
    Validate that ``words`` is a list of str containing no repeated entry.

    Args:
        words: the value to validate; must be a list whose elements are str.
        arg_name (str): parameter name used in error messages.

    Returns:
        set: the distinct words, so callers can reuse it (e.g. to test for
        overlap with another word list).

    Raises:
        ValueError: when the same word occurs more than once.
        (``type_check`` presumably raises TypeError for a non-list ``words``
        or a non-str element — confirm in validator_helpers.)
    """
    type_check(words, (list,), arg_name)
    seen = set()
    for candidate in words:
        type_check(candidate, (str,), arg_name)
        # Reject the first repeat as soon as it is found.
        if candidate in seen:
            raise ValueError(arg_name + " contains duplicate word: " + candidate + ".")
        seen.add(candidate)
    return seen
40
41
def check_lookup(method):
    """
    Decorator that validates ``vocab``, ``unknown_token`` and ``data_type``
    before calling ``method``.

    Checked parameters (bound by ``parse_user_args`` in this order):
        vocab: must be a ``cde.Vocab``.
        unknown_token: str, or None to skip the check.
        data_type: must be a ``typing.Type`` (a MindSpore dtype).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, unknown_token, data_type], _ = parse_user_args(method, *args, **kwargs)

        if unknown_token is not None:
            type_check(unknown_token, (str,), "unknown_token")

        # Pass the parameter *name*, as the other type_check calls here do;
        # a full sentence in this slot produced a garbled error message.
        type_check(vocab, (cde.Vocab,), "vocab")
        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return new_method
58
59
def check_from_file(method):
    """
    Decorator that validates the arguments of a file-based vocab builder
    before calling ``method``.

    Checked parameters (bound by ``parse_user_args`` in this order):
        file_path (str), delimiter (str),
        vocab_size: positive number, or None to skip the check,
        special_tokens: list of unique str, or None,
        special_first (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                               **kwargs)
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
        if vocab_size is not None:
            check_positive(vocab_size, "vocab_size")
        # The third argument is the parameter *name*. The original passed the
        # value itself, so the error named the bad value instead of the parameter.
        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method
77
78
def check_from_list(method):
    """
    Decorator that validates the arguments of a list-based vocab builder
    before calling ``method``.

    ``word_list`` must be a list of unique str. ``special_tokens`` (optional)
    must also be a list of unique str sharing no word with ``word_list``.
    ``special_first`` must be a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs)

        words = check_unique_list_of_words(word_list, "word_list")
        if special_tokens is not None:
            overlap = words & check_unique_list_of_words(special_tokens, "special_tokens")
            # A word may not be both a regular word and a special token.
            if overlap:
                raise ValueError("special_tokens and word_list contain duplicate word :" + str(overlap) + ".")

        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method
100
101
def check_from_dict(method):
    """
    Decorator that validates a ``word_dict`` argument before calling ``method``.

    ``word_dict`` must be a dict mapping str words to int ids. Each id must
    pass ``check_value`` against the range (0, INT32_MAX).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word_dict], _ = parse_user_args(method, *args, **kwargs)
        type_check(word_dict, (dict,), "word_dict")

        for key, value in word_dict.items():
            type_check(key, (str,), "word")
            type_check(value, (int,), "word_id")
            check_value(value, (0, INT32_MAX), "word_id")

        return method(self, *args, **kwargs)

    return new_method
118
119
def check_jieba_init(method):
    """
    Decorator that validates the JiebaTokenizer constructor arguments.

    ``hmm_path`` and ``mp_path`` must be provided (not None) and be str.
    ``with_offsets`` must be a bool. The third parameter is not validated here.

    Raises:
        ValueError: when a dictionary path is None.
        TypeError: when a path is not a str or ``with_offsets`` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs)

        # Checked in order: HMM path first, then MP path; for each, presence
        # before type.
        path_rules = (
            (hmm_path,
             "The dict of HMMSegment in cppjieba is not provided.",
             "Wrong input type for hmm_path, should be string."),
            (mp_path,
             "The dict of MPSegment in cppjieba is not provided.",
             "Wrong input type for mp_path, should be string."),
        )
        for dict_path, missing_message, type_message in path_rules:
            if dict_path is None:
                raise ValueError(missing_message)
            if not isinstance(dict_path, str):
                raise TypeError(type_message)

        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method
140
141
def check_jieba_add_word(method):
    """
    Decorator that validates the arguments of a jieba ``add_word`` call.

    ``word`` must be provided (not None). ``freq``, when given, must pass
    ``check_uint32``.

    Raises:
        ValueError: when ``word`` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [word, freq], _ = parse_user_args(method, *args, **kwargs)
        if word is None:
            raise ValueError("word is not provided.")
        if freq is not None:
            # Name the argument so a range/type error says which input was bad.
            check_uint32(freq, "freq")
        return method(self, *args, **kwargs)

    return new_method
155
156
def check_jieba_add_dict(method):
    """
    Decorator for a jieba ``add_dict`` call.

    No per-argument checks are performed. The arguments are only run through
    ``parse_user_args`` (presumably so an invalid binding is rejected there —
    confirm in validator_helpers), and its result is discarded.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _ = parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method
166
167
def check_with_offsets(method):
    """
    Decorator for methods whose only parameter is ``with_offsets``.

    Raises:
        TypeError: when ``with_offsets`` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if isinstance(with_offsets, bool):
            return method(self, *args, **kwargs)
        raise TypeError("Wrong input type for with_offsets, should be boolean.")

    return new_method
179
180
def check_unicode_script_tokenizer(method):
    """
    Decorator that validates the UnicodeScriptTokenizer arguments.

    Both ``keep_whitespace`` and ``with_offsets`` must be bool, checked in
    that order.

    Raises:
        TypeError: on the first non-bool flag.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        flag_rules = (
            (keep_whitespace, "Wrong input type for keep_whitespace, should be boolean."),
            (with_offsets, "Wrong input type for with_offsets, should be boolean."),
        )
        for flag, message in flag_rules:
            if not isinstance(flag, bool):
                raise TypeError(message)
        return method(self, *args, **kwargs)

    return new_method
194
195
def check_wordpiece_tokenizer(method):
    """
    Decorator that validates the WordpieceTokenizer arguments.

    Checked parameters (in order):
        vocab: required ``cde.Vocab``,
        suffix_indicator (str), unknown_token (str), with_offsets (bool),
        max_bytes_per_token: must pass ``check_uint32``.

    Raises:
        ValueError: when ``vocab`` is None.
        TypeError: when any of the typed parameters has the wrong type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        # Name the argument so a failure says which input was bad.
        check_uint32(max_bytes_per_token, "max_bytes_per_token")
        return method(self, *args, **kwargs)

    return new_method
217
218
def check_regex_replace(method):
    """
    Decorator that validates the RegexReplace arguments.

    ``pattern`` and ``replace`` must be str. ``replace_all`` must be a bool.
    They are checked in that order.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [pattern, replace, replace_all], _ = parse_user_args(method, *args, **kwargs)
        expectations = (
            (pattern, (str,), "pattern"),
            (replace, (str,), "replace"),
            (replace_all, (bool,), "replace_all"),
        )
        for value, accepted, name in expectations:
            type_check(value, accepted, name)
        return method(self, *args, **kwargs)

    return new_method
231
232
def check_regex_tokenizer(method):
    """
    Decorator that validates the RegexTokenizer arguments.

    ``delim_pattern`` is required and must be a str. ``keep_delim_pattern``
    must be a str. ``with_offsets`` must be a bool.

    Raises:
        ValueError: when ``delim_pattern`` is None.
        TypeError: on the first argument with the wrong type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if delim_pattern is None:
            raise ValueError("delim_pattern is not provided.")

        type_rules = (
            (delim_pattern, str, "Wrong input type for delim_pattern, should be string."),
            (keep_delim_pattern, str, "Wrong input type for keep_delim_pattern, should be string."),
            (with_offsets, bool, "Wrong input type for with_offsets, should be boolean."),
        )
        for value, expected_type, message in type_rules:
            if not isinstance(value, expected_type):
                raise TypeError(message)
        return method(self, *args, **kwargs)

    return new_method
250
251
def check_basic_tokenizer(method):
    """
    Decorator that validates the BasicTokenizer arguments.

    ``lower_case``, ``keep_whitespace``, ``preserve_unused_token`` and
    ``with_offsets`` must all be bool. The third parameter is bound but not
    validated by this decorator.

    Raises:
        TypeError: on the first non-bool flag (checked in the order above).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \
            parse_user_args(method, *args, **kwargs)
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method
270
271
def check_bert_tokenizer(method):
    """
    Decorator that validates the BertTokenizer arguments.

    Checked parameters (in order):
        vocab: required ``cde.Vocab``,
        suffix_indicator (str),
        max_bytes_per_token: int that must also pass ``check_uint32``,
        unknown_token (str), lower_case (bool), keep_whitespace (bool),
        preserve_unused_token (bool), with_offsets (bool).
    The seventh parameter is bound but not validated here.

    Raises:
        ValueError: when ``vocab`` is None.
        TypeError: on the first argument with the wrong type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
         preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if vocab is None:
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, cde.Vocab):
            raise TypeError("Wrong input type for vocab, should be Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(max_bytes_per_token, int):
            raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
        # Name the argument so a range error says which input was bad.
        check_uint32(max_bytes_per_token, "max_bytes_per_token")

        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused_token, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method
302
303
def check_from_dataset(method):
    """
    Decorator that validates the arguments of a dataset-based vocab builder.

    Checked parameters (the first, the dataset, is not validated here):
        columns: a str or a list of str, or None,
        freq_range: None or a 2-tuple whose items are int or None, with
            0 <= low (and low <= high when both are given),
        top_k: None or a positive int,
        special_tokens: None or a list of unique str,
        special_first (bool).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                            **kwargs)
        if columns is not None:
            # A single column may be given bare. Normalise it to a list so the
            # element check covers both forms. Previously the check sat inside
            # this branch, so elements of a real list were never validated.
            if not isinstance(columns, list):
                columns = [columns]
            type_check_list(columns, (str,), "col")

        if freq_range is not None:
            type_check(freq_range, (tuple,), "freq_range")

            if len(freq_range) != 2:
                raise ValueError("freq_range needs to be a tuple of 2 element.")

            for num in freq_range:
                if num is not None and (not isinstance(num, int)):
                    raise ValueError(
                        "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")

            low, high = freq_range
            # Reject a negative lower bound even when the upper bound is None,
            # and an inverted range when both bounds are given.
            if low is not None and (low < 0 or (high is not None and low > high)):
                raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")

        type_check(top_k, (int, type(None)), "top_k")

        if isinstance(top_k, int):
            check_positive(top_k, "top_k")
        type_check(special_first, (bool,), "special_first")

        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")

        return method(self, *args, **kwargs)

    return new_method
344
345
def check_slidingwindow(method):
    """
    Decorator that validates the sliding-window operation arguments.

    ``width`` must pass ``check_pos_int32``. ``axis`` must be an int.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [width, axis], _ = parse_user_args(method, *args, **kwargs)

        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")

        return method(self, *args, **kwargs)

    return new_method
357
358
def check_ngram(method):
    """
    Decorator that validates and normalises the Ngram arguments.

    - ``n``: an int (wrapped into a one-element list) or a non-empty list of
      positive ints.
    - ``left_pad`` / ``right_pad``: ``(str, int)`` tuples whose width is >= 0.
    - ``separator``: a str.

    The normalised values are forwarded to ``method`` as keyword arguments
    only; the positional ``args`` are not passed on.
    """

    def _is_pad_spec(pad):
        """Return True when ``pad`` is a 2-tuple of (str, int)."""
        return isinstance(pad, tuple) and len(pad) == 2 and isinstance(pad[0], str) and isinstance(pad[1], int)

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs)

        grams = [n] if isinstance(n, int) else n

        if not isinstance(grams, list) or not grams:
            raise ValueError("n needs to be a non-empty list of positive integers.")

        for idx, size in enumerate(grams):
            type_check(size, (int,), "gram[{0}]".format(idx))
            check_positive(size, "gram_{}".format(idx))

        if not _is_pad_spec(left_pad):
            raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if not _is_pad_spec(right_pad):
            raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if left_pad[1] < 0 or right_pad[1] < 0:
            raise ValueError("padding width need to be positive numbers.")

        type_check(separator, (str,), "separator")

        kwargs.update(n=grams, left_pad=left_pad, right_pad=right_pad, separator=separator)
        return method(self, **kwargs)

    return new_method
397
398
def check_pair_truncate(method):
    """
    Decorator for the pair-truncate operation.

    No per-argument checks are performed. The arguments are only run through
    ``parse_user_args`` (presumably to reject an invalid binding — confirm in
    validator_helpers), and its result is discarded.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _ = parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return new_method
408
409
def check_to_number(method):
    """
    Decorator that validates the ToNumber ``data_type`` argument.

    ``data_type`` must be a ``typing.Type`` that belongs to
    ``mstype.number_type``.

    Raises:
        TypeError: when ``data_type`` is not a numeric MindSpore dtype.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [data_type], _ = parse_user_args(method, *args, **kwargs)
        type_check(data_type, (typing.Type,), "data_type")

        if data_type in mstype.number_type:
            return method(self, *args, **kwargs)
        raise TypeError("data_type: " + str(data_type) + " is not numeric data type.")

    return new_method
424
425
def check_python_tokenizer(method):
    """
    Decorator that validates the PythonTokenizer ``tokenizer`` argument.

    Raises:
        TypeError: when ``tokenizer`` is not callable.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [tokenizer], _ = parse_user_args(method, *args, **kwargs)

        if callable(tokenizer):
            return method(self, *args, **kwargs)
        raise TypeError("tokenizer is not a callable Python function.")

    return new_method
439
440
def check_from_dataset_sentencepiece(method):
    """
    Decorator that validates the arguments of a dataset-based SentencePiece
    vocab builder. The first parameter (the dataset) is not validated.

    - col_names: None or a list of str.
    - vocab_size: required; must pass ``check_uint32``.
    - character_coverage: None or a float.
    - model_type: None, a str, or a ``SentencePieceModel``.
    - params: None or a dict.

    Raises:
        TypeError: when ``vocab_size`` is None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [_, col_names, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        if col_names is not None:
            type_check_list(col_names, (str,), "col_names")

        if vocab_size is None:
            raise TypeError("vocab_size must be provided.")
        check_uint32(vocab_size, "vocab_size")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            # Imported lazily, only when needed (possibly to avoid an import
            # cycle with .utils — unverified).
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method
469
470
def check_from_file_sentencepiece(method):
    """
    Decorator that validates the arguments of a file-based SentencePiece
    vocab builder. Every argument may be None, which skips its check.

    - file_path: a list.
    - vocab_size: must pass ``check_uint32``.
    - character_coverage: a float.
    - model_type: a str or a ``SentencePieceModel``.
    - params: a dict.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, vocab_size, character_coverage, model_type, params], _ = parse_user_args(method, *args, **kwargs)

        # Each check below runs only for arguments that were actually supplied.
        if file_path is not None:
            type_check(file_path, (list,), "file_path")
        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")
        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")
        if model_type is not None:
            # Imported lazily, only when a model_type is given.
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")
        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return new_method
497
498
def check_save_model(method):
    """
    Decorator that validates the arguments of a SentencePiece ``save_model``
    call. Every argument may be None, which skips its check.

    - vocab: a ``cde.SentencePieceVocab``.
    - path: a str.
    - filename: a str.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, path, filename], _ = parse_user_args(method, *args, **kwargs)

        # Only supplied (non-None) arguments are type-checked.
        if vocab is not None:
            type_check(vocab, (cde.SentencePieceVocab,), "vocab")
        if path is not None:
            type_check(path, (str,), "path")
        if filename is not None:
            type_check(filename, (str,), "filename")

        return method(self, *args, **kwargs)

    return new_method
518
519
def check_sentence_piece_tokenizer(method):
    """
    Decorator that validates the SentencePieceTokenizer arguments.

    ``mode`` must be a str or a ``cde.SentencePieceVocab``. ``out_type`` must
    be an ``SPieceTokenizerOutType``.
    """

    # Imported when the decorator is applied rather than at module import
    # (possibly to avoid an import cycle with .utils — unverified).
    from .utils import SPieceTokenizerOutType

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [mode, out_type], _ = parse_user_args(method, *args, **kwargs)

        # Pass parameter *names*, as the rest of this module does; full
        # sentences in this slot produced garbled error messages.
        type_check(mode, (str, cde.SentencePieceVocab), "mode")
        type_check(out_type, (SPieceTokenizerOutType,), "out_type")

        return method(self, *args, **kwargs)

    return new_method
535