• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/token-feature-extractor.h"
18 
19 #include <cctype>
20 #include <string>
21 
22 #include "utils/base/logging.h"
23 #include "utils/hash/farmhash.h"
24 #include "utils/strings/stringpiece.h"
25 #include "utils/utf8/unicodetext.h"
26 
27 namespace libtextclassifier3 {
28 
29 namespace {
30 
RemapTokenAscii(const std::string & token,const TokenFeatureExtractorOptions & options)31 std::string RemapTokenAscii(const std::string& token,
32                             const TokenFeatureExtractorOptions& options) {
33   if (!options.remap_digits && !options.lowercase_tokens) {
34     return token;
35   }
36 
37   std::string copy = token;
38   for (int i = 0; i < token.size(); ++i) {
39     if (options.remap_digits && isdigit(copy[i])) {
40       copy[i] = '0';
41     }
42     if (options.lowercase_tokens) {
43       copy[i] = tolower(copy[i]);
44     }
45   }
46   return copy;
47 }
48 
RemapTokenUnicode(const std::string & token,const TokenFeatureExtractorOptions & options,const UniLib & unilib,UnicodeText * remapped)49 void RemapTokenUnicode(const std::string& token,
50                        const TokenFeatureExtractorOptions& options,
51                        const UniLib& unilib, UnicodeText* remapped) {
52   if (!options.remap_digits && !options.lowercase_tokens) {
53     // Leave remapped untouched.
54     return;
55   }
56 
57   UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
58   remapped->clear();
59   for (auto it = word.begin(); it != word.end(); ++it) {
60     if (options.remap_digits && unilib.IsDigit(*it)) {
61       remapped->push_back('0');
62     } else if (options.lowercase_tokens) {
63       remapped->push_back(unilib.ToLower(*it));
64     } else {
65       remapped->push_back(*it);
66     }
67   }
68 }
69 
70 }  // namespace
71 
TokenFeatureExtractor(const TokenFeatureExtractorOptions & options,const UniLib & unilib)72 TokenFeatureExtractor::TokenFeatureExtractor(
73     const TokenFeatureExtractorOptions& options, const UniLib& unilib)
74     : options_(options), unilib_(unilib) {
75   for (const std::string& pattern : options.regexp_features) {
76     regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
77         unilib_.CreateRegexPattern(UTF8ToUnicodeText(
78             pattern.c_str(), pattern.size(), /*do_copy=*/false))));
79   }
80 }
81 
Extract(const Token & token,bool is_in_span,std::vector<int> * sparse_features,std::vector<float> * dense_features) const82 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
83                                     std::vector<int>* sparse_features,
84                                     std::vector<float>* dense_features) const {
85   if (!dense_features) {
86     return false;
87   }
88   if (sparse_features) {
89     *sparse_features = ExtractCharactergramFeatures(token);
90   }
91   *dense_features = ExtractDenseFeatures(token, is_in_span);
92   return true;
93 }
94 
ExtractCharactergramFeatures(const Token & token) const95 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
96     const Token& token) const {
97   if (options_.unicode_aware_features) {
98     return ExtractCharactergramFeaturesUnicode(token);
99   } else {
100     return ExtractCharactergramFeaturesAscii(token);
101   }
102 }
103 
ExtractDenseFeatures(const Token & token,bool is_in_span) const104 std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
105     const Token& token, bool is_in_span) const {
106   std::vector<float> dense_features;
107 
108   if (options_.extract_case_feature) {
109     if (options_.unicode_aware_features) {
110       UnicodeText token_unicode =
111           UTF8ToUnicodeText(token.value, /*do_copy=*/false);
112       const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
113       if (!token.value.empty() && is_upper) {
114         dense_features.push_back(1.0);
115       } else {
116         dense_features.push_back(-1.0);
117       }
118     } else {
119       if (!token.value.empty() && isupper(*token.value.begin())) {
120         dense_features.push_back(1.0);
121       } else {
122         dense_features.push_back(-1.0);
123       }
124     }
125   }
126 
127   if (options_.extract_selection_mask_feature) {
128     if (is_in_span) {
129       dense_features.push_back(1.0);
130     } else {
131       if (options_.unicode_aware_features) {
132         dense_features.push_back(-1.0);
133       } else {
134         dense_features.push_back(0.0);
135       }
136     }
137   }
138 
139   // Add regexp features.
140   if (!regex_patterns_.empty()) {
141     UnicodeText token_unicode =
142         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
143     for (int i = 0; i < regex_patterns_.size(); ++i) {
144       if (!regex_patterns_[i].get()) {
145         dense_features.push_back(-1.0);
146         continue;
147       }
148       auto matcher = regex_patterns_[i]->Matcher(token_unicode);
149       int status;
150       if (matcher->Matches(&status)) {
151         dense_features.push_back(1.0);
152       } else {
153         dense_features.push_back(-1.0);
154       }
155     }
156   }
157 
158   return dense_features;
159 }
160 
HashToken(StringPiece token) const161 int TokenFeatureExtractor::HashToken(StringPiece token) const {
162   if (options_.allowed_chargrams.empty()) {
163     return tc3farmhash::Fingerprint64(token) % options_.num_buckets;
164   } else {
165     // Padding and out-of-vocabulary tokens have extra buckets reserved because
166     // they are special and important tokens, and we don't want them to share
167     // embedding with other charactergrams.
168     // TODO(zilka): Experimentally verify.
169     const int kNumExtraBuckets = 2;
170     const std::string token_string = token.ToString();
171     if (token_string == "<PAD>") {
172       return 1;
173     } else if (options_.allowed_chargrams.find(token_string) ==
174                options_.allowed_chargrams.end()) {
175       return 0;  // Out-of-vocabulary.
176     } else {
177       return (tc3farmhash::Fingerprint64(token) %
178               (options_.num_buckets - kNumExtraBuckets)) +
179              kNumExtraBuckets;
180     }
181   }
182 }
183 
ExtractCharactergramFeaturesAscii(const Token & token) const184 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
185     const Token& token) const {
186   std::vector<int> result;
187   if (token.is_padding || token.value.empty()) {
188     result.push_back(HashToken("<PAD>"));
189   } else {
190     const std::string word = RemapTokenAscii(token.value, options_);
191 
192     // Trim words that are over max_word_length characters.
193     const int max_word_length = options_.max_word_length;
194     std::string feature_word;
195     if (word.size() > max_word_length) {
196       feature_word =
197           "^" + word.substr(0, max_word_length / 2) + "\1" +
198           word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
199           "$";
200     } else {
201       // Add a prefix and suffix to the word.
202       feature_word = "^" + word + "$";
203     }
204 
205     // Upper-bound the number of charactergram extracted to avoid resizing.
206     result.reserve(options_.chargram_orders.size() * feature_word.size());
207 
208     if (options_.chargram_orders.empty()) {
209       result.push_back(HashToken(feature_word));
210     } else {
211       // Generate the character-grams.
212       for (int chargram_order : options_.chargram_orders) {
213         if (chargram_order == 1) {
214           for (int i = 1; i < feature_word.size() - 1; ++i) {
215             result.push_back(
216                 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
217           }
218         } else {
219           for (int i = 0;
220                i < static_cast<int>(feature_word.size()) - chargram_order + 1;
221                ++i) {
222             result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
223                                                    /*len=*/chargram_order)));
224           }
225         }
226       }
227     }
228   }
229   return result;
230 }
231 
ExtractCharactergramFeaturesUnicode(const Token & token) const232 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
233     const Token& token) const {
234   std::vector<int> result;
235   if (token.is_padding || token.value.empty()) {
236     result.push_back(HashToken("<PAD>"));
237   } else {
238     UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
239     RemapTokenUnicode(token.value, options_, unilib_, &word);
240 
241     // Trim the word if needed by finding a left-cut point and right-cut point.
242     auto left_cut = word.begin();
243     auto right_cut = word.end();
244     for (int i = 0; i < options_.max_word_length / 2; i++) {
245       if (left_cut < right_cut) {
246         ++left_cut;
247       }
248       if (left_cut < right_cut) {
249         --right_cut;
250       }
251     }
252 
253     std::string feature_word;
254     if (left_cut == right_cut) {
255       feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
256     } else {
257       // clang-format off
258       feature_word = "^" +
259                      word.UTF8Substring(word.begin(), left_cut) +
260                      "\1" +
261                      word.UTF8Substring(right_cut, word.end()) +
262                      "$";
263       // clang-format on
264     }
265 
266     const UnicodeText feature_word_unicode =
267         UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
268 
269     // Upper-bound the number of charactergram extracted to avoid resizing.
270     result.reserve(options_.chargram_orders.size() * feature_word.size());
271 
272     if (options_.chargram_orders.empty()) {
273       result.push_back(HashToken(feature_word));
274     } else {
275       // Generate the character-grams.
276       for (int chargram_order : options_.chargram_orders) {
277         UnicodeText::const_iterator it_start = feature_word_unicode.begin();
278         UnicodeText::const_iterator it_end = feature_word_unicode.end();
279         if (chargram_order == 1) {
280           ++it_start;
281           --it_end;
282         }
283 
284         UnicodeText::const_iterator it_chargram_start = it_start;
285         UnicodeText::const_iterator it_chargram_end = it_start;
286         bool chargram_is_complete = true;
287         for (int i = 0; i < chargram_order; ++i) {
288           if (it_chargram_end == it_end) {
289             chargram_is_complete = false;
290             break;
291           }
292           ++it_chargram_end;
293         }
294         if (!chargram_is_complete) {
295           continue;
296         }
297 
298         for (; it_chargram_end <= it_end;
299              ++it_chargram_start, ++it_chargram_end) {
300           const int length_bytes =
301               it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
302           result.push_back(HashToken(
303               StringPiece(it_chargram_start.utf8_data(), length_bytes)));
304         }
305       }
306     }
307   }
308   return result;
309 }
310 
311 }  // namespace libtextclassifier3
312