• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "smartselect/token-feature-extractor.h"
18 
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 
22 namespace libtextclassifier {
23 namespace {
24 
25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
26  public:
27   using TokenFeatureExtractor::TokenFeatureExtractor;
28   using TokenFeatureExtractor::HashToken;
29 };
30 
TEST(TokenFeatureExtractorTest,ExtractAscii)31 TEST(TokenFeatureExtractorTest, ExtractAscii) {
32   TokenFeatureExtractorOptions options;
33   options.num_buckets = 1000;
34   options.chargram_orders = std::vector<int>{1, 2, 3};
35   options.extract_case_feature = true;
36   options.unicode_aware_features = false;
37   options.extract_selection_mask_feature = true;
38   TestingTokenFeatureExtractor extractor(options);
39 
40   std::vector<int> sparse_features;
41   std::vector<float> dense_features;
42 
43   extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
44                     &dense_features);
45 
46   EXPECT_THAT(sparse_features,
47               testing::ElementsAreArray({
48                   // clang-format off
49                   extractor.HashToken("H"),
50                   extractor.HashToken("e"),
51                   extractor.HashToken("l"),
52                   extractor.HashToken("l"),
53                   extractor.HashToken("o"),
54                   extractor.HashToken("^H"),
55                   extractor.HashToken("He"),
56                   extractor.HashToken("el"),
57                   extractor.HashToken("ll"),
58                   extractor.HashToken("lo"),
59                   extractor.HashToken("o$"),
60                   extractor.HashToken("^He"),
61                   extractor.HashToken("Hel"),
62                   extractor.HashToken("ell"),
63                   extractor.HashToken("llo"),
64                   extractor.HashToken("lo$")
65                   // clang-format on
66               }));
67   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
68 
69   sparse_features.clear();
70   dense_features.clear();
71   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
72                     &dense_features);
73 
74   EXPECT_THAT(sparse_features,
75               testing::ElementsAreArray({
76                   // clang-format off
77                   extractor.HashToken("w"),
78                   extractor.HashToken("o"),
79                   extractor.HashToken("r"),
80                   extractor.HashToken("l"),
81                   extractor.HashToken("d"),
82                   extractor.HashToken("!"),
83                   extractor.HashToken("^w"),
84                   extractor.HashToken("wo"),
85                   extractor.HashToken("or"),
86                   extractor.HashToken("rl"),
87                   extractor.HashToken("ld"),
88                   extractor.HashToken("d!"),
89                   extractor.HashToken("!$"),
90                   extractor.HashToken("^wo"),
91                   extractor.HashToken("wor"),
92                   extractor.HashToken("orl"),
93                   extractor.HashToken("rld"),
94                   extractor.HashToken("ld!"),
95                   extractor.HashToken("d!$"),
96                   // clang-format on
97               }));
98   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
99 }
100 
TEST(TokenFeatureExtractorTest,ExtractUnicode)101 TEST(TokenFeatureExtractorTest, ExtractUnicode) {
102   TokenFeatureExtractorOptions options;
103   options.num_buckets = 1000;
104   options.chargram_orders = std::vector<int>{1, 2, 3};
105   options.extract_case_feature = true;
106   options.unicode_aware_features = true;
107   options.extract_selection_mask_feature = true;
108   TestingTokenFeatureExtractor extractor(options);
109 
110   std::vector<int> sparse_features;
111   std::vector<float> dense_features;
112 
113   extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
114                     &dense_features);
115 
116   EXPECT_THAT(sparse_features,
117               testing::ElementsAreArray({
118                   // clang-format off
119                   extractor.HashToken("H"),
120                   extractor.HashToken("ě"),
121                   extractor.HashToken("l"),
122                   extractor.HashToken("l"),
123                   extractor.HashToken("ó"),
124                   extractor.HashToken("^H"),
125                   extractor.HashToken("Hě"),
126                   extractor.HashToken("ěl"),
127                   extractor.HashToken("ll"),
128                   extractor.HashToken("ló"),
129                   extractor.HashToken("ó$"),
130                   extractor.HashToken("^Hě"),
131                   extractor.HashToken("Hěl"),
132                   extractor.HashToken("ěll"),
133                   extractor.HashToken("lló"),
134                   extractor.HashToken("ló$")
135                   // clang-format on
136               }));
137   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
138 
139   sparse_features.clear();
140   dense_features.clear();
141   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
142                     &dense_features);
143 
144   EXPECT_THAT(sparse_features,
145               testing::ElementsAreArray({
146                   // clang-format off
147                   extractor.HashToken("w"),
148                   extractor.HashToken("o"),
149                   extractor.HashToken("r"),
150                   extractor.HashToken("l"),
151                   extractor.HashToken("d"),
152                   extractor.HashToken("!"),
153                   extractor.HashToken("^w"),
154                   extractor.HashToken("wo"),
155                   extractor.HashToken("or"),
156                   extractor.HashToken("rl"),
157                   extractor.HashToken("ld"),
158                   extractor.HashToken("d!"),
159                   extractor.HashToken("!$"),
160                   extractor.HashToken("^wo"),
161                   extractor.HashToken("wor"),
162                   extractor.HashToken("orl"),
163                   extractor.HashToken("rld"),
164                   extractor.HashToken("ld!"),
165                   extractor.HashToken("d!$"),
166                   // clang-format on
167               }));
168   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
169 }
170 
// With only the case feature enabled (no selection mask), the single dense
// value is expected to be 1.0 for tokens starting with an upper-case letter
// and -1.0 otherwise, including for non-ASCII letters.
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = true;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = false;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  TokenFeatureExtractor extractor(opts);

  std::vector<int> sparse_ids;
  std::vector<float> dense_values;

  // Upper-case ASCII initial followed by accented letters.
  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({1.0}));

  // All lower-case ASCII.
  sparse_ids.clear();
  dense_values.clear();
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_ids,
                    &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0}));

  // Single upper-case non-ASCII letter.
  sparse_ids.clear();
  dense_values.clear();
  extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({1.0}));

  // Single lower-case non-ASCII letter.
  sparse_ids.clear();
  dense_values.clear();
  extractor.Extract(Token{"ř", 23, 29}, false, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0}));
}
204 
// With remap_digits enabled (non-Unicode mode), tokens that differ only in
// which digits they contain must yield identical sparse features, while a
// token with a different number of digits must not.
TEST(TokenFeatureExtractorTest, DigitRemapping) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = false;
  opts.remap_digits = true;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  TokenFeatureExtractor extractor(opts);

  // Dense output is collected but never inspected in this test.
  std::vector<float> dense_values;

  std::vector<int> reference_ids;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &reference_ids,
                    &dense_values);

  // Same shape, different digits: features must match the reference.
  std::vector<int> candidate_ids;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &candidate_ids,
                    &dense_values);
  EXPECT_THAT(reference_ids, testing::ElementsAreArray(candidate_ids));

  // One more digit: features must differ from the reference.
  // NOTE(review): the span end (6) does not match this 7-character token;
  // the offsets look like placeholders here -- confirm they are unused.
  extractor.Extract(Token{"10:32am", 0, 6}, true, &candidate_ids,
                    &dense_values);
  EXPECT_THAT(reference_ids,
              testing::Not(testing::ElementsAreArray(candidate_ids)));
}
228 
// Unicode-aware variant of the digit-remapping test: tokens that differ only
// in which digits they contain must yield identical sparse features, while a
// token with a different number of digits must not.
TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = true;
  opts.remap_digits = true;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  TokenFeatureExtractor extractor(opts);

  // Dense output is collected but never inspected in this test.
  std::vector<float> dense_values;

  std::vector<int> reference_ids;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &reference_ids,
                    &dense_values);

  // Same shape, different digits: features must match the reference.
  std::vector<int> candidate_ids;
  extractor.Extract(Token{"5:32am", 0, 6}, true, &candidate_ids,
                    &dense_values);
  EXPECT_THAT(reference_ids, testing::ElementsAreArray(candidate_ids));

  // One more digit: features must differ from the reference.
  // NOTE(review): the span end (6) does not match this 7-character token;
  // the offsets look like placeholders here -- confirm they are unused.
  extractor.Extract(Token{"10:32am", 0, 6}, true, &candidate_ids,
                    &dense_values);
  EXPECT_THAT(reference_ids,
              testing::Not(testing::ElementsAreArray(candidate_ids)));
}
252 
// With lowercase_tokens enabled (non-Unicode mode), tokens that differ only
// in letter case must yield identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseAscii) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = false;
  opts.lowercase_tokens = true;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  TokenFeatureExtractor extractor(opts);

  // Dense output is collected but never inspected in this test.
  // NOTE(review): every token below uses the span (0, 6) although the text is
  // 4 characters long; the offsets look like placeholders -- confirm.
  std::vector<float> dense_values;

  std::vector<int> reference_ids;
  extractor.Extract(Token{"AABB", 0, 6}, true, &reference_ids, &dense_values);

  // Partially lower-cased variant.
  std::vector<int> candidate_ids;
  extractor.Extract(Token{"aaBB", 0, 6}, true, &candidate_ids, &dense_values);
  EXPECT_THAT(reference_ids, testing::ElementsAreArray(candidate_ids));

  // Alternating-case variant.
  extractor.Extract(Token{"aAbB", 0, 6}, true, &candidate_ids, &dense_values);
  EXPECT_THAT(reference_ids, testing::ElementsAreArray(candidate_ids));
}
275 
// With lowercase_tokens and unicode_aware_features enabled, an upper-case
// non-ASCII token and its lower-case form must yield identical sparse
// features.
TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = true;
  opts.lowercase_tokens = true;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  TokenFeatureExtractor extractor(opts);

  // Dense output is collected but never inspected in this test.
  std::vector<float> dense_values;

  std::vector<int> upper_ids;
  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &upper_ids, &dense_values);

  std::vector<int> lower_ids;
  extractor.Extract(Token{"řř", 0, 6}, true, &lower_ids, &dense_values);

  EXPECT_THAT(upper_ids, testing::ElementsAreArray(lower_ids));
}
293 
// Each configured regexp contributes one dense value, in configuration order:
// the expectations below use 1.0 when the token matches the pattern and -1.0
// when it does not.
TEST(TokenFeatureExtractorTest, RegexFeatures) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = false;
  opts.remap_digits = false;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;
  opts.regexp_features.push_back("^[a-z]+$");  // all lower case.
  opts.regexp_features.push_back("^[0-9]+$");  // all digits.
  TokenFeatureExtractor extractor(opts);

  // Sparse output is collected but never inspected in this test.
  std::vector<int> sparse_ids;
  std::vector<float> dense_values;

  // Contains an upper-case letter: matches neither pattern.
  extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0, -1.0}));

  // Only lower-case letters: matches the first pattern.
  dense_values.clear();
  extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({1.0, -1.0}));

  // Digits with a letter in the middle: matches neither pattern.
  dense_values.clear();
  extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0, -1.0}));

  // Only digits: matches the second pattern.
  dense_values.clear();
  extractor.Extract(Token{"12345", 0, 6}, true, &sparse_ids, &dense_values);
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0, 1.0}));
}
325 
// Runs extraction with a single chargram order of 22 on a 26-character token
// containing a non-ASCII letter. The expected features hash strings in which
// the middle of the token ("klmnop") is replaced by "\1" -- presumably the
// extractor shortens over-long tokens this way; confirm in the implementation.
TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = true;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  opts.chargram_orders = std::vector<int>{22};
  opts.num_buckets = 1000;
  TestingTokenFeatureExtractor extractor(opts);

  // Test that this runs. ASAN should catch problems.
  std::vector<int> sparse_ids;
  std::vector<float> dense_values;
  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
                    &sparse_ids, &dense_values);

  // Exactly two 22-grams are expected: one anchored at the start of the
  // token ("^...") and one anchored at its end ("...$").
  EXPECT_THAT(sparse_ids,
              testing::ElementsAreArray({
                  // clang-format off
                  extractor.HashToken("^abcdefghij\1qřstuvwxyz"),
                  extractor.HashToken("abcdefghij\1qřstuvwxyz$"),
                  // clang-format on
              }));
}
349 
// For pure-ASCII inputs, the Unicode-aware extractor and the non-Unicode
// extractor must produce identical sparse and dense features. Inputs cover
// long URLs, punctuation-heavy text, short words, a single character and the
// empty string.
TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor_unicode(options);

  // Same configuration, differing only in Unicode awareness.
  options.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(options);

  for (const std::string& input :
       {"https://www.abcdefgh.com/in/xxxkkkvayio",
        "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
        "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
        "x", "Hello", "Hey,", "Hi", ""}) {
    std::vector<int> sparse_features_unicode;
    std::vector<float> dense_features_unicode;
    extractor_unicode.Extract(Token{input, 0, 0}, true,
                              &sparse_features_unicode,
                              &dense_features_unicode);

    std::vector<int> sparse_features_ascii;
    std::vector<float> dense_features_ascii;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
                            &dense_features_ascii);

    // Use an explicit container matcher, like the other tests in this file,
    // rather than relying on the implicit value-to-Eq-matcher conversion; a
    // mismatch is then reported element by element. The offending input is
    // appended to the failure message.
    EXPECT_THAT(sparse_features_unicode,
                testing::ElementsAreArray(sparse_features_ascii))
        << input;
    EXPECT_THAT(dense_features_unicode,
                testing::ElementsAreArray(dense_features_ascii))
        << input;
  }
}
382 
// A default-constructed Token, outside the selection, is expected to produce
// a single sparse feature equal to the hash of "<PAD>" and the dense values
// {-1.0, 0.0}.
TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
  TokenFeatureExtractorOptions opts;
  opts.unicode_aware_features = false;
  opts.extract_case_feature = true;
  opts.extract_selection_mask_feature = true;
  opts.chargram_orders = std::vector<int>{1, 2};
  opts.num_buckets = 1000;

  TestingTokenFeatureExtractor extractor(opts);

  std::vector<int> sparse_ids;
  std::vector<float> dense_values;

  extractor.Extract(Token(), false, &sparse_ids, &dense_values);

  // No character n-grams are expected for the pad token, only the "<PAD>"
  // hash.
  EXPECT_THAT(sparse_ids,
              testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
  EXPECT_THAT(dense_values, testing::ElementsAreArray({-1.0, 0.0}));
}
402 
403 }  // namespace
404 }  // namespace libtextclassifier
405