• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator_test-include.h"
18 
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23 
24 #include "annotator/annotator.h"
25 #include "annotator/collections.h"
26 #include "annotator/model_generated.h"
27 #include "annotator/test-utils.h"
28 #include "annotator/types-test-util.h"
29 #include "annotator/types.h"
30 #include "utils/grammar/utils/locale-shard-map.h"
31 #include "utils/grammar/utils/rules.h"
32 #include "utils/testing/annotator.h"
33 #include "lang_id/fb_model/lang-id-from-fb.h"
34 #include "lang_id/lang-id.h"
35 
36 namespace libtextclassifier3 {
37 namespace test_internal {
38 
39 using ::testing::Contains;
40 using ::testing::ElementsAre;
41 using ::testing::ElementsAreArray;
42 using ::testing::Eq;
43 using ::testing::IsEmpty;
44 using ::testing::UnorderedElementsAreArray;
45 
GetTestModelPath()46 std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
47 
GetModelWithVocabPath()48 std::string GetModelWithVocabPath() {
49   return GetModelPath() + "test_vocab_model.fb";
50 }
51 
GetTestModelWithDatetimeRegEx()52 std::string GetTestModelWithDatetimeRegEx() {
53   std::string model_buffer = ReadFile(GetTestModelPath());
54   model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
55     model->datetime_grammar_model.reset(nullptr);
56   });
57   return model_buffer;
58 }
59 
ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan> & result,const std::string & currency,const std::string & amount,const int whole_part,const int decimal_part,const int nanos)60 void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
61                               const std::string& currency,
62                               const std::string& amount, const int whole_part,
63                               const int decimal_part, const int nanos) {
64   ASSERT_GT(result.size(), 0);
65   ASSERT_GT(result[0].classification.size(), 0);
66   ASSERT_EQ(result[0].classification[0].collection, "money");
67 
68   const EntityData* entity_data =
69       GetEntityData(result[0].classification[0].serialized_entity_data.data());
70   ASSERT_NE(entity_data, nullptr);
71   ASSERT_NE(entity_data->money(), nullptr);
72   EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
73   EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
74   EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
75   EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
76   EXPECT_EQ(entity_data->money()->nanos(), nanos);
77 }
78 
// Loading the `wrong_embeddings.fb` fixture must not produce an annotator.
TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
  const std::string broken_model_path = GetModelPath() + "wrong_embeddings.fb";
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      broken_model_path, unilib_.get(), calendarlib_.get());
  EXPECT_FALSE(annotator);
}
85 
VerifyClassifyText(const Annotator * classifier)86 void VerifyClassifyText(const Annotator* classifier) {
87   ASSERT_TRUE(classifier);
88 
89   EXPECT_EQ("other",
90             FirstResult(classifier->ClassifyText(
91                 "this afternoon Barack Obama gave a speech at", {15, 27})));
92   EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
93                          "Call me at (800) 123-456 today", {11, 24})));
94 
95   // More lines.
96   EXPECT_EQ("other",
97             FirstResult(classifier->ClassifyText(
98                 "this afternoon Barack Obama gave a speech at|Visit "
99                 "www.google.com every today!|Call me at (800) 123-456 today.",
100                 {15, 27})));
101   EXPECT_EQ("phone",
102             FirstResult(classifier->ClassifyText(
103                 "this afternoon Barack Obama gave a speech at|Visit "
104                 "www.google.com every today!|Call me at (800) 123-456 today.",
105                 {90, 103})));
106 
107   // Single word.
108   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
109   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
110 
111   // Junk. These should not crash the test.
112   classifier->ClassifyText("", {0, 0});
113   classifier->ClassifyText("asdf", {0, 0});
114   classifier->ClassifyText("asdf", {0, 27});
115   classifier->ClassifyText("asdf", {-30, 300});
116   classifier->ClassifyText("asdf", {-10, -1});
117   classifier->ClassifyText("asdf", {100, 17});
118   classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
119 
120   // Test invalid utf8 input.
121   EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
122                                      "\xf0\x9f\x98\x8b\x8b", {0, 0})));
123 }
124 
// Runs the shared ClassifyText() checks on an annotator loaded from the
// default test model file.
TEST_F(AnnotatorTest, ClassifyText) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyText(annotator.get());
}
130 
// Checks how the detected-language tags passed in the classification options
// affect whether "isotope" is classified as "dictionary".
TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string word = "isotope";
  const CodepointSpan whole_word(0, 7);

  // Classifies `word` with only the detected-language tags set.
  const auto classify_with_tags = [&](const std::string& language_tags) {
    ClassificationOptions options;
    options.detected_text_language_tags = language_tags;
    return FirstResult(annotator->ClassifyText(word, whole_word, options));
  };

  // No options at all.
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(word, whole_word)));

  // Detected language "en".
  EXPECT_EQ("dictionary", classify_with_tags("en"));

  // Detected language "uz".
  EXPECT_EQ("other", classify_with_tags("uz"));
}
147 
// Requesting the vocab annotator on the default test model (loaded from
// GetTestModelPath(), not the vocab model) must still classify "isotope" as
// "dictionary".
TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions options;
  options.detected_text_language_tags = "en";
  options.use_vocab_annotator = true;

  const auto collection = FirstResult(
      annotator->ClassifyText("isotope", CodepointSpan(0, 7), options));
  EXPECT_EQ("dictionary", collection);
}
160 
#ifdef TC3_VOCAB_ANNOTATOR_IMPL
// With the vocab model loaded, toggling `use_vocab_annotator` must change how
// "integrity" is classified.
TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string word = "integrity";
  const CodepointSpan whole_word(0, 9);

  ClassificationOptions options;
  options.detected_text_language_tags = "en";

  // "integrity" is labelled "dictionary" by the vocab annotator but not by the
  // FFModel, which makes it a probe for whether the vocab annotator ran.
  options.use_vocab_annotator = true;
  const auto with_vocab =
      FirstResult(annotator->ClassifyText(word, whole_word, options));
  options.use_vocab_annotator = false;
  const auto without_vocab =
      FirstResult(annotator->ClassifyText(word, whole_word, options));

  EXPECT_EQ("dictionary", with_vocab);
  EXPECT_EQ("other", without_vocab);
}
#endif  // TC3_VOCAB_ANNOTATOR_IMPL
181 
// A model without a classification model must fail to load even when only the
// selection mode is enabled in its triggering options.
TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(model != nullptr);

  // Drop the classification model and restrict triggering to selection.
  model->classification_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_SELECTION;

  // Re-serialize the modified model.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, packed_size, unilib_.get(), calendarlib_.get());

  // Loading is expected to fail: selection scoring still relies on the
  // classification model.
  ASSERT_FALSE(annotator);
}
201 
// With the model's enabled modes set to ModeFlag_ANNOTATION_AND_SELECTION,
// ClassifyText() must return no results.
TEST_F(AnnotatorTest, ClassifyTextDisabled) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  // Re-serialize the modified model.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, packed_size, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const auto classifications = annotator->ClassifyText(
      "Call me at (800) 123-456 today", CodepointSpan(11, 24));
  EXPECT_THAT(classifications, IsEmpty());
}
220 
// A collection listed in `filtered_collections_classification` must no longer
// be returned by ClassifyText(), while other collections are unaffected.
TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  const std::string phone_text = "Call me at (800) 123-456 today";
  const CodepointSpan phone_span(11, 24);

  // Baseline: the unmodified model classifies the number as "phone".
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      serialized_model.c_str(), serialized_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(phone_text, phone_span)));

  // Build a copy of the model that filters "phone" out of classification.
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_classification.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  annotator = Annotator::FromUnownedBuffer(packed_model, packed_size,
                                           unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // The phone number is now reported as "other".
  EXPECT_EQ("other",
            FirstResult(annotator->ClassifyText(phone_text, phone_span)));

  // Address classification is not filtered and must keep working.
  EXPECT_EQ("address", FirstResult(annotator->ClassifyText(
                           "350 Third Street, Cambridge", CodepointSpan(0, 27))));
}
253 
TEST_F(AnnotatorTest,ClassifyTextRegularExpression)254 TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
255   const std::string test_model = ReadFile(GetTestModelPath());
256   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
257 
258   // Add test regex models.
259   unpacked_model->regex_model->patterns.push_back(MakePattern(
260       "person", "Barack Obama", /*enabled_for_classification=*/true,
261       /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
262   unpacked_model->regex_model->patterns.push_back(MakePattern(
263       "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
264       /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));
265   std::unique_ptr<RegexModel_::PatternT> verified_pattern =
266       MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
267                   /*enabled_for_classification=*/true,
268                   /*enabled_for_selection=*/false,
269                   /*enabled_for_annotation=*/false, 1.0);
270   verified_pattern->verification_options.reset(new VerificationOptionsT);
271   verified_pattern->verification_options->verify_luhn_checksum = true;
272   unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
273 
274   flatbuffers::FlatBufferBuilder builder;
275   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
276 
277   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
278       reinterpret_cast<const char*>(builder.GetBufferPointer()),
279       builder.GetSize(), unilib_.get(), calendarlib_.get());
280   ASSERT_TRUE(classifier);
281 
282   EXPECT_EQ("flight",
283             FirstResult(classifier->ClassifyText(
284                 "Your flight LX373 is delayed by 3 hours.", {12, 17})));
285   EXPECT_EQ("person",
286             FirstResult(classifier->ClassifyText(
287                 "this afternoon Barack Obama gave a speech at", {15, 27})));
288   EXPECT_EQ("email",
289             FirstResult(classifier->ClassifyText("you@android.com", {0, 15})));
290   EXPECT_EQ("email", FirstResult(classifier->ClassifyText(
291                          "Contact me at you@android.com", {14, 29})));
292 
293   EXPECT_EQ("url", FirstResult(classifier->ClassifyText(
294                        "Visit www.google.com every today!", {6, 20})));
295 
296   EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("LX 37", {0, 5})));
297   EXPECT_EQ("flight", FirstResult(classifier->ClassifyText("flight LX 37 abcd",
298                                                            {7, 12})));
299   EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
300                                 "cc: 4012 8888 8888 1881", {4, 23})));
301   EXPECT_EQ("payment_card", FirstResult(classifier->ClassifyText(
302                                 "2221 0067 4735 6281", {0, 19})));
303   // Luhn check fails.
304   EXPECT_EQ("other", FirstResult(classifier->ClassifyText("2221 0067 4735 6282",
305                                                           {0, 19})));
306 
307   // More lines.
308   EXPECT_EQ("url",
309             FirstResult(classifier->ClassifyText(
310                 "this afternoon Barack Obama gave a speech at|Visit "
311                 "www.google.com every today!|Call me at (800) 123-456 today.",
312                 {51, 65})));
313 }
314 
315 #ifndef TC3_DISABLE_LUA
TEST_F(AnnotatorTest,ClassifyTextRegularExpressionLuaVerification)316 TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
317   const std::string test_model = ReadFile(GetTestModelPath());
318   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
319 
320   // Add test regex models.
321   std::unique_ptr<RegexModel_::PatternT> verified_pattern =
322       MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
323                   /*enabled_for_classification=*/true,
324                   /*enabled_for_selection=*/false,
325                   /*enabled_for_annotation=*/false, 1.0);
326   verified_pattern->verification_options.reset(new VerificationOptionsT);
327   verified_pattern->verification_options->lua_verifier = 0;
328   unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
329   unpacked_model->regex_model->lua_verifier.push_back(
330       "return match[2].text==\"99\"");
331 
332   flatbuffers::FlatBufferBuilder builder;
333   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
334 
335   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
336       reinterpret_cast<const char*>(builder.GetBufferPointer()),
337       builder.GetSize(), unilib_.get(), calendarlib_.get());
338   ASSERT_TRUE(classifier);
339 
340   // Custom rule triggers and is correctly verified.
341   EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText(
342                                    "99-00-123456-12345678", {0, 21})));
343 
344   // Custom verification fails.
345   EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
346                          "90-00-123456-12345678", {0, 21})));
347 }
348 #endif  // TC3_DISABLE_LUA
349 
TEST_F(AnnotatorTest,ClassifyTextRegularExpressionEntityData)350 TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
351   const std::string test_model = ReadFile(GetTestModelPath());
352   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
353 
354   // Add fake entity schema metadata.
355   AddTestEntitySchemaData(unpacked_model.get());
356 
357   AddTestRegexModel(unpacked_model.get());
358 
359   flatbuffers::FlatBufferBuilder builder;
360   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
361 
362   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
363       reinterpret_cast<const char*>(builder.GetBufferPointer()),
364       builder.GetSize(), unilib_.get(), calendarlib_.get());
365   ASSERT_TRUE(classifier);
366 
367   // Check with full name.
368   {
369     auto classifications =
370         classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
371     EXPECT_EQ(1, classifications.size());
372     EXPECT_EQ("person_with_age", classifications[0].collection);
373 
374     // Check entity data.
375     const flatbuffers::Table* entity =
376         flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
377             classifications[0].serialized_entity_data.data()));
378     EXPECT_EQ(
379         entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
380         "Barack");
381     EXPECT_EQ(
382         entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
383         "Obama");
384     // Check `age`.
385     EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
386 
387     // Check `is_alive`.
388     EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
389 
390     // Check `former_us_president`.
391     EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
392   }
393 
394   // Check only with first name.
395   {
396     auto classifications =
397         classifier->ClassifyText("Barack is 57 years old", {0, 22});
398     EXPECT_EQ(1, classifications.size());
399     EXPECT_EQ("person_with_age", classifications[0].collection);
400 
401     // Check entity data.
402     const flatbuffers::Table* entity =
403         flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
404             classifications[0].serialized_entity_data.data()));
405     EXPECT_EQ(
406         entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
407         "Barack");
408 
409     // Check `age`.
410     EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);
411 
412     // Check `is_alive`.
413     EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));
414 
415     // Check `former_us_president`.
416     EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
417   }
418 }
419 
// Checks that a codepointwise UPPERCASE normalization configured on a
// capturing group is applied to the value stored in the entity data.
//
// Preconditions for the later reads (capturing group 2 exists, exactly one
// classification, entity data and string field present) are fatal assertions,
// so a regression fails the test instead of causing an out-of-bounds access.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model != nullptr);

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing.
  ASSERT_FALSE(unpacked_model->regex_model->patterns.empty());
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  ASSERT_GT(pattern->capturing_group.size(), 2u);
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  auto classifications =
      classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
  // Fatal: classifications[0] is read below.
  ASSERT_EQ(classifications.size(), 1u);
  EXPECT_EQ("person_with_age", classifications[0].collection);

  // Check entity data normalization.
  ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classifications[0].serialized_entity_data.data()));
  ASSERT_NE(entity, nullptr);
  // Unset flatbuffer string fields read back as null pointers.
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
458 
// Two regex patterns match the same text; ClassifyText() must report the
// collection of whichever pattern has the higher priority score.
TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));

  // Replace all regex patterns with two identical expressions that differ
  // only in collection name and priority score.
  auto& patterns = model->regex_model->patterns;
  patterns.clear();
  patterns.push_back(MakePattern(
      "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/1.0));
  patterns.push_back(MakePattern(
      "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/0.0));

  const std::string flight_text = "Your flight LX373 is delayed by 3 hours.";
  const CodepointSpan flight_span(12, 17);

  // Serializes the current state of `model` into `fbb` and loads an annotator
  // from that buffer. The buffer is passed as unowned, so callers keep `fbb`
  // alive while using the annotator.
  const auto load_annotator = [&](flatbuffers::FlatBufferBuilder* fbb) {
    FinishModelBuffer(*fbb, Model::Pack(*fbb, model.get()));
    return Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb->GetBufferPointer()), fbb->GetSize(),
        unilib_.get(), calendarlib_.get());
  };

  // "flight1" has the higher priority score (1.0 vs 0.0).
  {
    flatbuffers::FlatBufferBuilder fbb;
    const std::unique_ptr<Annotator> annotator = load_annotator(&fbb);
    ASSERT_TRUE(annotator);

    EXPECT_EQ("flight1",
              FirstResult(annotator->ClassifyText(flight_text, flight_span)));
  }

  // After raising the priority score of "flight2" to 3.0 it must win.
  patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder fbb;
    const std::unique_ptr<Annotator> annotator = load_annotator(&fbb);
    ASSERT_TRUE(annotator);

    EXPECT_EQ("flight2",
              FirstResult(annotator->ClassifyText(flight_text, flight_span)));
  }
}
501 
// Two identical "flight" regex patterns with different priority scores are
// registered; Annotate() must report the classification that carries the
// higher priority score.
//
// The emptiness checks on the results and on their classifications are fatal
// because element [0] of each is read right afterwards.
TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model != nullptr);
  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
  // Add test regex models. One of them has a higher priority score than the
  // other. We'll test that the one with the higher priority score always ends
  // up winning.
  unpacked_model->regex_model->patterns.clear();
  const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/0.0));

  // "flight" that wins should have a priority score of 1.0.
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 0.9);
  }

  // When we increase the priority score, the "flight" that wins should have a
  // priority score of 3.0.
  unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 2.9);
  }
}
554 
// Adds selection-only regex patterns (person, flight, Luhn-verified payment
// card) and checks that SuggestSelection() expands a click to the span
// matched by the corresponding pattern.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Register the test patterns; all are enabled for selection only.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;
  // The payment card pattern additionally requires a valid Luhn checksum.
  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, packed_size, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Each click must be expanded to the full regex match.
  const CodepointSpan flight_selection = annotator->SuggestSelection(
      "Your flight MA 0123 is delayed by 3 hours.", CodepointSpan(12, 14));
  EXPECT_EQ(flight_selection, CodepointSpan(12, 19));

  const CodepointSpan person_selection = annotator->SuggestSelection(
      "this afternoon Barack Obama gave a speech at", CodepointSpan(15, 21));
  EXPECT_EQ(person_selection, CodepointSpan(15, 27));

  const CodepointSpan card_selection = annotator->SuggestSelection(
      "cc: 4012 8888 8888 1881", CodepointSpan(9, 14));
  EXPECT_EQ(card_selection, CodepointSpan(4, 23));
}
594 
// A selection-only regex pattern whose capturing groups carry per-group
// `extend_selection` flags; checks the span SuggestSelection() returns for a
// date range and for the "for ever" alternative.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  std::unique_ptr<RegexModel_::PatternT> date_range_pattern =
      MakePattern("date_range",
                  "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
                  "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);

  // One entry per capturing group, in index order: group 0 does not extend
  // the selection, groups 1-3 do.
  const bool extend_selection_by_group[] = {false, true, true, true};
  for (const bool extend_selection : extend_selection_by_group) {
    date_range_pattern->capturing_group.emplace_back(new CapturingGroupT);
    date_range_pattern->capturing_group.back()->extend_selection =
        extend_selection;
  }
  model->regex_model->patterns.push_back(std::move(date_range_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, packed_size, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Date range: the expected span excludes the leading "from ".
  const CodepointSpan range_selection = annotator->SuggestSelection(
      "it's from 04/30/1789 to 03/04/1797", CodepointSpan(21, 23));
  EXPECT_EQ(range_selection, CodepointSpan(10, 34));

  // The "for ever" alternative.
  const CodepointSpan for_ever_selection =
      annotator->SuggestSelection("it takes for ever", CodepointSpan(9, 12));
  EXPECT_EQ(for_ever_selection, CodepointSpan(9, 17));
}
637 
// When a regex "flight" pattern with priority score 0.5 overlaps the address
// selection, SuggestSelection() must return the whole address span.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Register the test patterns; both are enabled for selection only.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const auto packed_size = fbb.GetSize();

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, packed_size, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // The click lands on "MA", which the flight regex also matches ("MA 0123").
  const CodepointSpan suggested = annotator->SuggestSelection(
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
      CodepointSpan(55, 57));
  EXPECT_EQ(suggested, CodepointSpan(26, 62));
}
666 
// Selection conflict between regex patterns and the rest of the model: with
// the "flight" pattern's priority score raised to 1.1, the expected selection
// is the narrow span (55, 62), i.e. "MA 0123".
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Register two selection-only regex patterns.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // The flight pattern was pushed last, so back() refers to it.
  patterns.back()->priority_score = 1.1;

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Check conflict resolution: the click is on "MA" (55, 57).
  const std::string context =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  EXPECT_EQ(annotator->SuggestSelection(context, {55, 57}),
            CodepointSpan(55, 62));
}
695 
// Annotation with extra annotation-only regex patterns ("person", "flight"
// and a Luhn-verified "payment_card") added to the test model.
TEST_F(AnnotatorTest, AnnotateRegex) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  // The payment card pattern additionally requires a valid Luhn checksum.
  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/true, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string input =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
  EXPECT_THAT(annotator->Annotate(input),
              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
                                IsAnnotatedSpan(28, 55, "address"),
                                IsAnnotatedSpan(79, 91, "phone"),
                                IsAnnotatedSpan(107, 126, "payment_card")}));
}
732 
// Flight number annotation with the unmodified test model.
TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // ICAO is only used for selected airlines.
  // Expected: LX373, EZY1234 and U21234 (SWR373 is not annotated).
  const std::string input = "flights LX373, SWR373, EZY1234, U21234";
  EXPECT_THAT(annotator->Annotate(input),
              ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
                                IsAnnotatedSpan(23, 30, "flight"),
                                IsAnnotatedSpan(32, 38, "flight")}));
}
746 
747 #ifndef TC3_DISABLE_LUA
// Annotation with a regex pattern whose matches are checked by a Lua verifier
// snippet stored in the regex model.
TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models.
  std::unique_ptr<RegexModel_::PatternT> parcel_pattern =
      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/true, 1.0);
  parcel_pattern->verification_options.reset(new VerificationOptionsT);
  // Index into the regex model's `lua_verifier` list populated below.
  parcel_pattern->verification_options->lua_verifier = 0;

  RegexModelT* const regex_model = model->regex_model.get();
  regex_model->patterns.push_back(std::move(parcel_pattern));
  // The snippet accepts a match only when capturing group 2 is "99".
  regex_model->lua_verifier.push_back("return match[2].text==\"99\"");

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string input =
      "your parcel is on the way: 99-00-123456-12345678";
  EXPECT_THAT(annotator->Annotate(input),
              ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
}
777 #endif  // TC3_DISABLE_LUA
778 
// Annotates text using the test regex model and inspects the serialized
// entity data attached to the single resulting classification.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  const auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // These must be fatal: the element accesses below would be out of bounds
  // if either container did not hold exactly one element.
  ASSERT_EQ(1u, annotations.size());
  ASSERT_EQ(1u, annotations[0].classification.size());
  const auto& classification = annotations[0].classification[0];
  EXPECT_EQ("person_with_age", classification.collection);

  // Check entity data. An empty buffer cannot be parsed as a flatbuffer, so
  // stop here instead of reading past its end.
  ASSERT_FALSE(classification.serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classification.serialized_entity_data.data()));

  // String fields are null pointers when absent; fail cleanly rather than
  // dereferencing them.
  const flatbuffers::String* first_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(first_name != nullptr);
  EXPECT_EQ(first_name->str(), "Barack");
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(last_name != nullptr);
  EXPECT_EQ(last_name->str(), "Obama");

  // Check `age`.
  EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

  // Check `is_alive`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

  // Check `former_us_president`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
}
821 
// Same setup as the entity data test, with an uppercase normalization applied
// to capturing group 2 of the last regex pattern; checks that the normalized
// value ends up in the serialized entity data.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing.
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  const auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // These must be fatal: the element accesses below would be out of bounds
  // if either container did not hold exactly one element.
  ASSERT_EQ(1u, annotations.size());
  ASSERT_EQ(1u, annotations[0].classification.size());
  const auto& classification = annotations[0].classification[0];
  EXPECT_EQ("person_with_age", classification.collection);

  // Check normalization. An empty buffer cannot be parsed as a flatbuffer,
  // and an absent string field is a null pointer; fail cleanly in both cases.
  ASSERT_FALSE(classification.serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classification.serialized_entity_data.data()));
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(last_name != nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
863 
// With `is_serialized_entity_data_enabled` set to false, the classification
// is still produced but carries no serialized entity data.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = false;
  const auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // These must be fatal: the element accesses below would be out of bounds
  // if either container did not hold exactly one element.
  ASSERT_EQ(1u, annotations.size());
  ASSERT_EQ(1u, annotations[0].classification.size());
  const auto& classification = annotations[0].classification[0];
  EXPECT_EQ("person_with_age", classification.collection);

  // Check entity data.
  EXPECT_EQ("", classification.serialized_entity_data);
}
892 
// Classification of phone-like selections: the two shorter selections are
// expected to classify as "phone", the longest one as "other".
TEST_F(AnnotatorTest, PhoneFiltering) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string short_text = "phone: (123) 456 789";
  const std::string long_text = "phone: (123) 456 789,0001112";

  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(short_text, {7, 20})));
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(long_text, {7, 25})));
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(long_text, {7, 28})));
}
905 
// Basic SuggestSelection behavior with the unmodified test model.
TEST_F(AnnotatorTest, SuggestSelection) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a selection suggestion with default options.
  const auto suggest = [&annotator](const std::string& context,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(context, click);
  };
  const std::string phone_text = "call me at 857 225 3556 today";

  EXPECT_EQ(suggest("this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 21));

  // Try passing whole string.
  // If more than 1 token is specified, we should return back what entered.
  EXPECT_EQ(suggest("350 Third Street, Cambridge", {0, 27}),
            CodepointSpan(0, 27));

  // Single letter.
  EXPECT_EQ(suggest("a", {0, 1}), CodepointSpan(0, 1));

  // Single word.
  EXPECT_EQ(suggest("asdf", {0, 4}), CodepointSpan(0, 4));

  EXPECT_EQ(suggest(phone_text, {11, 14}), CodepointSpan(11, 23));

  // Unpaired bracket stripping.
  EXPECT_EQ(suggest("call me at (857) 225 3556 today", {12, 14}),
            CodepointSpan(11, 25));
  EXPECT_EQ(suggest("call me at (857 today", {12, 14}), CodepointSpan(12, 15));
  EXPECT_EQ(suggest("call me at 3556) today", {12, 14}),
            CodepointSpan(11, 15));
  EXPECT_EQ(suggest("call me at )857( today", {12, 14}),
            CodepointSpan(12, 15));

  // If the resulting selection would be empty, the original span is returned.
  EXPECT_EQ(suggest("call me at )( today", {11, 13}), CodepointSpan(11, 13));
  EXPECT_EQ(suggest("call me at ( today", {11, 12}), CodepointSpan(11, 12));
  EXPECT_EQ(suggest("call me at ) today", {11, 12}), CodepointSpan(11, 12));

  // If the original span is larger than the found selection, the original span
  // is returned.
  EXPECT_EQ(suggest(phone_text, {5, 24}), CodepointSpan(5, 24));
}
956 
// Loading must fail when the selection model is removed while annotation mode
// is the enabled triggering mode.
TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Disable the selection model.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;

  // Re-serialize the modified model and try to load an annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());

  // Selection model needs to be present for annotation.
  ASSERT_FALSE(annotator);
}
975 
// With the selection model removed and only classification enabled, selection
// returns the click unchanged, classification still works and annotation
// yields nothing.
TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Disable the selection model.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
  model->enabled_modes = ModeFlag_CLASSIFICATION;

  // Disable the number annotator. With the selection model disabled, there is
  // no feature processor, which is required for the number annotator.
  model->number_annotator_options->enabled = false;

  // Re-serialize the modified model and load an annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(
      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 14));

  const std::string bracketed_phone_text = "call me at (800) 123-456 today";
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(bracketed_phone_text,
                                                         {11, 24})));

  EXPECT_THAT(annotator->Annotate(bracketed_phone_text), IsEmpty());
}
1008 
// Filtering the "phone" collection out of selection: the phone click is
// returned unchanged while address selection keeps working.
TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  const std::string phone_text = "call me at 857 225 3556 today";

  // Baseline: the unmodified model expands the click to the phone number.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      serialized_model.c_str(), serialized_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 23));

  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->output_options.reset(new OutputOptionsT);

  // Disable phone selection
  model->output_options->filtered_collections_selection.push_back("phone");
  // We need to force this for filtering.
  model->selection_options->always_classify_suggested_selection = true;

  // Re-serialize the modified model and reload the annotator from it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const packed_model =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  annotator = Annotator::FromUnownedBuffer(
      packed_model, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 14));

  // Address selection should still work.
  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9}),
            CodepointSpan(0, 27));
}
1045 
// Clicking different tokens of the same address yields the same selection.
TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a selection suggestion with default options.
  const auto suggest = [&annotator](const std::string& context,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(context, click);
  };
  const std::string address = "350 Third Street, Cambridge";

  EXPECT_EQ(suggest(address, {0, 3}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest(address, {4, 9}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest(address, {10, 16}), CodepointSpan(0, 27));
  // The same address preceded by three short lines.
  EXPECT_EQ(suggest("a\nb\nc\n350 Third Street, Cambridge", {16, 22}),
            CodepointSpan(6, 33));
}
1062 
// Selection suggestions on inputs that contain newline characters.
TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Phone number on its own line, after and before another line.
  const std::string phone_on_second_line = "abc\n857 225 3556";
  EXPECT_EQ(annotator->SuggestSelection(phone_on_second_line, {4, 7}),
            CodepointSpan(4, 16));
  const std::string phone_on_first_line = "857 225 3556\nabc";
  EXPECT_EQ(annotator->SuggestSelection(phone_on_first_line, {0, 3}),
            CodepointSpan(0, 12));

  // Phone number with a newline inside it, using explicitly passed
  // default-constructed options.
  const std::string phone_across_lines = "857 225\n3556\nabc";
  SelectionOptions default_options;
  EXPECT_EQ(annotator->SuggestSelection(phone_across_lines, {0, 3},
                                        default_options),
            CodepointSpan(0, 12));
}
1077 
// Punctuation adjacent to the clicked word is not included in the suggested
// selection.
TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a selection suggestion with default options.
  const auto suggest = [&annotator](const std::string& context,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(context, click);
  };

  // From the right.
  EXPECT_EQ(suggest("this afternoon BarackObama, gave a speech at", {15, 26}),
            CodepointSpan(15, 26));

  // From the right multiple.
  EXPECT_EQ(
      suggest("this afternoon BarackObama,.,.,, gave a speech at", {15, 26}),
      CodepointSpan(15, 26));

  // From the left multiple.
  EXPECT_EQ(
      suggest("this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}),
      CodepointSpan(21, 32));

  // From both sides.
  EXPECT_EQ(
      suggest("this afternoon !BarackObama,- gave a speech at", {16, 27}),
      CodepointSpan(16, 27));
}
1103 
// Invalid click spans and invalid UTF-8 are returned unchanged instead of
// crashing.
TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a selection suggestion with default options.
  const auto suggest = [&annotator](const std::string& context,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(context, click);
  };
  const std::string words = "Word 1 2 3 hello!";

  // Try passing in bunch of invalid selections.
  EXPECT_EQ(suggest("", {0, 27}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest("", {-10, 27}), CodepointSpan(-10, 27));
  EXPECT_EQ(suggest(words, {0, 27}), CodepointSpan(0, 27));
  EXPECT_EQ(suggest(words, {-30, 300}), CodepointSpan(-30, 300));
  EXPECT_EQ(suggest(words, {-10, -1}), CodepointSpan(-10, -1));
  EXPECT_EQ(suggest(words, {100, 17}), CodepointSpan(100, 17));

  // Try passing invalid utf8.
  EXPECT_EQ(suggest("\xf0\x9f\x98\x8b\x8b", {-1, -1}), CodepointSpan(-1, -1));
}
1126 
// Selection suggestions when the click covers only whitespace.
TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Shorthand for a selection suggestion with default options.
  const auto suggest = [&annotator](const std::string& context,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(context, click);
  };
  const std::string phone_text = "call me at 857 225 3556 today";
  const std::string plain_text = "call me today";
  const std::string blank = "      ";

  EXPECT_EQ(suggest(phone_text, {14, 15}), CodepointSpan(11, 23));
  EXPECT_EQ(suggest(phone_text, {10, 11}), CodepointSpan(10, 11));
  EXPECT_EQ(suggest(phone_text, {23, 24}), CodepointSpan(23, 24));
  EXPECT_EQ(suggest("call me at 857 225 3556, today", {23, 24}),
            CodepointSpan(23, 24));
  EXPECT_EQ(suggest("call me at 857   225 3556, today", {14, 17}),
            CodepointSpan(11, 25));
  EXPECT_EQ(suggest("call me at 857-225 3556, today", {14, 17}),
            CodepointSpan(11, 23));
  EXPECT_EQ(suggest("let's meet at 350 Third Street Cambridge and go there",
                    {30, 31}),
            CodepointSpan(14, 40));
  EXPECT_EQ(suggest(plain_text, {4, 5}), CodepointSpan(4, 5));
  EXPECT_EQ(suggest(plain_text, {7, 8}), CodepointSpan(7, 8));

  // With a punctuation around the selected whitespace.
  EXPECT_EQ(suggest("let's meet at 350 Third Street, Cambridge and go there",
                    {31, 32}),
            CodepointSpan(14, 41));

  // When all's whitespace, should return the original indices.
  EXPECT_EQ(suggest(blank, {0, 1}), CodepointSpan(0, 1));
  EXPECT_EQ(suggest(blank, {0, 3}), CodepointSpan(0, 3));
  EXPECT_EQ(suggest(blank, {2, 3}), CodepointSpan(2, 3));
  EXPECT_EQ(suggest(blank, {5, 6}), CodepointSpan(5, 6));
}
1175 
// Exercises internal::SnapLeftIfWhitespaceSelection on a whitespace click with
// and without non-whitespace text to its left.
TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
  // Wraps the string literal without copying it and applies the snapping to
  // the given span.
  const auto snap_left = [this](const char* utf8, CodepointSpan span) {
    const UnicodeText text = UTF8ToUnicodeText(utf8, /*do_copy=*/false);
    return internal::SnapLeftIfWhitespaceSelection(span, text, *unilib_);
  };

  EXPECT_EQ(snap_left("abcd efgh", {4, 5}), CodepointSpan(3, 4));
  EXPECT_EQ(snap_left("abcd     ", {4, 5}), CodepointSpan(3, 4));

  // Nothing on the left.
  EXPECT_EQ(snap_left("     efgh", {4, 5}), CodepointSpan(4, 5));
  EXPECT_EQ(snap_left("     efgh", {0, 1}), CodepointSpan(0, 1));

  // Whitespace only.
  EXPECT_EQ(snap_left("     ", {2, 3}), CodepointSpan(2, 3));
  EXPECT_EQ(snap_left("     ", {4, 5}), CodepointSpan(4, 5));
  EXPECT_EQ(snap_left("     ", {0, 1}), CodepointSpan(0, 1));
}
1205 
// Basic annotation with the unmodified test model.
TEST_F(AnnotatorTest, Annotate) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string input =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(input),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // Same phone number with a space and with a newline before the last group,
  // using explicitly passed default-constructed options.
  AnnotationOptions default_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));

  // Try passing invalid utf8.
  const std::string invalid_utf8 = "853 225 3556\n\xf0\x9f\x98\x8b\x8b";
  EXPECT_THAT(annotator->Annotate(invalid_utf8, default_options), IsEmpty());
}
1230 
// Phone annotation spans keep paired brackets and drop unpaired ones.
TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Paired brackets stay inside the span.
  const std::string paired = "call me at (0845) 100 1000 today";
  EXPECT_THAT(annotator->Annotate(paired),
              ElementsAreArray({IsAnnotatedSpan(11, 26, "phone")}));

  // Unpaired bracket stripping.
  const std::string opening_only = "call me at (07038201818 today";
  EXPECT_THAT(annotator->Annotate(opening_only),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
  const std::string closing_only = "call me at 07038201818) today";
  EXPECT_THAT(annotator->Annotate(closing_only),
              ElementsAreArray({IsAnnotatedSpan(11, 22, "phone")}));
  const std::string reversed = "call me at )07038201818( today";
  EXPECT_THAT(annotator->Annotate(reversed),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
}
1255 
// In the raw annotation usecase, overlapping annotations (phone, number,
// percentage) are all expected in the output, in any order.
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  AnnotationOptions raw_options;
  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // Number, float number and percentage annotator.
  const std::string input =
      "853 225 3556 and then turn it up 99%, 99 "
      "number, 12345.12345 float number";
  EXPECT_THAT(
      annotator->Annotate(input, raw_options),
      UnorderedElementsAreArray(
          {IsAnnotatedSpan(0, 12, "phone"), IsAnnotatedSpan(0, 3, "number"),
           IsAnnotatedSpan(4, 7, "number"), IsAnnotatedSpan(8, 12, "number"),
           IsAnnotatedSpan(33, 35, "number"),
           IsAnnotatedSpan(33, 36, "percentage"),
           IsAnnotatedSpan(38, 40, "number"), IsAnnotatedSpan(49, 60, "number"),
           IsAnnotatedSpan(49, 60, "phone")}));
}
1276 
// In the smart annotation usecase only the phone and percentage spans are
// expected; plain "number" spans are absent.
TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  AnnotationOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  const std::string input = "853 225 3556 and then turn it up 99%, 99 number";
  EXPECT_THAT(annotator->Annotate(input, smart_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"),
                                IsAnnotatedSpan(33, 36, "percentage")}));
}
1289 
VerifyAnnotatesDurationsInRawMode(const Annotator * classifier)1290 void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) {
1291   ASSERT_TRUE(classifier);
1292   AnnotationOptions options;
1293   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1294 
1295   // Duration annotator.
1296   EXPECT_THAT(classifier->Annotate(
1297                   "it took 9 minutes and 7 seconds to get there", options),
1298               Contains(IsDurationSpan(
1299                   /*start=*/8, /*end=*/31,
1300                   /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
1301 }
1302 
// Runs the shared raw-mode duration check against the unmodified test model.
TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotatesDurationsInRawMode(annotator.get());
}
1308 
VerifyDurationAndRelativeTimeCanOverlapInRawMode(const Annotator * classifier)1309 void VerifyDurationAndRelativeTimeCanOverlapInRawMode(
1310     const Annotator* classifier) {
1311   ASSERT_TRUE(classifier);
1312   AnnotationOptions options;
1313   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1314   options.locales = "en";
1315 
1316   const std::vector<AnnotatedSpan> annotations =
1317       classifier->Annotate("let's meet in 3 hours", options);
1318 
1319   EXPECT_THAT(annotations,
1320               Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
1321                                       /*time_ms_utc=*/10800000L,
1322                                       DatetimeGranularity::GRANULARITY_HOUR)));
1323   EXPECT_THAT(annotations,
1324               Contains(IsDurationSpan(/*start=*/14, /*end=*/21,
1325                                       /*duration_ms=*/3 * 60 * 60 * 1000)));
1326 }
1327 
// Runs VerifyDurationAndRelativeTimeCanOverlapInRawMode against the default
// test model.
TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1333 
// Runs VerifyDurationAndRelativeTimeCanOverlapInRawMode against the model
// returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) {
  // The buffer is not owned by the annotator, so it has to stay alive for the
  // whole test body.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1342 
