• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "smartselect/text-classification-model.h"

#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "base.h"
#include "gtest/gtest.h"
26 
27 namespace libtextclassifier {
28 namespace {
29 
GetModelPath()30 std::string GetModelPath() {
31   return TEST_DATA_DIR "smartselection.model";
32 }
33 
// Reads only the ModelOptions stored in the model file and checks a couple of
// their fields.
TEST(TextClassificationModelTest, ReadModelOptions) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  // Fail with a clear message instead of passing -1 to the reader.
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;

  ModelOptions model_options;
  // Close the descriptor before asserting: ASSERT_* returns from the test on
  // failure, which previously skipped close() and leaked the descriptor.
  const bool read_ok = ReadSelectionModelOptions(fd, &model_options);
  close(fd);
  ASSERT_TRUE(read_ok);

  EXPECT_EQ("en", model_options.language());
  EXPECT_GT(model_options.version(), 0);
}
44 
// Checks that SuggestSelection expands a partial selection ("Barack") to the
// whole name, and returns multi-token, single-letter and single-word
// selections unchanged.
TEST(TextClassificationModelTest, SuggestSelection) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  // The model is used after this point, so it must not rely on fd staying
  // open.
  close(fd);

  EXPECT_EQ(std::make_pair(15, 27),
            model->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}));

  // Try passing whole string.
  // If more than 1 token is specified, we should return back what entered.
  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection("350 Third Street, Cambridge", {0, 27}));

  // Single letter.
  EXPECT_EQ(std::make_pair(0, 1), model->SuggestSelection("a", {0, 1}));

  // Single word.
  EXPECT_EQ(std::make_pair(0, 4), model->SuggestSelection("asdf", {0, 4}));
}
67 
// Selecting any single token of the address must yield the same full address
// span, including when the address is preceded by other lines.
TEST(TextClassificationModelTest, SuggestSelectionsAreSymmetric) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection("350 Third Street, Cambridge", {0, 3}));
  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection("350 Third Street, Cambridge", {4, 9}));
  EXPECT_EQ(std::make_pair(0, 27),
            model->SuggestSelection("350 Third Street, Cambridge", {10, 16}));
  // Same address shifted by the 6-character prefix "a\nb\nc\n".
  EXPECT_EQ(std::make_pair(6, 33),
            model->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
                                    {16, 22}));
}
85 
// Selection expansion must not cross a newline in either direction.
TEST(TextClassificationModelTest, SuggestSelectionWithNewLine) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // Newline before the entity.
  selection = model->SuggestSelection("abc\nBarack Obama", {4, 10});
  EXPECT_EQ(4, std::get<0>(selection));
  EXPECT_EQ(16, std::get<1>(selection));

  // Newline after the entity.
  selection = model->SuggestSelection("Barack Obama\nabc", {0, 6});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(12, std::get<1>(selection));
}
102 
// Punctuation adjacent to the entity (on either side) must not be included in
// the suggested span.
TEST(TextClassificationModelTest, SuggestSelectionWithPunctuation) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // From the right.
  selection = model->SuggestSelection(
      "this afternoon Barack Obama, gave a speech at", {15, 21});
  EXPECT_EQ(15, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From the right multiple.
  selection = model->SuggestSelection(
      "this afternoon Barack Obama,.,.,, gave a speech at", {15, 21});
  EXPECT_EQ(15, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From the left multiple.
  selection = model->SuggestSelection(
      "this afternoon ,.,.,,Barack Obama gave a speech at", {21, 27});
  EXPECT_EQ(21, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // From both sides.
  selection = model->SuggestSelection(
      "this afternoon !Barack Obama,- gave a speech at", {16, 22});
  EXPECT_EQ(16, std::get<0>(selection));
  EXPECT_EQ(28, std::get<1>(selection));
}
136 
137 class TestingTextClassificationModel
138     : public libtextclassifier::TextClassificationModel {
139  public:
TestingTextClassificationModel(int fd)140   explicit TestingTextClassificationModel(int fd)
141       : libtextclassifier::TextClassificationModel(fd) {}
142 
143   using libtextclassifier::TextClassificationModel::StripPunctuation;
144 
DisableClassificationHints()145   void DisableClassificationHints() {
146     sharing_options_.set_always_accept_url_hint(false);
147     sharing_options_.set_always_accept_email_hint(false);
148   }
149 };
150 
// Exercises StripPunctuation directly through the testing subclass.
TEST(TextClassificationModelTest, StripPunctuation) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  EXPECT_EQ(std::make_pair(3, 10),
            model->StripPunctuation({0, 10}, ".,-abcd.()"));
  // Enclosing parentheses are kept, while [] and {} are stripped.
  EXPECT_EQ(std::make_pair(0, 6), model->StripPunctuation({0, 6}, "(abcd)"));
  EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "[abcd]"));
  EXPECT_EQ(std::make_pair(1, 5), model->StripPunctuation({0, 6}, "{abcd}"));

  // Empty result.
  EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 1}, "&"));
  EXPECT_EQ(std::make_pair(0, 0), model->StripPunctuation({0, 4}, "&-,}"));

  // Invalid indices are returned unchanged.
  EXPECT_EQ(std::make_pair(-1, 523), model->StripPunctuation({-1, 523}, "a"));
  EXPECT_EQ(std::make_pair(-1, -1), model->StripPunctuation({-1, -1}, "a"));
  EXPECT_EQ(std::make_pair(0, -1), model->StripPunctuation({0, -1}, "a"));
}
173 
// Feeds SuggestSelection empty text and out-of-range or inverted selections.
// The model must not crash, and such selections are expected to be returned
// unchanged.
TEST(TextClassificationModelTest, SuggestSelectionNoCrashWithJunk) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TextClassificationModel> ff_model(
      new TextClassificationModel(fd));
  close(fd);

  std::tuple<int, int> selection;

  // Empty text, selection past its end.
  selection = ff_model->SuggestSelection("", {0, 27});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // Empty text, negative start.
  selection = ff_model->SuggestSelection("", {-10, 27});
  EXPECT_EQ(-10, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // Selection end past the end of the 17-character text.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {0, 27});
  EXPECT_EQ(0, std::get<0>(selection));
  EXPECT_EQ(27, std::get<1>(selection));

  // Both ends out of range.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-30, 300});
  EXPECT_EQ(-30, std::get<0>(selection));
  EXPECT_EQ(300, std::get<1>(selection));

  // Both ends negative.
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {-10, -1});
  EXPECT_EQ(-10, std::get<0>(selection));
  EXPECT_EQ(-1, std::get<1>(selection));

  // Inverted selection (start > end).
  selection = ff_model->SuggestSelection("Word 1 2 3 hello!", {100, 17});
  EXPECT_EQ(100, std::get<0>(selection));
  EXPECT_EQ(17, std::get<1>(selection));
}
214 
// Returns the label with the highest score in `results`, or
// "<INVALID RESULTS>" when `results` is empty. When several labels share the
// top score, the first of them in `results` is returned.
//
// (Already inside the file's outer anonymous namespace, so no nested one is
// needed.)
std::string FindBestResult(
    const std::vector<std::pair<std::string, float>>& results) {
  if (results.empty()) {
    return "<INVALID RESULTS>";
  }
  // Only the maximum is needed. A single linear scan replaces the full sort,
  // and the input no longer has to be copied.
  const auto best = std::max_element(
      results.begin(), results.end(),
      [](const std::pair<std::string, float>& lhs,
         const std::pair<std::string, float>& rhs) {
        return lhs.second < rhs.second;
      });
  return best->first;
}
231 
// Checks the top-scoring label returned by ClassifyText with the
// always-accept hint options disabled, including multi-line input, single
// words and junk.
TEST(TextClassificationModelTest, ClassifyText) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  model->DisableClassificationHints();
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})));
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText("you@android.com", {0, 15})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "Contact me at you@android.com", {14, 29})));
  EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "Visit www.google.com every today!", {6, 20})));

  // More lines.
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText(
                "this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {15, 27})));
  EXPECT_EQ("other",
            FindBestResult(model->ClassifyText(
                "this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {51, 65})));
  EXPECT_EQ("phone",
            FindBestResult(model->ClassifyText(
                "this afternoon Barack Obama gave a speech at|Visit "
                "www.google.com every today!|Call me at (800) 123-456 today.",
                {90, 103})));

  // Single word.
  EXPECT_EQ("other", FindBestResult(model->ClassifyText("obama", {0, 5})));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText("asdf", {0, 4})));
  // An empty selection yields no results.
  EXPECT_EQ("<INVALID RESULTS>",
            FindBestResult(model->ClassifyText("asdf", {0, 0})));

  // Junk.
  EXPECT_EQ("<INVALID RESULTS>",
            FindBestResult(model->ClassifyText("", {0, 0})));
  EXPECT_EQ("<INVALID RESULTS>", FindBestResult(model->ClassifyText(
                                     "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})));
}
281 
// Checks that SELECTION_IS_EMAIL / SELECTION_IS_URL hints decide the label
// while the always-accept options are on, and have no effect once
// DisableClassificationHints() has been called.
TEST(TextClassificationModelTest, ClassifyTextWithHints) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  // When the EMAIL hint is passed, the result should be email.
  EXPECT_EQ("email",
            FindBestResult(model->ClassifyText(
                "x", {0, 1}, TextClassificationModel::SELECTION_IS_EMAIL)));
  // When the URL hint is passed, the result should be url.
  EXPECT_EQ("url",
            FindBestResult(model->ClassifyText(
                "x", {0, 1}, TextClassificationModel::SELECTION_IS_URL)));
  // When both hints are passed, the result should be url (as it's probably
  // better to let Chrome handle this case).
  EXPECT_EQ("url", FindBestResult(model->ClassifyText(
                       "x", {0, 1},
                       TextClassificationModel::SELECTION_IS_EMAIL |
                           TextClassificationModel::SELECTION_IS_URL)));

  // With disabled hints, we should get the same prediction regardless of the
  // hint.
  model->DisableClassificationHints();
  EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
            model->ClassifyText("x", {0, 1},
                                TextClassificationModel::SELECTION_IS_EMAIL));

  EXPECT_EQ(model->ClassifyText("x", {0, 1}, 0),
            model->ClassifyText("x", {0, 1},
                                TextClassificationModel::SELECTION_IS_URL));
}
315 
// Checks which selections over "(123) 456 789,0001112" are classified as a
// phone number. Selecting the whole trailing digit run is expected to yield
// "other".
TEST(TextClassificationModelTest, PhoneFiltering) {
  const std::string model_path = GetModelPath();
  const int fd = open(model_path.c_str(), O_RDONLY);
  ASSERT_GE(fd, 0) << "Failed to open model file: " << model_path;
  std::unique_ptr<TestingTextClassificationModel> model(
      new TestingTextClassificationModel(fd));
  close(fd);

  EXPECT_EQ("phone", FindBestResult(model->ClassifyText("phone: (123) 456 789",
                                                        {7, 20}, 0)));
  EXPECT_EQ("phone", FindBestResult(model->ClassifyText(
                         "phone: (123) 456 789,0001112", {7, 25}, 0)));
  EXPECT_EQ("other", FindBestResult(model->ClassifyText(
                         "phone: (123) 456 789,0001112", {7, 28}, 0)));
}
330 
331 }  // namespace
332 }  // namespace libtextclassifier
333