• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""
16validators for text ops
17"""
18from functools import wraps
19import numpy as np
20
21import mindspore._c_dataengine as cde
22import mindspore.common.dtype as mstype
23from mindspore._c_expression import typing
24
25import mindspore.dataset.text as text
26from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \
27    INT32_MAX, check_value, check_positive, check_pos_int32, check_filename, check_non_negative_int32
28
29
def check_add_token(method):
    """Decorate an add-token method so its arguments are validated first.

    The wrapper resolves the call arguments through parse_user_args, then
    passes `token` (expected str) and `begin` (expected bool) to type_check
    before forwarding the original arguments unchanged to `method`.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        token, begin = params
        # Checked in declaration order so the first offending argument is reported.
        for value, expected, label in ((token, (str,), "token"), (begin, (bool,), "begin")):
            type_check(value, expected, label)
        return method(self, *args, **kwargs)

    return validated
41
42
def check_unique_list_of_words(words, arg_name):
    """Validate that `words` is a list of distinct strings.

    Args:
        words: Object that is passed to type_check as a list; every element
            is passed to type_check as a str.
        arg_name (str): Label used for `words` in error messages.

    Returns:
        set, the words collected from `words`.

    Raises:
        ValueError: If a word occurs more than once in `words`.
    """
    type_check(words, (list,), arg_name)

    seen = set()
    for item in words:
        type_check(item, (str,), arg_name)
        if item not in seen:
            seen.add(item)
            continue
        # Second occurrence of an already collected word.
        raise ValueError(arg_name + " contains duplicate word: " + item + ".")
    return seen
54
55
def check_lookup(method):
    """Decorate a method taking (vocab, unknown_token, data_type) with argument checks.

    Checks, in order:
      * `unknown_token`, when not None, is passed to type_check as a str.
      * `vocab` is passed to type_check as a text.Vocab.
      * `vocab.c_vocab` is passed to type_check as a cde.Vocab.
      * `data_type` is passed to type_check as a typing.Type.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        vocab, unknown_token, data_type = params

        # unknown_token is optional; only a supplied value is inspected.
        if unknown_token is not None:
            type_check(unknown_token, (str,), "unknown_token")

        # The wrapper object is validated before its c_vocab attribute is read.
        type_check(vocab, (text.Vocab,), "vocab is not an instance of text.Vocab.")
        inner_vocab = vocab.c_vocab
        type_check(inner_vocab, (cde.Vocab,), "vocab.c_vocab is not an instance of cde.Vocab.")

        type_check(data_type, (typing.Type,), "data_type")

        return method(self, *args, **kwargs)

    return validated
73
74
def check_from_file(method):
    """Decorate a method taking (file_path, delimiter, vocab_size, special_tokens, special_first).

    Checks, in order:
      * `special_tokens`, when not None, must be a list of unique str
        (check_unique_list_of_words).
      * `file_path` and `delimiter` are passed to type_check_list as str.
      * `vocab_size`, when not None, is passed to check_positive.
      * `special_first` is passed to type_check as a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                               **kwargs)
        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")
        type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"])
        if vocab_size is not None:
            check_positive(vocab_size, "vocab_size")
        # The third argument is the parameter *name* used in the error message;
        # it must be the literal "special_first", not the value being checked.
        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return new_method
92
93
def check_vocab(c_vocab):
    """Ensure `c_vocab` is a cde.Vocab instance.

    Raises:
        RuntimeError: If `c_vocab` is not an instance of cde.Vocab; the message
            reports the actual type and lists the Vocab builder methods.
    """
    if isinstance(c_vocab, cde.Vocab):
        return

    problem = "The Vocab has not built yet, got type {0}. ".format(type(c_vocab))
    hint = ("Use Vocab.from_dataset(), Vocab.from_list(), Vocab.from_file() or Vocab.from_dict() "
            "to build a Vocab.")
    raise RuntimeError(problem + hint)
