• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "smartselect/token-feature-extractor.h"
18 
19 #include <string>
20 
21 #include "util/base/logging.h"
22 #include "util/hash/farmhash.h"
23 #include "util/strings/stringpiece.h"
24 #include "util/utf8/unicodetext.h"
25 #include "unicode/regex.h"
26 #include "unicode/uchar.h"
27 
28 namespace libtextclassifier {
29 
30 namespace {
31 
RemapTokenAscii(const std::string & token,const TokenFeatureExtractorOptions & options)32 std::string RemapTokenAscii(const std::string& token,
33                             const TokenFeatureExtractorOptions& options) {
34   if (!options.remap_digits && !options.lowercase_tokens) {
35     return token;
36   }
37 
38   std::string copy = token;
39   for (int i = 0; i < token.size(); ++i) {
40     if (options.remap_digits && isdigit(copy[i])) {
41       copy[i] = '0';
42     }
43     if (options.lowercase_tokens) {
44       copy[i] = tolower(copy[i]);
45     }
46   }
47   return copy;
48 }
49 
RemapTokenUnicode(const std::string & token,const TokenFeatureExtractorOptions & options,UnicodeText * remapped)50 void RemapTokenUnicode(const std::string& token,
51                        const TokenFeatureExtractorOptions& options,
52                        UnicodeText* remapped) {
53   if (!options.remap_digits && !options.lowercase_tokens) {
54     // Leave remapped untouched.
55     return;
56   }
57 
58   UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
59   icu::UnicodeString icu_string;
60   for (auto it = word.begin(); it != word.end(); ++it) {
61     if (options.remap_digits && u_isdigit(*it)) {
62       icu_string.append('0');
63     } else if (options.lowercase_tokens) {
64       icu_string.append(u_tolower(*it));
65     } else {
66       icu_string.append(*it);
67     }
68   }
69   std::string utf8_str;
70   icu_string.toUTF8String(utf8_str);
71   remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
72 }
73 
74 }  // namespace
75 
TokenFeatureExtractor(const TokenFeatureExtractorOptions & options)76 TokenFeatureExtractor::TokenFeatureExtractor(
77     const TokenFeatureExtractorOptions& options)
78     : options_(options) {
79   UErrorCode status;
80   for (const std::string& pattern : options.regexp_features) {
81     status = U_ZERO_ERROR;
82     regex_patterns_.push_back(
83         std::unique_ptr<icu::RegexPattern>(icu::RegexPattern::compile(
84             icu::UnicodeString(pattern.c_str(), pattern.size(), "utf-8"), 0,
85             status)));
86     if (U_FAILURE(status)) {
87       TC_LOG(WARNING) << "Failed to load pattern" << pattern;
88     }
89   }
90 }
91 
HashToken(StringPiece token) const92 int TokenFeatureExtractor::HashToken(StringPiece token) const {
93   return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
94 }
95 
ExtractCharactergramFeatures(const Token & token) const96 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
97     const Token& token) const {
98   if (options_.unicode_aware_features) {
99     return ExtractCharactergramFeaturesUnicode(token);
100   } else {
101     return ExtractCharactergramFeaturesAscii(token);
102   }
103 }
104 
ExtractCharactergramFeaturesAscii(const Token & token) const105 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
106     const Token& token) const {
107   std::vector<int> result;
108   if (token.is_padding || token.value.empty()) {
109     result.push_back(HashToken("<PAD>"));
110   } else {
111     const std::string word = RemapTokenAscii(token.value, options_);
112 
113     // Trim words that are over max_word_length characters.
114     const int max_word_length = options_.max_word_length;
115     std::string feature_word;
116     if (word.size() > max_word_length) {
117       feature_word =
118           "^" + word.substr(0, max_word_length / 2) + "\1" +
119           word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
120           "$";
121     } else {
122       // Add a prefix and suffix to the word.
123       feature_word = "^" + word + "$";
124     }
125 
126     // Upper-bound the number of charactergram extracted to avoid resizing.
127     result.reserve(options_.chargram_orders.size() * feature_word.size());
128 
129     // Generate the character-grams.
130     for (int chargram_order : options_.chargram_orders) {
131       if (chargram_order == 1) {
132         for (int i = 1; i < feature_word.size() - 1; ++i) {
133           result.push_back(
134               HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
135         }
136       } else {
137         for (int i = 0;
138              i < static_cast<int>(feature_word.size()) - chargram_order + 1;
139              ++i) {
140           result.push_back(HashToken(
141               StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
142         }
143       }
144     }
145   }
146   return result;
147 }
148 
ExtractCharactergramFeaturesUnicode(const Token & token) const149 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
150     const Token& token) const {
151   std::vector<int> result;
152   if (token.is_padding || token.value.empty()) {
153     result.push_back(HashToken("<PAD>"));
154   } else {
155     UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
156     RemapTokenUnicode(token.value, options_, &word);
157 
158     // Trim the word if needed by finding a left-cut point and right-cut point.
159     auto left_cut = word.begin();
160     auto right_cut = word.end();
161     for (int i = 0; i < options_.max_word_length / 2; i++) {
162       if (left_cut < right_cut) {
163         ++left_cut;
164       }
165       if (left_cut < right_cut) {
166         --right_cut;
167       }
168     }
169 
170     std::string feature_word;
171     if (left_cut == right_cut) {
172       feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
173     } else {
174       // clang-format off
175       feature_word = "^" +
176                      word.UTF8Substring(word.begin(), left_cut) +
177                      "\1" +
178                      word.UTF8Substring(right_cut, word.end()) +
179                      "$";
180       // clang-format on
181     }
182 
183     const UnicodeText feature_word_unicode =
184         UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
185 
186     // Upper-bound the number of charactergram extracted to avoid resizing.
187     result.reserve(options_.chargram_orders.size() * feature_word.size());
188 
189     // Generate the character-grams.
190     for (int chargram_order : options_.chargram_orders) {
191       UnicodeText::const_iterator it_start = feature_word_unicode.begin();
192       UnicodeText::const_iterator it_end = feature_word_unicode.end();
193       if (chargram_order == 1) {
194         ++it_start;
195         --it_end;
196       }
197 
198       UnicodeText::const_iterator it_chargram_start = it_start;
199       UnicodeText::const_iterator it_chargram_end = it_start;
200       bool chargram_is_complete = true;
201       for (int i = 0; i < chargram_order; ++i) {
202         if (it_chargram_end == it_end) {
203           chargram_is_complete = false;
204           break;
205         }
206         ++it_chargram_end;
207       }
208       if (!chargram_is_complete) {
209         continue;
210       }
211 
212       for (; it_chargram_end <= it_end;
213            ++it_chargram_start, ++it_chargram_end) {
214         const int length_bytes =
215             it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
216         result.push_back(HashToken(
217             StringPiece(it_chargram_start.utf8_data(), length_bytes)));
218       }
219     }
220   }
221   return result;
222 }
223 
Extract(const Token & token,bool is_in_span,std::vector<int> * sparse_features,std::vector<float> * dense_features) const224 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
225                                     std::vector<int>* sparse_features,
226                                     std::vector<float>* dense_features) const {
227   if (sparse_features == nullptr || dense_features == nullptr) {
228     return false;
229   }
230 
231   *sparse_features = ExtractCharactergramFeatures(token);
232 
233   if (options_.extract_case_feature) {
234     if (options_.unicode_aware_features) {
235       UnicodeText token_unicode =
236           UTF8ToUnicodeText(token.value, /*do_copy=*/false);
237       if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
238         dense_features->push_back(1.0);
239       } else {
240         dense_features->push_back(-1.0);
241       }
242     } else {
243       if (!token.value.empty() && isupper(*token.value.begin())) {
244         dense_features->push_back(1.0);
245       } else {
246         dense_features->push_back(-1.0);
247       }
248     }
249   }
250 
251   if (options_.extract_selection_mask_feature) {
252     if (is_in_span) {
253       dense_features->push_back(1.0);
254     } else {
255       if (options_.unicode_aware_features) {
256         dense_features->push_back(-1.0);
257       } else {
258         dense_features->push_back(0.0);
259       }
260     }
261   }
262 
263   // Add regexp features.
264   if (!regex_patterns_.empty()) {
265     icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
266                                    "utf-8");
267     for (int i = 0; i < regex_patterns_.size(); ++i) {
268       if (!regex_patterns_[i].get()) {
269         dense_features->push_back(-1.0);
270         continue;
271       }
272 
273       // Check for match.
274       UErrorCode status = U_ZERO_ERROR;
275       std::unique_ptr<icu::RegexMatcher> matcher(
276           regex_patterns_[i]->matcher(unicode_str, status));
277       if (matcher->find()) {
278         dense_features->push_back(1.0);
279       } else {
280         dense_features->push_back(-1.0);
281       }
282     }
283   }
284   return true;
285 }
286 
287 }  // namespace libtextclassifier
288