1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "smartselect/text-classification-model.h"
18
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
23
24 #include "base.h"
25 #include "gtest/gtest.h"
26
27 namespace libtextclassifier {
28 namespace {
29
GetModelPath()30 std::string GetModelPath() {
31 return TEST_DATA_DIR "smartselection.model";
32 }
33
// Reads only the model options from the model file and checks the basic
// metadata fields.
TEST(TextClassificationModelTest, ReadModelOptions) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  // Fail with a clear message if the test data is missing instead of handing
  // an invalid descriptor to the reader.
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;

  ModelOptions model_options;
  // Close the descriptor before asserting: ASSERT_* returns from the test body
  // immediately, so asserting first would leak `fd` on failure.
  const bool read_ok = ReadSelectionModelOptions(fd, &model_options);
  close(fd);
  ASSERT_TRUE(read_ok);

  EXPECT_EQ("en", model_options.language());
  EXPECT_GT(model_options.version(), 0);
}
44
// Basic SuggestSelection behavior: expanding a one-word selection, and leaving
// multi-token and single-token selections unchanged.
TEST(TextClassificationModelTest, SuggestSelection) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  // Selecting "Barack" should expand to "Barack Obama".
  EXPECT_EQ(std::make_pair(15, 27),
            model->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}));

  // Passing the whole string: when the selection already spans more than one
  // token, it is expected to be returned unchanged.
  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection("350 Third Street, Cambridge", {0, 27}));

  // Single letter.
  EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1}));

  // Single word.
  EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
}
67
// Selecting any single word of the address should yield the same full-address
// span, regardless of which word was selected.
TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  const char kAddress[] = "350 Third Street, Cambridge";
  // "350", "Third" and "Street" respectively.
  EXPECT_EQ(std::make_pair(0, 27), model->SuggestSelection(kAddress, {0, 3}));
  EXPECT_EQ(std::make_pair(0, 27), model->SuggestSelection(kAddress, {4, 9}));
  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection(kAddress, {10, 16}));

  // The same address preceded by three short lines (shifted by 6 characters);
  // "Street" is selected.
  EXPECT_EQ(std::make_pair(6, 33),
            model->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
                                    {16, 22}));
}
85
// The suggested selection must not extend across a newline on either side.
TEST(TextClassificationModelTest, SuggestSelectionWithNewLine) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // Newline before the entity: the span starts right after it.
  selection = model->SuggestSelection("abc\nBarack Obama", {4, 10});
  EXPECT_EQ(4, std::get<0>(selection));
  EXPECT_EQ(16, std::get<1>(selection));

  // Newline after the entity: the span ends right before it.
  selection = model->SuggestSelection("Barack Obama\nabc", {0, 6});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(12, std::get<1>(selection));
}
102
// Punctuation adjacent to the entity must not be included in the suggested
// selection, whichever side it is on.
TEST(TextClassificationModelTest, SuggestSelectionWithPunctuation) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // From the right.
  selection = model->SuggestSelection(
      "this afternoon Barack Obama, gave a speech at", {15, 21});
  EXPECT_EQ(15, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From the right, multiple characters.
  selection = model->SuggestSelection(
      "this afternoon Barack Obama,.,.,, gave a speech at", {15, 21});
  EXPECT_EQ(15, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From the left, multiple characters.
  selection = model->SuggestSelection(
      "this afternoon ,.,.,,Barack Obama gave a speech at", {21, 27});
  EXPECT_EQ(21, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From both sides.
  selection = model->SuggestSelection(
      "this afternoon !Barack Obama,- gave a speech at", {16, 22});
  EXPECT_EQ(16, std::get<0>(selection));
  EXPECT_EQ(28, std::get<1>(selection));
}
136
137 class TestingTextClassificationModel
138 : public libtextclassifier::TextClassificationModel {
139 public:
TestingTextClassificationModel(int fd)140 explicit TestingTextClassificationModel(int fd)
141 : libtextclassifier::TextClassificationModel(fd) {}
142
143 using libtextclassifier::TextClassificationModel::StripPunctuation;
144
DisableClassificationHints()145 void DisableClassificationHints() {
146 sharing_options_.set_always_accept_url_hint(false);
147 sharing_options_.set_always_accept_email_hint(false);
148 }
149 };
150
// Pins the spans produced by StripPunctuation for punctuation-wrapped,
// all-punctuation and out-of-range inputs.
TEST(TextClassificationModelTest, StripPunctuation) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  EXPECT_EQ(std::make_pair(3, 10),
            model->StripPunctuation({0, 10}, ".,-abcd.()"));
  EXPECT_EQ(std::make_pair(0, 6), model->StripPunctuation({0, 6}, "(abcd)"));
  EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "[abcd]"));
  EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "{abcd}"));

  // Empty result.
  EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 1}, "&"));
  EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 4}, "&-,}"));

  // Invalid indices are returned unchanged.
  EXPECT_EQ(std::make_pair(-1, 523), model->StripPunctuation({-1, 523}, "a"));
  EXPECT_EQ(std::make_pair(-1, -1), model->StripPunctuation({-1, -1}, "a"));
  EXPECT_EQ(std::make_pair(0, -1), model->StripPunctuation({0, -1}, "a"));
}
173
// SuggestSelection must not crash on invalid selections (out of range,
// negative, inverted, or on an empty context). In all of these cases the input
// span is expected to be returned unchanged.
TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> ff_model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // Span extends past the end of an empty context.
  selection = ff_model->SuggestSelection("", {0, 27});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // Negative start on an empty context.
  selection = ff_model->SuggestSelection("", {-10, 27});
  EXPECT_EQ(-10, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // End beyond the 17-character context.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {0, 27});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // Both ends out of range.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-30, 300});
  EXPECT_EQ(-30, std::get<0>(selection));
  EXPECT_EQ(300, std::get<1>(selection));

  // Both ends negative.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-10, -1});
  EXPECT_EQ(-10, std::get<0>(selection));
  EXPECT_EQ(-1, std::get<1>(selection));

  // Inverted span (start after end).
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {100, 17});
  EXPECT_EQ(100, std::get<0>(selection));
  EXPECT_EQ(17, std::get<1>(selection));
}
214
// Returns the label with the highest score in `results`, or the sentinel
// "<INVALID RESULTS>" when `results` is empty. If several labels share the
// highest score, the first of them in `results` is returned.
std::string FindBestResult(
    const std::vector<std::pair<std::string, float>>& results) {
  if (results.empty()) {
    return "<INVALID RESULTS>";
  }

  // Only the top-scoring entry is needed, so a linear max_element scan is
  // enough. The input does not need to be copied or sorted, and the comparator
  // takes references so no std::string is copied per comparison.
  const auto best = std::max_element(
      results.begin(), results.end(),
      [](const std::pair<std::string, float>& a,
         const std::pair<std::string, float>& b) {
        return a.second < b.second;
      });
  return best->first;
}
231
// Top classification label for a range of selections, with the URL/email
// classification hints disabled.
TEST(TextClassificationModelTest, ClassifyText) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  model->DisableClassificationHints();
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})));
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText("you@android.com", {0, 15})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "Contact me at you@android.com", {14, 29})));
  EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "Visit www.google.com every today!", {6, 20})));

  // More lines: three sentences joined by '|'. "Barack Obama" is at [15, 27),
  // "www.google.com" at [51, 65) and "(800) 123-456" at [90, 103).
  const char kMultiSentenceText[] =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";
  EXPECT_EQ("other", FindBestResult(
                         model->ClassifyText(kMultiSentenceText, {15, 27})));
  EXPECT_EQ("other", FindBestResult(
                         model->ClassifyText(kMultiSentenceText, {51, 65})));
  EXPECT_EQ("phone", FindBestResult(
                         model->ClassifyText(kMultiSentenceText, {90, 103})));

  // Single word.
  EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
  // An empty selection produces no results.
  EXPECT_EQ("<INVALID RESULTS>",
            FindBestResult(model->ClassifyText("asdf", {0, 0})));

  // Junk.
  EXPECT_EQ("<INVALID RESULTS>",
            FindBestResult(model->ClassifyText("", {0, 0})));
  EXPECT_EQ("<INVALID RESULTS>", FindBestResult(model->ClassifyText(
                                     "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
}
281
// With the model's default sharing options the URL/email hints decide the
// label. After DisableClassificationHints() the hints must not change the
// output.
TEST(TextClassificationModelTest, ClassifyTextWithHints) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  // When the EMAIL hint is passed, the result should be email.
  EXPECT_EQ("email",
            FindBestResult(model->ClassifyText(
                "x", {0, 1}, TextClassificationModel::SELECTION_IS_EMAIL)));
  // When the URL hint is passed, the result should be url.
  EXPECT_EQ("url",
            FindBestResult(model->ClassifyText(
                "x", {0, 1}, TextClassificationModel::SELECTION_IS_URL)));
  // When both hints are passed, the result should be url (as it's probably
  // better to let Chrome handle this case).
  EXPECT_EQ("url", FindBestResult(model->ClassifyText(
                       "x", {0, 1},
                       TextClassificationModel::SELECTION_IS_EMAIL |
                           TextClassificationModel::SELECTION_IS_URL)));

  // With disabled hints, we should get the same prediction regardless of the
  // hint.
  model->DisableClassificationHints();
  EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
            model->ClassifyText("x", {0, 1},
                                TextClassificationModel::SELECTION_IS_EMAIL));

  EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
            model->ClassifyText("x", {0, 1},
                                TextClassificationModel::SELECTION_IS_URL));
}
315
// Phone-number spans are classified as "phone" up to a certain length of
// trailing digits. The full, longer span is expected to be classified as
// "other".
TEST(TextClassificationModelTest, PhoneFiltering) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  EXPECT_EQ("phone", FindBestResult(model->ClassifyText("phone: (123) 456 789",
                                                        {7, 20}, 0)));
  EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
                         "phone: (123) 456 789,0001112", {7, 25}, 0)));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "phone: (123) 456 789,0001112", {7, 28}, 0)));
}
330
331 } // namespace
332 } // namespace libtextclassifier
333