102
103
def check_tokens_to_ids(method):
    """Decorate a method taking a single `tokens` argument with argument checks.

    `tokens` is passed to type_check as a str, list or numpy.ndarray. When it
    is a list, its elements are additionally passed to type_check_list as
    str / numpy.str_, each labelled "tokens[i]".
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (tokens,), _ = parse_user_args(method, *args, **kwargs)
        type_check(tokens, (str, list, np.ndarray), "tokens")

        if isinstance(tokens, list):
            # One label per element so a failure points at the exact position.
            labels = ["tokens[{0}]".format(position) for position, _item in enumerate(tokens)]
            type_check_list(tokens, (str, np.str_), labels)

        return method(self, *args, **kwargs)

    return validated
118
119
def check_ids_to_tokens(method):
    """Decorate a method taking a single `ids` argument with argument checks.

    `ids` is passed to type_check as an int, list or numpy.ndarray. An int is
    passed to check_value with the range (0, INT32_MAX). For a list, every
    element is passed to type_check as int / numpy.int_ and to check_value
    with the same range, labelled "ids[i]".
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (ids,), _ = parse_user_args(method, *args, **kwargs)
        type_check(ids, (int, list, np.ndarray), "ids")

        id_range = (0, INT32_MAX)
        if isinstance(ids, int):
            check_value(ids, id_range, "ids")
        if isinstance(ids, list):
            for position, element in enumerate(ids):
                label = "ids[{}]".format(position)
                type_check(element, (int, np.int_), label)
                check_value(element, id_range, label)

        return method(self, *args, **kwargs)

    return validated
137
138
def check_from_list(method):
    """Decorate a method taking (word_list, special_tokens, special_first) with argument checks.

    `word_list` must be a list of unique str. `special_tokens`, when not None,
    must also be a list of unique str and must not share any word with
    `word_list`. `special_first` is passed to type_check as a bool.

    Raises (from the wrapper):
        ValueError: If `special_tokens` and `word_list` have words in common.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        word_list, special_tokens, special_first = params

        words = check_unique_list_of_words(word_list, "word_list")
        if special_tokens is not None:
            specials = check_unique_list_of_words(special_tokens, "special_tokens")
            shared = words.intersection(specials)
            # A non-empty intersection means a word is both regular and special.
            if shared:
                raise ValueError("special_tokens and word_list contain duplicate word :" + str(shared) + ".")

        type_check(special_first, (bool,), "special_first")

        return method(self, *args, **kwargs)

    return validated
160
161
def check_from_dict(method):
    """Decorate a method taking a single `word_dict` argument with argument checks.

    `word_dict` is passed to type_check as a dict. Every key is passed to
    type_check as a str; every value is passed to type_check as an int and to
    check_value with the range (0, INT32_MAX).
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (word_dict,), _ = parse_user_args(method, *args, **kwargs)
        type_check(word_dict, (dict,), "word_dict")

        id_range = (0, INT32_MAX)
        for key, value in word_dict.items():
            type_check(key, (str,), "word")
            type_check(value, (int,), "word_id")
            check_value(value, id_range, "word_id")

        return method(self, *args, **kwargs)

    return validated
178
179
def check_jieba_init(method):
    """Decorate a jieba init method taking (hmm_path, mp_path, <unchecked>, with_offsets).

    Raises (from the wrapper):
        ValueError: If `hmm_path` or `mp_path` is None.
        TypeError: If `hmm_path` or `mp_path` is not a str, or `with_offsets`
            is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        hmm_path, mp_path, _unchecked, with_offsets = params

        # Each dictionary path gets a presence check followed by a type check,
        # hmm_path first, then mp_path.
        path_checks = (
            (hmm_path,
             "The dict of HMMSegment in cppjieba is not provided.",
             "Wrong input type for hmm_path, should be string."),
            (mp_path,
             "The dict of MPSegment in cppjieba is not provided.",
             "Wrong input type for mp_path, should be string."),
        )
        for path, missing_message, type_message in path_checks:
            if path is None:
                raise ValueError(missing_message)
            if not isinstance(path, str):
                raise TypeError(type_message)

        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")

        return method(self, *args, **kwargs)

    return validated
200
201
def check_jieba_add_word(method):
    """Decorate a jieba add-word method taking (word, freq) with argument checks.

    `word` must not be None. `freq`, when not None, is passed to check_uint32.

    Raises (from the wrapper):
        ValueError: If `word` is None.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        word, freq = params

        if word is None:
            raise ValueError("word is not provided.")

        # freq is optional; only a supplied value is range-checked.
        freq_given = freq is not None
        if freq_given:
            check_uint32(freq)

        return method(self, *args, **kwargs)

    return validated
