# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""Tokenization classes."""
17
18from __future__ import absolute_import, division, print_function, unicode_literals
19
20import collections
21import logging
22import os
23import unicodedata
24from io import open
25
26
27logger = logging.getLogger(__name__)
28
29
def load_vocab(vocab_file, vocab_map_ids_path=None):
    """Loads a vocabulary file into a dictionary.

    If `vocab_map_ids_path` is given, it must contain one integer id per line,
    aligned line-by-line with the tokens in `vocab_file`; those ids are used
    instead of the default consecutive numbering.
    """
    vocab = collections.OrderedDict()
    if vocab_map_ids_path is not None:
        # Read the remapped ids, one integer per line.
        vocab_new_ids = list()
        with open(vocab_map_ids_path, "r", encoding="utf-8") as vocab_new_ids_reader:
            while True:
                index = vocab_new_ids_reader.readline()
                if not index:
                    break
                index = index.strip()
                vocab_new_ids.append(int(index))
        # Assign the remapped ids to the tokens in file order.
        index = 0
        with open(vocab_file, "r", encoding="utf-8") as reader:
            while True:
                token = reader.readline()
                if not token:
                    break
                token = token.strip()
                vocab[token] = vocab_new_ids[index]
                index += 1
        return vocab
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab

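# Usage sketch (hedged): "vocab.txt" is an assumed local file with one token
# per line (for example "[PAD]", "[UNK]", "the", "##ing", ...). The default
# call maps each token to its line index:
#
#   vocab = load_vocab("vocab.txt")
#   vocab["the"]   # -> line index of "the" in vocab.txt
#
# With an id-remapping file ("vocab_map_ids.txt" is also an assumed name; one
# integer per line, aligned with vocab.txt), tokens take the remapped ids:
#
#   vocab = load_vocab("vocab.txt", vocab_map_ids_path="vocab_map_ids.txt")
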

def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer:
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, basic_only=False,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
          vocab_file: Path to a one-wordpiece-per-line vocabulary file.
          do_lower_case: Whether to lower case the input.
                         Only has an effect when do_basic_tokenize=True.
          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
          basic_only: If True, `tokenize` returns the basic tokens only, without
                         running wordpiece splitting on them.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
                         Only has an effect when do_basic_tokenize=True.
        """
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
        self.basic_only = basic_only

    def tokenize(self, text):
        """Tokenizes a piece of text into basic and/or wordpiece tokens."""
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text):
                if self.basic_only:
                    split_tokens.append(token)
                else:
                    for sub_token in self.wordpiece_tokenizer.tokenize(token):
                        split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab.get(token, self.vocab['[UNK]']))
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, 'vocab.txt')
        else:
            raise FileNotFoundError("Vocabulary path ({}) should be a directory".format(vocab_path))
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to %s: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!", vocab_file)
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a BertTokenizer from a pre-trained vocabulary file or from
        a directory containing a 'vocab.txt' file.
        """

        if pretrained_model_name_or_path.endswith('.txt'):
            resolved_vocab_file = pretrained_model_name_or_path
        else:
            resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer

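# Usage sketch (hedged): "bert_dir" and its vocab.txt are assumed paths, not
# files shipped with this module; the example pieces depend on the vocabulary.
#
#   tokenizer = BertTokenizer.from_pretrained("bert_dir", do_lower_case=True)
#   tokens = tokenizer.tokenize("unaffable behavior")   # e.g. ['un', '##aff', '##able', 'behavior']
#   ids = tokenizer.convert_tokens_to_ids(tokens)
#   tokenizer.convert_ids_to_tokens(ids)                # round-trips when all tokens are in the vocab
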

class BasicTokenizer:
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self,
                 do_lower_case=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
          do_lower_case: Whether to lower case the input.
          never_split: Tokens that are never lower cased or split on punctuation.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    @staticmethod
    def _run_strip_accents(text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    @staticmethod
    def _is_chinese_char(cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
                (0x4E00 <= cp <= 0x9FFF) or
                (0x3400 <= cp <= 0x4DBF) or
                (0x20000 <= cp <= 0x2A6DF) or
                (0x2A700 <= cp <= 0x2B73F) or
                (0x2B740 <= cp <= 0x2B81F) or
                (0x2B820 <= cp <= 0x2CEAF) or
                (0xF900 <= cp <= 0xFAFF) or
                (0x2F800 <= cp <= 0x2FA1F)
        ):
            return True

        return False

    @staticmethod
    def _clean_text(text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

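# Runnable sketch of BasicTokenizer on its own (no vocabulary needed):
#
#   basic = BasicTokenizer(do_lower_case=True)
#   basic.tokenize("Hello, Wörld!")   # -> ['hello', ',', 'world', '!']
#
# Lower casing and accent stripping are skipped for tokens listed in
# never_split, and every CJK character becomes its own token.
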

class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
          input = "unaffable"
          output = ["un", "##aff", "##able"]

        Args:
          text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.

        Returns:
          A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens

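# Runnable sketch of the greedy longest-match-first behaviour with a toy vocab:
#
#   wp = WordpieceTokenizer(vocab={"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3})
#   wp.tokenize("unaffable")   # -> ['un', '##aff', '##able']
#   wp.tokenize("xyz")         # -> ['[UNK]'] (no piece of "xyz" is in the vocab)
#
# Each token is consumed left to right, always taking the longest prefix (with
# a "##" continuation marker after the first piece) that exists in the vocab.
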

def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


class CustomizedBasicTokenizer(BasicTokenizer):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.),
    optionally merging runs of whitespace tokens that form a known keyword."""

    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedBasicTokenizer.

        Args:
          do_lower_case: Whether to lower case the input.
          never_split: Tokens that are never lower cased or split on punctuation.
          keywords: Optional collection of strings; adjacent whitespace tokens
            whose concatenation equals a keyword are merged into a single token.
        """
        super().__init__(do_lower_case, never_split)
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.keywords = keywords

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)

        if self.keywords is not None:
            # Merge adjacent tokens whose concatenation matches a keyword.
            # A keyword of N characters can span at most N tokens (each token
            # contributes at least one character), so N bounds the window size.
            new_orig_tokens = []
            lengths = [len(_) for _ in self.keywords]
            max_length = max(lengths)
            orig_tokens_len = len(orig_tokens)
            i = 0
            while i < orig_tokens_len:
                has_add = False
                for length in range(max_length, 0, -1):
                    if i + length > orig_tokens_len:
                        continue
                    add_token = ''.join(orig_tokens[i:i+length])
                    if add_token in self.keywords:
                        new_orig_tokens.append(add_token)
                        i += length
                        has_add = True
                        break
                if not has_add:
                    new_orig_tokens.append(orig_tokens[i])
                    i += 1
        else:
            new_orig_tokens = orig_tokens

        split_tokens = []
        for token in new_orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

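# Runnable sketch of the keyword merge (the sample keyword is illustrative):
#
#   tok = CustomizedBasicTokenizer(keywords=["机器学习"])
#   tok.tokenize("我爱机器学习")   # -> ['我', '爱', '机器学习']
#
# Without keywords, every CJK character would remain a separate token:
# ['我', '爱', '机', '器', '学', '习'].
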

class CustomizedTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: keyword-aware punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedTokenizer.

        Args:
          vocab_file: Path to a one-wordpiece-per-line vocabulary file.
          do_lower_case: Whether to lower case the input.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
          keywords: Optional collection of strings passed to
                         CustomizedBasicTokenizer for keyword merging.
        """
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split,
                                                        keywords=keywords)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        split_tokens = []
        basic_tokens = self.basic_tokenizer.tokenize(text)
        for token in basic_tokens:
            wordpiece_tokens = self.wordpiece_tokenizer.tokenize(token)
            for sub_token in wordpiece_tokens:
                split_tokens.append(sub_token)
        return split_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a CustomizedTokenizer from a directory containing a
        'customized_vocab.txt' file.
        """
        resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'customized_vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer

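# Usage sketch (hedged): "model_dir" and its customized_vocab.txt are assumed
# to exist; the keyword below is illustrative.
#
#   tokenizer = CustomizedTokenizer.from_pretrained("model_dir", keywords=["机器学习"])
#   tokenizer.tokenize("我爱机器学习")
#
# Keywords survive the basic-tokenization step as single tokens and are then
# looked up by the wordpiece tokenizer (falling back to [UNK] if absent).
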

class CustomizedTextBasicTokenizer(BasicTokenizer):
    """Basic tokenizer variant that isolates every non-ASCII character (not
    just CJK) as its own token and skips accent stripping."""

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK or other multi-byte UTF-8 character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp) or len(char.encode('utf-8')) > 1:
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

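# Runnable sketch: every non-ASCII character becomes its own token, and accents
# are kept rather than stripped:
#
#   tok = CustomizedTextBasicTokenizer(do_lower_case=True)
#   tok.tokenize("Café 机器")   # -> ['caf', 'é', '机', '器']
#
# The plain BasicTokenizer would instead fold "é" into a single token 'cafe'.
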

class CustomizedTextTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"),
                 vocab_map_ids_path=None):
        """Constructs a CustomizedTextTokenizer.

        Args:
          vocab_file: Path to a one-wordpiece-per-line vocabulary file.
          do_lower_case: Whether to lower case the input.
          max_len: An artificial maximum length to truncate tokenized sequences to;
                         Effective maximum length is always the minimum of this
                         value (if specified) and the underlying BERT model's
                         sequence length.
          never_split: List of tokens which will never be split during tokenization.
          vocab_map_ids_path: Optional file with one integer id per line, used to
                         remap the ids in `vocab_file` (see `load_vocab`).
        """
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file, vocab_map_ids_path)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedTextBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

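# Usage sketch (hedged): the file names below are assumed, not shipped with
# this module; vocab_map_ids.txt remaps vocabulary line indices to new ids.
#
#   tokenizer = CustomizedTextTokenizer("vocab.txt", vocab_map_ids_path="vocab_map_ids.txt")
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Café 机器学习"))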