1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "smartselect/token-feature-extractor.h"
18
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21
22 namespace libtextclassifier {
23 namespace {
24
25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
26 public:
27 using TokenFeatureExtractor::TokenFeatureExtractor;
28 using TokenFeatureExtractor::HashToken;
29 };
30
// ASCII (non-Unicode-aware) extraction. Each token yields chargrams of orders
// 1..3 over its text, with "^" and "$" marking the token boundaries. Two dense
// features follow: the case feature first, then the selection mask.
TEST(TokenFeatureExtractorTest, ExtractAscii) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2, 3};
  opts.extract_case_feature = true;
  opts.unicode_aware_features = false;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts);

  // Hashes each expected chargram in order, so expectations read as strings.
  const auto hashes_of = [&extractor](const std::vector<const char*>& grams) {
    std::vector<int> result;
    for (const char* gram : grams) {
      result.push_back(extractor.HashToken(gram));
    }
    return result;
  };

  // Capitalized token inside the selection.
  std::vector<int> hello_sparse;
  std::vector<float> hello_dense;
  extractor.Extract(Token{"Hello", 0, 5}, true, &hello_sparse, &hello_dense);
  // clang-format off
  EXPECT_THAT(hello_sparse, testing::ElementsAreArray(hashes_of({
      "H", "e", "l", "l", "o",
      "^H", "He", "el", "ll", "lo", "o$",
      "^He", "Hel", "ell", "llo", "lo$",
  })));
  // clang-format on
  EXPECT_THAT(hello_dense, testing::ElementsAreArray({1.0, 1.0}));

  // Lower-case token with punctuation, outside the selection.
  std::vector<int> world_sparse;
  std::vector<float> world_dense;
  extractor.Extract(Token{"world!", 23, 29}, false, &world_sparse,
                    &world_dense);
  // clang-format off
  EXPECT_THAT(world_sparse, testing::ElementsAreArray(hashes_of({
      "w", "o", "r", "l", "d", "!",
      "^w", "wo", "or", "rl", "ld", "d!", "!$",
      "^wo", "wor", "orl", "rld", "ld!", "d!$",
  })));
  // clang-format on
  EXPECT_THAT(world_dense, testing::ElementsAreArray({-1.0, 0.0}));
}
100
// Unicode-aware extraction. Chargrams are built over code points rather than
// bytes, so multi-byte characters such as "ě" and "ó" count as one character
// each. Two dense features follow: the case feature, then the selection mask.
//
// NOTE(review): for the out-of-span token below, this test expects a
// selection-mask value of -1.0, whereas ExtractAscii expects 0.0 in the same
// situation. This pins the current extractor behavior; confirm that the
// asymmetry between the two modes is intended.
TEST(TokenFeatureExtractorTest, ExtractUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2, 3};
  opts.extract_case_feature = true;
  opts.unicode_aware_features = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts);

  // Hashes each expected chargram in order, so expectations read as strings.
  const auto hashes_of = [&extractor](const std::vector<const char*>& grams) {
    std::vector<int> result;
    for (const char* gram : grams) {
      result.push_back(extractor.HashToken(gram));
    }
    return result;
  };

  // Capitalized token with non-ASCII letters, inside the selection.
  std::vector<int> hello_sparse;
  std::vector<float> hello_dense;
  extractor.Extract(Token{"Hělló", 0, 5}, true, &hello_sparse, &hello_dense);
  // clang-format off
  EXPECT_THAT(hello_sparse, testing::ElementsAreArray(hashes_of({
      "H", "ě", "l", "l", "ó",
      "^H", "Hě", "ěl", "ll", "ló", "ó$",
      "^Hě", "Hěl", "ěll", "lló", "ló$",
  })));
  // clang-format on
  EXPECT_THAT(hello_dense, testing::ElementsAreArray({1.0, 1.0}));

  // Lower-case ASCII token with punctuation, outside the selection.
  std::vector<int> world_sparse;
  std::vector<float> world_dense;
  extractor.Extract(Token{"world!", 23, 29}, false, &world_sparse,
                    &world_dense);
  // clang-format off
  EXPECT_THAT(world_sparse, testing::ElementsAreArray(hashes_of({
      "w", "o", "r", "l", "d", "!",
      "^w", "wo", "or", "rl", "ld", "d!", "!$",
      "^wo", "wor", "orl", "rld", "ld!", "d!$",
  })));
  // clang-format on
  EXPECT_THAT(world_dense, testing::ElementsAreArray({-1.0, -1.0}));
}
170
// With the selection mask disabled, the only dense feature is the case
// feature. In Unicode-aware mode it must recognize non-ASCII upper-case
// ("Ř") and lower-case ("ř") letters.
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.extract_case_feature = true;
  opts.unicode_aware_features = true;
  opts.extract_selection_mask_feature = false;
  TokenFeatureExtractor extractor(opts);

  // Runs the extractor on a single token and returns only its dense features.
  const auto dense_for = [&extractor](const Token& token, bool in_span) {
    std::vector<int> sparse;
    std::vector<float> dense;
    extractor.Extract(token, in_span, &sparse, &dense);
    return dense;
  };

  EXPECT_THAT(dense_for(Token{"Hělló", 0, 5}, true),
              testing::ElementsAreArray({1.0}));
  EXPECT_THAT(dense_for(Token{"world!", 23, 29}, false),
              testing::ElementsAreArray({-1.0}));
  EXPECT_THAT(dense_for(Token{"Ř", 23, 29}, false),
              testing::ElementsAreArray({1.0}));
  EXPECT_THAT(dense_for(Token{"ř", 23, 29}, false),
              testing::ElementsAreArray({-1.0}));
}
204
// With remap_digits enabled (ASCII mode), tokens that differ only in which
// digits they contain ("9:30am" vs "5:32am") must produce identical sparse
// features. A token with a different number of digits ("10:32am") must not.
TEST(TokenFeatureExtractorTest, DigitRemapping) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;  // Scratch output; not inspected.
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Use a fresh output vector. If an already-populated vector were reused and
  // Extract() appended to it, the lengths would differ and the negative check
  // below would pass trivially without comparing any features.
  std::vector<int> sparse_features3;
  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features3,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features3)));
}
228
// Unicode-aware counterpart of DigitRemapping. Tokens that differ only in
// which digits they contain must produce identical sparse features, and a
// token with a different number of digits must not.
TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = true;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;  // Scratch output; not inspected.
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Use a fresh output vector. If an already-populated vector were reused and
  // Extract() appended to it, the lengths would differ and the negative check
  // below would pass trivially without comparing any features.
  std::vector<int> sparse_features3;
  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features3,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features3)));
}
252
// With lowercase_tokens enabled (ASCII mode), tokens that differ only in
// letter case must produce identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseAscii) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.lowercase_tokens = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;  // Scratch output; not inspected.
  extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Use a fresh output vector, so this comparison does not depend on whether
  // Extract() overwrites or appends to an already-populated vector.
  std::vector<int> sparse_features3;
  extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features3,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features3));
}
275
// Unicode-aware lowercasing: the upper-case token "ŘŘ" and the lower-case
// token "řř" must produce identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.lowercase_tokens = true;
  opts.unicode_aware_features = true;
  TokenFeatureExtractor extractor(opts);

  std::vector<float> dense;  // Shared scratch output; not inspected.
  std::vector<int> upper_sparse;
  std::vector<int> lower_sparse;
  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &upper_sparse, &dense);
  extractor.Extract(Token{"řř", 0, 6}, true, &lower_sparse, &dense);

  EXPECT_THAT(upper_sparse, testing::ElementsAreArray(lower_sparse));
}
293
// Each configured regexp contributes one dense feature, in the order the
// regexps were added: 1.0 when the token matches it and -1.0 otherwise.
TEST(TokenFeatureExtractorTest, RegexFeatures) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.remap_digits = false;
  opts.unicode_aware_features = false;
  opts.regexp_features.push_back("^[a-z]+$");  // all lower case.
  opts.regexp_features.push_back("^[0-9]+$");  // all digits.
  TokenFeatureExtractor extractor(opts);

  struct RegexCase {
    const char* text;
    std::vector<double> expected_dense;
  };
  const std::vector<RegexCase> cases = {
      {"abCde", {-1.0, -1.0}},
      {"abcde", {1.0, -1.0}},
      {"12c45", {-1.0, -1.0}},
      {"12345", {-1.0, 1.0}},
  };

  // The sparse output is shared across tokens and never inspected; the dense
  // output is emptied before each token so only that token's values remain.
  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  for (const RegexCase& regex_case : cases) {
    dense_features.clear();
    extractor.Extract(Token{regex_case.text, 0, 6}, true, &sparse_features,
                      &dense_features);
    EXPECT_THAT(dense_features,
                testing::ElementsAreArray(regex_case.expected_dense))
        << regex_case.text;
  }
}
325
// Very long token combined with a large chargram order. The run itself is the
// memory-safety check (ASAN should catch problems). The expected hashes encode
// the long word being shortened to its first ten and last ten characters,
// joined by a '\1' marker, before chargrams are taken.
TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{22};
  opts.extract_case_feature = true;
  opts.unicode_aware_features = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts);

  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true, &sparse,
                    &dense);

  EXPECT_THAT(sparse,
              testing::ElementsAre(
                  extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
                  extractor.HashToken("abcdefghij\1qřstuvwxyz$")));
}
349
// For pure-ASCII inputs, including the empty string, the Unicode-aware and
// ASCII code paths must produce exactly the same sparse and dense features.
TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;

  opts.unicode_aware_features = true;
  TestingTokenFeatureExtractor extractor_unicode(opts);
  opts.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(opts);

  const std::vector<std::string> inputs = {
      "https://www.abcdefgh.com/in/xxxkkkvayio",
      "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
      "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w",
      "abcd",
      "x",
      "Hello",
      "Hey,",
      "Hi",
      "",
  };

  for (const std::string& input : inputs) {
    std::vector<int> unicode_sparse;
    std::vector<float> unicode_dense;
    extractor_unicode.Extract(Token{input, 0, 0}, true, &unicode_sparse,
                              &unicode_dense);

    std::vector<int> ascii_sparse;
    std::vector<float> ascii_dense;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &ascii_sparse,
                            &ascii_dense);

    EXPECT_THAT(unicode_sparse, testing::ContainerEq(ascii_sparse)) << input;
    EXPECT_THAT(unicode_dense, testing::ContainerEq(ascii_dense)) << input;
  }
}
382
// A default-constructed Token (presumably the padding token) yields the
// single sparse feature HashToken("<PAD>"). Outside the selection, its dense
// features in ASCII mode are (-1.0, 0.0).
TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.extract_case_feature = true;
  opts.unicode_aware_features = false;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts);

  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(Token(), false, &sparse, &dense);

  EXPECT_THAT(sparse, testing::ElementsAre(extractor.HashToken("<PAD>")));
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 0.0}));
}
402
403 } // namespace
404 } // namespace libtextclassifier
405