215
216
def check_jieba_add_dict(method):
    """Decorate a jieba add-dict method.

    The wrapper only runs parse_user_args on the call arguments and discards
    the result; no per-argument checks are made here.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        # Return value intentionally unused.
        # NOTE(review): presumably this rejects calls that do not match the
        # signature of `method` - confirm against parse_user_args.
        _ = parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return validated
226
227
def check_with_offsets(method):
    """Decorate a method whose only parameter is `with_offsets`.

    Raises (from the wrapper):
        TypeError: If `with_offsets` is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (with_offsets,), _ = parse_user_args(method, *args, **kwargs)
        if isinstance(with_offsets, bool):
            return method(self, *args, **kwargs)
        raise TypeError("Wrong input type for with_offsets, should be boolean.")

    return validated
239
240
def check_unicode_script_tokenizer(method):
    """Decorate a method taking (keep_whitespace, with_offsets) with argument checks.

    Raises (from the wrapper):
        TypeError: If `keep_whitespace` or `with_offsets` is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        keep_whitespace, with_offsets = params

        # Flags are inspected in declaration order; the first non-bool is reported.
        flag_checks = (
            (keep_whitespace, "Wrong input type for keep_whitespace, should be boolean."),
            (with_offsets, "Wrong input type for with_offsets, should be boolean."),
        )
        for flag, message in flag_checks:
            if not isinstance(flag, bool):
                raise TypeError(message)

        return method(self, *args, **kwargs)

    return validated
254
255
def check_wordpiece_tokenizer(method):
    """Decorate a method taking (vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets).

    After the isinstance checks below pass, `max_bytes_per_token` is passed
    to check_uint32.

    Raises (from the wrapper):
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a text.Vocab, `suffix_indicator` or
            `unknown_token` is not a str, or `with_offsets` is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets = params

        if vocab is None:
            raise ValueError("vocab is not provided.")

        # (value, required type, message) - evaluated top to bottom.
        typed_args = (
            (vocab, text.Vocab, "Wrong input type for vocab, should be text.Vocab object."),
            (suffix_indicator, str, "Wrong input type for suffix_indicator, should be string."),
            (unknown_token, str, "Wrong input type for unknown_token, should be string."),
            (with_offsets, bool, "Wrong input type for with_offsets, should be boolean."),
        )
        for value, required, message in typed_args:
            if not isinstance(value, required):
                raise TypeError(message)

        check_uint32(max_bytes_per_token)
        return method(self, *args, **kwargs)

    return validated
277
278
def check_regex_replace(method):
    """Decorate a method taking (pattern, replace, replace_all) with argument checks.

    `pattern` and `replace` are passed to type_check as str, `replace_all` as
    bool, in that order.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        pattern, replace, replace_all = params

        expectations = (
            (pattern, (str,), "pattern"),
            (replace, (str,), "replace"),
            (replace_all, (bool,), "replace_all"),
        )
        for value, allowed, label in expectations:
            type_check(value, allowed, label)

        return method(self, *args, **kwargs)

    return validated
291
292
def check_regex_tokenizer(method):
    """Decorate a method taking (delim_pattern, keep_delim_pattern, with_offsets).

    Raises (from the wrapper):
        ValueError: If `delim_pattern` is None.
        TypeError: If `delim_pattern` or `keep_delim_pattern` is not a str, or
            `with_offsets` is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        delim_pattern, keep_delim_pattern, with_offsets = params

        if delim_pattern is None:
            raise ValueError("delim_pattern is not provided.")

        # (value, required type, message) - evaluated top to bottom.
        typed_args = (
            (delim_pattern, str, "Wrong input type for delim_pattern, should be string."),
            (keep_delim_pattern, str, "Wrong input type for keep_delim_pattern, should be string."),
            (with_offsets, bool, "Wrong input type for with_offsets, should be boolean."),
        )
        for value, required, message in typed_args:
            if not isinstance(value, required):
                raise TypeError(message)

        return method(self, *args, **kwargs)

    return validated
