1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator_test-include.h"
18
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23
24 #include "annotator/annotator.h"
25 #include "annotator/collections.h"
26 #include "annotator/model_generated.h"
27 #include "annotator/test-utils.h"
28 #include "annotator/types-test-util.h"
29 #include "annotator/types.h"
30 #include "utils/grammar/utils/locale-shard-map.h"
31 #include "utils/grammar/utils/rules.h"
32 #include "utils/testing/annotator.h"
33 #include "lang_id/fb_model/lang-id-from-fb.h"
34 #include "lang_id/lang-id.h"
35
36 namespace libtextclassifier3 {
37 namespace test_internal {
38
39 using ::testing::Contains;
40 using ::testing::ElementsAre;
41 using ::testing::ElementsAreArray;
42 using ::testing::Eq;
43 using ::testing::IsEmpty;
44 using ::testing::UnorderedElementsAreArray;
45
GetTestModelPath()46 std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
47
GetModelWithVocabPath()48 std::string GetModelWithVocabPath() {
49 return GetModelPath() + "test_vocab_model.fb";
50 }
51
GetTestModelWithDatetimeRegEx()52 std::string GetTestModelWithDatetimeRegEx() {
53 std::string model_buffer = ReadFile(GetTestModelPath());
54 model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
55 model->datetime_grammar_model.reset(nullptr);
56 });
57 return model_buffer;
58 }
59
ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan> & result,const std::string & currency,const std::string & amount,const int whole_part,const int decimal_part,const int nanos)60 void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
61 const std::string& currency,
62 const std::string& amount, const int whole_part,
63 const int decimal_part, const int nanos) {
64 ASSERT_GT(result.size(), 0);
65 ASSERT_GT(result[0].classification.size(), 0);
66 ASSERT_EQ(result[0].classification[0].collection, "money");
67
68 const EntityData* entity_data =
69 GetEntityData(result[0].classification[0].serialized_entity_data.data());
70 ASSERT_NE(entity_data, nullptr);
71 ASSERT_NE(entity_data->money(), nullptr);
72 EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
73 EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
74 EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
75 EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
76 EXPECT_EQ(entity_data->money()->nanos(), nanos);
77 }
78
// Loading the "wrong_embeddings.fb" model must not produce an Annotator:
// FromPath is expected to return null for it.
TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
  const std::string broken_model_path = GetModelPath() + "wrong_embeddings.fb";
  const std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      broken_model_path, unilib_.get(), calendarlib_.get());
  EXPECT_FALSE(classifier);
}
85
VerifyClassifyText(const Annotator * classifier)86 void VerifyClassifyText(const Annotator* classifier) {
87 ASSERT_TRUE(classifier);
88
89 EXPECT_EQ("other",
90 FirstResult(classifier->ClassifyText(
91 "this afternoon Barack Obama gave a speech at", {15, 27})));
92 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
93 "Call me at (800) 123-456 today", {11, 24})));
94
95 // More lines.
96 EXPECT_EQ("other",
97 FirstResult(classifier->ClassifyText(
98 "this afternoon Barack Obama gave a speech at|Visit "
99 "www.google.com every today!|Call me at (800) 123-456 today.",
100 {15, 27})));
101 EXPECT_EQ("phone",
102 FirstResult(classifier->ClassifyText(
103 "this afternoon Barack Obama gave a speech at|Visit "
104 "www.google.com every today!|Call me at (800) 123-456 today.",
105 {90, 103})));
106
107 // Single word.
108 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
109 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
110
111 // Junk. These should not crash the test.
112 classifier->ClassifyText("", {0, 0});
113 classifier->ClassifyText("asdf", {0, 0});
114 classifier->ClassifyText("asdf", {0, 27});
115 classifier->ClassifyText("asdf", {-30, 300});
116 classifier->ClassifyText("asdf", {-10, -1});
117 classifier->ClassifyText("asdf", {100, 17});
118 classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
119
120 // Test invalid utf8 input.
121 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
122 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
123 }
124
// Runs the shared VerifyClassifyText checks on an Annotator loaded from the
// default test model file.
TEST_F(AnnotatorTest, ClassifyText) {
  const auto classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyText(classifier.get());
}
130
// "isotope" is classified as "dictionary" only when "en" is given as the
// detected text language; with no language hint or with "uz" it is "other".
TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Default options: no detected language.
  EXPECT_EQ("other", FirstResult(classifier->ClassifyText("isotope", {0, 7})));

  ClassificationOptions classification_options;
  const auto classify_isotope_with_language = [&](const char* language_tags) {
    classification_options.detected_text_language_tags = language_tags;
    return FirstResult(
        classifier->ClassifyText("isotope", {0, 7}, classification_options));
  };

  EXPECT_EQ("dictionary", classify_isotope_with_language("en"));
  EXPECT_EQ("other", classify_isotope_with_language("uz"));
}
147
// Requesting the vocab annotator on the default test model (which, per the
// test name, has no vocab model) must not break classification: an English
// dictionary word is still labeled "dictionary".
TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
  const std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  ClassificationOptions options;
  options.use_vocab_annotator = true;
  options.detected_text_language_tags = "en";

  EXPECT_EQ("dictionary",
            FirstResult(classifier->ClassifyText("isotope", {0, 7}, options)));
}
160
#ifdef TC3_VOCAB_ANNOTATOR_IMPL
// The vocab annotator labels "integrity" as "dictionary" while the FFModel
// model does not, so toggling use_vocab_annotator shows which one is in use:
// "dictionary" when enabled, "other" when disabled.
TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "en";

  const auto classify_integrity = [&](const bool use_vocab_annotator) {
    classification_options.use_vocab_annotator = use_vocab_annotator;
    return FirstResult(classifier->ClassifyText("integrity", {0, 9},
                                                classification_options));
  };

  EXPECT_EQ("dictionary", classify_integrity(true));
  EXPECT_EQ("other", classify_integrity(false));
}
#endif  // TC3_VOCAB_ANNOTATOR_IMPL
181
// A model with no classification model, triggered for selection only, must
// fail to load: the classification model is still needed for selection
// scores.
TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  TC3_CHECK(unpacked_model != nullptr);

  unpacked_model->classification_model.clear();
  auto& triggering_options = unpacked_model->triggering_options;
  triggering_options.reset(new ModelTriggeringOptionsT);
  triggering_options->enabled_modes = ModeFlag_SELECTION;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  const char* const model_buffer =
      reinterpret_cast<const char*>(builder.GetBufferPointer());

  const std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      model_buffer, builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_FALSE(classifier);
}
201
// With enabled_modes restricted to annotation and selection, ClassifyText
// returns no results, even for a phone number that the unmodified test model
// classifies as "phone".
TEST_F(AnnotatorTest, ClassifyTextDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Fail the test cleanly instead of dereferencing a null model below.
  ASSERT_NE(unpacked_model.get(), nullptr);

  unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_THAT(
      classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
      IsEmpty());
}
220
// Listing a collection in output_options.filtered_collections_classification
// removes it from ClassifyText results ("phone" falls back to "other"), while
// other collections such as "address" are still produced.
TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_NE(unpacked_model.get(), nullptr);
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone classification.
  unpacked_model->output_options->filtered_collections_classification.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // The annotator does not own `builder`'s buffer, so it is held in a variable
  // declared AFTER `builder`; it is then destroyed before the buffer is freed.
  // Reassigning the earlier `classifier` would make it outlive `builder`.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_EQ("other", FirstResult(filtered_classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  // Check that the address classification still passes.
  EXPECT_EQ("address", FirstResult(filtered_classifier->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27})));
}
253
// Adds classification-only regex patterns (person, flight and a Luhn-verified
// payment card) to the test model. It then checks ClassifyText on flight,
// person, email, url and payment-card inputs, including a card number that
// fails the Luhn checksum.
TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  auto& patterns = unpacked_model->regex_model->patterns;

  // Classification-only patterns.
  patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));

  // Card numbers must additionally pass a Luhn checksum.
  std::unique_ptr<RegexModel_::PatternT> payment_card_pattern =
      MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/false, 1.0);
  payment_card_pattern->verification_options.reset(new VerificationOptionsT);
  payment_card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(payment_card_pattern));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Top collection for `span` within `text`, using default options.
  const auto classify = [&classifier](const std::string& text,
                                      const CodepointSpan& span) {
    return FirstResult(classifier->ClassifyText(text, span));
  };

  EXPECT_EQ("flight",
            classify("Your flight LX373 is delayed by 3 hours.", {12, 17}));
  EXPECT_EQ("person",
            classify("this afternoon Barack Obama gave a speech at", {15, 27}));
  EXPECT_EQ("email", classify("you@android.com", {0, 15}));
  EXPECT_EQ("email", classify("Contact me at you@android.com", {14, 29}));

  EXPECT_EQ("url", classify("Visit www.google.com every today!", {6, 20}));

  EXPECT_EQ("flight", classify("LX 37", {0, 5}));
  EXPECT_EQ("flight", classify("flight LX 37 abcd", {7, 12}));
  EXPECT_EQ("payment_card", classify("cc: 4012 8888 8888 1881", {4, 23}));
  EXPECT_EQ("payment_card", classify("2221 0067 4735 6281", {0, 19}));
  // Same shape as the previous number, but the Luhn check fails.
  EXPECT_EQ("other", classify("2221 0067 4735 6282", {0, 19}));

  // A url inside a longer, '|'-joined context.
  EXPECT_EQ("url",
            classify("this afternoon Barack Obama gave a speech at|Visit "
                     "www.google.com every today!|Call me at (800) 123-456 "
                     "today.",
                     {51, 65}));
}
314
315 #ifndef TC3_DISABLE_LUA
TEST_F(AnnotatorTest,ClassifyTextRegularExpressionLuaVerification)316 TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
317 const std::string test_model = ReadFile(GetTestModelPath());
318 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
319
320 // Add test regex models.
321 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
322 MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
323 /*enabled_for_classification=*/true,
324 /*enabled_for_selection=*/false,
325 /*enabled_for_annotation=*/false, 1.0);
326 verified_pattern->verification_options.reset(new VerificationOptionsT);
327 verified_pattern->verification_options->lua_verifier = 0;
328 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
329 unpacked_model->regex_model->lua_verifier.push_back(
330 "return match[2].text==\"99\"");
331
332 flatbuffers::FlatBufferBuilder builder;
333 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
334
335 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
336 reinterpret_cast<const char*>(builder.GetBufferPointer()),
337 builder.GetSize(), unilib_.get(), calendarlib_.get());
338 ASSERT_TRUE(classifier);
339
340 // Custom rule triggers and is correctly verified.
341 EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText(
342 "99-00-123456-12345678", {0, 21})));
343
344 // Custom verification fails.
345 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
346 "90-00-123456-12345678", {0, 21})));
347 }
348 #endif // TC3_DISABLE_LUA
349
// Regex capturing groups populate serialized entity data. Checks the
// "person_with_age" pattern (presumably added by AddTestRegexModel) for two
// inputs: a full-name match and a first-name-only match.
//
// Entity fields are read by vtable offset:
//   4 = first name, 6 = is_alive, 8 = last name, 10 = age,
//   12 = former_us_president.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_NE(unpacked_model.get(), nullptr);

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Check with full name.
  {
    auto classifications =
        classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
    // Fatal checks: element 0 and its entity data are read below. The unsigned
    // literal avoids a signed/unsigned comparison.
    ASSERT_EQ(classifications.size(), 1u);
    EXPECT_EQ("person_with_age", classifications[0].collection);
    ASSERT_FALSE(classifications[0].serialized_entity_data.empty());

    // Check entity data.
    const flatbuffers::Table* entity =
        flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
            classifications[0].serialized_entity_data.data()));
    ASSERT_NE(entity, nullptr);
    const flatbuffers::String* first_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
    const flatbuffers::String* last_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
    ASSERT_NE(first_name, nullptr);
    ASSERT_NE(last_name, nullptr);
    EXPECT_EQ(first_name->str(), "Barack");
    EXPECT_EQ(last_name->str(), "Obama");

    // Check `age`.
    EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

    // Check `is_alive`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

    // Check `former_us_president`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
  }

  // Check only with first name.
  {
    auto classifications =
        classifier->ClassifyText("Barack is 57 years old", {0, 22});
    ASSERT_EQ(classifications.size(), 1u);
    EXPECT_EQ("person_with_age", classifications[0].collection);
    ASSERT_FALSE(classifications[0].serialized_entity_data.empty());

    // Check entity data.
    const flatbuffers::Table* entity =
        flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
            classifications[0].serialized_entity_data.data()));
    ASSERT_NE(entity, nullptr);
    const flatbuffers::String* first_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
    ASSERT_NE(first_name, nullptr);
    EXPECT_EQ(first_name->str(), "Barack");

    // Check `age`.
    EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

    // Check `is_alive`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

    // Check `former_us_president`.
    EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
  }
}
419
// A capturing group's normalization options post-process the value written to
// entity data. Here the last-name group (capturing_group[2] of the last regex
// pattern) is upper-cased, so field 8 reads "OBAMA".
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_NE(unpacked_model.get(), nullptr);

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing.
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  auto classifications =
      classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
  // Fatal checks: element 0 and its entity data are read below.
  ASSERT_EQ(classifications.size(), 1u);
  EXPECT_EQ("person_with_age", classifications[0].collection);
  ASSERT_FALSE(classifications[0].serialized_entity_data.empty());

  // Check entity data normalization.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classifications[0].serialized_entity_data.data()));
  ASSERT_NE(entity, nullptr);
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
458
// Two classification patterns with the same regex and score compete; the one
// with the higher priority_score wins. "flight1" (priority 1.0) wins first.
// After "flight2" is raised from 0.0 to 3.0, "flight2" wins.
TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));

  // Replace all regex patterns with the two competing flight patterns.
  auto& patterns = unpacked_model->regex_model->patterns;
  patterns.clear();
  patterns.push_back(MakePattern(
      "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/1.0));
  patterns.push_back(MakePattern(
      "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/0.0));

  const std::string flight_sentence = "Your flight LX373 is delayed by 3 hours.";

  // Round 1: "flight1" has the higher priority.
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ("flight1",
              FirstResult(classifier->ClassifyText(flight_sentence, {12, 17})));
  }

  // Round 2: raise "flight2" above "flight1" and repack the model.
  patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ("flight2",
              FirstResult(classifier->ClassifyText(flight_sentence, {12, 17})));
  }
}
501
// Two "flight" annotation patterns with the same regex and score but different
// priority_score compete. The winning annotation must carry the higher
// priority score, both before and after the second pattern's priority is
// raised from 0.0 to 3.0.
TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
  // Add test regex models. One of them has higher priority score than
  // the other. We'll test that always the one with higher priority score
  // ends up winning.
  unpacked_model->regex_model->patterns.clear();
  const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/0.0));

  // "flight" that wins should have a priority score of 1.0.
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    // Fatal: classification[0] is read on the next line.
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 0.9);
  }

  // When we increase the priority score, the "flight" that wins should have a
  // priority score of 3.0.
  unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    // Fatal: classification[0] is read on the next line.
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 2.9);
  }
}
554
// Selection-only regex patterns widen a partial selection to the span of the
// pattern's capturing group. Covered cases: a spaced flight number (with a
// raised priority), a person name, and a Luhn-valid card number.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  auto& patterns = unpacked_model->regex_model->patterns;

  // Selection-only patterns.
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;

  std::unique_ptr<RegexModel_::PatternT> payment_card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);
  payment_card_pattern->verification_options.reset(new VerificationOptionsT);
  payment_card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(payment_card_pattern));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const CodepointSpan flight_selection = classifier->SuggestSelection(
      "Your flight MA 0123 is delayed by 3 hours.", {12, 14});
  EXPECT_EQ(flight_selection, CodepointSpan(12, 19));

  const CodepointSpan person_selection = classifier->SuggestSelection(
      "this afternoon Barack Obama gave a speech at", {15, 21});
  EXPECT_EQ(person_selection, CodepointSpan(15, 27));

  const CodepointSpan card_selection =
      classifier->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14});
  EXPECT_EQ(card_selection, CodepointSpan(4, 23));
}
594
// Capturing groups flagged with extend_selection define the suggested span.
// Group 0 (the full match) is not flagged, so "from " is left out and the
// selection covers only the two dates. For the "for ever" alternative, group 3
// covers the whole phrase.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  std::unique_ptr<RegexModel_::PatternT> date_range_pattern =
      MakePattern("date_range",
                  "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
                  "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);

  // One entry per capturing group, 0 through 3, in order.
  for (const bool extend_selection : {false, true, true, true}) {
    date_range_pattern->capturing_group.emplace_back(new CapturingGroupT);
    date_range_pattern->capturing_group.back()->extend_selection =
        extend_selection;
  }
  unpacked_model->regex_model->patterns.push_back(
      std::move(date_range_pattern));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const CodepointSpan range_selection = classifier->SuggestSelection(
      "it's from 04/30/1789 to 03/04/1797", {21, 23});
  EXPECT_EQ(range_selection, CodepointSpan(10, 34));

  const CodepointSpan forever_selection =
      classifier->SuggestSelection("it takes for ever", {9, 12});
  EXPECT_EQ(forever_selection, CodepointSpan(9, 17));
}
637
// With the flight pattern's priority_score at 0.5, the "MA 0123" regex match
// loses the conflict. The suggested selection is then the longer span 26-62,
// starting at "350 Third Street".
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  auto& patterns = unpacked_model->regex_model->patterns;

  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // Low priority: the regex should lose the conflict.
  patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string text =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  EXPECT_EQ(classifier->SuggestSelection(text, {55, 57}),
            CodepointSpan(26, 62));
}
666
// With the flight pattern's priority_score at 1.1, the "MA 0123" regex match
// wins the conflict and becomes the suggested selection (55-62).
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  auto& patterns = unpacked_model->regex_model->patterns;

  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  // High priority: the regex should win the conflict.
  patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string text =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  EXPECT_EQ(classifier->SuggestSelection(text, {55, 57}),
            CodepointSpan(55, 62));
}
695
// Annotation-only regex patterns (person, flight, Luhn-verified payment card)
// combine with the model's other annotations over a multi-line text. The
// expected output is exactly the person, address, phone and payment-card
// spans, in that order.
TEST_F(AnnotatorTest, AnnotateRegex) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  auto& patterns = unpacked_model->regex_model->patterns;

  // Annotation-only patterns.
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  std::unique_ptr<RegexModel_::PatternT> payment_card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/true, 1.0);
  payment_card_pattern->verification_options.reset(new VerificationOptionsT);
  payment_card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(payment_card_pattern));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
  const auto annotations = classifier->Annotate(text);
  EXPECT_THAT(annotations,
              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
                                IsAnnotatedSpan(28, 55, "address"),
                                IsAnnotatedSpan(79, 91, "phone"),
                                IsAnnotatedSpan(107, 126, "payment_card")}));
}
732
// The stock test model annotates flight numbers. ICAO codes are only accepted
// for selected airlines, so only LX373, EZY1234 and U21234 are expected.
TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string flights = "flights LX373, SWR373, EZY1234, U21234";
  const auto annotations = annotator->Annotate(flights);
  EXPECT_THAT(annotations, ElementsAreArray({IsAnnotatedSpan(8, 13, "flight"),
                                             IsAnnotatedSpan(23, 30, "flight"),
                                             IsAnnotatedSpan(32, 38, "flight")}));
}
746
747 #ifndef TC3_DISABLE_LUA
TEST_F(AnnotatorTest,AnnotateRegexLuaVerification)748 TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
749 const std::string test_model = ReadFile(GetTestModelPath());
750 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
751
752 // Add test regex models.
753 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
754 MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
755 /*enabled_for_classification=*/true,
756 /*enabled_for_selection=*/true,
757 /*enabled_for_annotation=*/true, 1.0);
758 verified_pattern->verification_options.reset(new VerificationOptionsT);
759 verified_pattern->verification_options->lua_verifier = 0;
760 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
761 unpacked_model->regex_model->lua_verifier.push_back(
762 "return match[2].text==\"99\"");
763
764 flatbuffers::FlatBufferBuilder builder;
765 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
766
767 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
768 reinterpret_cast<const char*>(builder.GetBufferPointer()),
769 builder.GetSize(), unilib_.get(), calendarlib_.get());
770 ASSERT_TRUE(classifier);
771
772 const std::string test_string =
773 "your parcel is on the way: 99-00-123456-12345678";
774 EXPECT_THAT(classifier->Annotate(test_string),
775 ElementsAreArray({IsAnnotatedSpan(27, 48, "parcel_tracking")}));
776 }
777 #endif // TC3_DISABLE_LUA
778
// Regex annotations carry serialized entity data built from capturing groups
// when `is_serialized_entity_data_enabled` is set.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these sizes are wrong.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check entity data.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          annotations[0].classification[0].serialized_entity_data.data()));
  ASSERT_NE(entity, nullptr);

  // Name fields hold the captured first and last name. GetPointer returns
  // nullptr for an absent field, so guard before dereferencing.
  const flatbuffers::String* first_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(first_name, nullptr);
  EXPECT_EQ(first_name->str(), "Barack");
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "Obama");

  // Check `age`.
  EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

  // Check `is_alive`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

  // Check `former_us_president`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
}
821
// Normalization options on a capturing group are applied to the entity data
// value. Here the last-name group is upper-cased.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper-case the last name as post-processing.
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these sizes are wrong.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check normalization.
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          annotations[0].classification[0].serialized_entity_data.data()));
  ASSERT_NE(entity, nullptr);
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
863
// With `is_serialized_entity_data_enabled` off, regex annotations are still
// produced but carry no serialized entity data.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = false;
  auto annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these sizes are wrong.
  ASSERT_EQ(1, annotations.size());
  ASSERT_EQ(1, annotations[0].classification.size());
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check entity data.
  EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data);
}
892
// Classification of a phone number followed by a comma and extra digits
// depends on how much of the trailing digits the span includes.
TEST_F(AnnotatorTest, PhoneFiltering) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string plain = "phone: (123) 456 789";
  const std::string with_suffix = "phone: (123) 456 789,0001112";

  // The bare number is a phone.
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(plain, {7, 20})));
  // Including only part of the trailing digits still yields a phone.
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(with_suffix, {7, 25})));
  // Including all of the trailing digits is not classified as a phone.
  EXPECT_EQ("other",
            FirstResult(annotator->ClassifyText(with_suffix, {7, 28})));
}
905
// Basic SuggestSelection behavior: expansion to full entities, bracket
// handling, and falling back to the original span.
TEST_F(AnnotatorTest, SuggestSelection) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 21));

  // When the selection already covers several tokens (here, the whole
  // string), it is returned unchanged.
  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {0, 27}),
            CodepointSpan(0, 27));

  // A single letter and a single unknown word stay as they are.
  EXPECT_EQ(annotator->SuggestSelection("a", {0, 1}), CodepointSpan(0, 1));
  EXPECT_EQ(annotator->SuggestSelection("asdf", {0, 4}), CodepointSpan(0, 4));

  const std::string phone_text = "call me at 857 225 3556 today";
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 23));

  // Paired brackets are kept; an unpaired bracket is stripped.
  EXPECT_EQ(
      annotator->SuggestSelection("call me at (857) 225 3556 today", {12, 14}),
      CodepointSpan(11, 25));
  EXPECT_EQ(annotator->SuggestSelection("call me at (857 today", {12, 14}),
            CodepointSpan(12, 15));
  EXPECT_EQ(annotator->SuggestSelection("call me at 3556) today", {12, 14}),
            CodepointSpan(11, 15));
  EXPECT_EQ(annotator->SuggestSelection("call me at )857( today", {12, 14}),
            CodepointSpan(12, 15));

  // If stripping would leave an empty selection, the original span is kept.
  EXPECT_EQ(annotator->SuggestSelection("call me at )( today", {11, 13}),
            CodepointSpan(11, 13));
  EXPECT_EQ(annotator->SuggestSelection("call me at ( today", {11, 12}),
            CodepointSpan(11, 12));
  EXPECT_EQ(annotator->SuggestSelection("call me at ) today", {11, 12}),
            CodepointSpan(11, 12));

  // A user span larger than the found entity is returned unchanged.
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {5, 24}),
            CodepointSpan(5, 24));
}
956
// Enabling annotation without a selection model is an invalid configuration:
// the annotator must fail to load.
TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
  const std::string model_data = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_data.c_str());

  // Remove the selection model but keep annotation mode enabled.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_FALSE(annotator);
}
975
// With only classification enabled, selection is a no-op and annotation
// returns nothing, but classification still works.
TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
  const std::string model_data = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_data.c_str());

  // Drop the selection model and restrict the model to classification.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
  model->enabled_modes = ModeFlag_CLASSIFICATION;

  // Without a selection model there is no feature processor, which the
  // number annotator requires, so it must be turned off as well.
  model->number_annotator_options->enabled = false;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(
      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 14));

  const std::string text = "call me at (800) 123-456 today";
  EXPECT_EQ("phone", FirstResult(annotator->ClassifyText(text, {11, 24})));
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
1008
// Collections listed in `filtered_collections_selection` are suppressed from
// SuggestSelection while other collections keep working.
TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());
  const std::string phone_text = "call me at 857 225 3556 today";

  // Baseline: the stock model expands the click to the full phone number.
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  EXPECT_EQ(classifier->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 23));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone selection.
  unpacked_model->output_options->filtered_collections_selection.push_back(
      "phone");
  // Filtering requires the suggested selection to be classified, so force it.
  unpacked_model->selection_options->always_classify_suggested_selection = true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer borrows `builder`'s memory. Declare this annotator after
  // `builder` so it is destroyed first. Reassigning `classifier`, which was
  // declared earlier, would let the buffer be freed while the annotator still
  // referenced it.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_EQ(filtered_classifier->SuggestSelection(phone_text, {11, 14}),
            CodepointSpan(11, 14));

  // Address selection should still work.
  EXPECT_EQ(filtered_classifier->SuggestSelection("350 Third Street, Cambridge",
                                                  {4, 9}),
            CodepointSpan(0, 27));
}
1045
// Clicking any token of an address yields the same full-address selection.
TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address = "350 Third Street, Cambridge";
  const CodepointSpan whole_address(0, 27);
  EXPECT_EQ(annotator->SuggestSelection(address, {0, 3}), whole_address);
  EXPECT_EQ(annotator->SuggestSelection(address, {4, 9}), whole_address);
  EXPECT_EQ(annotator->SuggestSelection(address, {10, 16}), whole_address);

  // The same holds when the address follows several short lines.
  EXPECT_EQ(annotator->SuggestSelection("a\nb\nc\n350 Third Street, Cambridge",
                                        {16, 22}),
            CodepointSpan(6, 33));
}
1062
// Newlines around or inside a phone number do not break selection.
TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Number on the line after / before other text.
  EXPECT_EQ(annotator->SuggestSelection("abc\n857 225 3556", {4, 7}),
            CodepointSpan(4, 16));
  EXPECT_EQ(annotator->SuggestSelection("857 225 3556\nabc", {0, 3}),
            CodepointSpan(0, 12));

  // Number split by a newline, using explicitly passed default options.
  SelectionOptions selection_options;
  EXPECT_EQ(annotator->SuggestSelection("857 225\n3556\nabc", {0, 3},
                                        selection_options),
            CodepointSpan(0, 12));
}
1077
// Punctuation touching the clicked word is never pulled into the selection,
// whether it sits on the right, on the left, or on both sides.
TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  struct PunctuationCase {
    std::string text;
    CodepointSpan click;  // Also the expected selection.
  };
  const PunctuationCase kCases[] = {
      // From the right.
      {"this afternoon BarackObama, gave a speech at", {15, 26}},
      // From the right, multiple characters.
      {"this afternoon BarackObama,.,.,, gave a speech at", {15, 26}},
      // From the left, multiple characters.
      {"this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}},
      // From both sides.
      {"this afternoon !BarackObama,- gave a speech at", {16, 27}},
  };
  for (const PunctuationCase& test_case : kCases) {
    EXPECT_EQ(annotator->SuggestSelection(test_case.text, test_case.click),
              test_case.click)
        << test_case.text;
  }
}
1103
// Invalid spans and invalid UTF-8 must not crash. The input span is echoed
// back unchanged.
TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  struct JunkCase {
    std::string text;
    CodepointSpan span;
  };
  const JunkCase kCases[] = {
      // Out-of-range, negative and inverted spans.
      {"", {0, 27}},
      {"", {-10, 27}},
      {"Word 1 2 3 hello!", {0, 27}},
      {"Word 1 2 3 hello!", {-30, 300}},
      {"Word 1 2 3 hello!", {-10, -1}},
      {"Word 1 2 3 hello!", {100, 17}},
      // Invalid UTF-8.
      {"\xf0\x9f\x98\x8b\x8b", {-1, -1}},
  };
  for (const JunkCase& junk : kCases) {
    EXPECT_EQ(annotator->SuggestSelection(junk.text, junk.span), junk.span);
  }
}
1126
// Clicking on whitespace inside an entity selects the entity. Whitespace
// outside any entity, or whitespace-only text, is returned unchanged.
TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone = "call me at 857 225 3556 today";
  const std::string phone_comma = "call me at 857 225 3556, today";

  EXPECT_EQ(annotator->SuggestSelection(phone, {14, 15}),
            CodepointSpan(11, 23));
  EXPECT_EQ(annotator->SuggestSelection(phone, {10, 11}),
            CodepointSpan(10, 11));
  EXPECT_EQ(annotator->SuggestSelection(phone, {23, 24}),
            CodepointSpan(23, 24));
  EXPECT_EQ(annotator->SuggestSelection(phone_comma, {23, 24}),
            CodepointSpan(23, 24));
  EXPECT_EQ(annotator->SuggestSelection(phone_comma, {14, 17}),
            CodepointSpan(11, 25));
  EXPECT_EQ(
      annotator->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
      CodepointSpan(11, 23));
  EXPECT_EQ(annotator->SuggestSelection(
                "let's meet at 350 Third Street Cambridge and go there",
                {30, 31}),
            CodepointSpan(14, 40));

  const std::string no_entity = "call me today";
  EXPECT_EQ(annotator->SuggestSelection(no_entity, {4, 5}),
            CodepointSpan(4, 5));
  EXPECT_EQ(annotator->SuggestSelection(no_entity, {7, 8}),
            CodepointSpan(7, 8));

  // Punctuation next to the clicked whitespace.
  EXPECT_EQ(annotator->SuggestSelection(
                "let's meet at 350 Third Street, Cambridge and go there",
                {31, 32}),
            CodepointSpan(14, 41));

  // Whitespace-only text: the original indices come back.
  EXPECT_EQ(annotator->SuggestSelection(" ", {0, 1}), CodepointSpan(0, 1));
  EXPECT_EQ(annotator->SuggestSelection(" ", {0, 3}), CodepointSpan(0, 3));
  EXPECT_EQ(annotator->SuggestSelection(" ", {2, 3}), CodepointSpan(2, 3));
  EXPECT_EQ(annotator->SuggestSelection(" ", {5, 6}), CodepointSpan(5, 6));
}
1175
// internal::SnapLeftIfWhitespaceSelection moves a one-codepoint whitespace
// selection onto the preceding codepoint when there is one to the left.
TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
  // Wraps the literal without copying (literals have static storage) and snaps.
  const auto snap = [this](CodepointSpan span, const char* utf8) {
    const UnicodeText unicode = UTF8ToUnicodeText(utf8, /*do_copy=*/false);
    return internal::SnapLeftIfWhitespaceSelection(span, unicode, *unilib_);
  };

  EXPECT_EQ(snap({4, 5}, "abcd efgh"), CodepointSpan(3, 4));
  EXPECT_EQ(snap({4, 5}, "abcd "), CodepointSpan(3, 4));

  // Nothing on the left.
  EXPECT_EQ(snap({4, 5}, " efgh"), CodepointSpan(4, 5));
  EXPECT_EQ(snap({0, 1}, " efgh"), CodepointSpan(0, 1));

  // Whitespace only.
  EXPECT_EQ(snap({2, 3}, " "), CodepointSpan(2, 3));
  EXPECT_EQ(snap({4, 5}, " "), CodepointSpan(4, 5));
  EXPECT_EQ(snap({0, 1}, " "), CodepointSpan(0, 1));
}
1205
// End-to-end Annotate on the stock model, including invalid UTF-8 input.
TEST_F(AnnotatorTest, Annotate) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // In mixed text only the address and the phone number are reported.
  const std::string mixed_text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(mixed_text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // A bare phone number, separated by spaces or by a newline.
  AnnotationOptions annotation_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", annotation_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", annotation_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));

  // Invalid UTF-8 yields no annotations at all.
  EXPECT_THAT(annotator->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b",
                                  annotation_options),
              IsEmpty());
}
1230
// Annotate keeps paired brackets around a phone number and strips unpaired
// ones.
TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Paired brackets stay inside the span.
  EXPECT_THAT(annotator->Annotate("call me at (0845) 100 1000 today"),
              ElementsAreArray({IsAnnotatedSpan(11, 26, "phone")}));

  // Unpaired leading, trailing, or reversed brackets are stripped.
  EXPECT_THAT(annotator->Annotate("call me at (07038201818 today"),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
  EXPECT_THAT(annotator->Annotate("call me at 07038201818) today"),
              ElementsAreArray({IsAnnotatedSpan(11, 22, "phone")}));
  EXPECT_THAT(annotator->Annotate("call me at )07038201818( today"),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
}
1255
// Same bracket-stripping expectations as above, with `enable_optimization`
// turned on.
TEST_F(AnnotatorTest, AnnotatesWithBracketStrippingOptimized) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions optimized;
  optimized.enable_optimization = true;

  // Paired brackets stay inside the span.
  EXPECT_THAT(annotator->Annotate("call me at (0845) 100 1000 today", optimized),
              ElementsAreArray({IsAnnotatedSpan(11, 26, "phone")}));

  // Unpaired leading, trailing, or reversed brackets are stripped.
  EXPECT_THAT(annotator->Annotate("call me at (07038201818 today", optimized),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
  EXPECT_THAT(annotator->Annotate("call me at 07038201818) today", optimized),
              ElementsAreArray({IsAnnotatedSpan(11, 22, "phone")}));
  EXPECT_THAT(annotator->Annotate("call me at )07038201818( today", optimized),
              ElementsAreArray({IsAnnotatedSpan(12, 23, "phone")}));
}
1283
// In RAW mode, number, percentage and phone annotations may overlap.
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions raw_options;
  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  const std::string text =
      "853 225 3556 and then turn it up 99%, 99 "
      "number, 12345.12345 float number";
  EXPECT_THAT(annotator->Annotate(text, raw_options),
              UnorderedElementsAreArray({
                  // The phone number and each of its digit groups.
                  IsAnnotatedSpan(0, 12, "phone"),
                  IsAnnotatedSpan(0, 3, "number"),
                  IsAnnotatedSpan(4, 7, "number"),
                  IsAnnotatedSpan(8, 12, "number"),
                  // "99%" as both a number and a percentage.
                  IsAnnotatedSpan(33, 35, "number"),
                  IsAnnotatedSpan(33, 36, "percentage"),
                  IsAnnotatedSpan(38, 40, "number"),
                  // The float is reported as both a number and a phone.
                  IsAnnotatedSpan(49, 60, "number"),
                  IsAnnotatedSpan(49, 60, "phone"),
              }));
}
1304
// In SMART mode bare numbers are not annotated. Phones and percentages are.
TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  const std::string text = "853 225 3556 and then turn it up 99%, 99 number";
  EXPECT_THAT(annotator->Annotate(text, smart_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone"),
                                IsAnnotatedSpan(33, 36, "percentage")}));
}
1317
VerifyAnnotatesDurationsInRawMode(const Annotator * classifier)1318 void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) {
1319 ASSERT_TRUE(classifier);
1320 AnnotationOptions options;
1321 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1322
1323 // Duration annotator.
1324 EXPECT_THAT(classifier->Annotate(
1325 "it took 9 minutes and 7 seconds to get there", options),
1326 Contains(IsDurationSpan(
1327 /*start=*/8, /*end=*/31,
1328 /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
1329 }
1330
// Runs the RAW-mode duration check against the stock test model.
TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotatesDurationsInRawMode(annotator.get());
}
1336
VerifyDurationAndRelativeTimeCanOverlapInRawMode(const Annotator * classifier)1337 void VerifyDurationAndRelativeTimeCanOverlapInRawMode(
1338 const Annotator* classifier) {
1339 ASSERT_TRUE(classifier);
1340 AnnotationOptions options;
1341 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1342 options.locales = "en";
1343
1344 const std::vector<AnnotatedSpan> annotations =
1345 classifier->Annotate("let's meet in 3 hours", options);
1346
1347 EXPECT_THAT(annotations,
1348 Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
1349 /*time_ms_utc=*/10800000L,
1350 DatetimeGranularity::GRANULARITY_HOUR)));
1351 EXPECT_THAT(annotations,
1352 Contains(IsDurationSpan(/*start=*/14, /*end=*/21,
1353 /*duration_ms=*/3 * 60 * 60 * 1000)));
1354 }
1355
// Runs the duration/relative-time overlap check against the stock test model.
TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1361
// Same overlap check, using the model variant without the datetime grammar
// model (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest,
       DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) {
  // Must outlive the annotator, which borrows this buffer.
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1370
// With `only_use_line_with_click`, annotating two joined lines gives the same
// result as annotating each line on its own.
TEST_F(AnnotatorTest, AnnotateSplitLines) {
  std::string model_data = ReadFile(GetTestModelPath());
  model_data = ModifyAnnotatorModel(model_data, [](ModelT* model) {
    model->selection_feature_options->only_use_line_with_click = true;
  });
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_data.data(), model_data.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string chatter =
      "hey, sorry, just finished up. i didn't hear back from you in time.";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const int kAnnotationLength = 26;

  // Each line on its own.
  EXPECT_THAT(annotator->Annotate(chatter), IsEmpty());
  EXPECT_THAT(
      annotator->Annotate(address_line),
      ElementsAreArray({IsAnnotatedSpan(0, kAnnotationLength, "address")}));

  // Joined with a newline, the address is shifted by the first line + '\n'.
  const std::string joined = chatter + "\n" + address_line;
  const auto offset = chatter.size() + 1;
  EXPECT_THAT(annotator->Annotate(joined),
              ElementsAreArray({IsAnnotatedSpan(
                  offset, offset + kAnnotationLength, "address")}));
}
1398
// With `use_pipe_character_for_newline`, '|' splits lines. Annotating the
// joined text must find the phone and the address as if each line were
// annotated separately.
TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
  std::string model_data = ReadFile(GetTestModelPath());
  model_data = ModifyAnnotatorModel(model_data, [](ModelT* model) {
    model->selection_feature_options->only_use_line_with_click = true;
    model->selection_feature_options->use_pipe_character_for_newline = true;
  });
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_data.data(), model_data.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_line = "hey, this is my phone number 853 225 3556";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string joined = phone_line + "|" + address_line;
  const int kPhoneLength = 12;
  const int kAddressLength = 26;
  const int address_start = static_cast<int>(phone_line.size()) + 1;

  const std::vector<AnnotatedSpan> spans = annotator->Annotate(joined);
  EXPECT_THAT(spans,
              ElementsAreArray(
                  {IsAnnotatedSpan(29, 29 + kPhoneLength, "phone"),
                   IsAnnotatedSpan(
                       address_start,
                       static_cast<int>(phone_line.size() + 1 + kAddressLength),
                       "address")}));
}
1429
// Without `use_pipe_character_for_newline`, '|' does not split lines, so the
// joined text is treated as a single line.
TEST_F(AnnotatorTest,
       NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
    model->selection_feature_options->only_use_line_with_click = true;
    model->selection_feature_options->use_pipe_character_for_newline = false;
  });
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());

  ASSERT_TRUE(classifier);

  const std::string str1 = "hey, this is my phone number 853 225 3556";
  const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string str3 = str1 + "|" + str2;
  const std::vector<AnnotatedSpan> annotated_spans = classifier->Annotate(str3);
  // Note: We only check that we get a single annotated span here when the '|'
  // character is not used to split lines. The model is not precise for this
  // example, and the resulting span might change when the model changes.
  // SizeIs prints the actual spans on failure, unlike comparing size() with a
  // raw value used as a matcher.
  EXPECT_THAT(annotated_spans, ::testing::SizeIs(1));
}
1454
// Annotation results must not depend on the selection batch size: repack the
// test model with a batch size of 4 and expect the usual spans.
TEST_F(AnnotatorTest, AnnotateSmallBatches) {
  const std::string model_bytes = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
  model->selection_options->batch_size = 4;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  // Short inputs, with and without a newline inside the number.
  AnnotationOptions default_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", default_options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
}
1484
// A minimum annotation confidence above any achievable score (2.0) must drop
// every annotation.
TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  // Add test threshold.
  unpacked_model->triggering_options->min_annotate_confidence =
      2.f;  // Discards all results.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // IsEmpty() avoids the size_t/int sign comparison and prints any unexpected
  // spans on failure.
  EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
}
1507
// A zero minimum confidence with all modes enabled keeps both the address and
// the phone annotation.
TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test thresholds.
  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  unpacked_model->triggering_options->min_annotate_confidence =
      0.f;  // Keeps all results.
  unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // SizeIs() avoids the size_t/int sign comparison and prints the spans on
  // failure.
  EXPECT_THAT(classifier->Annotate(test_string), ::testing::SizeIs(2));
}
1530
// A model enabled only for classification and selection must produce no
// annotations.
TEST_F(AnnotatorTest, AnnotateDisabled) {
  const std::string model_bytes = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());
  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
