# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Tokenization classes."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open


logger = logging.getLogger(__name__)


def load_vocab(vocab_file, vocab_map_ids_path=None):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    if vocab_map_ids_path is not None:
        # Remap token ids: line i of vocab_map_ids_path gives the id assigned
        # to the token on line i of vocab_file.
        vocab_new_ids = list()
        with open(vocab_map_ids_path, "r", encoding="utf-8") as vocab_new_ids_reader:
            while True:
                index = vocab_new_ids_reader.readline()
                if not index:
                    break
                index = index.strip()
                vocab_new_ids.append(int(index))
        index = 0
        with open(vocab_file, "r", encoding="utf-8") as reader:
            while True:
                token = reader.readline()
                if not token:
                    break
                token = token.strip()
                vocab[token] = vocab_new_ids[index]
                index += 1
        return vocab
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer:
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, basic_only=False,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BertTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
                Only has an effect when do_basic_tokenize=True.
            do_basic_tokenize: Whether to do basic tokenization before wordpiece.
            basic_only: If True, skip wordpiece and return only the basic tokens.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True.
        """
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
        self.basic_only = basic_only

    def tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text):
                if self.basic_only:
                    split_tokens.append(token)
                else:
                    for sub_token in self.wordpiece_tokenizer.tokenize(token):
                        split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            ids.append(self.vocab.get(token, self.vocab['[UNK]']))
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into wordpiece tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    def save_vocabulary(self, vocab_path):
        """Saves the tokenizer vocabulary (vocab.txt) to a directory."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, 'vocab.txt')
        else:
            raise FileNotFoundError("vocab_path must be an existing directory: %s" % vocab_path)
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # Vocabulary indices are not consecutive; re-sync before writing.
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a tokenizer from a pre-trained vocabulary file or a
        directory containing vocab.txt.
        """
        # A path containing 'txt' is treated as the vocabulary file itself;
        # otherwise it is treated as a directory containing vocab.txt.
        if 'txt' in pretrained_model_name_or_path:
            resolved_vocab_file = pretrained_model_name_or_path
        else:
            resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer


class BasicTokenizer:
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self,
                 do_lower_case=True,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    @staticmethod
    def _run_strip_accents(text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    @staticmethod
    def _is_chinese_char(cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
                (0x4E00 <= cp <= 0x9FFF) or
                (0x3400 <= cp <= 0x4DBF) or
                (0x20000 <= cp <= 0x2A6DF) or
                (0x2A700 <= cp <= 0x2B73F) or
                (0x2B740 <= cp <= 0x2B81F) or
                (0x2B820 <= cp <= 0x2CEAF) or
                (0xF900 <= cp <= 0xFAFF) or
                (0x2F800 <= cp <= 0x2FA1F)
        ):
            return True

        return False

    @staticmethod
    def _clean_text(text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char in (" ", "\t", "\n", "\r"):
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char in ("\t", "\n", "\r"):
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


class CustomizedBasicTokenizer(BasicTokenizer):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.) with keyword merging."""

    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedBasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            keywords: Optional collection of keywords; consecutive tokens that
                join into a keyword are merged back into a single token.
        """
        super().__init__(do_lower_case, never_split)
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.keywords = keywords

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)

        if self.keywords is not None:
            # Greedily merge runs of consecutive tokens that re-join into a known
            # keyword, trying the longest possible run first.
            new_orig_tokens = []
            lengths = [len(keyword) for keyword in self.keywords]
            max_length = max(lengths)
            orig_tokens_len = len(orig_tokens)
            i = 0
            while i < orig_tokens_len:
                has_add = False
                for length in range(max_length, 0, -1):
                    if i + length > orig_tokens_len:
                        continue
                    add_token = ''.join(orig_tokens[i:i + length])
                    if add_token in self.keywords:
                        new_orig_tokens.append(add_token)
                        i += length
                        has_add = True
                        break
                if not has_add:
                    new_orig_tokens.append(orig_tokens[i])
                    i += 1
        else:
            new_orig_tokens = orig_tokens

        split_tokens = []
        for token in new_orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens


class CustomizedTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece, with keyword merging."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"), keywords=None):
        """Constructs a CustomizedTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
            keywords: Optional collection of keywords forwarded to CustomizedBasicTokenizer.
        """
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split,
                                                        keywords=keywords)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        split_tokens = []
        basic_tokens = self.basic_tokenizer.tokenize(text)
        for token in basic_tokens:
            wordpiece_tokens = self.wordpiece_tokenizer.tokenize(token)
            for sub_token in wordpiece_tokens:
                split_tokens.append(sub_token)
        return split_tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a tokenizer from a directory containing customized_vocab.txt.
        """
        resolved_vocab_file = os.path.join(pretrained_model_name_or_path, 'customized_vocab.txt')

        max_len = 512
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer


class CustomizedTextBasicTokenizer(BasicTokenizer):
    """Basic tokenizer variant that keeps accents and isolates every non-ASCII character."""

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK or other non-ASCII character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp) or len(char.encode('utf-8')) > 1:
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class CustomizedTextTokenizer(BertTokenizer):
    """Runs end-to-end tokenization: punctuation splitting + wordpiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"),
                 vocab_map_ids_path=None):
        """Constructs a CustomizedTextTokenizer.

        Args:
            vocab_file: Path to a one-wordpiece-per-line vocabulary file.
            do_lower_case: Whether to lower case the input.
            max_len: An artificial maximum length to truncate tokenized sequences to;
                the effective maximum length is always the minimum of this
                value (if specified) and the underlying BERT model's
                sequence length.
            never_split: List of tokens which will never be split during tokenization.
            vocab_map_ids_path: Optional path to a file with one integer per line
                that remaps each vocabulary entry to a new token id.
        """
        super().__init__(vocab_file, do_lower_case, max_len, never_split=never_split)

        self.vocab = load_vocab(vocab_file, vocab_map_ids_path)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = CustomizedTextBasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)
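

if __name__ == "__main__":
    # Minimal usage sketch. It assumes a standard one-token-per-line BERT
    # vocabulary file (containing at least "[UNK]") whose path is passed as the
    # first command-line argument; the variable names below are illustrative.
    import sys

    demo_tokenizer = BertTokenizer(sys.argv[1], do_lower_case=True)
    demo_tokens = demo_tokenizer.tokenize("unaffable")
    demo_ids = demo_tokenizer.convert_tokens_to_ids(demo_tokens)
    # With a typical BERT vocabulary this prints ["un", "##aff", "##able"]
    # followed by the corresponding ids; pieces missing from the vocabulary
    # fall back to "[UNK]".
    print(demo_tokens)
    print(demo_ids)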