310
311
def check_basic_tokenizer(method):
    """Decorate a basic-tokenizer method with boolean-flag checks.

    The decorated method takes five parameters; the first, second, fourth and
    fifth (lower_case, keep_whitespace, preserve_unused, with_offsets) must be
    bool. The third parameter is not checked here.

    Raises (from the wrapper):
        TypeError: If any of the four flags is not a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        lower_case, keep_whitespace, _unchecked, preserve_unused, with_offsets = params

        # Flags are inspected in declaration order; the first non-bool is reported.
        flag_checks = (
            (lower_case, "Wrong input type for lower_case, should be boolean."),
            (keep_whitespace, "Wrong input type for keep_whitespace, should be boolean."),
            (preserve_unused, "Wrong input type for preserve_unused_token, should be boolean."),
            (with_offsets, "Wrong input type for with_offsets, should be boolean."),
        )
        for flag, message in flag_checks:
            if not isinstance(flag, bool):
                raise TypeError(message)

        return method(self, *args, **kwargs)

    return validated
330
331
def check_bert_tokenizer(method):
    """Wrapper method to check the parameters of BertTokenizer.

    The decorated method takes nine parameters: vocab, suffix_indicator,
    max_bytes_per_token, unknown_token, lower_case, keep_whitespace, one
    parameter that is not checked here, preserve_unused_token and with_offsets.
    `max_bytes_per_token` is also passed to check_uint32.

    Raises (from the wrapper):
        ValueError: If `vocab` is None.
        TypeError: If `vocab` is not a text.Vocab; `suffix_indicator` or
            `unknown_token` is not a str; `max_bytes_per_token` is not an int;
            or any of `lower_case`, `keep_whitespace`, `preserve_unused_token`,
            `with_offsets` is not a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _,
         preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs)
        if vocab is None:
            # Message spelling aligned with check_wordpiece_tokenizer ("vocab", not "vacab").
            raise ValueError("vocab is not provided.")
        if not isinstance(vocab, text.Vocab):
            raise TypeError("Wrong input type for vocab, should be text.Vocab object.")
        if not isinstance(suffix_indicator, str):
            raise TypeError("Wrong input type for suffix_indicator, should be string.")
        if not isinstance(max_bytes_per_token, int):
            raise TypeError("Wrong input type for max_bytes_per_token, should be int.")
        check_uint32(max_bytes_per_token)

        if not isinstance(unknown_token, str):
            raise TypeError("Wrong input type for unknown_token, should be string.")
        if not isinstance(lower_case, bool):
            raise TypeError("Wrong input type for lower_case, should be boolean.")
        if not isinstance(keep_whitespace, bool):
            raise TypeError("Wrong input type for keep_whitespace, should be boolean.")
        if not isinstance(preserve_unused_token, bool):
            raise TypeError("Wrong input type for preserve_unused_token, should be boolean.")
        if not isinstance(with_offsets, bool):
            raise TypeError("Wrong input type for with_offsets, should be boolean.")
        return method(self, *args, **kwargs)

    return new_method
362
363
def check_from_dataset(method):
    """Decorate a method taking (<unchecked>, columns, freq_range, top_k, special_tokens, special_first).

    Checks, in order:
      * `columns`, when not None, is normalised to a list (a single value is
        wrapped) and every element is passed to type_check_list as a str.
      * `freq_range`, when not None, must be a 2-tuple whose items are each an
        int or None; when both are ints they must satisfy 0 <= a <= b.
      * `top_k` must be an int or None; an int is passed to check_positive.
      * `special_first` is passed to type_check as a bool.
      * `special_tokens`, when not None, must be a list of unique str.

    Raises (from the wrapper):
        ValueError: If `freq_range` has the wrong length, holds a non-int,
            non-None item, or its integer bounds are out of order / negative.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):

        [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args,
                                                                                            **kwargs)
        if columns is not None:
            if not isinstance(columns, list):
                columns = [columns]
            # Validate the elements whether the caller passed a single column
            # or a list; previously a caller-supplied list skipped this check.
            type_check_list(columns, (str,), "col")

        if freq_range is not None:
            type_check(freq_range, (tuple,), "freq_range")

            if len(freq_range) != 2:
                raise ValueError("freq_range needs to be a tuple of 2 element.")

            for num in freq_range:
                if num is not None and (not isinstance(num, int)):
                    raise ValueError(
                        "freq_range needs to be either None or a tuple of 2 integers or an int and a None.")

            # The ordering constraint only applies when both bounds are given.
            if isinstance(freq_range[0], int) and isinstance(freq_range[1], int):
                if freq_range[0] > freq_range[1] or freq_range[0] < 0:
                    raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).")

        type_check(top_k, (int, type(None)), "top_k")

        if isinstance(top_k, int):
            check_positive(top_k, "top_k")
        type_check(special_first, (bool,), "special_first")

        if special_tokens is not None:
            check_unique_list_of_words(special_tokens, "special_tokens")

        return method(self, *args, **kwargs)

    return new_method
404
405
def check_slidingwindow(method):
    """Decorate a sliding-window method taking (width, axis) with argument checks.

    `width` is passed to check_pos_int32 and `axis` is passed to type_check
    as an int.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        width, axis = params

        check_pos_int32(width, "width")
        type_check(axis, (int,), "axis")

        return method(self, *args, **kwargs)

    return validated
