1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "smartselect/token-feature-extractor.h"
18
19 #include <string>
20
21 #include "util/base/logging.h"
22 #include "util/hash/farmhash.h"
23 #include "util/strings/stringpiece.h"
24 #include "util/utf8/unicodetext.h"
25 #include "unicode/regex.h"
26 #include "unicode/uchar.h"
27
28 namespace libtextclassifier {
29
30 namespace {
31
RemapTokenAscii(const std::string & token,const TokenFeatureExtractorOptions & options)32 std::string RemapTokenAscii(const std::string& token,
33 const TokenFeatureExtractorOptions& options) {
34 if (!options.remap_digits && !options.lowercase_tokens) {
35 return token;
36 }
37
38 std::string copy = token;
39 for (int i = 0; i < token.size(); ++i) {
40 if (options.remap_digits && isdigit(copy[i])) {
41 copy[i] = '0';
42 }
43 if (options.lowercase_tokens) {
44 copy[i] = tolower(copy[i]);
45 }
46 }
47 return copy;
48 }
49
RemapTokenUnicode(const std::string & token,const TokenFeatureExtractorOptions & options,UnicodeText * remapped)50 void RemapTokenUnicode(const std::string& token,
51 const TokenFeatureExtractorOptions& options,
52 UnicodeText* remapped) {
53 if (!options.remap_digits && !options.lowercase_tokens) {
54 // Leave remapped untouched.
55 return;
56 }
57
58 UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
59 icu::UnicodeString icu_string;
60 for (auto it = word.begin(); it != word.end(); ++it) {
61 if (options.remap_digits && u_isdigit(*it)) {
62 icu_string.append('0');
63 } else if (options.lowercase_tokens) {
64 icu_string.append(u_tolower(*it));
65 } else {
66 icu_string.append(*it);
67 }
68 }
69 std::string utf8_str;
70 icu_string.toUTF8String(utf8_str);
71 remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
72 }
73
74 } // namespace
75
TokenFeatureExtractor(const TokenFeatureExtractorOptions & options)76 TokenFeatureExtractor::TokenFeatureExtractor(
77 const TokenFeatureExtractorOptions& options)
78 : options_(options) {
79 UErrorCode status;
80 for (const std::string& pattern : options.regexp_features) {
81 status = U_ZERO_ERROR;
82 regex_patterns_.push_back(
83 std::unique_ptr<icu::RegexPattern>(icu::RegexPattern::compile(
84 icu::UnicodeString(pattern.c_str(), pattern.size(), "utf-8"), 0,
85 status)));
86 if (U_FAILURE(status)) {
87 TC_LOG(WARNING) << "Failed to load pattern" << pattern;
88 }
89 }
90 }
91
HashToken(StringPiece token) const92 int TokenFeatureExtractor::HashToken(StringPiece token) const {
93 return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
94 }
95
ExtractCharactergramFeatures(const Token & token) const96 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
97 const Token& token) const {
98 if (options_.unicode_aware_features) {
99 return ExtractCharactergramFeaturesUnicode(token);
100 } else {
101 return ExtractCharactergramFeaturesAscii(token);
102 }
103 }
104
ExtractCharactergramFeaturesAscii(const Token & token) const105 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
106 const Token& token) const {
107 std::vector<int> result;
108 if (token.is_padding || token.value.empty()) {
109 result.push_back(HashToken("<PAD>"));
110 } else {
111 const std::string word = RemapTokenAscii(token.value, options_);
112
113 // Trim words that are over max_word_length characters.
114 const int max_word_length = options_.max_word_length;
115 std::string feature_word;
116 if (word.size() > max_word_length) {
117 feature_word =
118 "^" + word.substr(0, max_word_length / 2) + "\1" +
119 word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
120 "$";
121 } else {
122 // Add a prefix and suffix to the word.
123 feature_word = "^" + word + "$";
124 }
125
126 // Upper-bound the number of charactergram extracted to avoid resizing.
127 result.reserve(options_.chargram_orders.size() * feature_word.size());
128
129 // Generate the character-grams.
130 for (int chargram_order : options_.chargram_orders) {
131 if (chargram_order == 1) {
132 for (int i = 1; i < feature_word.size() - 1; ++i) {
133 result.push_back(
134 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
135 }
136 } else {
137 for (int i = 0;
138 i < static_cast<int>(feature_word.size()) - chargram_order + 1;
139 ++i) {
140 result.push_back(HashToken(
141 StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
142 }
143 }
144 }
145 }
146 return result;
147 }
148
ExtractCharactergramFeaturesUnicode(const Token & token) const149 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
150 const Token& token) const {
151 std::vector<int> result;
152 if (token.is_padding || token.value.empty()) {
153 result.push_back(HashToken("<PAD>"));
154 } else {
155 UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
156 RemapTokenUnicode(token.value, options_, &word);
157
158 // Trim the word if needed by finding a left-cut point and right-cut point.
159 auto left_cut = word.begin();
160 auto right_cut = word.end();
161 for (int i = 0; i < options_.max_word_length / 2; i++) {
162 if (left_cut < right_cut) {
163 ++left_cut;
164 }
165 if (left_cut < right_cut) {
166 --right_cut;
167 }
168 }
169
170 std::string feature_word;
171 if (left_cut == right_cut) {
172 feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
173 } else {
174 // clang-format off
175 feature_word = "^" +
176 word.UTF8Substring(word.begin(), left_cut) +
177 "\1" +
178 word.UTF8Substring(right_cut, word.end()) +
179 "$";
180 // clang-format on
181 }
182
183 const UnicodeText feature_word_unicode =
184 UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
185
186 // Upper-bound the number of charactergram extracted to avoid resizing.
187 result.reserve(options_.chargram_orders.size() * feature_word.size());
188
189 // Generate the character-grams.
190 for (int chargram_order : options_.chargram_orders) {
191 UnicodeText::const_iterator it_start = feature_word_unicode.begin();
192 UnicodeText::const_iterator it_end = feature_word_unicode.end();
193 if (chargram_order == 1) {
194 ++it_start;
195 --it_end;
196 }
197
198 UnicodeText::const_iterator it_chargram_start = it_start;
199 UnicodeText::const_iterator it_chargram_end = it_start;
200 bool chargram_is_complete = true;
201 for (int i = 0; i < chargram_order; ++i) {
202 if (it_chargram_end == it_end) {
203 chargram_is_complete = false;
204 break;
205 }
206 ++it_chargram_end;
207 }
208 if (!chargram_is_complete) {
209 continue;
210 }
211
212 for (; it_chargram_end <= it_end;
213 ++it_chargram_start, ++it_chargram_end) {
214 const int length_bytes =
215 it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
216 result.push_back(HashToken(
217 StringPiece(it_chargram_start.utf8_data(), length_bytes)));
218 }
219 }
220 }
221 return result;
222 }
223
Extract(const Token & token,bool is_in_span,std::vector<int> * sparse_features,std::vector<float> * dense_features) const224 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
225 std::vector<int>* sparse_features,
226 std::vector<float>* dense_features) const {
227 if (sparse_features == nullptr || dense_features == nullptr) {
228 return false;
229 }
230
231 *sparse_features = ExtractCharactergramFeatures(token);
232
233 if (options_.extract_case_feature) {
234 if (options_.unicode_aware_features) {
235 UnicodeText token_unicode =
236 UTF8ToUnicodeText(token.value, /*do_copy=*/false);
237 if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
238 dense_features->push_back(1.0);
239 } else {
240 dense_features->push_back(-1.0);
241 }
242 } else {
243 if (!token.value.empty() && isupper(*token.value.begin())) {
244 dense_features->push_back(1.0);
245 } else {
246 dense_features->push_back(-1.0);
247 }
248 }
249 }
250
251 if (options_.extract_selection_mask_feature) {
252 if (is_in_span) {
253 dense_features->push_back(1.0);
254 } else {
255 if (options_.unicode_aware_features) {
256 dense_features->push_back(-1.0);
257 } else {
258 dense_features->push_back(0.0);
259 }
260 }
261 }
262
263 // Add regexp features.
264 if (!regex_patterns_.empty()) {
265 icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
266 "utf-8");
267 for (int i = 0; i < regex_patterns_.size(); ++i) {
268 if (!regex_patterns_[i].get()) {
269 dense_features->push_back(-1.0);
270 continue;
271 }
272
273 // Check for match.
274 UErrorCode status = U_ZERO_ERROR;
275 std::unique_ptr<icu::RegexMatcher> matcher(
276 regex_patterns_[i]->matcher(unicode_str, status));
277 if (matcher->find()) {
278 dense_features->push_back(1.0);
279 } else {
280 dense_features->push_back(-1.0);
281 }
282 }
283 }
284 return true;
285 }
286
287 } // namespace libtextclassifier
288