// With `only_use_line_with_click` enabled, annotating two lines joined by '\n'
// must give the same address span as annotating the second line on its own,
// shifted by the length of the first line plus the newline.
TEST_F(AnnotatorTest, AnnotateSplitLines) {
  const std::string model_buffer = ModifyAnnotatorModel(
      ReadFile(GetTestModelPath()), [](ModelT* model) {
        model->selection_feature_options->only_use_line_with_click = true;
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_buffer.data(), model_buffer.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string chatter_line =
      "hey, sorry, just finished up. i didn't hear back from you in time.";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const int kAnnotationLength = 26;

  // Each line on its own.
  EXPECT_THAT(annotator->Annotate(chatter_line), IsEmpty());
  EXPECT_THAT(
      annotator->Annotate(address_line),
      ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")}));

  // Both lines together: the address starts right after the newline.
  const std::string both_lines = chatter_line + "\n" + address_line;
  const auto address_begin = chatter_line.size() + 1;
  EXPECT_THAT(annotator->Annotate(both_lines),
              ElementsAreArray({IsAnnotatedSpan(
                  address_begin, address_begin + kAnnotationLength,
                  "address")}));
}
1370 
// With `use_pipe_character_for_newline` enabled, a '|' between two sentences
// is expected to act as a line break: the phone number of the first part and
// the address of the second part are both annotated.
TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
  const std::string model_buffer = ModifyAnnotatorModel(
      ReadFile(GetTestModelPath()), [](ModelT* model) {
        auto* feature_options = model->selection_feature_options.get();
        feature_options->only_use_line_with_click = true;
        feature_options->use_pipe_character_for_newline = true;
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_buffer.data(), model_buffer.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_line = "hey, this is my phone number 853 225 3556";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string joined = phone_line + "|" + address_line;
  const int kAnnotationLengthPhone = 12;
  const int kAnnotationLengthAddress = 26;

  // The address begins one codepoint after the pipe separator.
  const int address_begin = static_cast<int>(phone_line.size()) + 1;
  const int address_end =
      static_cast<int>(phone_line.size() + 1 + kAnnotationLengthAddress);

  // Annotating the joined text should find the same spans (phone and address)
  // as annotating the two parts individually.
  const std::vector<AnnotatedSpan> spans = annotator->Annotate(joined);
  EXPECT_THAT(
      spans,
      ElementsAreArray(
          {IsAnnotatedSpan(29, 29 + kAnnotationLengthPhone, "phone"),
           IsAnnotatedSpan(address_begin, address_end, "address")}));
}
1401 
// Counterpart of the pipe-as-newline test: with
// `use_pipe_character_for_newline` disabled the '|'-joined text is not split,
// and exactly one span is expected.
TEST_F(AnnotatorTest,
       NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
  const std::string model_buffer = ModifyAnnotatorModel(
      ReadFile(GetTestModelPath()), [](ModelT* model) {
        auto* feature_options = model->selection_feature_options.get();
        feature_options->only_use_line_with_click = true;
        feature_options->use_pipe_character_for_newline = false;
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      model_buffer.data(), model_buffer.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_line = "hey, this is my phone number 853 225 3556";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string joined = phone_line + "|" + address_line;
  const std::vector<AnnotatedSpan> spans = annotator->Annotate(joined);

  // Only the number of spans is checked. The model is not precise for this
  // kind of input when '|' does not split lines, so the exact span could
  // change whenever the model changes.
  EXPECT_THAT(spans.size(), 1);
}
1426 
// Annotation results must be unaffected by a small selection batch size (4).
TEST_F(AnnotatorTest, AnnotateSmallBatches) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Shrink the batch size, then re-serialize the model.
  model->selection_options->batch_size = 4;
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Multi-line text with an address and a phone number.
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // A lone phone number, with a space and with a newline inside it.
  AnnotationOptions default_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
}
1456 
// A `min_annotate_confidence` of 2 is expected to filter out every annotation.
TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Install fresh triggering options whose threshold discards all results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 2.f;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_EQ(annotator->Annotate(text).size(), 0);
}
1479 
// A `min_annotate_confidence` of 0 with all modes enabled is expected to keep
// both annotations of the sample text.
TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Install fresh triggering options whose threshold keeps all results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->min_annotate_confidence = 0.f;
  model->triggering_options->enabled_modes = ModeFlag_ALL;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_EQ(annotator->Annotate(text).size(), 2);
}
1502 
// When the model's enabled modes cover only classification and selection,
// Annotate() must return nothing.
TEST_F(AnnotatorTest, AnnotateDisabled) {
  const std::string original_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Leave annotation out of the enabled modes.
  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
1521 
// Collections listed in `filtered_collections_annotation` must be dropped from
// the annotation output: first verify the baseline (address + phone), then
// filter "phone" and expect only the address to remain.
TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
  const std::string original_model = ReadFile(GetTestModelPath());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline with the unmodified model.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      original_model.c_str(), original_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // Rebuild the model with the "phone" collection filtered out.
  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_annotation.push_back("phone");

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1559 
// A filtered collection can suppress another annotation: a custom "suppress"
// regex pattern is added that is meant to win against the phone
// classification, and that collection is filtered from the output, so only
// the address remains.
TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string original_model = ReadFile(GetTestModelPath());
  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline with the unmodified model.
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      original_model.c_str(), original_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> model = UnPackModel(original_model.c_str());

  // Filter the "suppress" collection from the annotation output...
  model->output_options.reset(new OutputOptionsT);
  model->output_options->filtered_collections_annotation.push_back("suppress");

  // ...and add the annotation-only pattern that produces it.
  model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, model.get()));

  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1603 
VerifyClassifyTextDateInZurichTimezone(const Annotator * classifier)1604 void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
1605   EXPECT_TRUE(classifier);
1606   ClassificationOptions options;
1607   options.reference_timezone = "Europe/Zurich";
1608   options.locales = "en";
1609 
1610   std::vector<ClassificationResult> result =
1611       classifier->ClassifyText("january 1, 2017", {0, 15}, options);
1612 
1613   EXPECT_THAT(result,
1614               ElementsAre(IsDateResult(1483225200000,
1615                                        DatetimeGranularity::GRANULARITY_DAY)));
1616 }
1617 
// Runs VerifyClassifyTextDateInZurichTimezone against the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1623 
// Runs VerifyClassifyTextDateInZurichTimezone against the model returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) {
  // The annotator does not own the buffer; keep it alive for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1631 
VerifyClassifyTextDateInLATimezone(const Annotator * classifier)1632 void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
1633   EXPECT_TRUE(classifier);
1634   ClassificationOptions options;
1635   options.reference_timezone = "America/Los_Angeles";
1636   options.locales = "en";
1637 
1638   std::vector<ClassificationResult> result =
1639       classifier->ClassifyText("march 1, 2017", {0, 13}, options);
1640 
1641   EXPECT_THAT(result,
1642               ElementsAre(IsDateResult(1488355200000,
1643                                        DatetimeGranularity::GRANULARITY_DAY)));
1644 }
1645 
// Runs VerifyClassifyTextDateInLATimezone against the model returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) {
  // The annotator does not own the buffer; keep it alive for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1653 
