• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/token-feature-extractor.h"
18 
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 
22 namespace libtextclassifier3 {
23 namespace {
24 
25 class TokenFeatureExtractorTest : public ::testing::Test {
26  protected:
TokenFeatureExtractorTest()27   explicit TokenFeatureExtractorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
28   UniLib unilib_;
29 };
30 
31 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
32  public:
33   using TokenFeatureExtractor::HashToken;
34   using TokenFeatureExtractor::TokenFeatureExtractor;
35 };
36 
TEST_F(TokenFeatureExtractorTest,ExtractAscii)37 TEST_F(TokenFeatureExtractorTest, ExtractAscii) {
38   TokenFeatureExtractorOptions options;
39   options.num_buckets = 1000;
40   options.chargram_orders = std::vector<int>{1, 2, 3};
41   options.extract_case_feature = true;
42   options.unicode_aware_features = false;
43   options.extract_selection_mask_feature = true;
44   TestingTokenFeatureExtractor extractor(options, &unilib_);
45 
46   std::vector<int> sparse_features;
47   std::vector<float> dense_features;
48 
49   extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
50                     &dense_features);
51 
52   EXPECT_THAT(sparse_features,
53               testing::ElementsAreArray({
54                   // clang-format off
55                   extractor.HashToken("H"),
56                   extractor.HashToken("e"),
57                   extractor.HashToken("l"),
58                   extractor.HashToken("l"),
59                   extractor.HashToken("o"),
60                   extractor.HashToken("^H"),
61                   extractor.HashToken("He"),
62                   extractor.HashToken("el"),
63                   extractor.HashToken("ll"),
64                   extractor.HashToken("lo"),
65                   extractor.HashToken("o$"),
66                   extractor.HashToken("^He"),
67                   extractor.HashToken("Hel"),
68                   extractor.HashToken("ell"),
69                   extractor.HashToken("llo"),
70                   extractor.HashToken("lo$")
71                   // clang-format on
72               }));
73   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
74 
75   sparse_features.clear();
76   dense_features.clear();
77   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
78                     &dense_features);
79 
80   EXPECT_THAT(sparse_features,
81               testing::ElementsAreArray({
82                   // clang-format off
83                   extractor.HashToken("w"),
84                   extractor.HashToken("o"),
85                   extractor.HashToken("r"),
86                   extractor.HashToken("l"),
87                   extractor.HashToken("d"),
88                   extractor.HashToken("!"),
89                   extractor.HashToken("^w"),
90                   extractor.HashToken("wo"),
91                   extractor.HashToken("or"),
92                   extractor.HashToken("rl"),
93                   extractor.HashToken("ld"),
94                   extractor.HashToken("d!"),
95                   extractor.HashToken("!$"),
96                   extractor.HashToken("^wo"),
97                   extractor.HashToken("wor"),
98                   extractor.HashToken("orl"),
99                   extractor.HashToken("rld"),
100                   extractor.HashToken("ld!"),
101                   extractor.HashToken("d!$"),
102                   // clang-format on
103               }));
104   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
105 }
106 
// With an empty chargram_orders list (and unicode_aware_features disabled) the
// only sparse feature expected per token is the hash of the whole token
// wrapped in the "^" / "$" markers.
TEST_F(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{};
  opts.unicode_aware_features = false;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  const Token hello{"Hello", 0, 5};
  extractor.Extract(hello, true, &sparse, &dense);

  const std::vector<int> expected_hello = {extractor.HashToken("^Hello$")};
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  const Token world{"world!", 23, 29};
  extractor.Extract(world, false, &sparse, &dense);

  const std::vector<int> expected_world = {extractor.HashToken("^world!$")};
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_world));
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 0.0}));
}
135 
TEST_F(TokenFeatureExtractorTest,ExtractUnicode)136 TEST_F(TokenFeatureExtractorTest, ExtractUnicode) {
137   TokenFeatureExtractorOptions options;
138   options.num_buckets = 1000;
139   options.chargram_orders = std::vector<int>{1, 2, 3};
140   options.extract_case_feature = true;
141   options.unicode_aware_features = true;
142   options.extract_selection_mask_feature = true;
143   TestingTokenFeatureExtractor extractor(options, &unilib_);
144 
145   std::vector<int> sparse_features;
146   std::vector<float> dense_features;
147 
148   extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
149                     &dense_features);
150 
151   EXPECT_THAT(sparse_features,
152               testing::ElementsAreArray({
153                   // clang-format off
154                   extractor.HashToken("H"),
155                   extractor.HashToken("ě"),
156                   extractor.HashToken("l"),
157                   extractor.HashToken("l"),
158                   extractor.HashToken("ó"),
159                   extractor.HashToken("^H"),
160                   extractor.HashToken("Hě"),
161                   extractor.HashToken("ěl"),
162                   extractor.HashToken("ll"),
163                   extractor.HashToken("ló"),
164                   extractor.HashToken("ó$"),
165                   extractor.HashToken("^Hě"),
166                   extractor.HashToken("Hěl"),
167                   extractor.HashToken("ěll"),
168                   extractor.HashToken("lló"),
169                   extractor.HashToken("ló$")
170                   // clang-format on
171               }));
172   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
173 
174   sparse_features.clear();
175   dense_features.clear();
176   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
177                     &dense_features);
178 
179   EXPECT_THAT(sparse_features,
180               testing::ElementsAreArray({
181                   // clang-format off
182                   extractor.HashToken("w"),
183                   extractor.HashToken("o"),
184                   extractor.HashToken("r"),
185                   extractor.HashToken("l"),
186                   extractor.HashToken("d"),
187                   extractor.HashToken("!"),
188                   extractor.HashToken("^w"),
189                   extractor.HashToken("wo"),
190                   extractor.HashToken("or"),
191                   extractor.HashToken("rl"),
192                   extractor.HashToken("ld"),
193                   extractor.HashToken("d!"),
194                   extractor.HashToken("!$"),
195                   extractor.HashToken("^wo"),
196                   extractor.HashToken("wor"),
197                   extractor.HashToken("orl"),
198                   extractor.HashToken("rld"),
199                   extractor.HashToken("ld!"),
200                   extractor.HashToken("d!$"),
201                   // clang-format on
202               }));
203   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
204 }
205 
// Unicode-aware variant of the "no chargrams" test: with an empty
// chargram_orders list each token is expected to yield exactly one sparse
// feature, the hash of "^<token>$".
TEST_F(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{};
  opts.unicode_aware_features = true;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  const Token hello{"Hělló", 0, 5};
  extractor.Extract(hello, true, &sparse, &dense);

  const std::vector<int> expected_hello = {extractor.HashToken("^Hělló$")};
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_hello));
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();
  const Token world{"world!", 23, 29};
  extractor.Extract(world, false, &sparse, &dense);

  const std::vector<int> expected_world = {
      extractor.HashToken("^world!$"),
  };
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_world));
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, -1.0}));
}
235 
236 #ifdef TC3_TEST_ICU
// Only the case feature is enabled here (selection-mask feature off), so a
// single dense value is expected per token: 1.0 for tokens starting with an
// upper-case letter ("Hělló", "Ř") and -1.0 for lower-case ones ("world!",
// "ř"), including non-ASCII letters.
TEST_F(TokenFeatureExtractorTest, ICUCaseFeature) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = true;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = false;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  // Empties both output vectors, then runs the extractor on `token`.
  const auto extract_fresh = [&](const Token& token, bool flag) {
    sparse.clear();
    dense.clear();
    extractor.Extract(token, flag, &sparse, &dense);
  };

  extract_fresh(Token{"Hělló", 0, 5}, true);
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0}));

  extract_fresh(Token{"world!", 23, 29}, false);
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0}));

  extract_fresh(Token{"Ř", 23, 29}, false);
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0}));

  extract_fresh(Token{"ř", 23, 29}, false);
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0}));
}
270 #endif
271 
// With remap_digits enabled (byte-oriented mode), two tokens that differ only
// in which digits they contain must yield identical sparse features, while a
// token with a different number of digits must not.
TEST_F(TokenFeatureExtractorTest, DigitRemapping) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = false;
  opts.remap_digits = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  // Reference features for "9:30am".
  std::vector<int> reference_sparse;
  std::vector<float> dense;
  const Token reference_token{"9:30am", 0, 6};
  extractor.Extract(reference_token, true, &reference_sparse, &dense);

  // Same digit layout, different digits: expected to match the reference.
  std::vector<int> other_sparse;
  const Token same_shape_token{"5:32am", 0, 6};
  extractor.Extract(same_shape_token, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, testing::ElementsAreArray(other_sparse));

  // One more leading digit: expected to differ from the reference.
  const Token longer_token{"10:32am", 0, 6};
  extractor.Extract(longer_token, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse,
              testing::Not(testing::ElementsAreArray(other_sparse)));
}
295 
// Unicode-aware counterpart of the digit-remapping test: tokens that differ
// only in digit values must produce the same sparse features, and a token
// with an extra digit must produce different ones.
TEST_F(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = true;
  opts.remap_digits = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  // Reference features for "9:30am".
  std::vector<int> reference_sparse;
  std::vector<float> dense;
  const Token reference_token{"9:30am", 0, 6};
  extractor.Extract(reference_token, true, &reference_sparse, &dense);

  // Same digit layout, different digits: expected to match the reference.
  std::vector<int> other_sparse;
  const Token same_shape_token{"5:32am", 0, 6};
  extractor.Extract(same_shape_token, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse, testing::ElementsAreArray(other_sparse));

  // One more leading digit: expected to differ from the reference.
  const Token longer_token{"10:32am", 0, 6};
  extractor.Extract(longer_token, true, &other_sparse, &dense);
  EXPECT_THAT(reference_sparse,
              testing::Not(testing::ElementsAreArray(other_sparse)));
}
319 
// With lowercase_tokens enabled (byte-oriented mode), tokens that differ only
// in letter case must yield identical sparse features.
TEST_F(TokenFeatureExtractorTest, LowercaseAscii) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = false;
  opts.lowercase_tokens = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  // Reference features come from the all-upper-case spelling.
  std::vector<int> reference_sparse;
  std::vector<float> dense;
  const Token upper_token{"AABB", 0, 6};
  extractor.Extract(upper_token, true, &reference_sparse, &dense);

  // Each mixed-case spelling is extracted into the same vector and compared
  // against the reference.
  std::vector<int> variant_sparse;
  const Token first_variant{"aaBB", 0, 6};
  extractor.Extract(first_variant, true, &variant_sparse, &dense);
  EXPECT_THAT(reference_sparse, testing::ElementsAreArray(variant_sparse));

  const Token second_variant{"aAbB", 0, 6};
  extractor.Extract(second_variant, true, &variant_sparse, &dense);
  EXPECT_THAT(reference_sparse, testing::ElementsAreArray(variant_sparse));
}
342 
343 #ifdef TC3_TEST_ICU
// Unicode-aware lowercasing: the upper-case and lower-case spellings of a
// non-ASCII token must yield identical sparse features.
TEST_F(TokenFeatureExtractorTest, LowercaseUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = true;
  opts.lowercase_tokens = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> upper_sparse;
  std::vector<float> dense;
  const Token upper_token{"ŘŘ", 0, 6};
  extractor.Extract(upper_token, true, &upper_sparse, &dense);

  std::vector<int> lower_sparse;
  const Token lower_token{"řř", 0, 6};
  extractor.Extract(lower_token, true, &lower_sparse, &dense);

  EXPECT_THAT(upper_sparse, testing::ElementsAreArray(lower_sparse));
}
361 #endif
362 
363 #ifdef TC3_TEST_ICU
// Two regexp features are configured, in this order: "all lower case" and
// "all digits". For each token the dense output is expected to hold one value
// per pattern: 1.0 where the pattern matches the token and -1.0 where it does
// not.
TEST_F(TokenFeatureExtractorTest, RegexFeatures) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = false;
  opts.remap_digits = false;
  opts.regexp_features.push_back("^[a-z]+$");  // all lower case.
  opts.regexp_features.push_back("^[0-9]+$");  // all digits.
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  // Mixed case: matches neither pattern.
  const Token mixed_case{"abCde", 0, 6};
  extractor.Extract(mixed_case, true, &sparse, &dense);
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, -1.0}));

  // All lower case: matches the first pattern only.
  dense.clear();
  const Token lower_case{"abcde", 0, 6};
  extractor.Extract(lower_case, true, &sparse, &dense);
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0, -1.0}));

  // Digits with one letter: matches neither pattern.
  dense.clear();
  const Token mostly_digits{"12c45", 0, 6};
  extractor.Extract(mostly_digits, true, &sparse, &dense);
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, -1.0}));

  // All digits: matches the second pattern only.
  dense.clear();
  const Token all_digits{"12345", 0, 6};
  extractor.Extract(all_digits, true, &sparse, &dense);
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 1.0}));
}
395 #endif
396 
// Feeds a 26-letter token (with one non-ASCII letter) to an extractor
// configured with a single chargram order of 22. The expected chargrams show
// the middle of the word collapsed into a single "\1" character, keeping the
// first and last ten letters.
TEST_F(TokenFeatureExtractorTest, ExtractTooLongWord) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{22};
  opts.unicode_aware_features = true;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  // Primarily a "does it run" check; memory errors are left for ASAN to
  // report.
  std::vector<int> sparse;
  std::vector<float> dense;
  const Token long_token{"abcdefghijklmnopqřstuvwxyz", 0, 0};
  extractor.Extract(long_token, true, &sparse, &dense);

  const std::vector<int> expected_sparse = {
      // clang-format off
      extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
      extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
      // clang-format on
  };
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_sparse));
}
420 
// For pure-ASCII inputs (including the empty string) the Unicode-aware and
// the byte-oriented extractors must produce identical sparse and dense
// features. Both extractors are built from the same options object; only
// unicode_aware_features is flipped between the two constructions.
TEST_F(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;

  opts.unicode_aware_features = true;
  TestingTokenFeatureExtractor extractor_unicode(opts, &unilib_);

  opts.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(opts, &unilib_);

  const std::vector<std::string> inputs = {
      "https://www.abcdefgh.com/in/xxxkkkvayio",
      "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
      "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w",
      "abcd",
      "x",
      "Hello",
      "Hey,",
      "Hi",
      "",
  };

  for (const std::string& input : inputs) {
    std::vector<int> unicode_sparse;
    std::vector<float> unicode_dense;
    extractor_unicode.Extract(Token{input, 0, 0}, true, &unicode_sparse,
                              &unicode_dense);

    std::vector<int> ascii_sparse;
    std::vector<float> ascii_dense;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &ascii_sparse,
                            &ascii_dense);

    // The offending input is appended to the failure message.
    EXPECT_THAT(unicode_sparse, ascii_sparse) << input;
    EXPECT_THAT(unicode_dense, ascii_dense) << input;
  }
}
454 
// A default-constructed Token is expected to produce exactly one sparse
// feature, the hash of "<PAD>", regardless of the configured chargram orders.
TEST_F(TokenFeatureExtractorTest, ExtractForPadToken) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.unicode_aware_features = false;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;

  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(Token(), false, &sparse, &dense);

  const std::vector<int> expected_sparse = {extractor.HashToken("<PAD>")};
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_sparse));
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 0.0}));
}
474 
TEST_F(TokenFeatureExtractorTest,ExtractFiltered)475 TEST_F(TokenFeatureExtractorTest, ExtractFiltered) {
476   TokenFeatureExtractorOptions options;
477   options.num_buckets = 1000;
478   options.chargram_orders = std::vector<int>{1, 2, 3};
479   options.extract_case_feature = true;
480   options.unicode_aware_features = false;
481   options.extract_selection_mask_feature = true;
482   options.allowed_chargrams.insert("^H");
483   options.allowed_chargrams.insert("ll");
484   options.allowed_chargrams.insert("llo");
485   options.allowed_chargrams.insert("w");
486   options.allowed_chargrams.insert("!");
487   options.allowed_chargrams.insert("\xc4");  // UTF8 control character.
488 
489   TestingTokenFeatureExtractor extractor(options, &unilib_);
490 
491   std::vector<int> sparse_features;
492   std::vector<float> dense_features;
493 
494   extractor.Extract(Token{"Hěllo", 0, 5}, true, &sparse_features,
495                     &dense_features);
496 
497   EXPECT_THAT(sparse_features,
498               testing::ElementsAreArray({
499                   // clang-format off
500                   0,
501                   extractor.HashToken("\xc4"),
502                   0,
503                   0,
504                   0,
505                   0,
506                   extractor.HashToken("^H"),
507                   0,
508                   0,
509                   0,
510                   extractor.HashToken("ll"),
511                   0,
512                   0,
513                   0,
514                   0,
515                   0,
516                   0,
517                   extractor.HashToken("llo"),
518                   0
519                   // clang-format on
520               }));
521   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
522 
523   sparse_features.clear();
524   dense_features.clear();
525   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
526                     &dense_features);
527 
528   EXPECT_THAT(sparse_features, testing::ElementsAreArray({
529                                    // clang-format off
530                   extractor.HashToken("w"),
531                   0,
532                   0,
533                   0,
534                   0,
535                   extractor.HashToken("!"),
536                   0,
537                   0,
538                   0,
539                   0,
540                   0,
541                   0,
542                   0,
543                   0,
544                   0,
545                   0,
546                   0,
547                   0,
548                   0,
549                                    // clang-format on
550                                }));
551   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
552   EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
553 }
554 
// Extracting from a default-constructed Token with the boolean argument set to
// true must not crash; the single expected sparse feature is the hash of
// "<PAD>".
TEST_F(TokenFeatureExtractorTest, ExtractEmptyToken) {
  TokenFeatureExtractorOptions opts;
  opts.num_buckets = 1000;
  opts.chargram_orders = std::vector<int>{1, 2, 3};
  opts.unicode_aware_features = false;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(opts, &unilib_);

  std::vector<int> sparse;
  std::vector<float> dense;

  // Should not crash.
  extractor.Extract(Token(), true, &sparse, &dense);

  const std::vector<int> expected_sparse = {
      // clang-format off
      extractor.HashToken("<PAD>"),
      // clang-format on
  };
  EXPECT_THAT(sparse, testing::ElementsAreArray(expected_sparse));
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 1.0}));
}
577 
578 }  // namespace
579 }  // namespace libtextclassifier3
580