/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/token-feature-extractor.h"

#include <cctype>
#include <string>

#include "utils/base/logging.h"
#include "utils/hash/farmhash.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"

namespace libtextclassifier3 {

namespace {

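// Rewrites an ASCII token according to the options: remaps digits to '0'
// (remap_digits) and/or lowercases characters (lowercase_tokens). For example,
// with both options set, "Room21" becomes "room00".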
std::string RemapTokenAscii(const std::string& token,
                            const TokenFeatureExtractorOptions& options) {
  if (!options.remap_digits && !options.lowercase_tokens) {
    return token;
  }

  std::string copy = token;
  for (int i = 0; i < token.size(); ++i) {
    if (options.remap_digits && isdigit(copy[i])) {
      copy[i] = '0';
    }
    if (options.lowercase_tokens) {
      copy[i] = tolower(copy[i]);
    }
  }
  return copy;
}

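// Unicode-aware variant of RemapTokenAscii. Writes the remapped codepoints
// into `remapped`; when neither remap_digits nor lowercase_tokens is set, the
// output is intentionally left untouched so the caller can keep using the
// original text without a copy.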
void RemapTokenUnicode(const std::string& token,
                       const TokenFeatureExtractorOptions& options,
                       const UniLib& unilib, UnicodeText* remapped) {
  if (!options.remap_digits && !options.lowercase_tokens) {
    // Leave remapped untouched.
    return;
  }

  UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
  remapped->clear();
  for (auto it = word.begin(); it != word.end(); ++it) {
    if (options.remap_digits && unilib.IsDigit(*it)) {
      remapped->push_back('0');
    } else if (options.lowercase_tokens) {
      remapped->push_back(unilib.ToLower(*it));
    } else {
      remapped->push_back(*it);
    }
  }
}

}  // namespace

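// The regex feature patterns are compiled once at construction; an entry that
// fails to compile stays null and is treated as a non-match (-1.0) by
// ExtractDenseFeatures.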
TokenFeatureExtractor::TokenFeatureExtractor(
    const TokenFeatureExtractorOptions& options, const UniLib* unilib)
    : options_(options), unilib_(*unilib) {
  for (const std::string& pattern : options.regexp_features) {
    regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
        unilib_.CreateRegexPattern(UTF8ToUnicodeText(
            pattern.c_str(), pattern.size(), /*do_copy=*/false))));
  }
}

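// Extracts features for a single token. `dense_features` must be non-null
// (otherwise false is returned); sparse charactergram features are only
// produced when `sparse_features` is non-null.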
bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
                                    std::vector<int>* sparse_features,
                                    std::vector<float>* dense_features) const {
  if (!dense_features) {
    return false;
  }
  if (sparse_features) {
    *sparse_features = ExtractCharactergramFeatures(token);
  }
  *dense_features = ExtractDenseFeatures(token, is_in_span);
  return true;
}

std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
    const Token& token) const {
  if (options_.unicode_aware_features) {
    return ExtractCharactergramFeaturesUnicode(token);
  } else {
    return ExtractCharactergramFeaturesAscii(token);
  }
}

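// Dense features are emitted in a fixed order, each gated by its option: the
// case feature (+1.0 if the first character is upper-case, -1.0 otherwise),
// the selection-mask feature (+1.0 inside the span; -1.0 or 0.0 outside,
// depending on unicode_aware_features), and one +1.0/-1.0 entry per regex
// feature.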
std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
    const Token& token, bool is_in_span) const {
  std::vector<float> dense_features;

  if (options_.extract_case_feature) {
    if (options_.unicode_aware_features) {
      UnicodeText token_unicode =
          UTF8ToUnicodeText(token.value, /*do_copy=*/false);
      if (!token.value.empty() && unilib_.IsUpper(*token_unicode.begin())) {
        dense_features.push_back(1.0);
      } else {
        dense_features.push_back(-1.0);
      }
    } else {
      if (!token.value.empty() && isupper(*token.value.begin())) {
        dense_features.push_back(1.0);
      } else {
        dense_features.push_back(-1.0);
      }
    }
  }

  if (options_.extract_selection_mask_feature) {
    if (is_in_span) {
      dense_features.push_back(1.0);
    } else {
      if (options_.unicode_aware_features) {
        dense_features.push_back(-1.0);
      } else {
        dense_features.push_back(0.0);
      }
    }
  }

  // Add regexp features.
  if (!regex_patterns_.empty()) {
    UnicodeText token_unicode =
        UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    for (int i = 0; i < regex_patterns_.size(); ++i) {
      if (!regex_patterns_[i].get()) {
        dense_features.push_back(-1.0);
        continue;
      }
      auto matcher = regex_patterns_[i]->Matcher(token_unicode);
      int status;
      if (matcher->Matches(&status)) {
        dense_features.push_back(1.0);
      } else {
        dense_features.push_back(-1.0);
      }
    }
  }

  return dense_features;
}

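// Hashes a charactergram into an embedding bucket. When allowed_chargrams is
// non-empty, bucket 0 is reserved for out-of-vocabulary charactergrams,
// bucket 1 for "<PAD>", and the remaining num_buckets - 2 buckets are used
// for in-vocabulary charactergrams.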
int TokenFeatureExtractor::HashToken(StringPiece token) const {
  if (options_.allowed_chargrams.empty()) {
    return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
  } else {
    // Padding and out-of-vocabulary tokens have extra buckets reserved because
    // they are special and important tokens, and we don't want them to share
    // an embedding with other charactergrams.
    // TODO(zilka): Experimentally verify.
    const int kNumExtraBuckets = 2;
    const std::string token_string = token.ToString();
    if (token_string == "<PAD>") {
      return 1;
    } else if (options_.allowed_chargrams.find(token_string) ==
               options_.allowed_chargrams.end()) {
      return 0;  // Out-of-vocabulary.
    } else {
      return (tc3farmhash::Fingerprint64(token) %
              (options_.num_buckets - kNumExtraBuckets)) +
             kNumExtraBuckets;
    }
  }
}

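// Example (illustrative): for the token "hello" with chargram_orders = {1, 2}
// and a sufficiently large max_word_length, the feature word is "^hello$" and
// the hashed charactergrams are "h", "e", "l", "l", "o" (order 1, boundary
// markers skipped) and "^h", "he", "el", "ll", "lo", "o$" (order 2).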
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
    const Token& token) const {
  std::vector<int> result;
  if (token.is_padding || token.value.empty()) {
    result.push_back(HashToken("<PAD>"));
  } else {
    const std::string word = RemapTokenAscii(token.value, options_);

    // Trim words that are over max_word_length characters.
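    // E.g., with max_word_length == 6, "extraordinary" becomes "^ext\1ary$"
    // (the first 3 characters, a "\1" separator, and the last 3 characters).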
    const int max_word_length = options_.max_word_length;
    std::string feature_word;
    if (word.size() > max_word_length) {
      feature_word =
          "^" + word.substr(0, max_word_length / 2) + "\1" +
          word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
          "$";
    } else {
      // Add a prefix and suffix to the word.
      feature_word = "^" + word + "$";
    }

    // Upper-bound the number of charactergrams extracted to avoid resizing.
    result.reserve(options_.chargram_orders.size() * feature_word.size());

    if (options_.chargram_orders.empty()) {
      result.push_back(HashToken(feature_word));
    } else {
      // Generate the character-grams.
      for (int chargram_order : options_.chargram_orders) {
        if (chargram_order == 1) {
          for (int i = 1; i < feature_word.size() - 1; ++i) {
            result.push_back(
                HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
          }
        } else {
          for (int i = 0;
               i < static_cast<int>(feature_word.size()) - chargram_order + 1;
               ++i) {
            result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
                                                   /*len=*/chargram_order)));
          }
        }
      }
    }
  }
  return result;
}

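// Unicode-aware variant of ExtractCharactergramFeaturesAscii: trimming and
// n-gram windows are computed over codepoints rather than bytes, so a
// multi-byte UTF-8 character counts as a single unit.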
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
    const Token& token) const {
  std::vector<int> result;
  if (token.is_padding || token.value.empty()) {
    result.push_back(HashToken("<PAD>"));
  } else {
    UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    RemapTokenUnicode(token.value, options_, unilib_, &word);

    // Trim the word if needed by finding a left-cut point and right-cut point.
    auto left_cut = word.begin();
    auto right_cut = word.end();
    for (int i = 0; i < options_.max_word_length / 2; i++) {
      if (left_cut < right_cut) {
        ++left_cut;
      }
      if (left_cut < right_cut) {
        --right_cut;
      }
    }

    std::string feature_word;
    if (left_cut == right_cut) {
      feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
    } else {
      // clang-format off
      feature_word = "^" +
                     word.UTF8Substring(word.begin(), left_cut) +
                     "\1" +
                     word.UTF8Substring(right_cut, word.end()) +
                     "$";
      // clang-format on
    }

    const UnicodeText feature_word_unicode =
        UTF8ToUnicodeText(feature_word, /*do_copy=*/false);

    // Upper-bound the number of charactergrams extracted to avoid resizing.
    result.reserve(options_.chargram_orders.size() * feature_word.size());

    if (options_.chargram_orders.empty()) {
      result.push_back(HashToken(feature_word));
    } else {
      // Generate the character-grams.
      for (int chargram_order : options_.chargram_orders) {
        UnicodeText::const_iterator it_start = feature_word_unicode.begin();
        UnicodeText::const_iterator it_end = feature_word_unicode.end();
        if (chargram_order == 1) {
          ++it_start;
          --it_end;
        }

        UnicodeText::const_iterator it_chargram_start = it_start;
        UnicodeText::const_iterator it_chargram_end = it_start;
        bool chargram_is_complete = true;
        for (int i = 0; i < chargram_order; ++i) {
          if (it_chargram_end == it_end) {
            chargram_is_complete = false;
            break;
          }
          ++it_chargram_end;
        }
        if (!chargram_is_complete) {
          continue;
        }

        for (; it_chargram_end <= it_end;
             ++it_chargram_start, ++it_chargram_end) {
          const int length_bytes =
              it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
          result.push_back(HashToken(
              StringPiece(it_chargram_start.utf8_data(), length_bytes)));
        }
      }
    }
  }
  return result;
}

}  // namespace libtextclassifier3