/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/wordpiece_tokenizer.h"

#include "utils/utf8/unicodetext.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"

namespace libtextclassifier3 {

namespace {

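// Looks up the byte range [byte_start, byte_end) of 'token' in 'vocab_map' and
// reports the result through 'in_vocab'. For a non-initial range
// (byte_start > 0), 'suffix_indicator' (e.g. "##" in BERT-style vocabularies)
// is prepended before the lookup.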
LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
                    const std::string& suffix_indicator,
                    const WordpieceVocab* vocab_map, bool* in_vocab) {
  int byte_len = byte_end - byte_start;
  absl::string_view substr(token.data() + byte_start, byte_len);
  std::string lookup_value;
  if (byte_start > 0) {
    lookup_value = absl::StrCat(suffix_indicator, substr);
  } else {
    // Equivalent of absl::CopyToString: copy the bytes of 'substr'.
    lookup_value.assign(substr.begin(), substr.end());
  }
  return vocab_map->Contains(lookup_value, in_vocab);
}

// Sets 'byte_end' to the end of the longest byte sequence starting at
// 'byte_start' which:
//   1) is a proper UTF8 sequence, and
//   2) is in the vocab, OR, if split_unknown_characters is true, is a single
//      UTF8 character.
// If no such sequence exists, 'found_match' remains false.
LookupStatus LongestMatchStartingAt(
    int byte_start, const absl::string_view token,
    const std::string& suffix_indicator, const int max_chars_per_subtoken,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    int* byte_end, bool* found_match, bool* match_is_unknown_character) {
  *match_is_unknown_character = false;
  *found_match = false;
  const UnicodeText unicode_token =
      UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
  std::vector<int32_t> byte_ends;
  int32_t codepoint_offset = byte_start;
  for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
    codepoint_offset += it.utf8_length();
    byte_ends.push_back(codepoint_offset);
    if (max_chars_per_subtoken > 0 &&
        static_cast<int>(byte_ends.size()) == max_chars_per_subtoken) {
      // If the maximum number of characters per subtoken is known, do not
      // search beyond that length.
      break;
    }
  }
  int n = byte_ends.size();
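  // Try candidate end positions from longest to shortest, so the first
  // vocabulary hit is the longest possible match.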
  for (int i = n - 1; i >= 0; i--) {
    bool in_vocab;
    auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
                         vocab_map, &in_vocab);
    if (!status.success) return status;
    if (in_vocab) {
      *byte_end = byte_ends[i];
      *found_match = true;
      return LookupStatus::OK();
    }
    if (i == 0 && split_unknown_characters) {
      *byte_end = byte_ends[0];
      *found_match = true;
      *match_is_unknown_character = true;
      return LookupStatus::OK();
    }
  }
  return LookupStatus::OK();
}

// Handles the case where no token is found: appends a single entry (either
// 'unknown_token' or the whole token) to 'subwords', records its offsets in
// 'begin_offset' and 'end_offset', and increments 'num_word_pieces'.
LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
                          const std::string& unknown_token,
                          std::vector<std::string>* subwords,
                          std::vector<int>* begin_offset,
                          std::vector<int>* end_offset, int* num_word_pieces) {
  begin_offset->push_back(0);
  if (use_unknown_token) {
    subwords->push_back(unknown_token);
  } else {
    subwords->emplace_back(token.data(), token.length());
  }
  end_offset->push_back(token.length());
  ++(*num_word_pieces);

  return LookupStatus::OK();
}

// When a subword match is found, adds the subword to 'subwords' and its byte
// offsets to 'begin_offset' and 'end_offset'.
void AddWord(const absl::string_view token, int byte_start, int byte_end,
             const std::string& suffix_indicator,
             std::vector<std::string>* subwords, std::vector<int>* begin_offset,
             std::vector<int>* end_offset) {
  begin_offset->push_back(byte_start);
  int len = byte_end - byte_start;

  if (byte_start > 0) {
    // Prepend suffix_indicator if the token is within a word.
    subwords->push_back(::absl::StrCat(
        suffix_indicator, absl::string_view(token.data() + byte_start, len)));
  } else {
    subwords->emplace_back(token.data(), len);
  }
  end_offset->push_back(byte_end);
}

// Adds a single unknown-character subword, found when split_unknown_characters
// is true.
void AddUnknownCharacter(const absl::string_view token, int byte_start,
                         int byte_end, const std::string& suffix_indicator,
                         bool use_unknown_token,
                         const std::string& unknown_token,
                         std::vector<std::string>* subwords,
                         std::vector<int>* begin_offset,
                         std::vector<int>* end_offset) {
  begin_offset->push_back(byte_start);
  end_offset->push_back(byte_end);
  int len = byte_end - byte_start;
  if (use_unknown_token) {
    if (byte_start > 0) {
      // Prepend suffix_indicator if the character is within a word.
      subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token));
    } else {
      subwords->push_back(unknown_token);
    }
  } else {
    if (byte_start > 0) {
      // Prepend suffix_indicator if the character is within a word.
      subwords->push_back(::absl::StrCat(
          suffix_indicator, absl::string_view(token.data() + byte_start, len)));
    } else {
      subwords->emplace_back(token.data(), len);
    }
  }
}

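// Tokenizes 'token' left to right, greedily taking the longest vocabulary
// match at each position. Candidate pieces are buffered in local vectors so
// that, if any position yields no match at all, the whole token is emitted
// via NoTokenFound() instead of a partial sequence of pieces.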
LookupStatus TokenizeL2RGreedy(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  std::vector<std::string> candidate_subwords;
  std::vector<int> candidate_begin_offsets;
  std::vector<int> candidate_end_offsets;
  const int token_len = token.length();
  for (int byte_start = 0; byte_start < token_len;) {
    int byte_end;
    bool found_subword;
    bool match_is_unknown_character;
    auto status = LongestMatchStartingAt(
        byte_start, token, suffix_indicator, max_chars_per_subtoken,
        split_unknown_characters, vocab_map, &byte_end, &found_subword,
        &match_is_unknown_character);
    if (!status.success) return status;
    if (found_subword) {
      if (match_is_unknown_character) {
        AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator,
                            use_unknown_token, unknown_token,
                            &candidate_subwords, &candidate_begin_offsets,
                            &candidate_end_offsets);
      } else {
        AddWord(token, byte_start, byte_end, suffix_indicator,
                &candidate_subwords, &candidate_begin_offsets,
                &candidate_end_offsets);
      }
      byte_start = byte_end;
    } else {
      return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
                          begin_offset, end_offset, num_word_pieces);
    }
  }

  subwords->insert(subwords->end(), candidate_subwords.begin(),
                   candidate_subwords.end());
  begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(),
                       candidate_begin_offsets.end());
  end_offset->insert(end_offset->end(), candidate_end_offsets.begin(),
                     candidate_end_offsets.end());
  *num_word_pieces += candidate_subwords.size();
  return LookupStatus::OK();
}

}  // namespace

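// Tokenizes 'token' into word pieces using 'vocab_map'. If 'token' exceeds
// 'max_bytes_per_token' bytes, it is not subdivided: a single piece is
// emitted (either 'unknown_token' or the token itself, depending on
// 'use_unknown_token') and 'num_word_pieces' is set to 1. Otherwise the token
// is tokenized greedily and 'num_word_pieces' is incremented per piece.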
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  int token_len = token.size();
  if (token_len > max_bytes_per_token) {
    begin_offset->push_back(0);
    *num_word_pieces = 1;
    if (use_unknown_token) {
      subwords->emplace_back(unknown_token);
    } else {
      subwords->emplace_back(token);
    }
    end_offset->push_back(token.size());
    return LookupStatus::OK();
  }
  return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
                           suffix_indicator, use_unknown_token, unknown_token,
                           split_unknown_characters, vocab_map, subwords,
                           begin_offset, end_offset, num_word_pieces);
}

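// Convenience overload with no cap on subtoken length and no splitting of
// unknown characters.
//
// A minimal usage sketch (the WordpieceVocab instance and vocab contents here
// are hypothetical):
//
//   std::vector<std::string> subwords;
//   std::vector<int> begin_offset, end_offset;
//   int num_word_pieces = 0;
//   LookupStatus status = WordpieceTokenize(
//       "unaffable", /*max_bytes_per_token=*/100, /*suffix_indicator=*/"##",
//       /*use_unknown_token=*/true, /*unknown_token=*/"[UNK]", &vocab,
//       &subwords, &begin_offset, &end_offset, &num_word_pieces);
//   // With a vocab containing "un", "##aff", "##able", subwords ends up as
//   // {"un", "##aff", "##able"}.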
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const std::string& suffix_indicator, bool use_unknown_token,
    const std::string& unknown_token, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  return WordpieceTokenize(token, max_bytes_per_token,
                           /* max_chars_per_subtoken= */ 0, suffix_indicator,
                           use_unknown_token, unknown_token,
                           /* split_unknown_characters= */ false, vocab_map,
                           subwords, begin_offset, end_offset, num_word_pieces);
}

}  // namespace libtextclassifier3