1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/token-feature-extractor.h"
18
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21
22 namespace libtextclassifier3 {
23 namespace {
24
25 class TokenFeatureExtractorTest : public ::testing::Test {
26 protected:
TokenFeatureExtractorTest()27 explicit TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
28 UniLib unilib_;
29 };
30
31 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
32 public:
33 using TokenFeatureExtractor::HashToken;
34 using TokenFeatureExtractor::TokenFeatureExtractor;
35 };
36
// Non-unicode-aware extraction with chargram orders 1, 2 and 3. The expected
// sparse features are the hashes of all unigrams, then all bigrams, then all
// trigrams, where "^" and "$" mark the token boundaries. Two dense features
// are expected because both the case feature and the selection-mask feature
// are enabled.
TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.unicode_aware_features = false;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  // Capitalized token, extracted with the boolean flag set.
  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse, &dense);

  const std::vector<int> expected_hello = {
      // clang-format off
      // Order 1.
      extractor.HashToken("H"),
      extractor.HashToken("e"),
      extractor.HashToken("l"),
      extractor.HashToken("l"),
      extractor.HashToken("o"),
      // Order 2.
      extractor.HashToken("^H"),
      extractor.HashToken("He"),
      extractor.HashToken("el"),
      extractor.HashToken("ll"),
      extractor.HashToken("lo"),
      extractor.HashToken("o$"),
      // Order 3.
      extractor.HashToken("^He"),
      extractor.HashToken("Hel"),
      extractor.HashToken("ell"),
      extractor.HashToken("llo"),
      extractor.HashToken("lo$")
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, ElementsAreArray({1.0, 1.0}));

  // Lowercase token, extracted with the boolean flag cleared.
  sparse.clear();
  dense.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);

  const std::vector<int> expected_world = {
      // clang-format off
      // Order 1.
      extractor.HashToken("w"),
      extractor.HashToken("o"),
      extractor.HashToken("r"),
      extractor.HashToken("l"),
      extractor.HashToken("d"),
      extractor.HashToken("!"),
      // Order 2.
      extractor.HashToken("^w"),
      extractor.HashToken("wo"),
      extractor.HashToken("or"),
      extractor.HashToken("rl"),
      extractor.HashToken("ld"),
      extractor.HashToken("d!"),
      extractor.HashToken("!$"),
      // Order 3.
      extractor.HashToken("^wo"),
      extractor.HashToken("wor"),
      extractor.HashToken("orl"),
      extractor.HashToken("rld"),
      extractor.HashToken("ld!"),
      extractor.HashToken("d!$"),
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_world));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 0.0}));
}
106
// Non-unicode-aware extraction with an empty list of chargram orders: the
// single expected sparse feature is the hash of the whole token wrapped in
// the "^" / "$" boundary markers.
TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{};
  options.unicode_aware_features = false;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse, &dense);

  const std::vector<int> expected_hello = {extractor.HashToken("^Hello$")};
  EXPECT_THAT(sparse, ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);

  const std::vector<int> expected_world = {extractor.HashToken("^world!$")};
  EXPECT_THAT(sparse, ElementsAreArray(expected_world));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 0.0}));
}
135
// Unicode-aware extraction with chargram orders 1, 2 and 3. Accented
// characters such as "ě" and "ó" are expected to appear as single units in
// the chargrams. With unicode-aware features the second dense value expected
// for a call with the boolean flag cleared is -1.0.
TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.unicode_aware_features = true;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse, &dense);

  const std::vector<int> expected_hello = {
      // clang-format off
      // Order 1.
      extractor.HashToken("H"),
      extractor.HashToken("ě"),
      extractor.HashToken("l"),
      extractor.HashToken("l"),
      extractor.HashToken("ó"),
      // Order 2.
      extractor.HashToken("^H"),
      extractor.HashToken("Hě"),
      extractor.HashToken("ěl"),
      extractor.HashToken("ll"),
      extractor.HashToken("ló"),
      extractor.HashToken("ó$"),
      // Order 3.
      extractor.HashToken("^Hě"),
      extractor.HashToken("Hěl"),
      extractor.HashToken("ěll"),
      extractor.HashToken("lló"),
      extractor.HashToken("ló$")
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);

  const std::vector<int> expected_world = {
      // clang-format off
      // Order 1.
      extractor.HashToken("w"),
      extractor.HashToken("o"),
      extractor.HashToken("r"),
      extractor.HashToken("l"),
      extractor.HashToken("d"),
      extractor.HashToken("!"),
      // Order 2.
      extractor.HashToken("^w"),
      extractor.HashToken("wo"),
      extractor.HashToken("or"),
      extractor.HashToken("rl"),
      extractor.HashToken("ld"),
      extractor.HashToken("d!"),
      extractor.HashToken("!$"),
      // Order 3.
      extractor.HashToken("^wo"),
      extractor.HashToken("wor"),
      extractor.HashToken("orl"),
      extractor.HashToken("rld"),
      extractor.HashToken("ld!"),
      extractor.HashToken("d!$"),
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_world));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, -1.0}));
}
205
// Unicode-aware extraction with an empty list of chargram orders: the single
// expected sparse feature is the hash of the whole token wrapped in the
// "^" / "$" boundary markers.
TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{};
  options.unicode_aware_features = true;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse, &dense);

  const std::vector<int> expected_hello = {extractor.HashToken("^Hělló$")};
  EXPECT_THAT(sparse, ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);

  const std::vector<int> expected_world = {extractor.HashToken("^world!$")};
  EXPECT_THAT(sparse, ElementsAreArray(expected_world));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, -1.0}));
}
235
236 #ifdef TC3_TEST_ICU
// Checks the case feature alone (the selection-mask feature is disabled, so
// exactly one dense value is expected): 1.0 for tokens starting with an
// uppercase letter, -1.0 otherwise, including non-ASCII letters.
TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = true;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = false;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;
  // Empties both output vectors between extractions.
  const auto reset_outputs = [&sparse, &dense]() {
    sparse.clear();
    dense.clear();
  };

  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({1.0}));

  reset_outputs();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({-1.0}));

  // Single uppercase non-ASCII letter.
  reset_outputs();
  extractor.Extract(Token{"Ř", 23, 29}, false, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({1.0}));

  // Its lowercase counterpart.
  reset_outputs();
  extractor.Extract(Token{"ř", 23, 29}, false, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({-1.0}));
}
270 #endif
271
// With remap_digits enabled (non-unicode-aware mode), tokens that differ only
// in which digits they contain must yield identical sparse features, while a
// token with a different number of digits must not.
TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
  using testing::ElementsAreArray;
  using testing::Not;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = false;
  options.remap_digits = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // Dense features are collected but never inspected by this test.
  std::vector<float> dense;

  std::vector<int> reference_sparse;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &reference_sparse, &dense);

  // Same shape, different digits: features must match the reference.
  std::vector<int> other_sparse;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, ElementsAreArray(other_sparse));

  // One extra digit: features must differ from the reference.
  // NOTE(review): the span end (6) is kept from the original even though
  // "10:32am" has 7 characters; only sparse features are compared here.
  extractor.Extract(Token{"10:32am", 0, 6}, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, Not(ElementsAreArray(other_sparse)));
}
295
// Unicode-aware counterpart of the digit remapping check: tokens that differ
// only in which digits they contain must yield identical sparse features,
// while a token with a different number of digits must not.
TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  using testing::ElementsAreArray;
  using testing::Not;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = true;
  options.remap_digits = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // Dense features are collected but never inspected by this test.
  std::vector<float> dense;

  std::vector<int> reference_sparse;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &reference_sparse, &dense);

  // Same shape, different digits: features must match the reference.
  std::vector<int> other_sparse;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, ElementsAreArray(other_sparse));

  // One extra digit: features must differ from the reference.
  // NOTE(review): the span end (6) is kept from the original even though
  // "10:32am" has 7 characters; only sparse features are compared here.
  extractor.Extract(Token{"10:32am", 0, 6}, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, Not(ElementsAreArray(other_sparse)));
}
319
// With lowercase_tokens enabled (non-unicode-aware mode), ASCII tokens that
// differ only in letter case must yield identical sparse features.
TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = false;
  options.lowercase_tokens = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // Dense features are collected but never inspected by this test.
  std::vector<float> dense;

  std::vector<int> reference_sparse;
  extractor.Extract(Token{"AABB", 0, 6}, true, &reference_sparse, &dense);

  // Partially lowercased variant.
  std::vector<int> variant_sparse;
  extractor.Extract(Token{"aaBB", 0, 6}, true, &variant_sparse, &dense);
  EXPECT_THAT(reference_sparse, ElementsAreArray(variant_sparse));

  // Alternating-case variant.
  extractor.Extract(Token{"aAbB", 0, 6}, true, &variant_sparse, &dense);
  EXPECT_THAT(reference_sparse, ElementsAreArray(variant_sparse));
}
342
343 #ifdef TC3_TEST_ICU
// With lowercase_tokens enabled in unicode-aware mode, the uppercase and
// lowercase forms of a non-ASCII letter must yield identical sparse features.
TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = true;
  options.lowercase_tokens = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // Dense features are collected but never inspected by this test.
  std::vector<float> dense;

  std::vector<int> upper_sparse;
  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &upper_sparse, &dense);

  std::vector<int> lower_sparse;
  extractor.Extract(Token{"řř", 0, 6}, true, &lower_sparse, &dense);

  EXPECT_THAT(upper_sparse, ElementsAreArray(lower_sparse));
}
361 #endif
362
363 #ifdef TC3_TEST_ICU
// Each configured regular expression contributes one dense feature, in the
// order the patterns were added: 1.0 when the token matches the pattern and
// -1.0 when it does not.
TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = false;
  options.remap_digits = false;
  options.regexp_features.push_back("^[a-z]+$");  // Lowercase letters only.
  options.regexp_features.push_back("^[0-9]+$");  // Digits only.
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // Sparse features are collected but never inspected by this test; only the
  // dense vector is cleared between extractions.
  std::vector<int> sparse;
  std::vector<float> dense;

  // Mixed case: matches neither pattern.
  extractor.Extract(Token{"abCde", 0, 6}, true, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({-1.0, -1.0}));

  // All lowercase: matches the first pattern only.
  dense.clear();
  extractor.Extract(Token{"abcde", 0, 6}, true, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({1.0, -1.0}));

  // Digits with a letter in between: matches neither pattern.
  dense.clear();
  extractor.Extract(Token{"12c45", 0, 6}, true, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({-1.0, -1.0}));

  // All digits: matches the second pattern only.
  dense.clear();
  extractor.Extract(Token{"12345", 0, 6}, true, &sparse, &dense);
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 1.0}));
}
395 #endif
396
// Extracts order-22 chargrams from a 26-character token containing a
// non-ASCII letter. The expected chargrams contain the byte \1 in place of
// the middle of the word.
TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{22};
  options.unicode_aware_features = true;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  // The main point is that extraction completes; memory errors are expected
  // to be reported by ASAN.
  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true, &sparse,
                    &dense);

  const std::vector<int> expected_sparse = {
      // clang-format off
      extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
      extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_sparse));
}
420
// For pure-ASCII inputs (including the empty string), the unicode-aware and
// the non-unicode-aware extractors must produce identical sparse and dense
// features.
TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;

  // Build both extractors from the same options, toggling only the
  // unicode-awareness switch in between.
  options.unicode_aware_features = true;
  TestingTokenFeatureExtractor extractor_unicode(options, &unilib_);

  options.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(options, &unilib_);

  const std::vector<std::string> inputs = {
      "https://www.abcdefgh.com/in/xxxkkkvayio",
      "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
      "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w",
      "abcd",
      "x",
      "Hello",
      "Hey,",
      "Hi",
      ""};

  for (const std::string& input : inputs) {
    std::vector<int> unicode_sparse;
    std::vector<float> unicode_dense;
    extractor_unicode.Extract(Token{input, 0, 0}, true, &unicode_sparse,
                              &unicode_dense);

    std::vector<int> ascii_sparse;
    std::vector<float> ascii_dense;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &ascii_sparse,
                            &ascii_dense);

    // The offending input is appended to the failure message.
    EXPECT_THAT(unicode_sparse, testing::Eq(ascii_sparse)) << input;
    EXPECT_THAT(unicode_dense, testing::Eq(ascii_dense)) << input;
  }
}
454
// A default-constructed Token, extracted with the boolean flag cleared, is
// expected to yield the single sparse feature hash("<PAD>") and the dense
// features {-1.0, 0.0}.
TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.unicode_aware_features = false;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;

  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(Token(), false, &sparse, &dense);

  const std::vector<int> expected_sparse = {extractor.HashToken("<PAD>")};
  EXPECT_THAT(sparse, ElementsAreArray(expected_sparse));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 0.0}));
}
474
// With an allow-list of chargrams configured, every chargram that is not on
// the list is expected to appear as 0 in the sparse features, while allowed
// chargrams keep their hash. Extraction is non-unicode-aware, so both test
// tokens produce 19 entries for orders 1, 2 and 3.
TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.unicode_aware_features = false;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  // "\xc4" is the first byte of the two-byte UTF-8 encoding of "ě".
  for (const char* allowed : {"^H", "ll", "llo", "w", "!", "\xc4"}) {
    options.allowed_chargrams.insert(allowed);
  }

  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse, &dense);

  // Start from all-filtered (0) and fill in the positions of the allowed
  // chargrams.
  std::vector<int> expected_hello(19, 0);
  expected_hello[1] = extractor.HashToken("\xc4");
  expected_hello[6] = extractor.HashToken("^H");
  expected_hello[10] = extractor.HashToken("ll");
  expected_hello[17] = extractor.HashToken("llo");
  EXPECT_THAT(sparse, ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);

  // Only the unigrams "w" (first) and "!" (sixth) are on the allow-list.
  std::vector<int> expected_world(19, 0);
  expected_world[0] = extractor.HashToken("w");
  expected_world[5] = extractor.HashToken("!");
  EXPECT_THAT(sparse, ElementsAreArray(expected_world));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 0.0}));

  // The padding token is expected to hash to bucket 1.
  EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
}
554
// A default-constructed Token, extracted with the boolean flag set, must not
// crash and is expected to yield the single sparse feature hash("<PAD>") and
// the dense features {-1.0, 1.0}.
TEST_F(TokenFeatureExtractorTest, ExtractEmptyToken) {
  using testing::ElementsAreArray;

  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.unicode_aware_features = false;
  options.extract_case_feature = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  // The call itself is part of the check: it should not crash.
  extractor.Extract(Token(), true, &sparse, &dense);

  const std::vector<int> expected_sparse = {
      // clang-format off
      extractor.HashToken("<PAD>"),
      // clang-format on
  };
  EXPECT_THAT(sparse, ElementsAreArray(expected_sparse));
  EXPECT_THAT(dense, ElementsAreArray({-1.0, 1.0}));
}
577
578 } // namespace
579 } // namespace libtextclassifier3
580