417
418
def check_ngram(method):
    """Decorate a method taking (n, left_pad, right_pad, separator) with argument checks.

    An int `n` is wrapped into a one-element list. After validation the
    method is invoked with keyword arguments only: `n` (the normalised list),
    `left_pad`, `right_pad` and `separator` are written into kwargs and the
    positional arguments are not forwarded.

    Raises (from the wrapper):
        ValueError: If `n` is not a non-empty list, if `left_pad` or
            `right_pad` is not a 2-tuple of (str, int), or if either pad
            width is negative.
    """

    def _is_pad_spec(pad):
        # A pad spec is exactly (pad token as str, pad width as int).
        return (isinstance(pad, tuple) and len(pad) == 2
                and isinstance(pad[0], str) and isinstance(pad[1], int))

    @wraps(method)
    def validated(self, *args, **kwargs):
        params, _ = parse_user_args(method, *args, **kwargs)
        n, left_pad, right_pad, separator = params

        # A bare int is shorthand for a single-element list.
        grams = [n] if isinstance(n, int) else n

        if not isinstance(grams, list) or grams == []:
            raise ValueError("n needs to be a non-empty list of positive integers.")

        for position, gram in enumerate(grams):
            type_check(gram, (int,), "gram[{0}]".format(position))
            check_positive(gram, "gram_{}".format(position))

        if not _is_pad_spec(left_pad):
            raise ValueError("left_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if not _is_pad_spec(right_pad):
            raise ValueError("right_pad needs to be a tuple of (str, int) str is pad token and int is pad_width.")

        if left_pad[1] < 0 or right_pad[1] < 0:
            raise ValueError("padding width need to be positive numbers.")

        type_check(separator, (str,), "separator")

        # Hand everything over by keyword so the normalised `n` replaces the
        # caller's original value.
        kwargs.update(n=grams, left_pad=left_pad, right_pad=right_pad, separator=separator)

        return method(self, **kwargs)

    return validated
457
458
def check_truncate(method):
    """Decorate a truncate method taking a single `max_seq_len` argument.

    `max_seq_len` is passed to check_pos_int32 before `method` is called.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (max_seq_len,), _ = parse_user_args(method, *args, **kwargs)
        check_pos_int32(max_seq_len, "max_seq_len")
        return method(self, *args, **kwargs)

    return validated
469
470
def check_pair_truncate(method):
    """Decorate a pair-truncate method.

    The wrapper only runs parse_user_args on the call arguments and discards
    the result; no per-argument checks are made here.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        # Return value intentionally unused.
        # NOTE(review): presumably this rejects calls that do not match the
        # signature of `method` - confirm against parse_user_args.
        _ = parse_user_args(method, *args, **kwargs)
        return method(self, *args, **kwargs)

    return validated
480
481
def check_to_number(method):
    """Decorate a ToNumber method taking a single `data_type` argument.

    `data_type` is passed to type_check as a typing.Type and must be a member
    of mstype.number_type.

    Raises (from the wrapper):
        TypeError: If `data_type` is not in mstype.number_type.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (data_type,), _ = parse_user_args(method, *args, **kwargs)
        type_check(data_type, (typing.Type,), "data_type")

        if data_type in mstype.number_type:
            return method(self, *args, **kwargs)

        raise TypeError("data_type: " + str(data_type) + " is not numeric data type.")

    return validated
496
497
def check_python_tokenizer(method):
    """Decorate a PythonTokenizer method taking a single `tokenizer` argument.

    Raises (from the wrapper):
        TypeError: If `tokenizer` is not callable.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        (tokenizer,), _ = parse_user_args(method, *args, **kwargs)

        if callable(tokenizer):
            return method(self, *args, **kwargs)

        raise TypeError("tokenizer is not a callable Python function.")

    return validated
511
512
def check_from_dataset_sentencepiece(method):
    """Decorate a from_dataset method taking (<unchecked>, col_names, vocab_size, character_coverage, model_type, params).

    Checks, in order:
      * `col_names`, when not None, is passed to type_check_list as str.
      * `vocab_size` is mandatory and is passed to check_uint32.
      * `character_coverage`, when not None, is passed to type_check as a float.
      * `model_type`, when not None, is passed to type_check as a str or
        SentencePieceModel.
      * `params`, when not None, is passed to type_check as a dict.

    Raises (from the wrapper):
        TypeError: If `vocab_size` is None.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        _unchecked, col_names, vocab_size, character_coverage, model_type, params = parsed

        if col_names is not None:
            type_check_list(col_names, (str,), "col_names")

        # Unlike the other arguments, vocab_size may not be omitted.
        if vocab_size is None:
            raise TypeError("vocab_size must be provided.")
        check_uint32(vocab_size, "vocab_size")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            # Imported lazily, only when a model_type actually has to be checked.
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return validated
541
542
def check_from_file_sentencepiece(method):
    """Decorate a from_file method taking (file_path, vocab_size, character_coverage, model_type, params).

    Every argument is optional here; only values that are not None are checked:
      * `file_path` is passed to type_check as a list.
      * `vocab_size` is passed to check_uint32.
      * `character_coverage` is passed to type_check as a float.
      * `model_type` is passed to type_check as a str or SentencePieceModel.
      * `params` is passed to type_check as a dict.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        file_path, vocab_size, character_coverage, model_type, params = parsed

        if file_path is not None:
            type_check(file_path, (list,), "file_path")

        if vocab_size is not None:
            check_uint32(vocab_size, "vocab_size")

        if character_coverage is not None:
            type_check(character_coverage, (float,), "character_coverage")

        if model_type is not None:
            # Imported lazily, only when a model_type actually has to be checked.
            from .utils import SentencePieceModel
            type_check(model_type, (str, SentencePieceModel), "model_type")

        if params is not None:
            type_check(params, (dict,), "params")

        return method(self, *args, **kwargs)

    return validated
569
570
def check_save_model(method):
    """Decorate a save_model method taking (vocab, path, filename) with argument checks.

    Only values that are not None are checked: `vocab` is passed to type_check
    as a text.SentencePieceVocab, `path` and `filename` as str.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        vocab, path, filename = parsed

        if vocab is not None:
            type_check(vocab, (text.SentencePieceVocab,), "vocab")

        # path and filename share the same optional-str rule.
        for value, label in ((path, "path"), (filename, "filename")):
            if value is not None:
                type_check(value, (str,), label)

        return method(self, *args, **kwargs)

    return validated
590
591
def check_sentence_piece_tokenizer(method):
    """Decorate a method taking (mode, out_type) with argument checks.

    `mode` is passed to type_check as a str or text.SentencePieceVocab and
    `out_type` as a SPieceTokenizerOutType.
    """
    # Resolved once, when the decorator is applied, not on every call.
    from .utils import SPieceTokenizerOutType

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        mode, out_type = parsed

        type_check(mode, (str, text.SentencePieceVocab),
                   "mode is not an instance of str or text.SentencePieceVocab.")
        type_check(out_type, (SPieceTokenizerOutType,),
                   "out_type is not an instance of SPieceTokenizerOutType")

        return method(self, *args, **kwargs)

    return validated
607
608
def check_from_file_vectors(method):
    """Decorate Vectors.from_file, taking (file_path, max_vectors), with argument checks.

    `file_path` is passed to type_check as a str and then to check_filename.
    `max_vectors`, when not None, is passed to type_check as an int and to
    check_non_negative_int32.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        file_path, max_vectors = parsed

        type_check(file_path, (str,), "file_path")
        check_filename(file_path)

        # max_vectors is optional; only a supplied value is checked.
        limit_given = max_vectors is not None
        if limit_given:
            type_check(max_vectors, (int,), "max_vectors")
            check_non_negative_int32(max_vectors, "max_vectors")

        return method(self, *args, **kwargs)

    return validated
625
626
def check_to_vectors(method):
    """Decorate ToVectors, taking (vectors, unk_init, lower_case_backup), with argument checks.

    `vectors` is passed to type_check as a cde.Vectors. `unk_init`, when not
    None, is passed to type_check as a list or tuple and each of its elements
    as an int or float, labelled "unk_init[i]". `lower_case_backup` is passed
    to type_check as a bool.
    """

    @wraps(method)
    def validated(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        vectors, unk_init, lower_case_backup = parsed

        type_check(vectors, (cde.Vectors,), "vectors")

        if unk_init is not None:
            type_check(unk_init, (list, tuple), "unk_init")
            for position, component in enumerate(unk_init):
                type_check(component, (int, float), "unk_init[{0}]".format(position))

        type_check(lower_case_backup, (bool,), "lower_case_backup")

        return method(self, *args, **kwargs)

    return validated
643