// Runs VerifyClassifyTextDateInLATimezone against the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1659 
VerifyClassifyTextDateOnAotherLine(const Annotator * classifier)1660 void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
1661   EXPECT_TRUE(classifier);
1662   ClassificationOptions options;
1663   options.reference_timezone = "Europe/Zurich";
1664   options.locales = "en";
1665 
1666   std::vector<ClassificationResult> result = classifier->ClassifyText(
1667       "hello world this is the first line\n"
1668       "january 1, 2017",
1669       {35, 50}, options);
1670 
1671   EXPECT_THAT(result,
1672               ElementsAre(IsDateResult(1483225200000,
1673                                        DatetimeGranularity::GRANULARITY_DAY)));
1674 }
1675 
// Runs VerifyClassifyTextDateOnAotherLine against the model returned by
// GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) {
  // The annotator does not own the buffer; keep it alive for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1683 
// Runs VerifyClassifyTextDateOnAotherLine against the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1689 
VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(const Annotator * classifier)1690 void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
1691     const Annotator* classifier) {
1692   EXPECT_TRUE(classifier);
1693   std::vector<ClassificationResult> result;
1694   ClassificationOptions options;
1695 
1696   options.reference_timezone = "Europe/Zurich";
1697   options.locales = "en-US";
1698   result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
1699 
1700   // In US, the date should be interpreted as <month>.<day>.
1701   EXPECT_THAT(result,
1702               ElementsAre(IsDatetimeResult(
1703                   5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
1704 }
1705 
// Runs VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay against the default
// test model.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1711 
// Runs VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay against the model
// returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) {
  // The annotator does not own the buffer; keep it alive for the whole test.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1720 
// With the "de" locale, "03.05.1970" must be read as <day>.<month> (note that
// the test name says "MonthDay" although the expectation is day-first).
// Uses the model returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
  const std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // Fatal on purpose: the classifier is dereferenced below, so continuing
  // with a null pointer would crash the test binary.
  ASSERT_TRUE(classifier);

  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);

  // In Germany, the date should be interpreted as <day>.<month>.
  // 10537200000 ms is 1970-05-02T23:00:00Z, i.e. May 3 00:00 at UTC+1.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1739 
// "10:30" without an am/pm marker is ambiguous, so ClassifyText() must return
// two minute-granularity datetime results, in this order: the morning and the
// evening interpretation.
TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal on purpose: the classifier is dereferenced below, so continuing
  // with a null pointer would crash the test binary.
  ASSERT_TRUE(classifier);

  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);

  // 34200000 ms = 09:30 UTC and 77400000 ms = 21:30 UTC, i.e. 10:30 and 22:30
  // at UTC+1.
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1756 
// Annotate() must report the ambiguous "10:30" as a single span that carries
// both datetime interpretations (morning first, then evening).
TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal on purpose: the classifier is dereferenced below, so continuing
  // with a null pointer would crash the test binary.
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<AnnotatedSpan> spans =
      classifier->Annotate("set an alarm for 10:30", options);

  // Must be fatal: spans[0] is accessed right after.
  ASSERT_EQ(spans.size(), 1);
  // Bind by reference; `spans` outlives the expectation, so no copy is needed.
  const std::vector<ClassificationResult>& result = spans[0].classification;

  // 34200000 ms = 09:30 UTC and 77400000 ms = 21:30 UTC, i.e. 10:30 and 22:30
  // at UTC+1.
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1775 
// When every datetime pattern is enabled for annotation and classification
// only, a date is still classified and annotated, but SuggestSelection() must
// return the input selection unchanged.
TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
  const std::string test_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. Range-for avoids comparing a signed
  // index against the unsigned container size.
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Classification still recognizes the date.
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  // Selection is not expanded beyond the clicked word.
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            CodepointSpan(0, 7));
  // Annotation still covers the whole date.
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
1799 
TEST_F(AnnotatorTest,AnnotatesWithGrammarModel)1800 TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
1801   const std::string test_model = ReadFile(GetTestModelPath());
1802   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
1803 
1804   // Add test grammar model.
1805   unpacked_model->grammar_model.reset(new GrammarModelT);
1806   GrammarModelT* grammar_model = unpacked_model->grammar_model.get();
1807   grammar_model->tokenizer_options.reset(new GrammarTokenizerOptionsT);
1808   grammar_model->tokenizer_options->tokenization_type = TokenizationType_ICU;
1809   grammar_model->tokenizer_options->icu_preserve_whitespace_tokens = false;
1810   grammar_model->tokenizer_options->tokenize_on_script_change = true;
1811 
1812   // Add test rules.
1813   grammar_model->rules.reset(new grammar::RulesSetT);
1814   grammar::LocaleShardMap locale_shard_map =
1815       grammar::LocaleShardMap::CreateLocaleShardMap({""});
1816   grammar::Rules rules(locale_shard_map);
1817   rules.Add("<tv_detective>", {"jessica", "fletcher"});
1818   rules.Add("<tv_detective>", {"columbo"});
1819   rules.Add("<tv_detective>", {"magnum"});
1820   rules.Add(
1821       "<famous_person>", {"<tv_detective>"},
1822       /*callback=*/
1823       static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
1824       /*callback_param=*/0 /* rule classification result */);
1825 
1826   // Set result.
1827   grammar_model->rule_classification_result.emplace_back(
1828       new GrammarModel_::RuleClassificationResultT);
1829   GrammarModel_::RuleClassificationResultT* result =
1830       grammar_model->rule_classification_result.back().get();
1831   result->collection_name = "famous person";
1832   result->enabled_modes = ModeFlag_ALL;
1833   rules.Finalize().Serialize(/*include_debug_information=*/false,
1834                              grammar_model->rules.get());
1835   flatbuffers::FlatBufferBuilder builder;
1836   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
1837 
1838   std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
1839       reinterpret_cast<const char*>(builder.GetBufferPointer()),
1840       builder.GetSize(), unilib_.get(), calendarlib_.get());
1841   ASSERT_TRUE(classifier);
1842 
1843   const std::string test_string =
1844       "Did you see the Novel Connection episode where Jessica Fletcher helps "
1845       "Magnum solve the case? I thought that was with Columbo ...";
1846 
1847   EXPECT_THAT(classifier->Annotate(test_string),
1848               ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
1849                           IsAnnotatedSpan(70, 76, "famous person"),
1850                           IsAnnotatedSpan(117, 124, "famous person")));
1851   EXPECT_THAT(FirstResult(classifier->ClassifyText("Jessica Fletcher",
1852                                                    CodepointSpan{0, 16})),
1853               Eq("famous person"));
1854   EXPECT_THAT(classifier->SuggestSelection("Jessica Fletcher", {0, 7}),
1855               Eq(CodepointSpan{0, 16}));
1856 }
1857 
// A single candidate has nothing to conflict with and must be chosen.
TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0)};
  std::vector<Locale> languages{Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             languages, smart_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAreArray({0}));
}
1873 
// Five adjacent, non-overlapping candidates must all be chosen.
TEST_F(AnnotatorTest, ResolveConflictsSequence) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
  };
  std::vector<Locale> languages{Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             languages, smart_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAreArray({0, 1, 2, 3, 4}));
}
1894 
// The lower-scored middle candidate overlaps both of its neighbours and must
// be dropped; the two outer candidates are chosen.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser.
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };
  std::vector<Locale> languages{Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             languages, smart_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAreArray({0, 2}));
}
1913 
// Mirror of the three-span case: the highest-scored middle candidate wins and
// both lower-scored, overlapping neighbours are dropped.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser.
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser.
  };
  std::vector<Locale> languages{Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             languages, smart_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAreArray({1}));
}
1932 
// With a default-constructed TestingAnnotator (no conflict resolution options
// set in this test), equally scored candidates are not ranked by length: the
// long "url" span does not win.
TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({3, 7}, "unit", 1),
      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser.
      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser.
      MakeAnnotatedSpan({14, 20}, "email", 1),
  };
  std::vector<Locale> languages{Locale::FromBCP47("en")};

  BaseOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             languages, smart_options,
                             /*interpreter_manager=*/nullptr, &selected);
  // The first and the last candidates are picked because they do not overlap
  // each other.
  EXPECT_THAT(selected, ElementsAreArray({0, 3}));
}
1953 
// With `prioritize_longest_annotation` enabled in the model's conflict
// resolution options, the longest of the equally scored candidates wins and
// every candidate overlapping it is dropped.
TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Setup preconditions use ASSERT_TRUE, like the other tests in this file,
  // so a failure is reported as a failed test rather than aborting the whole
  // test binary.
  ASSERT_TRUE(libtextclassifier3::DecompressModel(unpacked_model.get()));
  unpacked_model->conflict_resolution_options.reset(
      new Model_::ConflictResolutionOptionsT);
  unpacked_model->conflict_resolution_options->prioritize_longest_annotation =
      true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TestingAnnotator> classifier =
      TestingAnnotator::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier != nullptr);

  std::vector<AnnotatedSpan> candidates{{
      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser.
      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser.
      MakeAnnotatedSpan({5, 30}, "url", 1),     // Pick longest match.
      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser.
  }};
  std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<int> chosen;
  classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                               locales, options,
                               /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({2}));
}
1988 
// Resolves conflicts among five candidates forming two overlapping groups and
// expects candidates 0, 2 and 4 to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 2, 4));
}
2009 
// RAW use case: a KNOWLEDGE-sourced span listed first overlaps a regular
// "address" span; both are expected to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "address", 0.6),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2028 
// RAW use case: a regular "address" span overlaps a KNOWLEDGE-sourced span
// listed second; both are expected to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2047 
// RAW use case: two overlapping spans that both come from the KNOWLEDGE
// source; both are expected to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2067 
// RAW use case with a default-constructed TestingAnnotator: of two overlapping
// non-knowledge spans only the first one is expected to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0));
}
2085 
// RAW use case with `do_conflict_resolution_in_raw_mode` turned off in the
// model: both overlapping spans are expected to be chosen.
TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
  const auto disable_raw_conflict_resolution = [](ModelT* model) {
    model->conflict_resolution_options.reset(
        new Model_::ConflictResolutionOptionsT);
    model->conflict_resolution_options->do_conflict_resolution_in_raw_mode =
        false;
  };
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get(),
                             disable_raw_conflict_resolution);

  const std::vector<AnnotatedSpan> spans = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };
  const std::vector<Locale> language_tags = {Locale::FromBCP47("en")};

  BaseOptions base_options;
  base_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> selected;
  annotator.ResolveConflicts(spans, /*context=*/"", /*cached_tokens=*/{},
                             language_tags, base_options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2109 
VerifyLongInput(const Annotator * classifier)2110 void VerifyLongInput(const Annotator* classifier) {
2111   ASSERT_TRUE(classifier);
2112 
2113   for (const auto& type_value_pair :
2114        std::vector<std::pair<std::string, std::string>>{
2115            {"address", "350 Third Street, Cambridge"},
2116            {"phone", "123 456-7890"},
2117            {"url", "www.google.com"},
2118            {"email", "someone@gmail.com"},
2119            {"flight", "LX 38"},
2120            {"date", "September 1, 2018"}}) {
2121     const std::string input_100k = std::string(50000, ' ') +
2122                                    type_value_pair.second +
2123                                    std::string(50000, ' ');
2124     const int value_length = type_value_pair.second.size();
2125 
2126     AnnotationOptions annotation_options;
2127     annotation_options.locales = "en";
2128     EXPECT_THAT(classifier->Annotate(input_100k, annotation_options),
2129                 ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
2130                                                   type_value_pair.first)}));
2131     SelectionOptions selection_options;
2132     selection_options.locales = "en";
2133     EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001},
2134                                            selection_options),
2135               CodepointSpan(50000, 50000 + value_length));
2136 
2137     ClassificationOptions classification_options;
2138     classification_options.locales = "en";
2139     EXPECT_EQ(type_value_pair.first,
2140               FirstResult(classifier->ClassifyText(
2141                   input_100k, {50000, 50000 + value_length},
2142                   classification_options)));
2143   }
2144 }
2145 
// Runs the long-input checks against an annotator loaded from the test model
// file.
TEST_F(AnnotatorTest, LongInput) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2151 
// Runs the long-input checks against an annotator built from the buffer
// returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, LongInputWithRegExDatetime) {
  // The annotator only borrows the buffer, so keep it alive until the end.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2159 