1549
// Collections listed in `filtered_collections_annotation` are dropped from
// Annotate() output: first check the baseline, then filter out "phone".
TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone annotation.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // Declared after `builder` so this annotator, which borrows the builder's
  // buffer via FromUnownedBuffer, is destroyed before that buffer is freed.
  // Reassigning `classifier` would have made it outlive `builder`.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_THAT(filtered_classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1587
// A filtered collection can be used to suppress another annotation: a
// high-scoring custom "suppress" regex wins over the phone number, and because
// "suppress" is filtered, neither of them is reported.
TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string test_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // We add a custom annotator that wins against the phone classification
  // below and that we subsequently suppress.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "suppress");

  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // Declared after `builder` so this annotator, which borrows the builder's
  // buffer via FromUnownedBuffer, is destroyed before that buffer is freed.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_THAT(filtered_classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
1631
VerifyClassifyTextDateInZurichTimezone(const Annotator * classifier)1632 void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
1633 EXPECT_TRUE(classifier);
1634 ClassificationOptions options;
1635 options.reference_timezone = "Europe/Zurich";
1636 options.locales = "en";
1637
1638 std::vector<ClassificationResult> result =
1639 classifier->ClassifyText("january 1, 2017", {0, 15}, options);
1640
1641 EXPECT_THAT(result,
1642 ElementsAre(IsDateResult(1483225200000,
1643 DatetimeGranularity::GRANULARITY_DAY)));
1644 }
1645
// Runs the Zurich-timezone date check against the model loaded from disk.
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1651
// Same Zurich-timezone check, using the test model with its datetime grammar
// model removed (regex-based datetime only).
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) {
  const std::string buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1659
VerifyClassifyTextDateInLATimezone(const Annotator * classifier)1660 void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
1661 EXPECT_TRUE(classifier);
1662 ClassificationOptions options;
1663 options.reference_timezone = "America/Los_Angeles";
1664 options.locales = "en";
1665
1666 std::vector<ClassificationResult> result =
1667 classifier->ClassifyText("march 1, 2017", {0, 13}, options);
1668
1669 EXPECT_THAT(result,
1670 ElementsAre(IsDateResult(1488355200000,
1671 DatetimeGranularity::GRANULARITY_DAY)));
1672 }
1673
// Los Angeles date check using the regex-only datetime variant of the model.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) {
  const std::string buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1681