2160 // These coarse tests are there only to make sure the execution happens in
2161 // reasonable amount of time.
TEST_F(AnnotatorTest,LongInputNoResultCheck)2162 TEST_F(AnnotatorTest, LongInputNoResultCheck) {
2163   std::unique_ptr<Annotator> classifier = Annotator::FromPath(
2164       GetTestModelPath(), unilib_.get(), calendarlib_.get());
2165   ASSERT_TRUE(classifier);
2166 
2167   for (const std::string& value :
2168        std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
2169     const std::string input_100k =
2170         std::string(50000, ' ') + value + std::string(50000, ' ');
2171     const int value_length = value.size();
2172 
2173     classifier->Annotate(input_100k);
2174     classifier->SuggestSelection(input_100k, {50000, 50001});
2175     classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
2176   }
2177 }
2178 
// Checks the effect of `classification_options->max_num_tokens` on
// ClassifyText: with -1 the address is classified as "address", with a cap of
// 3 the same selection is classified as "other".
TEST_F(AnnotatorTest, MaxTokenLength) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  const std::string text = "I live at 350 Third Street, Cambridge.";

  std::unique_ptr<Annotator> annotator;

  // No limit on the number of tokens: classification behaves normally.
  model->classification_options->max_num_tokens = -1;

  flatbuffers::FlatBufferBuilder unlimited_fbb;
  FinishModelBuffer(unlimited_fbb, Model::Pack(unlimited_fbb, model.get()));
  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(unlimited_fbb.GetBufferPointer()),
      unlimited_fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "address");

  // Cap the number of tokens at 3, which suppresses the classification.
  model->classification_options->max_num_tokens = 3;

  flatbuffers::FlatBufferBuilder capped_fbb;
  FinishModelBuffer(capped_fbb, Model::Pack(capped_fbb, model.get()));
  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(capped_fbb.GetBufferPointer()),
      capped_fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "other");
}
2213 
// Checks the effect of `classification_options->address_min_num_tokens` on
// ClassifyText: with 0 the address is classified as "address", with a minimum
// of 5 the same selection is classified as "other".
TEST_F(AnnotatorTest, MinAddressTokenLength) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  const std::string text = "I live at 350 Third Street, Cambridge.";

  std::unique_ptr<Annotator> annotator;

  // No minimum on the number of address tokens: behaves normally.
  model->classification_options->address_min_num_tokens = 0;

  flatbuffers::FlatBufferBuilder no_minimum_fbb;
  FinishModelBuffer(no_minimum_fbb, Model::Pack(no_minimum_fbb, model.get()));
  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(no_minimum_fbb.GetBufferPointer()),
      no_minimum_fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "address");

  // Raise the minimum to 5 tokens to suppress the address classification.
  model->classification_options->address_min_num_tokens = 5;

  flatbuffers::FlatBufferBuilder minimum_fbb;
  FinishModelBuffer(minimum_fbb, Model::Pack(minimum_fbb, model.get()));
  annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(minimum_fbb.GetBufferPointer()),
      minimum_fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText(text, {10, 37})), "other");
}
2248 
// With `other_collection_priority_score` set to 1.0, "LX37" is expected to be
// classified as "other".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = 1.0;

  // The annotator borrows the builder's buffer; `fbb` must outlive it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const buffer =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "other");
}
2263 
// With `other_collection_priority_score` set to -100.0, "LX37" is expected to
// be classified as "flight". Note that, despite the test name, the score used
// here is a low one.
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = -100.0;

  // The annotator borrows the builder's buffer; `fbb` must outlive it.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const buffer =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "flight");
}
2278 
// VisitAnnotatorModel is expected to yield true from a "model is non-null"
// visitor for the existing test model and false for a non-existing path.
TEST_F(AnnotatorTest, VisitAnnotatorModel) {
  const auto model_is_present = [](const Model* model) {
    return model != nullptr;
  };

  EXPECT_TRUE(VisitAnnotatorModel<bool>(GetTestModelPath(), model_is_present));
  EXPECT_FALSE(VisitAnnotatorModel<bool>(
      GetModelPath() + "non_existing_model.fb", model_is_present));
}
2295 
// Model restricted to triggering locales "en,cs"; no detected language tags
// are passed. Annotate, ClassifyText and SuggestSelection are all expected to
// trigger on the phone number.
TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone_text),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(phone_text, {0, 14})));
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {6, 9}),
            CodepointSpan(0, 14));
}
2312 
// Model restricted to triggering locales "en,cs"; the detected language "cs"
// is in that list, so Annotate is expected to find the phone number.
TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";

  EXPECT_THAT(annotator->Annotate("(555) 225-3556", annotation_options),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
}
2327 
// Model restricted to triggering locales "en,cs"; the detected language "de"
// is not in that list, so Annotate is expected to return nothing.
TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";

  EXPECT_THAT(annotator->Annotate("(555) 225-3556", annotation_options),
              IsEmpty());
}
2341 
// Model restricted to triggering locales "en,cs"; the detected language "cs"
// is in that list, so ClassifyText is expected to return "phone".
TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";

  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText("(555) 225-3556", {0, 14},
                                                classification_options)));
}
2356 
// Model restricted to triggering locales "en,cs"; the detected language "de"
// is not in that list, so ClassifyText is expected to return no results.
TEST_F(AnnotatorTest,
       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";

  EXPECT_THAT(annotator->ClassifyText("(555) 225-3556", {0, 14},
                                      classification_options),
              IsEmpty());
}
2372 
// Model restricted to triggering locales "en,cs"; the detected language "cs"
// is in that list, so SuggestSelection is expected to expand the selection to
// the whole phone number.
TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";

  EXPECT_EQ(
      annotator->SuggestSelection("(555) 225-3556", {6, 9}, selection_options),
      CodepointSpan(0, 14));
}
2387 
// Model restricted to triggering locales "en,cs"; the detected language "de"
// is not in that list, so SuggestSelection is expected to return the input
// selection unchanged.
TEST_F(AnnotatorTest,
       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model = ModifyAnnotatorModel(
      stock_model, [](ModelT* model) { model->triggering_locales = "en,cs"; });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";

  EXPECT_EQ(
      annotator->SuggestSelection("(555) 225-3556", {6, 9}, selection_options),
      CodepointSpan(6, 9));
}
2403 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; no detected language tags are passed. Annotate, ClassifyText and
// SuggestSelection are all expected to trigger on the address.
TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address_text),
              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
  EXPECT_EQ("address",
            FirstResult(annotator->ClassifyText(address_text, {0, 27})));
  EXPECT_EQ(annotator->SuggestSelection(address_text, {4, 9}),
            CodepointSpan(0, 27));
}
2422 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "cs" is in that list, so Annotate is expected
// to find the address.
TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";

  EXPECT_THAT(
      annotator->Annotate("350 Third Street, Cambridge", annotation_options),
      ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
}
2439 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "de" is not in that list, so Annotate is
// expected to return nothing.
TEST_F(AnnotatorTest,
       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";

  EXPECT_THAT(
      annotator->Annotate("350 Third Street, Cambridge", annotation_options),
      IsEmpty());
}
2457 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "cs" is in that list, so ClassifyText is
// expected to return "address".
TEST_F(AnnotatorTest,
       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";

  EXPECT_EQ("address", FirstResult(annotator->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27},
                           classification_options)));
}
2475 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "de" is not in that list, so ClassifyText is
// expected to return no results.
TEST_F(AnnotatorTest,
       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";

  EXPECT_THAT(annotator->ClassifyText("350 Third Street, Cambridge", {0, 27},
                                      classification_options),
              IsEmpty());
}
2494 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "cs" is in that list, so SuggestSelection is
// expected to expand the selection to the whole address.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";

  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9},
                                        selection_options),
            CodepointSpan(0, 27));
}
2513 
// Both `triggering_locales` and `triggering_options->locales` are set to
// "en,cs"; the detected language "de" is not in that list, so SuggestSelection
// is expected to return the input selection unchanged.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  std::string stock_model = ReadFile(GetTestModelPath());
  const std::string patched_model =
      ModifyAnnotatorModel(stock_model, [](ModelT* model) {
        model->triggering_locales = "en,cs";
        model->triggering_options->locales = "en,cs";
      });
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      patched_model.data(), patched_model.size(), unilib_.get(),
      calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";

  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9},
                                        selection_options),
            CodepointSpan(4, 9));
}
2532 
VerifyClassifyTextOutputsDatetimeEntityData(const Annotator * classifier)2533 void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
2534   EXPECT_TRUE(classifier);
2535   std::vector<ClassificationResult> result;
2536   ClassificationOptions options;
2537   options.locales = "en-US";
2538 
2539   result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
2540 
2541   ASSERT_GE(result.size(), 0);
2542   const EntityData* entity_data =
2543       GetEntityData(result[0].serialized_entity_data.data());
2544   ASSERT_NE(entity_data, nullptr);
2545   ASSERT_NE(entity_data->datetime(), nullptr);
2546   EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
2547   EXPECT_EQ(entity_data->datetime()->granularity(),
2548             EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
2549   EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
2550 
2551   auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
2552   EXPECT_EQ(meridiem->component_type(),
2553             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
2554   EXPECT_EQ(meridiem->absolute_value(), 0);
2555   EXPECT_EQ(meridiem->relative_count(), 0);
2556   EXPECT_EQ(meridiem->relation_type(),
2557             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2558 
2559   auto* minute = entity_data->datetime()->datetime_component()->Get(1);
2560   EXPECT_EQ(minute->component_type(),
2561             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
2562   EXPECT_EQ(minute->absolute_value(), 0);
2563   EXPECT_EQ(minute->relative_count(), 0);
2564   EXPECT_EQ(minute->relation_type(),
2565             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2566 
2567   auto* hour = entity_data->datetime()->datetime_component()->Get(2);
2568   EXPECT_EQ(hour->component_type(),
2569             EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
2570   EXPECT_EQ(hour->absolute_value(), 0);
2571   EXPECT_EQ(hour->relative_count(), 0);
2572   EXPECT_EQ(hour->relation_type(),
2573             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2574 
2575   auto* day = entity_data->datetime()->datetime_component()->Get(3);
2576   EXPECT_EQ(
2577       day->component_type(),
2578       EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2579   EXPECT_EQ(day->absolute_value(), 5);
2580   EXPECT_EQ(day->relative_count(), 0);
2581   EXPECT_EQ(day->relation_type(),
2582             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2583 
2584   auto* month = entity_data->datetime()->datetime_component()->Get(4);
2585   EXPECT_EQ(month->component_type(),
2586             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2587   EXPECT_EQ(month->absolute_value(), 3);
2588   EXPECT_EQ(month->relative_count(), 0);
2589   EXPECT_EQ(month->relation_type(),
2590             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2591 
2592   auto* year = entity_data->datetime()->datetime_component()->Get(5);
2593   EXPECT_EQ(year->component_type(),
2594             EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2595   EXPECT_EQ(year->absolute_value(), 1970);
2596   EXPECT_EQ(year->relative_count(), 0);
2597   EXPECT_EQ(year->relation_type(),
2598             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2599 }
2600 
// Runs the shared ClassifyText datetime entity-data checks against an
// annotator created from the test model path.
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2606 
// Same checks as ClassifyTextOutputsDatetimeEntityData, but the annotator is
// created from the buffer returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) {
  // Declared before the annotator so the buffer outlives it; per the factory
  // name the annotator does not take ownership of the buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2614 
VerifyAnnotateOutputsDatetimeEntityData(const Annotator * classifier)2615 void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
2616   EXPECT_TRUE(classifier);
2617   std::vector<AnnotatedSpan> result;
2618   AnnotationOptions options;
2619   options.is_serialized_entity_data_enabled = true;
2620   options.locales = "en";
2621 
2622   result = classifier->Annotate("September 1, 2019", options);
2623 
2624   ASSERT_GE(result.size(), 0);
2625   ASSERT_GE(result[0].classification.size(), 0);
2626   ASSERT_EQ(result[0].classification[0].collection, "date");
2627   const EntityData* entity_data =
2628       GetEntityData(result[0].classification[0].serialized_entity_data.data());
2629   ASSERT_NE(entity_data, nullptr);
2630   ASSERT_NE(entity_data->datetime(), nullptr);
2631   EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
2632   EXPECT_EQ(entity_data->datetime()->granularity(),
2633             EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
2634   EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
2635 
2636   auto* day = entity_data->datetime()->datetime_component()->Get(0);
2637   EXPECT_EQ(
2638       day->component_type(),
2639       EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2640   EXPECT_EQ(day->absolute_value(), 1);
2641   EXPECT_EQ(day->relative_count(), 0);
2642   EXPECT_EQ(day->relation_type(),
2643             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2644 
2645   auto* month = entity_data->datetime()->datetime_component()->Get(1);
2646   EXPECT_EQ(month->component_type(),
2647             EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2648   EXPECT_EQ(month->absolute_value(), 9);
2649   EXPECT_EQ(month->relative_count(), 0);
2650   EXPECT_EQ(month->relation_type(),
2651             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2652 
2653   auto* year = entity_data->datetime()->datetime_component()->Get(2);
2654   EXPECT_EQ(year->component_type(),
2655             EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2656   EXPECT_EQ(year->absolute_value(), 2019);
2657   EXPECT_EQ(year->relative_count(), 0);
2658   EXPECT_EQ(year->relation_type(),
2659             EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2660 }
2661 
// Runs the shared Annotate datetime entity-data checks against an annotator
// created from the test model path.
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2667 
// Same checks as AnnotateOutputsDatetimeEntityData, but the annotator is
// created from the buffer returned by GetTestModelWithDatetimeRegEx().
// The "Datatime" spelling in the test name is left untouched so the test keeps
// its existing identifier.
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) {
  // Declared before the annotator so the buffer outlives it; per the factory
  // name the annotator does not take ownership of the buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2675 
TEST_F(AnnotatorTest,AnnotateOutputsMoneyEntityData)2676 TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
2677   // std::string model_buffer = GetTestModelWithDatetimeRegEx();
2678   // std::unique_ptr<Annotator> classifier =
2679   //     Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
2680   //                                  unilib_.get(), calendarlib_.get());
2681   std::unique_ptr<Annotator> classifier = Annotator::FromPath(
2682       GetTestModelPath(), unilib_.get(), calendarlib_.get());
2683   EXPECT_TRUE(classifier);
2684   AnnotationOptions options;
2685   options.is_serialized_entity_data_enabled = true;
2686 
2687   ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
2688                            /*amount=*/"3.5", /*whole_part=*/3,
2689                            /*decimal_part=*/5, /*nanos=*/500000000);
2690   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
2691                            /*amount=*/"3.5", /*whole_part=*/3,
2692                            /*decimal_part=*/5, /*nanos=*/500000000);
2693   ExpectFirstEntityIsMoney(
2694       classifier->Annotate("For online purchase of CHF 23.00 enter", options),
2695       "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
2696       /*nanos=*/0);
2697   ExpectFirstEntityIsMoney(
2698       classifier->Annotate("For online purchase of 23.00 CHF enter", options),
2699       "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
2700       /*nanos=*/0);
2701   ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
2702                            /*amount=*/"4.8198", /*whole_part=*/4,
2703                            /*decimal_part=*/8198, /*nanos=*/819800000);
2704   ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
2705                            /*amount=*/"4.8198", /*whole_part=*/4,
2706                            /*decimal_part=*/8198, /*nanos=*/819800000);
2707   ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
2708                            /*amount=*/"0.0255", /*whole_part=*/0,
2709                            /*decimal_part=*/255, /*nanos=*/25500000);
2710   ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
2711                            /*amount=*/"0.0255", /*whole_part=*/0,
2712                            /*decimal_part=*/255, /*nanos=*/25500000);
2713   ExpectFirstEntityIsMoney(
2714       classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
2715                            "OR on card ending 0000.",
2716                            options),
2717       "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
2718       /*nanos=*/0);
2719   ExpectFirstEntityIsMoney(
2720       classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
2721                            "OR on card ending 0000.",
2722                            options),
2723       "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
2724       /*nanos=*/0);
2725 
2726   ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
2727                            /*amount=*/"35",
2728                            /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
2729   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
2730                            /*amount=*/"35", /*whole_part=*/35,
2731                            /*decimal_part=*/0, /*nanos=*/0);
2732   ExpectFirstEntityIsMoney(
2733       classifier->Annotate("and win back up to CHF 150 - with digitec",
2734                            options),
2735       "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
2736       /*nanos=*/0);
2737   ExpectFirstEntityIsMoney(
2738       classifier->Annotate("and win back up to 150 CHF - with digitec",
2739                            options),
2740       "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
2741       /*nanos=*/0);
2742 
2743   ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
2744                            "CHF", /*amount=*/"3.555.333",
2745                            /*whole_part=*/3555333, /*decimal_part=*/0,
2746                            /*nanos=*/0);
2747   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
2748                            "CHF", /*amount=*/"3.555.333",
2749                            /*whole_part=*/3555333, /*decimal_part=*/0,
2750                            /*nanos=*/0);
2751   ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
2752                            /*amount=*/"10,000", /*whole_part=*/10000,
2753                            /*decimal_part=*/0, /*nanos=*/0);
2754   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
2755                            /*amount=*/"10,000", /*whole_part=*/10000,
2756                            /*decimal_part=*/0, /*nanos=*/0);
2757 
2758   ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
2759                            /*amount=*/"3,555.33", /*whole_part=*/3555,
2760                            /*decimal_part=*/33, /*nanos=*/330000000);
2761   ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
2762                            /*amount=*/"3,555.33", /*whole_part=*/3555,
2763                            /*decimal_part=*/33, /*nanos=*/330000000);
2764   ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
2765                            /*amount=*/"3,000.00", /*whole_part=*/3000,
2766                            /*decimal_part=*/0, /*nanos=*/0);
2767   ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
2768                            /*amount=*/"3,000.00", /*whole_part=*/3000,
2769                            /*decimal_part=*/0, /*nanos=*/0);
2770 
2771   ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
2772                            /*amount=*/"1.2", /*whole_part=*/1,
2773                            /*decimal_part=*/2, /*nanos=*/200000000);
2774   ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
2775                            /*amount=*/"1.2", /*whole_part=*/1,
2776                            /*decimal_part=*/2, /*nanos=*/200000000);
2777 
2778   ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
2779                            /*amount=*/"1.123456789", /*whole_part=*/1,
2780                            /*decimal_part=*/123456789, /*nanos=*/123456789);
2781   ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
2782                            /*amount=*/"10.01", /*whole_part=*/10,
2783                            /*decimal_part=*/1, /*nanos=*/10000000);
2784 
2785   ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
2786                            /*amount=*/"59 million", /*whole_part=*/59000000,
2787                            /*decimal_part=*/0, /*nanos=*/0);
2788   ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
2789                            /*amount=*/"7.05 k", /*whole_part=*/7050,
2790                            /*decimal_part=*/5, /*nanos=*/0);
2791   ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
2792                            /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
2793                            /*decimal_part=*/123456789, /*nanos=*/789000000);
2794   ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
2795                            /*amount=*/"7.000056789 k", /*whole_part=*/7000,
2796                            /*decimal_part=*/56789, /*nanos=*/56789000);
2797 
2798   ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
2799                            /*amount=*/"59.3 billion", /*whole_part=*/59,
2800                            /*decimal_part=*/3, /*nanos=*/300000000);
2801   ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
2802                            /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
2803                            /*decimal_part=*/5, /*nanos=*/0);
2804 }
2805 
// With a LangId model attached and "de" declared as the user's familiar
// language, classifying a span of an English sentence is expected to yield
// exactly one result, in the "translate" collection.
TEST_F(AnnotatorTest, TranslateAction) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal assertion: the classifier is dereferenced right below.
  ASSERT_TRUE(classifier);
  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
          GetModelPath() + "lang_id.smfb");
  classifier->SetLangId(langid_model.get());

  ClassificationOptions options;
  options.user_familiar_language_tags = "de";

  std::vector<ClassificationResult> classifications =
      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
  // Fatal assertion: element 0 is read on the next line.
  ASSERT_EQ(classifications.size(), 1);
  EXPECT_EQ(classifications[0].collection, "translate");
}
2822 
// AnnotateStructuredInput() is given two text fragments and must return one
// list of spans per fragment: a "money" span for the first fragment and a
// "date" span for the second.
TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal assertion: the classifier is dereferenced below.
  ASSERT_TRUE(classifier);

  std::vector<InputFragment> string_fragments = {
      {.text = "He owes me 3.5 CHF."},
      {.text = "...was born on 13/12/1989."},
  };

  AnnotationOptions annotation_options;
  annotation_options.locales = "en";
  StatusOr<Annotations> annotations_status =
      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
  ASSERT_TRUE(annotations_status.ok());
  Annotations annotations = annotations_status.ValueOrDie();
  // One entry per input fragment; fatal because both entries are indexed next.
  ASSERT_EQ(annotations.annotated_spans.size(), 2);
  EXPECT_THAT(annotations.annotated_spans[0],
              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
  EXPECT_THAT(annotations.annotated_spans[1],
              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
}
2844 
VerifyInputFragmentTimestampOverridesAnnotationOptions(const Annotator * classifier)2845 void VerifyInputFragmentTimestampOverridesAnnotationOptions(
2846     const Annotator* classifier) {
2847   AnnotationOptions annotation_options;
2848   annotation_options.locales = "en";
2849   annotation_options.reference_time_ms_utc =
2850       1554465190000;                             // 04/05/2019 11:53 am
2851   int64 fragment_reference_time = 946727580000;  // 01/01/2000 11:53 am
2852   std::vector<InputFragment> string_fragments = {
2853       {.text = "New event at 17:20"},
2854       {
2855           .text = "New event at 17:20",
2856           .datetime_options = Optional<DatetimeOptions>(
2857               {.reference_time_ms_utc = fragment_reference_time}),
2858       }};
2859   StatusOr<Annotations> annotations_status =
2860       classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2861   ASSERT_TRUE(annotations_status.ok());
2862   Annotations annotations = annotations_status.ValueOrDie();
2863   ASSERT_EQ(annotations.annotated_spans.size(), 2);
2864   EXPECT_THAT(annotations.annotated_spans[0],
2865               ElementsAreArray({IsDatetimeSpan(
2866                   /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
2867                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2868   EXPECT_THAT(annotations.annotated_spans[1],
2869               ElementsAreArray({IsDatetimeSpan(
2870                   /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
2871                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2872 }
2873 
// Runs the fragment-timestamp override checks against an annotator created
// from the buffer returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) {
  // Declared before the annotator so the buffer outlives it; per the factory
  // name the annotator does not take ownership of the buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyInputFragmentTimestampOverridesAnnotationOptions(annotator.get());
}
2882 
// Runs the fragment-timestamp override checks against an annotator created
// from the test model path.
TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyInputFragmentTimestampOverridesAnnotationOptions(annotator.get());
}
2888 
VerifyInputFragmentTimezoneOverridesAnnotationOptions(const Annotator * classifier)2889 void VerifyInputFragmentTimezoneOverridesAnnotationOptions(
2890     const Annotator* classifier) {
2891   std::vector<InputFragment> string_fragments = {
2892       {.text = "11/12/2020 17:20"},
2893       {
2894           .text = "11/12/2020 17:20",
2895           .datetime_options = Optional<DatetimeOptions>(
2896               {.reference_timezone = "Europe/Zurich"}),
2897       }};
2898   AnnotationOptions annotation_options;
2899   annotation_options.locales = "en-US";
2900   StatusOr<Annotations> annotations_status =
2901       classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2902   ASSERT_TRUE(annotations_status.ok());
2903   Annotations annotations = annotations_status.ValueOrDie();
2904   ASSERT_EQ(annotations.annotated_spans.size(), 2);
2905   EXPECT_THAT(annotations.annotated_spans[0],
2906               ElementsAreArray({IsDatetimeSpan(
2907                   /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
2908                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2909   EXPECT_THAT(annotations.annotated_spans[1],
2910               ElementsAreArray({IsDatetimeSpan(
2911                   /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
2912                   DatetimeGranularity::GRANULARITY_MINUTE)}));
2913 }
2914 
// Runs the fragment-timezone override checks against an annotator created
// from the test model path.
TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(annotator.get());
}
2920 
// Runs the fragment-timezone override checks against an annotator created
// from the buffer returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) {
  // Declared before the annotator so the buffer outlives it; per the factory
  // name the annotator does not take ownership of the buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(annotator.get());
}
2929 
2930 namespace {
AddDummyRegexDatetimeModel(ModelT * unpacked_model)2931 void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
2932   unpacked_model->datetime_model.reset(new DatetimeModelT);
2933   // This needs to be false otherwise we'd have to define some extractor. When
2934   // this is false, the 0-th capturing group (whole match) from the pattern is
2935   // used to come up with the indices.
2936   unpacked_model->datetime_model->use_extractors_for_locating = false;
2937   unpacked_model->datetime_model->locales.push_back("en-US");
2938   unpacked_model->datetime_model->default_locales.push_back(0);  // en-US
2939   unpacked_model->datetime_model->patterns.push_back(
2940       std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
2941   unpacked_model->datetime_model->patterns.back()->locales.push_back(
2942       0);  // en-US
2943   unpacked_model->datetime_model->patterns.back()->regexes.push_back(
2944       std::unique_ptr<DatetimeModelPattern_::RegexT>(
2945           new DatetimeModelPattern_::RegexT));
2946   unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern =
2947       "THIS_MATCHES_IN_REGEX_MODEL";
2948   unpacked_model->datetime_model->patterns.back()
2949       ->regexes.back()
2950       ->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
2951 }
2952 }  // namespace
2953 
// Annotates a string in RAW mode and checks that exactly one returned span has
// "phone" as its top classification, i.e. identical annotations are not
// reported twice.
TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // This test assumes that both ML model and Regex model trigger on the
  // following text and output "phone" annotation for it.
  const std::string test_string = "1000000000";
  AnnotationOptions options;
  options.annotation_usecase = ANNOTATION_USECASE_RAW;
  int num_phones = 0;
  for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
    // Guard the [0] access: a span without classifications is simply not a
    // phone span, and must not be indexed.
    if (!span.classification.empty() &&
        span.classification[0].collection == "phone") {
      ++num_phones;
    }
  }

  EXPECT_EQ(num_phones, 1);
}
2973 
2974 // This test tests the optimizations in Annotator, which make some of the
2975 // annotators not run in the RAW mode when not requested. We test here that the
2976 // results indeed don't contain such annotations. However, this is a bick hacky,
2977 // since one could also add post-filtering, in which case these tests would
2978 // trivially pass.
TEST_F(AnnotatorTest,RawModeOptimizationWorks)2979 TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
2980   std::unique_ptr<Annotator> classifier = Annotator::FromPath(
2981       GetTestModelPath(), unilib_.get(), calendarlib_.get());
2982   ASSERT_TRUE(classifier);
2983 
2984   AnnotationOptions options;
2985   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
2986   // Requesting a non-existing type to avoid overlap with existing types.
2987   options.entity_types.insert("some_unknown_entity_type");
2988 
2989   // Normally, the following command would produce the following annotations:
2990   //   Span(19, 24, date, 1.000000),
2991   //   Span(53, 56, number, 1.000000),
2992   //   Span(53, 80, address, 1.000000),
2993   //   Span(128, 142, phone, 1.000000),
2994   //   Span(129, 132, number, 1.000000),
2995   //   Span(192, 200, phone, 1.000000),
2996   //   Span(192, 206, datetime, 1.000000),
2997   //   Span(246, 253, number, 1.000000),
2998   //   Span(246, 253, phone, 1.000000),
2999   //   Span(292, 293, number, 1.000000),
3000   //   Span(292, 301, duration, 1.000000) }
3001   // But because of the optimizations, it doesn't produce anything, since
3002   // we didn't request any of these entities.
3003   EXPECT_THAT(classifier->Annotate(R"--(I saw Barack Obama today
3004                             350 Third Street, Cambridge
3005                             my phone number is (853) 225-3556
3006                             this is when we met: 1.9.2021 13:00
3007                             my number: 1234567
3008                             duration: 3 minutes
3009                             )--",
3010                                    options),
3011               IsEmpty());
3012 }
3013 
VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(const Annotator * classifier)3014 void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(
3015     const Annotator* classifier) {
3016   ASSERT_TRUE(classifier);
3017   struct Example {
3018     std::string collection;
3019     std::string text;
3020   };
3021 
3022   // These examples contain one example per annotator, to check that each of
3023   // the annotators can work in the RAW mode on its own.
3024   //
3025   // WARNING: This list doesn't contain yet entries for the app, contact, and
3026   // person annotators. Hopefully this won't be needed once b/155214735 is
3027   // fixed and the piping shared across annotators.
3028   std::vector<Example> examples{
3029       // ML Model.
3030       {.collection = Collections::Address(),
3031        .text = "... 350 Third Street, Cambridge ..."},
3032       // Datetime annotator.
3033       {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
3034       // Duration annotator.
3035       {.collection = Collections::Duration(),
3036        .text = "... 3 hours and 9 seconds ..."},
3037       // Regex annotator.
3038       {.collection = Collections::Email(),
3039        .text = "... platypus@theanimal.org ..."},
3040       // Number annotator.
3041       {.collection = Collections::Number(), .text = "... 100 ..."},
3042   };
3043 
3044   for (const Example& example : examples) {
3045     AnnotationOptions options;
3046     options.locales = "en";
3047     options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
3048     options.entity_types.insert(example.collection);
3049 
3050     EXPECT_THAT(classifier->Annotate(example.text, options),
3051                 Contains(IsAnnotationWithType(example.collection)))
3052         << " text: '" << example.text
3053         << "', collection: " << example.collection;
3054   }
3055 }
3056 
// Runs the per-collection RAW-mode filtering checks against an annotator
// created from the test model path.
TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3062 
// Runs the per-collection RAW-mode filtering checks against an annotator
// created from the buffer returned by GetTestModelWithDatetimeRegEx().
TEST_F(AnnotatorTest,
       AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) {
  // Declared before the annotator so the buffer outlives it; per the factory
  // name the annotator does not take ownership of the buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3071 
// An annotator built with Annotator::FromString() from the contents of the
// test model file must be usable: it annotates a phone-number-like string.
TEST_F(AnnotatorTest, InitializeFromString) {
  const std::string model_contents = ReadFile(GetTestModelPath());
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromString(model_contents, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const auto annotations = annotator->Annotate("(857) 225-3556");
  EXPECT_THAT(annotations, Not(IsEmpty()));
}
3079 
3080 // Regression test for cl/338280366. Enabling only_use_line_with_click had
3081 // the effect, that some annotators in the previous code releases would
3082 // receive only the last line of the input text. This test has the entity on the
3083 // first line (duration).
TEST_F(AnnotatorTest,RegressionTestOnlyUseLineWithClickLastLine)3084 TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
3085   const std::string test_model = ReadFile(GetTestModelPath());
3086   std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
3087 
3088   std::unique_ptr<Annotator> classifier;
3089 
3090   // With unrestricted number of tokens should behave normally.
3091   unpacked_model->selection_feature_options->only_use_line_with_click = true;
3092 
3093   flatbuffers::FlatBufferBuilder builder;
3094   FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
3095   classifier = Annotator::FromUnownedBuffer(
3096       reinterpret_cast<const char*>(builder.GetBufferPointer()),
3097       builder.GetSize(), unilib_.get(), calendarlib_.get());
3098   ASSERT_TRUE(classifier);
3099 
3100   AnnotationOptions options;
3101   options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
3102 
3103   const std::vector<AnnotatedSpan> annotations =
3104       classifier->Annotate("let's meet in 3 hours\nbut not now", options);
3105 
3106   EXPECT_THAT(annotations, Contains(IsDurationSpan(
3107                                /*start=*/14, /*end=*/21,
3108                                /*duration_ms=*/3 * 60 * 60 * 1000)));
3109 }
3110 
// Text containing invalid UTF-8 must not be processed: Annotate() and
// ClassifyText() return nothing and SuggestSelection() hands back the input
// span unchanged, even though the text starts with a phone number.
TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
  const std::string model_contents = ReadFile(GetTestModelPath());
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromString(model_contents, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // A phone number followed by byte sequences that make the text invalid.
  const std::string bad_text =
      "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";

  EXPECT_THAT(annotator->Annotate(bad_text), IsEmpty());
  EXPECT_THAT(annotator->SuggestSelection(bad_text, {1, 4}),
              Eq(CodepointSpan{1, 4}));
  EXPECT_THAT(annotator->ClassifyText(bad_text, {0, 14}), IsEmpty());
}
3128 
3129 }  // namespace test_internal
3130 }  // namespace libtextclassifier3
3131