// Los Angeles date check using the model loaded from disk.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1687
VerifyClassifyTextDateOnAotherLine(const Annotator * classifier)1688 void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
1689 EXPECT_TRUE(classifier);
1690 ClassificationOptions options;
1691 options.reference_timezone = "Europe/Zurich";
1692 options.locales = "en";
1693
1694 std::vector<ClassificationResult> result = classifier->ClassifyText(
1695 "hello world this is the first line\n"
1696 "january 1, 2017",
1697 {35, 50}, options);
1698
1699 EXPECT_THAT(result,
1700 ElementsAre(IsDateResult(1483225200000,
1701 DatetimeGranularity::GRANULARITY_DAY)));
1702 }
1703
// Second-line date check using the regex-only datetime variant of the model.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) {
  const std::string buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1711
// Second-line date check using the model loaded from disk.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1717
VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(const Annotator * classifier)1718 void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
1719 const Annotator* classifier) {
1720 EXPECT_TRUE(classifier);
1721 std::vector<ClassificationResult> result;
1722 ClassificationOptions options;
1723
1724 options.reference_timezone = "Europe/Zurich";
1725 options.locales = "en-US";
1726 result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
1727
1728 // In US, the date should be interpreted as <month>.<day>.
1729 EXPECT_THAT(result,
1730 ElementsAre(IsDatetimeResult(
1731 5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
1732 }
1733
// US month/day ordering check using the model loaded from disk.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1739
// US month/day ordering check using the regex-only datetime variant of the
// model.
TEST_F(AnnotatorTest,
       ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) {
  const std::string buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1748
// Checks that with the "de" locale, "03.05.1970" is read as May 3rd.
// NOTE(review): despite "MonthDay" in the test name, the expectation below is
// the day-first (<day>.<month>) reading. The name is kept so existing test
// filters keep working.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // ASSERT (not EXPECT): stop before dereferencing a null classifier.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";

  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);

  // In Germany, the date should be interpreted as <day>.<month>:
  // 10537200000 ms == 1970-05-03T00:00:00+01:00.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1767
// "10:30" without an am/pm marker is ambiguous, so classification is expected
// to return both readings.
TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // ASSERT (not EXPECT): stop before dereferencing a null classifier.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);

  // 34200000 ms and 77400000 ms are 10:30 and 22:30 Zurich time (UTC+1) on
  // 1970-01-01.
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1784
// Annotating an ambiguous time yields one span that carries both readings.
TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // ASSERT (not EXPECT): stop before dereferencing a null classifier.
  ASSERT_TRUE(classifier);
  AnnotationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<AnnotatedSpan> spans =
      classifier->Annotate("set an alarm for 10:30", options);

  // SizeIs avoids the size_t/int sign comparison; ASSERT guards spans[0].
  ASSERT_THAT(spans, ::testing::SizeIs(1));
  // Bind by reference: no need to copy the classification vector.
  const std::vector<ClassificationResult>& result = spans[0].classification;
  // 10:30 and 22:30 Zurich time (UTC+1) on 1970-01-01.
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1803
// When every regex datetime pattern is disabled for selection, classification
// and annotation still find the date, but SuggestSelection leaves the clicked
// span unchanged.
TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
  std::string test_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. Range-for avoids comparing a signed
  // index against the unsigned size().
  for (const auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            CodepointSpan(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
1827
// Builds a small grammar model that recognizes a few TV detectives as
// "famous person" and checks Annotate, ClassifyText and SuggestSelection on it.
TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
  const std::string model_bytes = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_bytes.c_str());

  // Grammar model with ICU tokenization, no whitespace tokens, and token
  // splits on script changes.
  model->grammar_model.reset(new GrammarModelT);
  GrammarModelT* grammar_config = model->grammar_model.get();
  grammar_config->tokenizer_options.reset(new GrammarTokenizerOptionsT);
  grammar_config->tokenizer_options->tokenization_type = TokenizationType_ICU;
  grammar_config->tokenizer_options->icu_preserve_whitespace_tokens = false;
  grammar_config->tokenizer_options->tokenize_on_script_change = true;
  grammar_config->rules.reset(new grammar::RulesSetT);

  // Rules: <famous_person> is the root rule, matched by any <tv_detective>.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  rules.Add("<tv_detective>", {"jessica", "fletcher"});
  rules.Add("<tv_detective>", {"columbo"});
  rules.Add("<tv_detective>", {"magnum"});
  rules.Add(
      "<famous_person>", {"<tv_detective>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0 /* rule classification result */);

  // Classification result #0, referenced by callback_param above.
  grammar_config->rule_classification_result.emplace_back(
      new GrammarModel_::RuleClassificationResultT);
  GrammarModel_::RuleClassificationResultT* famous_person =
      grammar_config->rule_classification_result.back().get();
  famous_person->collection_name = "famous person";
  famous_person->enabled_modes = ModeFlag_ALL;

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_config->rules.get());

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "Did you see the Novel Connection episode where Jessica Fletcher helps "
      "Magnum solve the case? I thought that was with Columbo ...";

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
                          IsAnnotatedSpan(70, 76, "famous person"),
                          IsAnnotatedSpan(117, 124, "famous person")));
  EXPECT_THAT(FirstResult(annotator->ClassifyText("Jessica Fletcher",
                                                  CodepointSpan{0, 16})),
              Eq("famous person"));
  EXPECT_THAT(annotator->SuggestSelection("Jessica Fletcher", {0, 7}),
              Eq(CodepointSpan{0, 16}));
}
1885
// A single candidate has nothing to conflict with and is kept.
TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0)};
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0}));
}
1901
// Adjacent, non-overlapping candidates are all kept, in order.
TEST_F(AnnotatorTest, ResolveConflictsSequence) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 1, 2, 3, 4}));
}
1922
// The low-scoring middle span overlaps both neighbours; only the outer two
// spans survive.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 2}));
}
1941
// The high-scoring middle span overlaps both neighbours and is the only one
// kept.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({1}));
}
1960
// With the default TestingAnnotator configuration, longer spans do not
// automatically win a conflict.
TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({3, 7}, "unit", 1),
      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser!
      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser!
      MakeAnnotatedSpan({14, 20}, "email", 1),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  // The first and last candidates do not overlap each other, so both are kept.
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 3}));
}
1981
// With `prioritize_longest_annotation` enabled, the longest overlapping span
// wins conflict resolution even against equally scored shorter spans.
TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // gtest assertions (instead of TC3_CHECK) report setup failures as ordinary
  // test failures attributed to this test.
  ASSERT_TRUE(libtextclassifier3::DecompressModel(unpacked_model.get()));
  unpacked_model->conflict_resolution_options.reset(
      new Model_::ConflictResolutionOptionsT);
  unpacked_model->conflict_resolution_options->prioritize_longest_annotation =
      true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TestingAnnotator> classifier =
      TestingAnnotator::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  std::vector<AnnotatedSpan> candidates{{
      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser!
      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser!
      MakeAnnotatedSpan({5, 30}, "url", 1),     // Pick longest match.
      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser!
  }};
  std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<int> chosen;
  classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                               locales, options,
                               /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({2}));
}
2016
// Mixed collections and scores across two overlap clusters. The expected
// losers are marked.
TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 2, 4}));
}
2037
// In RAW usecase, a KNOWLEDGE span (listed first) may overlap a regular span;
// both are kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "address", 0.6),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 1}));
}
2056
// In RAW usecase, a KNOWLEDGE span (listed second) may overlap a regular span;
// both are kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 1}));
}
2075
// In RAW usecase, two overlapping KNOWLEDGE spans are both kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 1}));
}
2095
// In RAW usecase with default options, two overlapping non-KNOWLEDGE spans
// still conflict, and only the first (higher-scoring) one is kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0}));
}
2113
// With `do_conflict_resolution_in_raw_mode` turned off, overlapping
// non-KNOWLEDGE spans are both kept in RAW usecase.
TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
  TestingAnnotator annotator(
      unilib_.get(), calendarlib_.get(), [](ModelT* model_to_patch) {
        model_to_patch->conflict_resolution_options.reset(
            new Model_::ConflictResolutionOptionsT);
        model_to_patch->conflict_resolution_options
            ->do_conflict_resolution_in_raw_mode = false;
      });

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };
  std::vector<Locale> locales{Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  std::vector<int> kept_indices;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &kept_indices);
  EXPECT_THAT(kept_indices, ElementsAreArray({0, 1}));
}
2137
VerifyLongInput(const Annotator * classifier)2138 void VerifyLongInput(const Annotator* classifier) {
2139 ASSERT_TRUE(classifier);
2140
2141 for (const auto& type_value_pair :
2142 std::vector<std::pair<std::string, std::string>>{
2143 {"address", "350 Third Street, Cambridge"},
2144 {"phone", "123 456-7890"},
2145 {"url", "www.google.com"},
2146 {"email", "someone@gmail.com"},
2147 {"flight", "LX 38"},
2148 {"date", "September 1, 2018"}}) {
2149 const std::string input_100k = std::string(50000, ' ') +
2150 type_value_pair.second +
2151 std::string(50000, ' ');
2152 const int value_length = type_value_pair.second.size();
2153
2154 AnnotationOptions annotation_options;
2155 annotation_options.locales = "en";
2156 EXPECT_THAT(classifier->Annotate(input_100k, annotation_options),
2157 ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
2158 type_value_pair.first)}));
2159 SelectionOptions selection_options;
2160 selection_options.locales = "en";
2161 EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001},
2162 selection_options),
2163 CodepointSpan(50000, 50000 + value_length));
2164
2165 ClassificationOptions classification_options;
2166 classification_options.locales = "en";
2167 EXPECT_EQ(type_value_pair.first,
2168 FirstResult(classifier->ClassifyText(
2169 input_100k, {50000, 50000 + value_length},
2170 classification_options)));
2171 }
2172 }
2173
// Runs the long-input checks against the unmodified test model.
TEST_F(AnnotatorTest, LongInput) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2179
// Runs the long-input checks against a copy of the test model whose datetime
// grammar model has been removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, LongInputWithRegExDatetime) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  // The annotator is built from an unowned buffer, so `regex_model` must stay
  // alive for as long as `annotator` does.
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2187
2188 // These coarse tests are there only to make sure the execution happens in
2189 // reasonable amount of time.
TEST_F(AnnotatorTest,LongInputNoResultCheck)2190 TEST_F(AnnotatorTest, LongInputNoResultCheck) {
2191 std::unique_ptr<Annotator> classifier = Annotator::FromPath(
2192 GetTestModelPath(), unilib_.get(), calendarlib_.get());
2193 ASSERT_TRUE(classifier);
2194
2195 for (const std::string& value :
2196 std::vector<std::string>{"http://www.aaaaaaaaaaaaaaaaaaaa.com "}) {
2197 const std::string input_100k =
2198 std::string(50000, ' ') + value + std::string(50000, ' ');
2199 const int value_length = value.size();
2200
2201 classifier->Annotate(input_100k);
2202 classifier->SuggestSelection(input_100k, {50000, 50001});
2203 classifier->ClassifyText(input_100k, {50000, 50000 + value_length});
2204 }
2205 }
2206
// Checks classification_options.max_num_tokens: with no limit (-1) the span
// "350 Third Street, Cambridge" is classified as "address"; with a limit of 3
// tokens the same span is classified as "other".
TEST_F(AnnotatorTest, MaxTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Each phase lives in its own scope so that the FlatBufferBuilder, which
  // owns the serialized model, is destroyed after the Annotator that borrows
  // its buffer through FromUnownedBuffer (locals are destroyed in reverse
  // declaration order).
  {
    // With unrestricted number of tokens should behave normally.
    unpacked_model->classification_options->max_num_tokens = -1;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "address");
  }

  {
    // Lower the maximum number of tokens to suppress the classification.
    unpacked_model->classification_options->max_num_tokens = 3;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "other");
  }
}
2241
// Checks classification_options.address_min_num_tokens: with no minimum (0)
// the span "350 Third Street, Cambridge" is classified as "address"; with a
// minimum of 5 address tokens it is classified as "other".
TEST_F(AnnotatorTest, MinAddressTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Each phase lives in its own scope so that the FlatBufferBuilder, which
  // owns the serialized model, is destroyed after the Annotator that borrows
  // its buffer through FromUnownedBuffer.
  {
    // With unrestricted number of address tokens should behave normally.
    unpacked_model->classification_options->address_min_num_tokens = 0;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "address");
  }

  {
    // Raise number of address tokens to suppress the address classification.
    unpacked_model->classification_options->address_min_num_tokens = 5;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "other");
  }
}
2276
// With the "other" collection priority score raised to 1.0, "LX37" is expected
// to be classified as "other" instead of as a flight number.
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = 1.0;

  // `fbb` owns the repacked model and is declared before the annotator that
  // borrows its buffer.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "other");
}
2291
// Counterpart of the test above. Despite the "High" in the name, the "other"
// collection priority score is set very LOW here (-100), so "LX37" is expected
// to be classified as "flight".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options->other_collection_priority_score = -100.0;

  // `fbb` owns the repacked model and is declared before the annotator that
  // borrows its buffer.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* packed = reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(FirstResult(annotator->ClassifyText("LX37", {0, 4})), "flight");
}
2306
// VisitAnnotatorModel should hand the visitor a non-null model for an existing
// model file and a null model for a path that does not exist.
TEST_F(AnnotatorTest, VisitAnnotatorModel) {
  const auto model_was_loaded = [](const Model* model) -> bool {
    return model != nullptr;
  };

  EXPECT_TRUE(VisitAnnotatorModel<bool>(GetTestModelPath(), model_was_loaded));
  EXPECT_FALSE(VisitAnnotatorModel<bool>(
      GetModelPath() + "non_existing_model.fb", model_was_loaded));
}
2323
// The model is restricted to triggering locales "en,cs". When no text language
// is supplied, all three entry points should still find the phone number.
TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_text = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone_text),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(phone_text, {0, 14})));
  EXPECT_EQ(annotator->SuggestSelection(phone_text, {6, 9}),
            CodepointSpan(0, 14));
}
2340
// Triggering locales are "en,cs"; a detected language of "cs" is among them,
// so Annotate is expected to return the phone number.
TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  AnnotationOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate("(555) 225-3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
}
2355
// Triggering locales are "en,cs"; a detected language of "de" is not among
// them, so Annotate is expected to return nothing.
TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  AnnotationOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate("(555) 225-3556", options), IsEmpty());
}
2369
// Triggering locales are "en,cs"; with detected language "cs" ClassifyText is
// expected to label the span as "phone".
TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  ClassificationOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string top_collection = FirstResult(
      annotator->ClassifyText("(555) 225-3556", {0, 14}, options));
  EXPECT_EQ("phone", top_collection);
}
2384
// Triggering locales are "en,cs"; with detected language "de" ClassifyText is
// expected to return no classification at all.
TEST_F(AnnotatorTest,
       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  ClassificationOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->ClassifyText("(555) 225-3556", {0, 14}, options),
              IsEmpty());
}
2400
// Triggering locales are "en,cs"; with detected language "cs" the selection
// {6, 9} is expected to expand to the whole phone number {0, 14}.
TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  SelectionOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection("(555) 225-3556", {6, 9}, options),
            CodepointSpan(0, 14));
}
2415
// Triggering locales are "en,cs"; with detected language "de" the selection
// {6, 9} is expected to be returned unchanged.
TEST_F(AnnotatorTest,
       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  SelectionOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection("(555) 225-3556", {6, 9}, options),
            CodepointSpan(6, 9));
}
2431
// Both the model-level triggering locales and the ML triggering options are
// restricted to "en,cs". With no detected language, the address (found by the
// ML model) should still be annotated, classified and selected.
TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address_text = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address_text),
              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
  EXPECT_EQ("address",
            FirstResult(annotator->ClassifyText(address_text, {0, 27})));
  EXPECT_EQ(annotator->SuggestSelection(address_text, {4, 9}),
            CodepointSpan(0, 27));
}
2450
// ML triggering restricted to "en,cs"; with detected language "cs" the address
// is expected to be annotated.
TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  AnnotationOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate("350 Third Street, Cambridge", options),
              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
}
2467
// ML triggering restricted to "en,cs"; with detected language "de" Annotate is
// expected to return nothing.
TEST_F(AnnotatorTest,
       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  AnnotationOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate("350 Third Street, Cambridge", options),
              IsEmpty());
}
2485
// ML triggering restricted to "en,cs"; with detected language "cs" the span is
// expected to be classified as "address".
TEST_F(AnnotatorTest,
       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  ClassificationOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string top_collection = FirstResult(annotator->ClassifyText(
      "350 Third Street, Cambridge", {0, 27}, options));
  EXPECT_EQ("address", top_collection);
}
2503
// ML triggering restricted to "en,cs"; with detected language "de"
// ClassifyText is expected to return no classification.
TEST_F(AnnotatorTest,
       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  ClassificationOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(
      annotator->ClassifyText("350 Third Street, Cambridge", {0, 27}, options),
      IsEmpty());
}
2522
// ML triggering restricted to "en,cs"; with detected language "cs" the
// selection {4, 9} is expected to expand to the full address {0, 27}.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  SelectionOptions options;
  options.detected_text_language_tags = "cs";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9},
                                        options),
            CodepointSpan(0, 27));
}
2541
// ML triggering restricted to "en,cs"; with detected language "de" the
// selection {4, 9} is expected to be returned unchanged.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);

  SelectionOptions options;
  options.detected_text_language_tags = "de";

  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection("350 Third Street, Cambridge", {4, 9},
                                        options),
            CodepointSpan(4, 9));
}
2560
VerifyClassifyTextOutputsDatetimeEntityData(const Annotator * classifier)2561 void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
2562 EXPECT_TRUE(classifier);
2563 std::vector<ClassificationResult> result;
2564 ClassificationOptions options;
2565 options.locales = "en-US";
2566
2567 result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
2568
2569 ASSERT_GE(result.size(), 0);
2570 const EntityData* entity_data =
2571 GetEntityData(result[0].serialized_entity_data.data());
2572 ASSERT_NE(entity_data, nullptr);
2573 ASSERT_NE(entity_data->datetime(), nullptr);
2574 EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
2575 EXPECT_EQ(entity_data->datetime()->granularity(),
2576 EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
2577 EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
2578
2579 auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
2580 EXPECT_EQ(meridiem->component_type(),
2581 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
2582 EXPECT_EQ(meridiem->absolute_value(), 0);
2583 EXPECT_EQ(meridiem->relative_count(), 0);
2584 EXPECT_EQ(meridiem->relation_type(),
2585 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2586
2587 auto* minute = entity_data->datetime()->datetime_component()->Get(1);
2588 EXPECT_EQ(minute->component_type(),
2589 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
2590 EXPECT_EQ(minute->absolute_value(), 0);
2591 EXPECT_EQ(minute->relative_count(), 0);
2592 EXPECT_EQ(minute->relation_type(),
2593 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2594
2595 auto* hour = entity_data->datetime()->datetime_component()->Get(2);
2596 EXPECT_EQ(hour->component_type(),
2597 EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
2598 EXPECT_EQ(hour->absolute_value(), 0);
2599 EXPECT_EQ(hour->relative_count(), 0);
2600 EXPECT_EQ(hour->relation_type(),
2601 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2602
2603 auto* day = entity_data->datetime()->datetime_component()->Get(3);
2604 EXPECT_EQ(
2605 day->component_type(),
2606 EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2607 EXPECT_EQ(day->absolute_value(), 5);
2608 EXPECT_EQ(day->relative_count(), 0);
2609 EXPECT_EQ(day->relation_type(),
2610 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2611
2612 auto* month = entity_data->datetime()->datetime_component()->Get(4);
2613 EXPECT_EQ(month->component_type(),
2614 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2615 EXPECT_EQ(month->absolute_value(), 3);
2616 EXPECT_EQ(month->relative_count(), 0);
2617 EXPECT_EQ(month->relation_type(),
2618 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2619
2620 auto* year = entity_data->datetime()->datetime_component()->Get(5);
2621 EXPECT_EQ(year->component_type(),
2622 EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2623 EXPECT_EQ(year->absolute_value(), 1970);
2624 EXPECT_EQ(year->relative_count(), 0);
2625 EXPECT_EQ(year->relation_type(),
2626 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2627 }
2628
// Datetime entity data check for ClassifyText using the unmodified test model.
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2634
// Same check as above, using the test model with its datetime grammar model
// removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  // `regex_model` must outlive the annotator built on its unowned buffer.
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyClassifyTextOutputsDatetimeEntityData(annotator.get());
}
2642
VerifyAnnotateOutputsDatetimeEntityData(const Annotator * classifier)2643 void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
2644 EXPECT_TRUE(classifier);
2645 std::vector<AnnotatedSpan> result;
2646 AnnotationOptions options;
2647 options.is_serialized_entity_data_enabled = true;
2648 options.locales = "en";
2649
2650 result = classifier->Annotate("September 1, 2019", options);
2651
2652 ASSERT_GE(result.size(), 0);
2653 ASSERT_GE(result[0].classification.size(), 0);
2654 ASSERT_EQ(result[0].classification[0].collection, "date");
2655 const EntityData* entity_data =
2656 GetEntityData(result[0].classification[0].serialized_entity_data.data());
2657 ASSERT_NE(entity_data, nullptr);
2658 ASSERT_NE(entity_data->datetime(), nullptr);
2659 EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
2660 EXPECT_EQ(entity_data->datetime()->granularity(),
2661 EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
2662 EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
2663
2664 auto* day = entity_data->datetime()->datetime_component()->Get(0);
2665 EXPECT_EQ(
2666 day->component_type(),
2667 EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2668 EXPECT_EQ(day->absolute_value(), 1);
2669 EXPECT_EQ(day->relative_count(), 0);
2670 EXPECT_EQ(day->relation_type(),
2671 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2672
2673 auto* month = entity_data->datetime()->datetime_component()->Get(1);
2674 EXPECT_EQ(month->component_type(),
2675 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2676 EXPECT_EQ(month->absolute_value(), 9);
2677 EXPECT_EQ(month->relative_count(), 0);
2678 EXPECT_EQ(month->relation_type(),
2679 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2680
2681 auto* year = entity_data->datetime()->datetime_component()->Get(2);
2682 EXPECT_EQ(year->component_type(),
2683 EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2684 EXPECT_EQ(year->absolute_value(), 2019);
2685 EXPECT_EQ(year->relative_count(), 0);
2686 EXPECT_EQ(year->relation_type(),
2687 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2688 }
2689
// Datetime entity data check for Annotate using the unmodified test model.
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2695
// Same check as above, using the test model with its datetime grammar model
// removed (see GetTestModelWithDatetimeRegEx). ("Datatime" in the test name is
// a historical typo, kept so the test ID does not change.)
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  // `regex_model` must outlive the annotator built on its unowned buffer.
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model.data(), regex_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateOutputsDatetimeEntityData(annotator.get());
}
2703
// Checks the serialized money entity data (currency, amount string, whole
// part, decimal part, nanos) produced by Annotate for a range of formats:
// currency before/after the amount, thousands separators, many decimals, and
// magnitude suffixes (k, m, Million, Billion).
TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fatal: `classifier` is dereferenced by every check below.
  ASSERT_TRUE(classifier);
  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;

  ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
                           /*amount=*/"3.5", /*whole_part=*/3,
                           /*decimal_part=*/5, /*nanos=*/500000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
                           /*amount=*/"3.5", /*whole_part=*/3,
                           /*decimal_part=*/5, /*nanos=*/500000000);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("For online purchase of CHF 23.00 enter", options),
      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("For online purchase of 23.00 CHF enter", options),
      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
                           /*amount=*/"4.8198", /*whole_part=*/4,
                           /*decimal_part=*/8198, /*nanos=*/819800000);
  ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
                           /*amount=*/"4.8198", /*whole_part=*/4,
                           /*decimal_part=*/8198, /*nanos=*/819800000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
                           /*amount=*/"0.0255", /*whole_part=*/0,
                           /*decimal_part=*/255, /*nanos=*/25500000);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
                           "OR on card ending 0000.",
                           options),
      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
                           "OR on card ending 0000.",
                           options),
      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
      /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
                           /*amount=*/"35",
                           /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
                           /*amount=*/"35", /*whole_part=*/35,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("and win back up to CHF 150 - with digitec",
                           options),
      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("and win back up to 150 CHF - with digitec",
                           options),
      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
      /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
                           "CHF", /*amount=*/"3.555.333",
                           /*whole_part=*/3555333, /*decimal_part=*/0,
                           /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
                           "CHF", /*amount=*/"3.555.333",
                           /*whole_part=*/3555333, /*decimal_part=*/0,
                           /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
                           /*amount=*/"10,000", /*whole_part=*/10000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
                           /*amount=*/"10,000", /*whole_part=*/10000,
                           /*decimal_part=*/0, /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
                           /*amount=*/"3,555.33", /*whole_part=*/3555,
                           /*decimal_part=*/33, /*nanos=*/330000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
                           /*amount=*/"3,555.33", /*whole_part=*/3555,
                           /*decimal_part=*/33, /*nanos=*/330000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
                           /*amount=*/"3,000.00", /*whole_part=*/3000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
                           /*amount=*/"3,000.00", /*whole_part=*/3000,
                           /*decimal_part=*/0, /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
                           /*amount=*/"1.2", /*whole_part=*/1,
                           /*decimal_part=*/2, /*nanos=*/200000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
                           /*amount=*/"1.2", /*whole_part=*/1,
                           /*decimal_part=*/2, /*nanos=*/200000000);

  ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
                           /*amount=*/"1.123456789", /*whole_part=*/1,
                           /*decimal_part=*/123456789, /*nanos=*/123456789);
  ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
                           /*amount=*/"10.01", /*whole_part=*/10,
                           /*decimal_part=*/1, /*nanos=*/10000000);

  ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
                           /*amount=*/"59 million", /*whole_part=*/59000000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
                           /*amount=*/"7.05 k", /*whole_part=*/7050,
                           /*decimal_part=*/5, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
                           /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
                           /*decimal_part=*/123456789, /*nanos=*/789000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
                           /*amount=*/"7.000056789 k", /*whole_part=*/7000,
                           /*decimal_part=*/56789, /*nanos=*/56789000);

  // NOTE(review): unlike "$1.5 Billion" below, the multiplier is not reflected
  // in whole_part here (59, not 59300000000). Presumably this is because the
  // scaled value would not fit in whole_part; confirm that this is intended.
  ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
                           /*amount=*/"59.3 billion", /*whole_part=*/59,
                           /*decimal_part=*/3, /*nanos=*/300000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
                           /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
                           /*decimal_part=*/5, /*nanos=*/0);
}
2833
// With a LangId model attached and "de" as the only user-familiar language,
// classifying part of an English sentence is expected to yield exactly one
// result, "translate".
TEST_F(AnnotatorTest, TranslateAction) {
  // The LangId model is declared first so that it outlives the annotator,
  // which only stores a pointer to it (locals are destroyed in reverse order).
  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
          GetModelPath() + "lang_id.smfb");
  ASSERT_TRUE(langid_model);

  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  classifier->SetLangId(langid_model.get());

  ClassificationOptions options;
  options.user_familiar_language_tags = "de";

  std::vector<ClassificationResult> classifications =
      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
  // Fatal: classifications[0] is read on the next line.
  ASSERT_EQ(classifications.size(), 1);
  EXPECT_EQ(classifications[0].collection, "translate");
}
2850
// Checks that AnnotateStructuredInput annotates every fragment, producing a
// money annotation for the first fragment and a date for the second.
TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  std::vector<InputFragment> string_fragments = {
      {.text = "He owes me 3.5 CHF."},
      {.text = "...was born on 13/12/1989."},
  };

  AnnotationOptions annotation_options;
  annotation_options.locales = "en";
  StatusOr<Annotations> annotations_status =
      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
  ASSERT_TRUE(annotations_status.ok());
  Annotations annotations = annotations_status.ValueOrDie();
  // One list of spans per input fragment; offsets are relative to the text of
  // the corresponding fragment.
  ASSERT_EQ(annotations.annotated_spans.size(), 2);
  // "3.5 CHF" in the first fragment.
  EXPECT_THAT(annotations.annotated_spans[0],
              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
  // "13/12/1989" in the second fragment.
  EXPECT_THAT(annotations.annotated_spans[1],
              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
}
2872
VerifyInputFragmentTimestampOverridesAnnotationOptions(const Annotator * classifier)2873 void VerifyInputFragmentTimestampOverridesAnnotationOptions(
2874 const Annotator* classifier) {
2875 AnnotationOptions annotation_options;
2876 annotation_options.locales = "en";
2877 annotation_options.reference_time_ms_utc =
2878 1554465190000; // 04/05/2019 11:53 am
2879 int64 fragment_reference_time = 946727580000; // 01/01/2000 11:53 am
2880 std::vector<InputFragment> string_fragments = {
2881 {.text = "New event at 17:20"},
2882 {
2883 .text = "New event at 17:20",
2884 .datetime_options = Optional<DatetimeOptions>(
2885 {.reference_time_ms_utc = fragment_reference_time}),
2886 }};
2887 StatusOr<Annotations> annotations_status =
2888 classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2889 ASSERT_TRUE(annotations_status.ok());
2890 Annotations annotations = annotations_status.ValueOrDie();
2891 ASSERT_EQ(annotations.annotated_spans.size(), 2);
2892 EXPECT_THAT(annotations.annotated_spans[0],
2893 ElementsAreArray({IsDatetimeSpan(
2894 /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
2895 DatetimeGranularity::GRANULARITY_MINUTE)}));
2896 EXPECT_THAT(annotations.annotated_spans[1],
2897 ElementsAreArray({IsDatetimeSpan(
2898 /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
2899 DatetimeGranularity::GRANULARITY_MINUTE)}));
2900 }
2901
// Runs the fragment-timestamp override check with the datetime grammar model
// removed from the test model (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest,
       InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) {
  // The annotator is built from an unowned buffer, so `model_buffer` is
  // declared first to outlive it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
}
2910
// Runs the fragment-timestamp override check against the default test model.
TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
}
2916
VerifyInputFragmentTimezoneOverridesAnnotationOptions(const Annotator * classifier)2917 void VerifyInputFragmentTimezoneOverridesAnnotationOptions(
2918 const Annotator* classifier) {
2919 std::vector<InputFragment> string_fragments = {
2920 {.text = "11/12/2020 17:20"},
2921 {
2922 .text = "11/12/2020 17:20",
2923 .datetime_options = Optional<DatetimeOptions>(
2924 {.reference_timezone = "Europe/Zurich"}),
2925 }};
2926 AnnotationOptions annotation_options;
2927 annotation_options.locales = "en-US";
2928 StatusOr<Annotations> annotations_status =
2929 classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2930 ASSERT_TRUE(annotations_status.ok());
2931 Annotations annotations = annotations_status.ValueOrDie();
2932 ASSERT_EQ(annotations.annotated_spans.size(), 2);
2933 EXPECT_THAT(annotations.annotated_spans[0],
2934 ElementsAreArray({IsDatetimeSpan(
2935 /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
2936 DatetimeGranularity::GRANULARITY_MINUTE)}));
2937 EXPECT_THAT(annotations.annotated_spans[1],
2938 ElementsAreArray({IsDatetimeSpan(
2939 /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
2940 DatetimeGranularity::GRANULARITY_MINUTE)}));
2941 }
2942
// Runs the fragment-timezone override check against the default test model.
TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
}
2948
// Runs the fragment-timezone override check with the datetime grammar model
// removed from the test model (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest,
       InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) {
  // The annotator is built from an unowned buffer, so `model_buffer` is
  // declared first to outlive it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
}
2957
2958 namespace {
AddDummyRegexDatetimeModel(ModelT * unpacked_model)2959 void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
2960 unpacked_model->datetime_model.reset(new DatetimeModelT);
2961 // This needs to be false otherwise we'd have to define some extractor. When
2962 // this is false, the 0-th capturing group (whole match) from the pattern is
2963 // used to come up with the indices.
2964 unpacked_model->datetime_model->use_extractors_for_locating = false;
2965 unpacked_model->datetime_model->locales.push_back("en-US");
2966 unpacked_model->datetime_model->default_locales.push_back(0); // en-US
2967 unpacked_model->datetime_model->patterns.push_back(
2968 std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
2969 unpacked_model->datetime_model->patterns.back()->locales.push_back(
2970 0); // en-US
2971 unpacked_model->datetime_model->patterns.back()->regexes.push_back(
2972 std::unique_ptr<DatetimeModelPattern_::RegexT>(
2973 new DatetimeModelPattern_::RegexT));
2974 unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern =
2975 "THIS_MATCHES_IN_REGEX_MODEL";
2976 unpacked_model->datetime_model->patterns.back()
2977 ->regexes.back()
2978 ->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
2979 }
2980 } // namespace
2981
// Checks that when several annotators emit the same "phone" annotation for the
// same text, Annotate() reports it only once.
TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // This test assumes that both ML model and Regex model trigger on the
  // following text and output "phone" annotation for it.
  const std::string test_string = "1000000000";
  AnnotationOptions options;
  options.annotation_usecase = ANNOTATION_USECASE_RAW;
  int num_phones = 0;
  for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
    // Skip spans without classifications; indexing [0] on an empty vector
    // would be undefined behavior.
    if (!span.classification.empty() &&
        span.classification[0].collection == "phone") {
      ++num_phones;
    }
  }

  EXPECT_EQ(num_phones, 1);
}
3001
// Tests the Annotator optimizations that skip running some of the annotators
// in RAW mode when their entity types are not requested. We check that the
// results indeed contain no such annotations. This is a bit hacky, though:
// one could also add post-filtering, in which case these tests would
// trivially pass.
TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  // Requesting a non-existing type to avoid overlap with existing types.
  options.entity_types.insert("some_unknown_entity_type");

  // Without the entity-type restriction, annotating the text below would
  // produce the following annotations (offsets as recorded when this test was
  // written):
  //   Span(19, 24, date, 1.000000),
  //   Span(53, 56, number, 1.000000),
  //   Span(53, 80, address, 1.000000),
  //   Span(128, 142, phone, 1.000000),
  //   Span(129, 132, number, 1.000000),
  //   Span(192, 200, phone, 1.000000),
  //   Span(192, 206, datetime, 1.000000),
  //   Span(246, 253, number, 1.000000),
  //   Span(246, 253, phone, 1.000000),
  //   Span(292, 293, number, 1.000000),
  //   Span(292, 301, duration, 1.000000)
  // Because of the optimizations, nothing is produced, since none of these
  // entity types were requested.
  EXPECT_THAT(classifier->Annotate(R"--(I saw Barack Obama today
350 Third Street, Cambridge
my phone number is (853) 225-3556
this is when we met: 1.9.2021 13:00
my number: 1234567
duration: 3 minutes
)--",
                                   options),
              IsEmpty());
}
3041
VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(const Annotator * classifier)3042 void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(
3043 const Annotator* classifier) {
3044 ASSERT_TRUE(classifier);
3045 struct Example {
3046 std::string collection;
3047 std::string text;
3048 };
3049
3050 // These examples contain one example per annotator, to check that each of
3051 // the annotators can work in the RAW mode on its own.
3052 //
3053 // WARNING: This list doesn't contain yet entries for the app, contact, and
3054 // person annotators. Hopefully this won't be needed once b/155214735 is
3055 // fixed and the piping shared across annotators.
3056 std::vector<Example> examples{
3057 // ML Model.
3058 {.collection = Collections::Address(),
3059 .text = "... 350 Third Street, Cambridge ..."},
3060 // Datetime annotator.
3061 {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
3062 // Duration annotator.
3063 {.collection = Collections::Duration(),
3064 .text = "... 3 hours and 9 seconds ..."},
3065 // Regex annotator.
3066 {.collection = Collections::Email(),
3067 .text = "... platypus@theanimal.org ..."},
3068 // Number annotator.
3069 {.collection = Collections::Number(), .text = "... 100 ..."},
3070 };
3071
3072 for (const Example& example : examples) {
3073 AnnotationOptions options;
3074 options.locales = "en";
3075 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
3076 options.entity_types.insert(example.collection);
3077
3078 EXPECT_THAT(classifier->Annotate(example.text, options),
3079 Contains(IsAnnotationWithType(example.collection)))
3080 << " text: '" << example.text
3081 << "', collection: " << example.collection;
3082 }
3083 }
3084
// Runs the RAW-mode per-collection filtering checks against the test model
// loaded from disk. The helper itself asserts that the annotator was created.
TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
  const std::string model_path = GetTestModelPath();
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3090
// Same RAW-mode per-collection filtering checks, but with the datetime grammar
// model removed from the test model (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest,
       AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) {
  // Must outlive `annotator`, which is built over this buffer without owning it.
  std::string serialized_model = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      serialized_model.data(), serialized_model.size(), unilib_.get(),
      calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3099
// Checks that an Annotator can be created with FromString() from the model's
// serialized contents and then produces annotations.
TEST_F(AnnotatorTest, InitializeFromString) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<Annotator> classifier =
      Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  // `Not` is fully qualified: it is not among this file's using-declarations
  // and would otherwise resolve only through argument-dependent lookup.
  EXPECT_THAT(classifier->Annotate("(857) 225-3556"),
              ::testing::Not(IsEmpty()));
}
3107
// Regression test for cl/338280366. Enabling only_use_line_with_click had the
// effect that, in previous code releases, some annotators received only the
// last line of the input text. This test places the entity (a duration) on the
// first line.
TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);
  ASSERT_TRUE(unpacked_model->selection_feature_options);

  // This is the setting that previously caused some annotators to see only the
  // last line of a multi-line input.
  unpacked_model->selection_feature_options->only_use_line_with_click = true;

  // `builder` owns the re-serialized model. The annotator is created with
  // FromUnownedBuffer, so `builder` is declared first to outlive it.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("let's meet in 3 hours\nbut not now", options);

  // "3 hours" on the first line must still be found.
  EXPECT_THAT(annotations, Contains(IsDurationSpan(
                               /*start=*/14, /*end=*/21,
                               /*duration_ms=*/3 * 60 * 60 * 1000)));
}
3138
// Checks that input containing invalid UTF-8 is not processed, even when it
// starts with a valid phone number. Each "\xed\xa0\x80" encodes the surrogate
// code point U+D800, which is not allowed in UTF-8.
TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
  const std::string text =
      "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";
  const std::string serialized_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> annotator = Annotator::FromString(
      serialized_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // No annotations are produced.
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
  // The requested selection is returned unchanged.
  EXPECT_THAT(annotator->SuggestSelection(text, {1, 4}),
              Eq(CodepointSpan{1, 4}));
  // The phone-number span gets no classification.
  EXPECT_THAT(annotator->ClassifyText(text, {0, 14}), IsEmpty());
}
3156
3157 } // namespace test_internal
3158 } // namespace libtextclassifier3
3159