1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator_test-include.h"
18
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <type_traits>
23
24 #include "annotator/annotator.h"
25 #include "annotator/collections.h"
26 #include "annotator/model_generated.h"
27 #include "annotator/test-utils.h"
28 #include "annotator/types-test-util.h"
29 #include "annotator/types.h"
30 #include "utils/grammar/utils/locale-shard-map.h"
31 #include "utils/grammar/utils/rules.h"
32 #include "utils/testing/annotator.h"
33 #include "lang_id/fb_model/lang-id-from-fb.h"
34 #include "lang_id/lang-id.h"
35
36 namespace libtextclassifier3 {
37 namespace test_internal {
38
39 using ::testing::Contains;
40 using ::testing::ElementsAre;
41 using ::testing::ElementsAreArray;
42 using ::testing::Eq;
43 using ::testing::IsEmpty;
44 using ::testing::UnorderedElementsAreArray;
45
GetTestModelPath()46 std::string GetTestModelPath() { return GetModelPath() + "test_model.fb"; }
47
GetModelWithVocabPath()48 std::string GetModelWithVocabPath() {
49 return GetModelPath() + "test_vocab_model.fb";
50 }
51
GetTestModelWithDatetimeRegEx()52 std::string GetTestModelWithDatetimeRegEx() {
53 std::string model_buffer = ReadFile(GetTestModelPath());
54 model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
55 model->datetime_grammar_model.reset(nullptr);
56 });
57 return model_buffer;
58 }
59
ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan> & result,const std::string & currency,const std::string & amount,const int whole_part,const int decimal_part,const int nanos)60 void ExpectFirstEntityIsMoney(const std::vector<AnnotatedSpan>& result,
61 const std::string& currency,
62 const std::string& amount, const int whole_part,
63 const int decimal_part, const int nanos) {
64 ASSERT_GT(result.size(), 0);
65 ASSERT_GT(result[0].classification.size(), 0);
66 ASSERT_EQ(result[0].classification[0].collection, "money");
67
68 const EntityData* entity_data =
69 GetEntityData(result[0].classification[0].serialized_entity_data.data());
70 ASSERT_NE(entity_data, nullptr);
71 ASSERT_NE(entity_data->money(), nullptr);
72 EXPECT_EQ(entity_data->money()->unnormalized_currency()->str(), currency);
73 EXPECT_EQ(entity_data->money()->unnormalized_amount()->str(), amount);
74 EXPECT_EQ(entity_data->money()->amount_whole_part(), whole_part);
75 EXPECT_EQ(entity_data->money()->amount_decimal_part(), decimal_part);
76 EXPECT_EQ(entity_data->money()->nanos(), nanos);
77 }
78
// Loading "wrong_embeddings.fb" must fail: FromPath is expected to hand back a
// null Annotator rather than a half-initialized one.
TEST_F(AnnotatorTest, EmbeddingExecutorLoadingFails) {
  const std::string broken_model_path = GetModelPath() + "wrong_embeddings.fb";
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      broken_model_path, unilib_.get(), calendarlib_.get());
  EXPECT_FALSE(annotator);
}
85
VerifyClassifyText(const Annotator * classifier)86 void VerifyClassifyText(const Annotator* classifier) {
87 ASSERT_TRUE(classifier);
88
89 EXPECT_EQ("other",
90 FirstResult(classifier->ClassifyText(
91 "this afternoon Barack Obama gave a speech at", {15, 27})));
92 EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
93 "Call me at (800) 123-456 today", {11, 24})));
94
95 // More lines.
96 EXPECT_EQ("other",
97 FirstResult(classifier->ClassifyText(
98 "this afternoon Barack Obama gave a speech at|Visit "
99 "www.google.com every today!|Call me at (800) 123-456 today.",
100 {15, 27})));
101 EXPECT_EQ("phone",
102 FirstResult(classifier->ClassifyText(
103 "this afternoon Barack Obama gave a speech at|Visit "
104 "www.google.com every today!|Call me at (800) 123-456 today.",
105 {90, 103})));
106
107 // Single word.
108 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("obama", {0, 5})));
109 EXPECT_EQ("other", FirstResult(classifier->ClassifyText("asdf", {0, 4})));
110
111 // Junk. These should not crash the test.
112 classifier->ClassifyText("", {0, 0});
113 classifier->ClassifyText("asdf", {0, 0});
114 classifier->ClassifyText("asdf", {0, 27});
115 classifier->ClassifyText("asdf", {-30, 300});
116 classifier->ClassifyText("asdf", {-10, -1});
117 classifier->ClassifyText("asdf", {100, 17});
118 classifier->ClassifyText("a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5});
119
120 // Test invalid utf8 input.
121 EXPECT_EQ("<INVALID RESULTS>", FirstResult(classifier->ClassifyText(
122 "\xf0\x9f\x98\x8b\x8b", {0, 0})));
123 }
124
// Loads the default test model from disk and runs the shared ClassifyText
// checks against it.
TEST_F(AnnotatorTest, ClassifyText) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyClassifyText(annotator.get());
}
130
// "isotope" is classified as "dictionary" only when the detected text language
// is "en". With no language tag, or with the unrelated tag "uz", it comes back
// as "other".
TEST_F(AnnotatorTest, ClassifyTextLocalesAndDictionary) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string word = "isotope";
  const CodepointSpan whole_word(0, 7);

  // No language hint.
  EXPECT_EQ("other", FirstResult(annotator->ClassifyText(word, whole_word)));

  ClassificationOptions options;
  options.detected_text_language_tags = "en";
  EXPECT_EQ("dictionary",
            FirstResult(annotator->ClassifyText(word, whole_word, options)));

  options.detected_text_language_tags = "uz";
  EXPECT_EQ("other",
            FirstResult(annotator->ClassifyText(word, whole_word, options)));
}
147
// Setting use_vocab_annotator on the default test model (as opposed to
// test_vocab_model.fb) must not break classification: "isotope" is still
// reported as "dictionary" for English text.
TEST_F(AnnotatorTest, ClassifyTextUseVocabAnnotatorWithoutVocabModel) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions options;
  options.detected_text_language_tags = "en";
  options.use_vocab_annotator = true;

  const auto results = annotator->ClassifyText("isotope", {0, 7}, options);
  EXPECT_EQ("dictionary", FirstResult(results));
}
160
#ifdef TC3_VOCAB_ANNOTATOR_IMPL
// Toggles ClassificationOptions::use_vocab_annotator on the vocab test model.
// Only the vocab annotator tags "integrity" as "dictionary"; the FFModel
// reports it as "other". The result therefore shows which path was taken.
TEST_F(AnnotatorTest, ClassifyTextWithVocabModel) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetModelWithVocabPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string word = "integrity";
  const CodepointSpan whole_word(0, 9);

  ClassificationOptions options;
  options.detected_text_language_tags = "en";

  // Vocab annotator on.
  options.use_vocab_annotator = true;
  EXPECT_EQ("dictionary",
            FirstResult(annotator->ClassifyText(word, whole_word, options)));

  // Vocab annotator off.
  options.use_vocab_annotator = false;
  EXPECT_EQ("other",
            FirstResult(annotator->ClassifyText(word, whole_word, options)));
}
#endif  // TC3_VOCAB_ANNOTATOR_IMPL
181
// A model that keeps only selection enabled but has no classification model
// must be rejected when loaded. Selection scoring still relies on the
// classification model.
TEST_F(AnnotatorTest, ClassifyTextDisabledFail) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(model != nullptr);

  // Drop classification and restrict triggering to selection only.
  model->classification_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_FALSE(annotator);
}
201
// With the model's enabled_modes limited to annotation and selection,
// ClassifyText must return no results, even for an obvious phone number.
TEST_F(AnnotatorTest, ClassifyTextDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Fail the test cleanly (instead of crashing the binary on a null
  // dereference) if the test model cannot be unpacked.
  ASSERT_TRUE(unpacked_model);

  unpacked_model->enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_THAT(
      classifier->ClassifyText("Call me at (800) 123-456 today", {11, 24}),
      IsEmpty());
}
220
// Collections listed in output_options.filtered_collections_classification are
// filtered from ClassifyText results. Once "phone" is filtered, the phone
// number is reported as "other", while "address" classification still works.
TEST_F(AnnotatorTest, ClassifyTextFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone classification.
  unpacked_model->output_options->filtered_collections_classification.push_back(
      "phone");

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer does not take ownership of the buffer (as its name
  // says), so the annotator must be destroyed before `builder`. This requires
  // a fresh variable declared after `builder`. Reassigning `classifier`
  // (declared earlier) would let it outlive the buffer it points into.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_EQ("other", FirstResult(filtered_classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  // Check that the address classification still passes.
  EXPECT_EQ("address", FirstResult(filtered_classifier->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27})));
}
253
// Adds classification-only regex patterns to the test model and checks that
// ClassifyText reports them. The patterns are "person", "flight" and a
// Luhn-verified "payment_card". Collections the base model already recognizes
// ("email", "url") are checked as well.
TEST_F(AnnotatorTest, ClassifyTextRegularExpression) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  auto& patterns = model->regex_model->patterns;

  // Each pattern below is enabled for classification only.
  patterns.push_back(MakePattern("person", "Barack Obama",
                                 /*enabled_for_classification=*/true,
                                 /*enabled_for_selection=*/false,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern("flight", "[a-zA-Z]{2}\\d{2,4}",
                                 /*enabled_for_classification=*/true,
                                 /*enabled_for_selection=*/false,
                                 /*enabled_for_annotation=*/false, 0.5));

  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "\\d{4}(?: \\d{4}){3}",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/false, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Top collection reported for the codepoint range `span` of `text`.
  const auto top_collection = [&annotator](const std::string& text,
                                           const CodepointSpan& span) {
    return FirstResult(annotator->ClassifyText(text, span));
  };

  EXPECT_EQ("flight",
            top_collection("Your flight LX373 is delayed by 3 hours.",
                           {12, 17}));
  EXPECT_EQ("person",
            top_collection("this afternoon Barack Obama gave a speech at",
                           {15, 27}));
  EXPECT_EQ("email", top_collection("you@android.com", {0, 15}));
  EXPECT_EQ("email", top_collection("Contact me at you@android.com", {14, 29}));

  EXPECT_EQ("url", top_collection("Visit www.google.com every today!", {6, 20}));

  EXPECT_EQ("flight", top_collection("LX 37", {0, 5}));
  EXPECT_EQ("flight", top_collection("flight LX 37 abcd", {7, 12}));
  EXPECT_EQ("payment_card", top_collection("cc: 4012 8888 8888 1881", {4, 23}));
  EXPECT_EQ("payment_card", top_collection("2221 0067 4735 6281", {0, 19}));
  // The last digit is off by one, so the Luhn checksum rejects the match.
  EXPECT_EQ("other", top_collection("2221 0067 4735 6282", {0, 19}));

  // Longer context made of several sentences joined by '|'.
  const std::string multi_line_text =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";
  EXPECT_EQ("url", top_collection(multi_line_text, {51, 65}));
}
314
315 #ifndef TC3_DISABLE_LUA
TEST_F(AnnotatorTest,ClassifyTextRegularExpressionLuaVerification)316 TEST_F(AnnotatorTest, ClassifyTextRegularExpressionLuaVerification) {
317 const std::string test_model = ReadFile(GetTestModelPath());
318 std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
319
320 // Add test regex models.
321 std::unique_ptr<RegexModel_::PatternT> verified_pattern =
322 MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
323 /*enabled_for_classification=*/true,
324 /*enabled_for_selection=*/false,
325 /*enabled_for_annotation=*/false, 1.0);
326 verified_pattern->verification_options.reset(new VerificationOptionsT);
327 verified_pattern->verification_options->lua_verifier = 0;
328 unpacked_model->regex_model->patterns.push_back(std::move(verified_pattern));
329 unpacked_model->regex_model->lua_verifier.push_back(
330 "return match[2].text==\"99\"");
331
332 flatbuffers::FlatBufferBuilder builder;
333 FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
334
335 std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
336 reinterpret_cast<const char*>(builder.GetBufferPointer()),
337 builder.GetSize(), unilib_.get(), calendarlib_.get());
338 ASSERT_TRUE(classifier);
339
340 // Custom rule triggers and is correctly verified.
341 EXPECT_EQ("parcel_tracking", FirstResult(classifier->ClassifyText(
342 "99-00-123456-12345678", {0, 21})));
343
344 // Custom verification fails.
345 EXPECT_EQ("other", FirstResult(classifier->ClassifyText(
346 "90-00-123456-12345678", {0, 21})));
347 }
348 #endif // TC3_DISABLE_LUA
349
// The regex pattern added by AddTestRegexModel fills entity data from its
// capturing groups. This test checks the collection and the serialized fields,
// first for a full name and then for a first name only.
//
// Entity data is read through the flatbuffers::Table API using the vtable
// offsets of the test entity schema:
//   4: first name ("Barack"), 6: `is_alive`, 8: last name ("Obama"),
//   10: `age`, 12: `former_us_president`.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // Check with full name.
  {
    auto classifications =
        classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
    // ASSERT (not EXPECT): classifications[0] is read below, and indexing an
    // empty vector would be undefined behavior.
    ASSERT_EQ(1u, classifications.size());
    EXPECT_EQ("person_with_age", classifications[0].collection);

    // Check entity data.
    ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
    const flatbuffers::Table* entity =
        flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
            classifications[0].serialized_entity_data.data()));

    // String fields are null when absent; check before dereferencing.
    const flatbuffers::String* first_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
    ASSERT_NE(first_name, nullptr);
    EXPECT_EQ(first_name->str(), "Barack");
    const flatbuffers::String* last_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
    ASSERT_NE(last_name, nullptr);
    EXPECT_EQ(last_name->str(), "Obama");

    // Check `age`.
    EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

    // Check `is_alive`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

    // Check `former_us_president`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
  }

  // Check only with first name.
  {
    auto classifications =
        classifier->ClassifyText("Barack is 57 years old", {0, 22});
    ASSERT_EQ(1u, classifications.size());
    EXPECT_EQ("person_with_age", classifications[0].collection);

    // Check entity data.
    ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
    const flatbuffers::Table* entity =
        flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
            classifications[0].serialized_entity_data.data()));

    const flatbuffers::String* first_name =
        entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
    ASSERT_NE(first_name, nullptr);
    EXPECT_EQ(first_name->str(), "Barack");

    // Check `age`.
    EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

    // Check `is_alive`.
    EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

    // Check `former_us_president`.
    EXPECT_FALSE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
  }
}
419
// A capturing group's normalization options apply to the value written into
// entity data. Upper-casing capturing group 2 stores the last name as "OBAMA"
// at entity vtable offset 8.
TEST_F(AnnotatorTest, ClassifyTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing. The last pattern is presumably
  // the one added by AddTestRegexModel. The accesses are guarded so that a
  // change in that helper fails the test instead of invoking undefined
  // behavior (back() on an empty vector, out-of-range capturing_group[2]).
  ASSERT_FALSE(unpacked_model->regex_model->patterns.empty());
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  ASSERT_GE(pattern->capturing_group.size(), 3u);
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  auto classifications =
      classifier->ClassifyText("Barack Obama is 57 years old", {0, 28});
  // ASSERT (not EXPECT): classifications[0] is read below.
  ASSERT_EQ(1u, classifications.size());
  EXPECT_EQ("person_with_age", classifications[0].collection);

  // Check entity data normalization.
  ASSERT_FALSE(classifications[0].serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          classifications[0].serialized_entity_data.data()));
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
458
// Two identical flight regexes compete for the same span. ClassifyText must
// report the collection of the one with the higher priority_score, and raising
// the second pattern's priority flips the winner.
TEST_F(AnnotatorTest, ClassifyTextPriorityResolution) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  TC3_CHECK(libtextclassifier3::DecompressModel(model.get()));

  // Replace all regexes with two patterns that differ only in collection name
  // and priority score.
  auto& patterns = model->regex_model->patterns;
  patterns.clear();
  patterns.push_back(MakePattern(
      "flight1", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/1.0));
  patterns.push_back(MakePattern(
      "flight2", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false,
      /*score=*/1.0, /*priority_score=*/0.0));

  const std::string flight_text = "Your flight LX373 is delayed by 3 hours.";
  const CodepointSpan flight_span(12, 17);

  // "flight1" (priority 1.0) outranks "flight2" (priority 0.0).
  {
    flatbuffers::FlatBufferBuilder fbb;
    FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
    const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
        unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(annotator);

    EXPECT_EQ("flight1",
              FirstResult(annotator->ClassifyText(flight_text, flight_span)));
  }

  // Raise "flight2" to priority 3.0 so that it now wins.
  patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder fbb;
    FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
    const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
        unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(annotator);

    EXPECT_EQ("flight2",
              FirstResult(annotator->ClassifyText(flight_text, flight_span)));
  }
}
501
// Two identical "flight" regexes with different priority scores compete for
// the same span. Annotate must report the classification carrying the higher
// priority_score. Raising the second pattern's priority from 0.0 to 3.0 must
// flip the winner.
TEST_F(AnnotatorTest, AnnotatePriorityResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);
  TC3_CHECK(libtextclassifier3::DecompressModel(unpacked_model.get()));
  // Add test regex models. One of them has higher priority score than
  // the other. We'll test that always the one with higher priority score
  // ends up winning.
  unpacked_model->regex_model->patterns.clear();
  const std::string flight_regex = "([a-zA-Z]{2}\\d{2,4})";
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", flight_regex, /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/true,
      /*score=*/1.0, /*priority_score=*/0.0));

  // "flight" that wins should have a priority score of 1.0.
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    // ASSERT (not EXPECT): classification[0] is read on the next line, and
    // indexing an empty vector would be undefined behavior.
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 0.9);
  }

  // When we increase the priority score, the "flight" that wins should have a
  // priority score of 3.0.
  unpacked_model->regex_model->patterns.back()->priority_score = 3.0;
  {
    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    const std::vector<AnnotatedSpan> results =
        classifier->Annotate("Your flight LX373 is delayed by 3 hours.");
    ASSERT_THAT(results, ::testing::Not(IsEmpty()));
    ASSERT_THAT(results[0].classification, ::testing::Not(IsEmpty()));
    EXPECT_GE(results[0].classification[0].priority_score, 2.9);
  }
}
554
// Selection-only regex patterns expand a partial selection to the pattern's
// capturing group. The patterns are "person", "flight" (priority 1.1) and a
// Luhn-verified "payment_card".
TEST_F(AnnotatorTest, SuggestSelectionRegularExpression) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  auto& patterns = model->regex_model->patterns;

  // Each pattern below is enabled for selection only.
  patterns.push_back(MakePattern("person", " (Barack Obama) ",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern("flight", "([a-zA-Z]{2} ?\\d{2,4})",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;

  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // "MA" expands to "MA 0123".
  EXPECT_EQ(annotator->SuggestSelection(
                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}),
            CodepointSpan(12, 19));
  // "Barack" expands to "Barack Obama" without the surrounding spaces.
  EXPECT_EQ(annotator->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 27));
  // One block of digits expands to the whole card number.
  EXPECT_EQ(annotator->SuggestSelection("cc: 4012 8888 8888 1881", {9, 14}),
            CodepointSpan(4, 23));
}
594
// Capturing groups flagged with extend_selection determine the suggested
// selection bounds.
//   - Group 0 (the whole match) is not extended.
//   - Groups 1-3 are extended.
// For a date range, the selection therefore spans from the first date to the
// second (excluding "it's "), and "for ever" is selected as a whole.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionCustomSelectionBounds) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  std::unique_ptr<RegexModel_::PatternT> range_pattern =
      MakePattern("date_range",
                  "(?:(?:from )?(\\d{2}\\/\\d{2}\\/\\d{4}) to "
                  "(\\d{2}\\/\\d{2}\\/\\d{4}))|(for ever)",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/false, 1.0);

  // extend_selection per capturing group index 0..3.
  const bool kExtendSelection[] = {false, true, true, true};
  for (size_t i = 0; i < 4; ++i) {
    range_pattern->capturing_group.emplace_back(new CapturingGroupT);
  }
  for (size_t i = 0; i < 4; ++i) {
    range_pattern->capturing_group[i]->extend_selection = kExtendSelection[i];
  }
  model->regex_model->patterns.push_back(std::move(range_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(annotator->SuggestSelection("it's from 04/30/1789 to 03/04/1797",
                                        {21, 23}),
            CodepointSpan(10, 34));
  EXPECT_EQ(annotator->SuggestSelection("it takes for ever", {9, 12}),
            CodepointSpan(9, 17));
}
637
// With a low priority_score (0.5) on the flight regex, the conflict around
// "MA 0123" resolves to the model's wider selection, which covers the whole
// street address.
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsModelWins) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  auto& patterns = model->regex_model->patterns;

  patterns.push_back(MakePattern("person", " (Barack Obama) ",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern("flight", "([a-zA-Z]{2} ?\\d{2,4})",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  EXPECT_EQ(annotator->SuggestSelection(text, {55, 57}),
            CodepointSpan(26, 62));
}
666
// With a high priority_score (1.1) on the flight regex, the conflict around
// "MA 0123" resolves to the regex selection ("MA 0123" only).
TEST_F(AnnotatorTest, SuggestSelectionRegularExpressionConflictsRegexWins) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  auto& patterns = model->regex_model->patterns;

  patterns.push_back(MakePattern("person", " (Barack Obama) ",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern("flight", "([a-zA-Z]{2} ?\\d{2,4})",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/true,
                                 /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123";
  EXPECT_EQ(annotator->SuggestSelection(text, {55, 57}),
            CodepointSpan(55, 62));
}
695
// Annotation-only regex patterns ("person", "flight", Luhn-verified
// "payment_card") show up in Annotate output next to the model's own
// "address" and "phone" annotations.
TEST_F(AnnotatorTest, AnnotateRegex) {
  const std::string serialized_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  auto& patterns = model->regex_model->patterns;

  // Each pattern below is enabled for annotation only.
  patterns.push_back(MakePattern("person", " (Barack Obama) ",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/false,
                                 /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern("flight", "([a-zA-Z]{2} ?\\d{2,4})",
                                 /*enabled_for_classification=*/false,
                                 /*enabled_for_selection=*/false,
                                 /*enabled_for_annotation=*/true, 0.5));

  std::unique_ptr<RegexModel_::PatternT> card_pattern =
      MakePattern("payment_card", "(\\d{4}(?: \\d{4}){3})",
                  /*enabled_for_classification=*/false,
                  /*enabled_for_selection=*/false,
                  /*enabled_for_annotation=*/true, 1.0);
  card_pattern->verification_options.reset(new VerificationOptionsT);
  card_pattern->verification_options->verify_luhn_checksum = true;
  patterns.push_back(std::move(card_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556\nand my card is 4012 8888 8888 1881.\n";
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAreArray({IsAnnotatedSpan(6, 18, "person"),
                                IsAnnotatedSpan(28, 55, "address"),
                                IsAnnotatedSpan(79, 91, "phone"),
                                IsAnnotatedSpan(107, 126, "payment_card")}));
}
732
// The stock test model annotates flight numbers. ICAO codes are only used for
// selected airlines, so the expectation below omits "SWR373".
TEST_F(AnnotatorTest, AnnotatesFlightNumbers) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Expected: LX373 at [8, 13), EZY1234 at [23, 30) and U21234 at [32, 38).
  const std::string flights = "flights LX373, SWR373, EZY1234, U21234";
  EXPECT_THAT(annotator->Annotate(flights),
              ElementsAre(IsAnnotatedSpan(8, 13, "flight"),
                          IsAnnotatedSpan(23, 30, "flight"),
                          IsAnnotatedSpan(32, 38, "flight")));
}
746
747 #ifndef TC3_DISABLE_LUA
// A regex pattern can delegate match verification to a Lua script. Here the
// script (verifier #0) checks `match[2].text == "99"`, and the test expects
// the "99-..." tracking number to be annotated.
TEST_F(AnnotatorTest, AnnotateRegexLuaVerification) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());
  auto& regex_model = *model->regex_model;

  // Lua verifier script at index 0.
  regex_model.lua_verifier.push_back("return match[2].text==\"99\"");

  // Parcel tracking pattern that refers to the verifier above by index.
  std::unique_ptr<RegexModel_::PatternT> tracking_pattern =
      MakePattern("parcel_tracking", "((\\d{2})-00-\\d{6}-\\d{8})",
                  /*enabled_for_classification=*/true,
                  /*enabled_for_selection=*/true,
                  /*enabled_for_annotation=*/true, 1.0);
  tracking_pattern->verification_options.reset(new VerificationOptionsT);
  tracking_pattern->verification_options->lua_verifier = 0;
  regex_model.patterns.push_back(std::move(tracking_pattern));

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string message =
      "your parcel is on the way: 99-00-123456-12345678";
  EXPECT_THAT(annotator->Annotate(message),
              ElementsAre(IsAnnotatedSpan(27, 48, "parcel_tracking")));
}
777 #endif // TC3_DISABLE_LUA
778
// Regex capturing groups populate serialized entity data: checks that the
// fields produced by AddTestRegexModel() can be read back from the flatbuffer
// attached to the classification result.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityData) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these counts are wrong.
  ASSERT_EQ(annotations.size(), 1u);
  ASSERT_EQ(annotations[0].classification.size(), 1u);
  const auto& result = annotations[0].classification[0];
  EXPECT_EQ("person_with_age", result.collection);

  // Check entity data. An empty buffer has no root table to read from.
  ASSERT_FALSE(result.serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result.serialized_entity_data.data()));

  // GetPointer() returns nullptr for an absent field; check before use.
  const flatbuffers::String* first_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_NE(first_name, nullptr);
  EXPECT_EQ(first_name->str(), "Barack");
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "Obama");

  // Check `age`.
  EXPECT_EQ(entity->GetField<int>(/*field=*/10, /*defaultval=*/0), 57);

  // Check `is_alive`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/6, /*defaultval=*/false));

  // Check `former_us_president`.
  EXPECT_TRUE(entity->GetField<bool>(/*field=*/12, /*defaultval=*/false));
}
821
// A codepoint-wise normalization (uppercase) configured on a capturing group
// is applied to the value written into the serialized entity data.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataNormalization) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  // Upper case last name as post-processing. The pattern layout comes from
  // AddTestRegexModel(), so guard the indexing instead of assuming it.
  ASSERT_FALSE(unpacked_model->regex_model->patterns.empty());
  RegexModel_::PatternT* pattern =
      unpacked_model->regex_model->patterns.back().get();
  ASSERT_GT(pattern->capturing_group.size(), 2u);
  pattern->capturing_group[2]->normalization_options.reset(
      new NormalizationOptionsT);
  pattern->capturing_group[2]
      ->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;
  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these counts are wrong.
  ASSERT_EQ(annotations.size(), 1u);
  ASSERT_EQ(annotations[0].classification.size(), 1u);
  const auto& result = annotations[0].classification[0];
  EXPECT_EQ("person_with_age", result.collection);

  // Check normalization.
  ASSERT_FALSE(result.serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          result.serialized_entity_data.data()));
  const flatbuffers::String* last_name =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_NE(last_name, nullptr);
  EXPECT_EQ(last_name->str(), "OBAMA");
}
863
// With is_serialized_entity_data_enabled = false the regex match is still
// classified, but no serialized entity data is attached.
TEST_F(AnnotatorTest, AnnotateTextRegularExpressionEntityDataDisabled) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add fake entity schema metadata.
  AddTestEntitySchemaData(unpacked_model.get());

  AddTestRegexModel(unpacked_model.get());

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = false;
  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("Barack Obama is 57 years old", options);
  // ASSERT (not EXPECT): the element accesses below are out of bounds if
  // these counts are wrong.
  ASSERT_EQ(annotations.size(), 1u);
  ASSERT_EQ(annotations[0].classification.size(), 1u);
  EXPECT_EQ("person_with_age", annotations[0].classification[0].collection);

  // Check entity data.
  EXPECT_EQ("", annotations[0].classification[0].serialized_entity_data);
}
892
// Classifies three spans of phone-like text. The span that extends over the
// whole ",0001112" suffix is expected to be rejected as "other".
TEST_F(AnnotatorTest, PhoneFiltering) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string plain = "phone: (123) 456 789";
  const std::string with_suffix = "phone: (123) 456 789,0001112";

  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(plain, {7, 20})));
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(with_suffix, {7, 25})));
  EXPECT_EQ("other",
            FirstResult(annotator->ClassifyText(with_suffix, {7, 28})));
}
905
// Exercises SuggestSelection() on the stock test model. Covers expanding a
// click into a phone number, stripping unpaired brackets, and the cases where
// the original span is returned unchanged.
TEST_F(AnnotatorTest, SuggestSelection) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // Thin forwarding wrapper that keeps each expectation short.
  const auto suggest = [&annotator](const std::string& text,
                                    CodepointSpan click) {
    return annotator->SuggestSelection(text, click);
  };

  EXPECT_EQ(suggest("this afternoon Barack Obama gave a speech at", {15, 21}),
            CodepointSpan(15, 21));

  // Whole string passed in: when more than one token is selected, the span
  // comes back as entered.
  EXPECT_EQ(suggest("350 Third Street, Cambridge", {0, 27}),
            CodepointSpan(0, 27));

  // Single letter, then single word.
  EXPECT_EQ(suggest("a", {0, 1}), CodepointSpan(0, 1));
  EXPECT_EQ(suggest("asdf", {0, 4}), CodepointSpan(0, 4));

  EXPECT_EQ(suggest("call me at 857 225 3556 today", {11, 14}),
            CodepointSpan(11, 23));

  // Unpaired bracket stripping.
  EXPECT_EQ(suggest("call me at (857) 225 3556 today", {12, 14}),
            CodepointSpan(11, 25));
  EXPECT_EQ(suggest("call me at (857 today", {12, 14}),
            CodepointSpan(12, 15));
  EXPECT_EQ(suggest("call me at 3556) today", {12, 14}),
            CodepointSpan(11, 15));
  EXPECT_EQ(suggest("call me at )857( today", {12, 14}),
            CodepointSpan(12, 15));

  // An empty resulting selection falls back to the original span.
  EXPECT_EQ(suggest("call me at )( today", {11, 13}), CodepointSpan(11, 13));
  EXPECT_EQ(suggest("call me at ( today", {11, 12}), CodepointSpan(11, 12));
  EXPECT_EQ(suggest("call me at ) today", {11, 12}), CodepointSpan(11, 12));

  // An original span larger than the found selection is kept as-is.
  EXPECT_EQ(suggest("call me at 857 225 3556 today", {5, 24}),
            CodepointSpan(5, 24));
}
956
// Removing the selection model while annotation mode is enabled must make
// model loading fail.
TEST_F(AnnotatorTest, SuggestSelectionDisabledFail) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Drop the selection model and enable annotation only.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  // Selection model needs to be present for annotation.
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_FALSE(annotator);
}
975
// With the selection model removed and only classification enabled, the
// model still loads. SuggestSelection() echoes the click, ClassifyText()
// still classifies, and Annotate() finds nothing.
TEST_F(AnnotatorTest, SuggestSelectionDisabled) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Remove the selection model; restrict both mode masks to classification.
  model->selection_model.clear();
  model->enabled_modes = ModeFlag_CLASSIFICATION;
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;

  // Disable the number annotator. With the selection model disabled, there is
  // no feature processor, which is required for the number annotator.
  model->number_annotator_options->enabled = false;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_EQ(
      annotator->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 14));

  const std::string phone_text = "call me at (800) 123-456 today";
  EXPECT_EQ("phone",
            FirstResult(annotator->ClassifyText(phone_text, {11, 24})));
  EXPECT_THAT(annotator->Annotate(phone_text), IsEmpty());
}
1008
// Filtering the "phone" collection for selection stops SuggestSelection()
// from expanding phone numbers, while other collections (address) still
// expand.
TEST_F(AnnotatorTest, SuggestSelectionFilteredCollections) {
  const std::string test_model = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      test_model.c_str(), test_model.size(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      CodepointSpan(11, 23));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);

  // Disable phone selection.
  unpacked_model->output_options->filtered_collections_selection.push_back(
      "phone");
  // We need to force this for filtering.
  unpacked_model->selection_options->always_classify_suggested_selection = true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  // This annotator borrows `builder`'s buffer, so it is declared after
  // `builder`. Locals are destroyed in reverse declaration order, so the
  // annotator is destroyed before the buffer it points into. Reassigning
  // `classifier` (declared above) would invert that order.
  std::unique_ptr<Annotator> filtered_classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(filtered_classifier);

  EXPECT_EQ(filtered_classifier->SuggestSelection(
                "call me at 857 225 3556 today", {11, 14}),
            CodepointSpan(11, 14));

  // Address selection should still work.
  EXPECT_EQ(filtered_classifier->SuggestSelection("350 Third Street, Cambridge",
                                                  {4, 9}),
            CodepointSpan(0, 27));
}
1045
// Clicking any of the first three tokens of an address yields the same full
// address selection, including when the address is preceded by other lines.
TEST_F(AnnotatorTest, SuggestSelectionsAreSymmetric) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address = "350 Third Street, Cambridge";

  // "350", "Third" and "Street" each expand to the whole address.
  for (const CodepointSpan& click :
       {CodepointSpan(0, 3), CodepointSpan(4, 9), CodepointSpan(10, 16)}) {
    EXPECT_EQ(annotator->SuggestSelection(address, click),
              CodepointSpan(0, 27));
  }

  // Three one-letter lines in front shift everything by six codepoints.
  EXPECT_EQ(annotator->SuggestSelection("a\nb\nc\n" + address, {16, 22}),
            CodepointSpan(6, 33));
}
1062
// Phone selection next to, and across, newline characters.
TEST_F(AnnotatorTest, SuggestSelectionWithNewLine) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  // The number sits on its own line, after or before another line.
  EXPECT_EQ(annotator->SuggestSelection("abc\n857 225 3556", {4, 7}),
            CodepointSpan(4, 16));
  EXPECT_EQ(annotator->SuggestSelection("857 225 3556\nabc", {0, 3}),
            CodepointSpan(0, 12));

  // With explicitly passed default options, the expected span runs across
  // the newline inside the number.
  SelectionOptions selection_options;
  EXPECT_EQ(annotator->SuggestSelection("857 225\n3556\nabc", {0, 3},
                                        selection_options),
            CodepointSpan(0, 12));
}
1077
// Each case clicks exactly "BarackObama"; punctuation adjacent to it must not
// be pulled into the suggestion, so the click comes back unchanged.
TEST_F(AnnotatorTest, SuggestSelectionWithPunctuation) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  struct PunctuationCase {
    const char* text;
    CodepointSpan click;
  };
  const PunctuationCase kCases[] = {
      // From the right.
      {"this afternoon BarackObama, gave a speech at", {15, 26}},
      // From the right multiple.
      {"this afternoon BarackObama,.,.,, gave a speech at", {15, 26}},
      // From the left multiple.
      {"this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}},
      // From both sides.
      {"this afternoon !BarackObama,- gave a speech at", {16, 27}},
  };

  for (const PunctuationCase& test_case : kCases) {
    SCOPED_TRACE(test_case.text);
    EXPECT_EQ(annotator->SuggestSelection(test_case.text, test_case.click),
              test_case.click);
  }
}
1103
// Invalid selections and invalid UTF-8 must not crash SuggestSelection(); the
// click span is expected back unchanged in every case.
TEST_F(AnnotatorTest, SuggestSelectionNoCrashWithJunk) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const auto expect_echoed = [&annotator](const std::string& text,
                                          const CodepointSpan& click) {
    EXPECT_EQ(annotator->SuggestSelection(text, click), click);
  };

  // Empty text, negative, reversed and out-of-range spans.
  expect_echoed("", {0, 27});
  expect_echoed("", {-10, 27});
  expect_echoed("Word 1 2 3 hello!", {0, 27});
  expect_echoed("Word 1 2 3 hello!", {-30, 300});
  expect_echoed("Word 1 2 3 hello!", {-10, -1});
  expect_echoed("Word 1 2 3 hello!", {100, 17});

  // Invalid UTF-8: a stray continuation byte after a 4-byte sequence.
  expect_echoed("\xf0\x9f\x98\x8b\x8b", {-1, -1});
}
1126
// Clicking whitespace: when the whitespace lies inside an entity, the
// suggestion expands to the entity; otherwise the click is returned unchanged.
TEST_F(AnnotatorTest, SuggestSelectionSelectSpace) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {14, 15}),
      CodepointSpan(11, 23));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {10, 11}),
      CodepointSpan(10, 11));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {23, 24}),
      CodepointSpan(23, 24));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556, today", {23, 24}),
      CodepointSpan(23, 24));
  // Three spaces separate "857" and "225", so [14, 17) selects exactly that
  // whitespace run and the number "857   225 3556" occupies [11, 25).
  EXPECT_EQ(classifier->SuggestSelection("call me at 857   225 3556, today",
                                         {14, 17}),
            CodepointSpan(11, 25));
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857-225 3556, today", {14, 17}),
      CodepointSpan(11, 23));
  EXPECT_EQ(
      classifier->SuggestSelection(
          "let's meet at 350 Third Street Cambridge and go there", {30, 31}),
      CodepointSpan(14, 40));
  EXPECT_EQ(classifier->SuggestSelection("call me today", {4, 5}),
            CodepointSpan(4, 5));
  EXPECT_EQ(classifier->SuggestSelection("call me today", {7, 8}),
            CodepointSpan(7, 8));

  // With a punctuation around the selected whitespace.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "let's meet at 350 Third Street, Cambridge and go there", {31, 32}),
      CodepointSpan(14, 41));

  // When all's whitespace, should return the original indices. The input is
  // six spaces, so every span below, up to [5, 6), lies inside it.
  EXPECT_EQ(classifier->SuggestSelection("      ", {0, 1}),
            CodepointSpan(0, 1));
  EXPECT_EQ(classifier->SuggestSelection("      ", {0, 3}),
            CodepointSpan(0, 3));
  EXPECT_EQ(classifier->SuggestSelection("      ", {2, 3}),
            CodepointSpan(2, 3));
  EXPECT_EQ(classifier->SuggestSelection("      ", {5, 6}),
            CodepointSpan(5, 6));
}
1175
// internal::SnapLeftIfWhitespaceSelection takes a one-codepoint whitespace
// selection. Per the expectations below, it snaps left onto a non-whitespace
// codepoint when one precedes it. Otherwise it returns the selection
// unchanged.
TEST_F(AnnotatorTest, SnapLeftIfWhitespaceSelection) {
  UnicodeText text;

  text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
            CodepointSpan(3, 4));
  text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
            CodepointSpan(3, 4));

  // Nothing on the left: five leading spaces, so both [4, 5) and [0, 1) are
  // whitespace with only whitespace before them.
  text = UTF8ToUnicodeText("     efgh", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
            CodepointSpan(4, 5));
  text = UTF8ToUnicodeText("     efgh", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
            CodepointSpan(0, 1));

  // Whitespace only: five spaces, so every span below is in range.
  text = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({2, 3}, text, *unilib_),
            CodepointSpan(2, 3));
  text = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({4, 5}, text, *unilib_),
            CodepointSpan(4, 5));
  text = UTF8ToUnicodeText("     ", /*do_copy=*/false);
  EXPECT_EQ(internal::SnapLeftIfWhitespaceSelection({0, 1}, text, *unilib_),
            CodepointSpan(0, 1));
}
1205
// Annotate() on the stock test model finds the address and the phone number,
// including a number broken by a newline. Invalid UTF-8 yields no
// annotations.
TEST_F(AnnotatorTest, Annotate) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string mixed_text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(mixed_text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address"),
                          IsAnnotatedSpan(79, 91, "phone")));

  AnnotationOptions annotation_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));

  // Try passing invalid utf8.
  EXPECT_THAT(annotator->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b",
                                  annotation_options),
              IsEmpty());
}
1230
// Paired brackets stay inside an annotated phone number. Unpaired brackets
// at either end are stripped from the span.
TEST_F(AnnotatorTest, AnnotatesWithBracketStripping) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate("call me at (0845) 100 1000 today"),
              ElementsAre(IsAnnotatedSpan(11, 26, "phone")));

  // Unpaired bracket stripping.
  EXPECT_THAT(annotator->Annotate("call me at (07038201818 today"),
              ElementsAre(IsAnnotatedSpan(12, 23, "phone")));
  EXPECT_THAT(annotator->Annotate("call me at 07038201818) today"),
              ElementsAre(IsAnnotatedSpan(11, 22, "phone")));
  EXPECT_THAT(annotator->Annotate("call me at )07038201818( today"),
              ElementsAre(IsAnnotatedSpan(12, 23, "phone")));
}
1255
// In the RAW usecase, number, percentage and phone annotations may overlap.
// All of them are reported.
TEST_F(AnnotatorTest, AnnotatesOverlappingNumbers) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions raw_options;
  raw_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // Number, float number and percentage annotator.
  const std::string text =
      "853 225 3556 and then turn it up 99%, 99 number, "
      "12345.12345 float number";
  EXPECT_THAT(annotator->Annotate(text, raw_options),
              UnorderedElementsAreArray({
                  // "853 225 3556" as a phone and as three integers.
                  IsAnnotatedSpan(0, 12, "phone"),
                  IsAnnotatedSpan(0, 3, "number"),
                  IsAnnotatedSpan(4, 7, "number"),
                  IsAnnotatedSpan(8, 12, "number"),
                  // "99%" as a number and a percentage.
                  IsAnnotatedSpan(33, 35, "number"),
                  IsAnnotatedSpan(33, 36, "percentage"),
                  IsAnnotatedSpan(38, 40, "number"),
                  // "12345.12345" as a number and as a phone.
                  IsAnnotatedSpan(49, 60, "number"),
                  IsAnnotatedSpan(49, 60, "phone"),
              }));
}
1276
// In the SMART usecase, plain numbers are not reported; the phone and the
// percentage still are.
TEST_F(AnnotatorTest, DoesNotAnnotateNumbersInSmartUsecase) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions smart_options;
  smart_options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;

  const std::string text = "853 225 3556 and then turn it up 99%, 99 number";
  EXPECT_THAT(annotator->Annotate(text, smart_options),
              ElementsAre(IsAnnotatedSpan(0, 12, "phone"),
                          IsAnnotatedSpan(33, 36, "percentage")));
}
1289
VerifyAnnotatesDurationsInRawMode(const Annotator * classifier)1290 void VerifyAnnotatesDurationsInRawMode(const Annotator* classifier) {
1291 ASSERT_TRUE(classifier);
1292 AnnotationOptions options;
1293 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1294
1295 // Duration annotator.
1296 EXPECT_THAT(classifier->Annotate(
1297 "it took 9 minutes and 7 seconds to get there", options),
1298 Contains(IsDurationSpan(
1299 /*start=*/8, /*end=*/31,
1300 /*duration_ms=*/9 * 60 * 1000 + 7 * 1000)));
1301 }
1302
// Runs the RAW-mode duration check against the stock test model.
TEST_F(AnnotatorTest, AnnotatesDurationsInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotatesDurationsInRawMode(annotator.get());
}
1308
VerifyDurationAndRelativeTimeCanOverlapInRawMode(const Annotator * classifier)1309 void VerifyDurationAndRelativeTimeCanOverlapInRawMode(
1310 const Annotator* classifier) {
1311 ASSERT_TRUE(classifier);
1312 AnnotationOptions options;
1313 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
1314 options.locales = "en";
1315
1316 const std::vector<AnnotatedSpan> annotations =
1317 classifier->Annotate("let's meet in 3 hours", options);
1318
1319 EXPECT_THAT(annotations,
1320 Contains(IsDatetimeSpan(/*start=*/11, /*end=*/21,
1321 /*time_ms_utc=*/10800000L,
1322 DatetimeGranularity::GRANULARITY_HOUR)));
1323 EXPECT_THAT(annotations,
1324 Contains(IsDurationSpan(/*start=*/14, /*end=*/21,
1325 /*duration_ms=*/3 * 60 * 60 * 1000)));
1326 }
1327
// Runs the duration/relative-time overlap check against the stock test model.
TEST_F(AnnotatorTest, DurationAndRelativeTimeCanOverlapInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1333
// Same overlap check, run on the model produced by
// GetTestModelWithDatetimeRegEx(), which has its datetime grammar model reset.
TEST_F(AnnotatorTest,
       DurationAndRelativeTimeCanOverlapInRawModeWithDatetimeRegEx) {
  // Owns the model bytes; must outlive the annotator that borrows them.
  const std::string regex_model_buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      regex_model_buffer.data(), regex_model_buffer.size(), unilib_.get(),
      calendarlib_.get());
  VerifyDurationAndRelativeTimeCanOverlapInRawMode(annotator.get());
}
1342
// With only_use_line_with_click enabled, the address on the second line of a
// newline-joined text is found at the same relative offsets as when that line
// is annotated on its own.
TEST_F(AnnotatorTest, AnnotateSplitLines) {
  std::string base_model = ReadFile(GetTestModelPath());
  std::string model_buffer =
      ModifyAnnotatorModel(base_model, [](ModelT* model) {
        model->selection_feature_options->only_use_line_with_click = true;
      });
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string chatter =
      "hey, sorry, just finished up. i didn't hear back from you in time.";
  const std::string address = "2000 Main Avenue, Apt #201, San Mateo";
  // Length of "2000 Main Avenue, Apt #201", the part annotated as address.
  const int kAddressLength = 26;

  EXPECT_THAT(annotator->Annotate(chatter), IsEmpty());
  EXPECT_THAT(annotator->Annotate(address),
              ElementsAre(IsAnnotatedSpan(0, kAddressLength, "address")));

  // Joined with '\n', the address starts one past the end of `chatter`.
  const int address_start = static_cast<int>(chatter.size()) + 1;
  EXPECT_THAT(annotator->Annotate(chatter + "\n" + address),
              ElementsAre(IsAnnotatedSpan(address_start,
                                          address_start + kAddressLength,
                                          "address")));
}
1370
// With use_pipe_character_for_newline enabled, '|' splits lines like '\n'.
// Annotating "phone_line|address_line" then finds the same phone and address
// spans as annotating each line on its own would.
TEST_F(AnnotatorTest, UsePipeAsNewLineCharacterShouldAnnotateSplitLines) {
  std::string base_model = ReadFile(GetTestModelPath());
  std::string model_buffer =
      ModifyAnnotatorModel(base_model, [](ModelT* model) {
        model->selection_feature_options->only_use_line_with_click = true;
        model->selection_feature_options->use_pipe_character_for_newline = true;
      });
  const std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone_line = "hey, this is my phone number 853 225 3556";
  const std::string address_line = "2000 Main Avenue, Apt #201, San Mateo";
  const int kPhoneStart = 29;
  const int kPhoneLength = 12;
  const int kAddressLength = 26;
  // The address begins right after the '|' separator.
  const int address_start = static_cast<int>(phone_line.size()) + 1;

  const std::vector<AnnotatedSpan> spans =
      annotator->Annotate(phone_line + "|" + address_line);
  EXPECT_THAT(spans,
              ElementsAre(IsAnnotatedSpan(kPhoneStart,
                                          kPhoneStart + kPhoneLength, "phone"),
                          IsAnnotatedSpan(address_start,
                                          address_start + kAddressLength,
                                          "address")));
}
1401
// With use_pipe_character_for_newline disabled, '|' does not split lines, so
// the joined text is expected to produce a single annotation.
TEST_F(AnnotatorTest,
       NotUsingPipeAsNewLineCharacterShouldNotAnnotateSplitLines) {
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, [](ModelT* model) {
    model->selection_feature_options->only_use_line_with_click = true;
    model->selection_feature_options->use_pipe_character_for_newline = false;
  });
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());

  ASSERT_TRUE(classifier);

  const std::string str1 = "hey, this is my phone number 853 225 3556";
  const std::string str2 = "2000 Main Avenue, Apt #201, San Mateo";
  const std::string str3 = str1 + "|" + str2;
  const std::vector<AnnotatedSpan> annotated_spans =
      classifier->Annotate(str3);
  // Note: We only check that we get a single annotated span here when the '|'
  // character is not used to split lines. The reason behind this is that the
  // model is not precise for such example and the resulted annotated span might
  // change when the model changes.
  // SizeIs() on the container (instead of comparing size()) prints the actual
  // spans on failure.
  EXPECT_THAT(annotated_spans, ::testing::SizeIs(1));
}
1426
// Annotation results with a small selection batch size (4) match those of the
// default model.
TEST_F(AnnotatorTest, AnnotateSmallBatches) {
  const std::string model_buffer = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(model_buffer.c_str());

  // Set the batch size.
  model->selection_options->batch_size = 4;

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));

  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
      unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string mixed_text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(mixed_text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address"),
                          IsAnnotatedSpan(79, 91, "phone")));

  AnnotationOptions annotation_options;
  EXPECT_THAT(annotator->Annotate("853 225 3556", annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));
  EXPECT_THAT(annotator->Annotate("853 225\n3556", annotation_options),
              ElementsAre(IsAnnotatedSpan(0, 12, "phone")));
}
1456
// Sets the minimum annotation confidence to 2, which the test expects to be
// above every candidate score, and checks that all annotations are discarded.
TEST_F(AnnotatorTest, AnnotateFilteringDiscardAll) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  // Add test threshold.
  unpacked_model->triggering_options->min_annotate_confidence =
      2.f;  // Discards all results.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // IsEmpty() avoids a signed/unsigned size comparison and prints the
  // offending spans on failure.
  EXPECT_THAT(classifier->Annotate(test_string), IsEmpty());
}
1479
// With a zero minimum annotation confidence and all modes enabled, both the
// address and the phone number in the test string must be kept.
TEST_F(AnnotatorTest, AnnotateFilteringKeepAll) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test thresholds.
  unpacked_model->triggering_options.reset(new ModelTriggeringOptionsT);
  unpacked_model->triggering_options->min_annotate_confidence =
      0.f;  // Keeps all results.
  unpacked_model->triggering_options->enabled_modes = ModeFlag_ALL;
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // Unsigned literal: size() is size_t.
  EXPECT_EQ(classifier->Annotate(test_string).size(), 2u);
}
1502
// Restricts the model's enabled modes to classification and selection and
// checks that Annotate() then returns no spans.
TEST_F(AnnotatorTest, AnnotateDisabled) {
  const std::string model_bytes = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked = UnPackModel(model_bytes.c_str());

  // Leave annotation out of the enabled modes.
  unpacked->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, unpacked.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  EXPECT_THAT(annotator->Annotate(text), IsEmpty());
}
1521
// Checks that listing "phone" in `filtered_collections_annotation` removes
// phone spans from the annotation output while keeping the address span.
TEST_F(AnnotatorTest, AnnotateFilteredCollections) {
  const std::string model_bytes = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_bytes.c_str(), model_bytes.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline: the unmodified model finds both the address and the phone.
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address"),
                          IsAnnotatedSpan(79, 91, "phone")));

  std::unique_ptr<ModelT> unpacked = UnPackModel(model_bytes.c_str());
  unpacked->output_options.reset(new OutputOptionsT);
  // Filter the phone collection out of annotation results.
  unpacked->output_options->filtered_collections_annotation.push_back("phone");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, unpacked.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  annotator = Annotator::FromUnownedBuffer(packed, fbb.GetSize(),
                                           unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address")));
}
1559
// Adds a high-scoring "suppress" regex that outcompetes the phone annotation,
// filters the "suppress" collection from the output, and checks that only the
// address span remains.
TEST_F(AnnotatorTest, AnnotateFilteredCollectionsSuppress) {
  const std::string model_bytes = ReadFile(GetTestModelPath());

  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_bytes.c_str(), model_bytes.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";

  // Baseline: the unmodified model finds both the address and the phone.
  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address"),
                          IsAnnotatedSpan(79, 91, "phone")));

  std::unique_ptr<ModelT> unpacked = UnPackModel(model_bytes.c_str());
  unpacked->output_options.reset(new OutputOptionsT);

  // An annotation-only regex with score 2.0 that wins against the phone
  // classification below...
  unpacked->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
  // ...and whose collection is then filtered out of the results.
  unpacked->output_options->filtered_collections_annotation.push_back(
      "suppress");

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, unpacked.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  annotator = Annotator::FromUnownedBuffer(packed, fbb.GetSize(),
                                           unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(28, 55, "address")));
}
1603
VerifyClassifyTextDateInZurichTimezone(const Annotator * classifier)1604 void VerifyClassifyTextDateInZurichTimezone(const Annotator* classifier) {
1605 EXPECT_TRUE(classifier);
1606 ClassificationOptions options;
1607 options.reference_timezone = "Europe/Zurich";
1608 options.locales = "en";
1609
1610 std::vector<ClassificationResult> result =
1611 classifier->ClassifyText("january 1, 2017", {0, 15}, options);
1612
1613 EXPECT_THAT(result,
1614 ElementsAre(IsDateResult(1483225200000,
1615 DatetimeGranularity::GRANULARITY_DAY)));
1616 }
1617
// Zurich-timezone date classification using the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezone) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1623
// Same Zurich check, using the test model with its datetime grammar model
// removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, ClassifyTextDateInZurichTimezoneWithDatetimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInZurichTimezone(annotator.get());
}
1631
VerifyClassifyTextDateInLATimezone(const Annotator * classifier)1632 void VerifyClassifyTextDateInLATimezone(const Annotator* classifier) {
1633 EXPECT_TRUE(classifier);
1634 ClassificationOptions options;
1635 options.reference_timezone = "America/Los_Angeles";
1636 options.locales = "en";
1637
1638 std::vector<ClassificationResult> result =
1639 classifier->ClassifyText("march 1, 2017", {0, 13}, options);
1640
1641 EXPECT_THAT(result,
1642 ElementsAre(IsDateResult(1488355200000,
1643 DatetimeGranularity::GRANULARITY_DAY)));
1644 }
1645
// Same Los Angeles check, using the test model with its datetime grammar model
// removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezoneWithDatetimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1653
// Los Angeles-timezone date classification using the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateInLATimezone) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateInLATimezone(annotator.get());
}
1659
VerifyClassifyTextDateOnAotherLine(const Annotator * classifier)1660 void VerifyClassifyTextDateOnAotherLine(const Annotator* classifier) {
1661 EXPECT_TRUE(classifier);
1662 ClassificationOptions options;
1663 options.reference_timezone = "Europe/Zurich";
1664 options.locales = "en";
1665
1666 std::vector<ClassificationResult> result = classifier->ClassifyText(
1667 "hello world this is the first line\n"
1668 "january 1, 2017",
1669 {35, 50}, options);
1670
1671 EXPECT_THAT(result,
1672 ElementsAre(IsDateResult(1483225200000,
1673 DatetimeGranularity::GRANULARITY_DAY)));
1674 }
1675
// Second-line date classification, using the test model with its datetime
// grammar model removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLineWithDatetimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1683
// Second-line date classification using the default test model.
TEST_F(AnnotatorTest, ClassifyTextDateOnAotherLine) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyTextDateOnAotherLine(annotator.get());
}
1689
VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(const Annotator * classifier)1690 void VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(
1691 const Annotator* classifier) {
1692 EXPECT_TRUE(classifier);
1693 std::vector<ClassificationResult> result;
1694 ClassificationOptions options;
1695
1696 options.reference_timezone = "Europe/Zurich";
1697 options.locales = "en-US";
1698 result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
1699
1700 // In US, the date should be interpreted as <month>.<day>.
1701 EXPECT_THAT(result,
1702 ElementsAre(IsDatetimeResult(
1703 5439600000, DatetimeGranularity::GRANULARITY_MINUTE)));
1704 }
1705
// US month/day ordering check using the default test model.
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleUSParsesDateAsMonthDay) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1711
// US month/day ordering check, using the test model with its datetime grammar
// model removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest,
       ClassifyTextWhenLocaleUSParsesDateAsMonthDayWithDatetimeRegEx) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyClassifyTextWhenLocaleUSParsesDateAsMonthDay(annotator.get());
}
1720
// Checks that with the "de" locale "03.05.1970 00:00vorm" is read as
// <day>.<month>, i.e. May 3, 1970. (The test name says "MonthDay"; it is kept
// for test-filter stability.) 10537200000 ms is 1970-05-02T23:00:00Z, i.e.
// local midnight of 1970-05-03 in Europe/Zurich (UTC+1).
TEST_F(AnnotatorTest, ClassifyTextWhenLocaleGermanyParsesDateAsMonthDay) {
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // Stop here instead of dereferencing a null classifier below.
  ASSERT_TRUE(classifier);
  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  options.reference_timezone = "Europe/Zurich";
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970 00:00vorm", {0, 20}, options);

  // In Germany, the date should be interpreted as <day>.<month>.
  EXPECT_THAT(result,
              ElementsAre(IsDatetimeResult(
                  10537200000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1739
// "10:30" without am/pm should yield two candidate datetimes. The expected
// constants are 09:30Z and 21:30Z on 1970-01-01, i.e. 10:30 and 22:30 in
// Europe/Zurich (UTC+1).
TEST_F(AnnotatorTest, ClassifyTextAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Stop here instead of dereferencing a null classifier below.
  ASSERT_TRUE(classifier);
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<ClassificationResult> result =
      classifier->ClassifyText("set an alarm for 10:30", {17, 22}, options);

  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1756
// Annotating "set an alarm for 10:30" should produce one span carrying both
// candidate interpretations: 10:30 and 22:30 on 1970-01-01 in Europe/Zurich
// (UTC+1).
TEST_F(AnnotatorTest, AnnotateAmbiguousDatetime) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Stop here instead of dereferencing a null classifier below.
  ASSERT_TRUE(classifier);
  AnnotationOptions options;
  options.reference_timezone = "Europe/Zurich";
  options.locales = "en-US";
  const std::vector<AnnotatedSpan> spans =
      classifier->Annotate("set an alarm for 10:30", options);

  ASSERT_EQ(spans.size(), 1u);
  // Borrow the classification list rather than copying it.
  const std::vector<ClassificationResult>& result = spans[0].classification;
  EXPECT_THAT(
      result,
      ElementsAre(
          IsDatetimeResult(34200000, DatetimeGranularity::GRANULARITY_MINUTE),
          IsDatetimeResult(77400000, DatetimeGranularity::GRANULARITY_MINUTE)));
}
1775
// Disables every regex datetime pattern for selection only. The test expects
// classification and annotation to still find the date, while
// SuggestSelection leaves the clicked span unchanged.
TEST_F(AnnotatorTest, SuggestTextDateDisabled) {
  std::string test_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Disable the patterns for selection. A range-for avoids the signed/unsigned
  // comparison of an int index against size().
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            CodepointSpan(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
1799
// Attaches a small grammar model that maps three TV-detective names to a
// "famous person" collection. Checks that annotation, classification and
// selection all use it.
TEST_F(AnnotatorTest, AnnotatesWithGrammarModel) {
  const std::string model_bytes = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked = UnPackModel(model_bytes.c_str());

  // Grammar model with ICU tokenization.
  unpacked->grammar_model.reset(new GrammarModelT);
  GrammarModelT* const grammar_spec = unpacked->grammar_model.get();
  grammar_spec->tokenizer_options.reset(new GrammarTokenizerOptionsT);
  auto* const tokenizer = grammar_spec->tokenizer_options.get();
  tokenizer->tokenization_type = TokenizationType_ICU;
  tokenizer->icu_preserve_whitespace_tokens = false;
  tokenizer->tokenize_on_script_change = true;

  // Rules: <famous_person> ::= <tv_detective>, with three detective names.
  grammar_spec->rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(shard_map);
  const std::vector<std::vector<std::string>> detectives = {
      {"jessica", "fletcher"}, {"columbo"}, {"magnum"}};
  for (const std::vector<std::string>& detective : detectives) {
    rules.Add("<tv_detective>", detective);
  }
  rules.Add(
      "<famous_person>", {"<tv_detective>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/0 /* rule classification result */);

  // Classification result referenced by callback_param 0.
  grammar_spec->rule_classification_result.emplace_back(
      new GrammarModel_::RuleClassificationResultT);
  GrammarModel_::RuleClassificationResultT* const classification =
      grammar_spec->rule_classification_result.back().get();
  classification->collection_name = "famous person";
  classification->enabled_modes = ModeFlag_ALL;
  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_spec->rules.get());

  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, unpacked.get()));

  const char* const packed =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      packed, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string text =
      "Did you see the Novel Connection episode where Jessica Fletcher helps "
      "Magnum solve the case? I thought that was with Columbo ...";

  EXPECT_THAT(annotator->Annotate(text),
              ElementsAre(IsAnnotatedSpan(47, 63, "famous person"),
                          IsAnnotatedSpan(70, 76, "famous person"),
                          IsAnnotatedSpan(117, 124, "famous person")));
  EXPECT_THAT(FirstResult(annotator->ClassifyText("Jessica Fletcher",
                                                  CodepointSpan{0, 16})),
              Eq("famous person"));
  EXPECT_THAT(annotator->SuggestSelection("Jessica Fletcher", {0, 7}),
              Eq(CodepointSpan{0, 16}));
}
1857
// A lone candidate has nothing to conflict with and must be chosen.
TEST_F(AnnotatorTest, ResolveConflictsTrivial) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0)};

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0));
}
1873
// Adjacent, non-overlapping candidates are all kept, in order.
TEST_F(AnnotatorTest, ResolveConflictsSequence) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 1}, "phone", 1.0),
      MakeAnnotatedSpan({1, 2}, "phone", 1.0),
      MakeAnnotatedSpan({2, 3}, "phone", 1.0),
      MakeAnnotatedSpan({3, 4}, "phone", 1.0),
      MakeAnnotatedSpan({4, 5}, "phone", 1.0),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1, 2, 3, 4));
}
1894
// The lower-scoring middle span overlaps both neighbours and is dropped.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 2));
}
1913
// The higher-scoring middle span beats both lower-scoring neighbours it
// overlaps.
TEST_F(AnnotatorTest, ResolveConflictsThreeSpansReversed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(1));
}
1932
// With default options, span length does not win conflicts: the long "url"
// span loses, and the two non-overlapping outer spans are kept.
TEST_F(AnnotatorTest, DoesNotPrioritizeLongerSpanWhenDoingConflictResolution) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({3, 7}, "unit", 1),
      MakeAnnotatedSpan({5, 13}, "unit", 1),  // Loser!
      MakeAnnotatedSpan({5, 30}, "url", 1),   // Loser!
      MakeAnnotatedSpan({14, 20}, "email", 1),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  // The first and last annotations are picked because they do not overlap.
  EXPECT_THAT(selected, ElementsAre(0, 3));
}
1953
// With `prioritize_longest_annotation` enabled, the longest overlapping span
// ("url", {5, 30}) wins over all the shorter candidates it overlaps.
TEST_F(AnnotatorTest, PrioritizeLongerSpanWhenDoingConflictResolution) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Use gtest assertions rather than TC3_CHECK, which presumably aborts the
  // whole test binary; a failure here should only fail this test.
  ASSERT_TRUE(libtextclassifier3::DecompressModel(unpacked_model.get()));
  unpacked_model->conflict_resolution_options.reset(
      new Model_::ConflictResolutionOptionsT);
  unpacked_model->conflict_resolution_options->prioritize_longest_annotation =
      true;

  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));

  std::unique_ptr<TestingAnnotator> classifier =
      TestingAnnotator::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  std::vector<AnnotatedSpan> candidates{{
      MakeAnnotatedSpan({3, 7}, "unit", 1),     // Loser!
      MakeAnnotatedSpan({5, 13}, "unit", 1),    // Loser!
      MakeAnnotatedSpan({5, 30}, "url", 1),     // Pick longest match.
      MakeAnnotatedSpan({14, 20}, "email", 1),  // Loser!
  }};
  std::vector<Locale> locales = {Locale::FromBCP47("en")};

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<int> chosen;
  classifier->ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                               locales, options,
                               /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({2}));
}
1988
// Mixed scores and collections across five candidates. The expected result
// keeps spans 0, 2 and 4.
TEST_F(AnnotatorTest, ResolveConflictsFiveSpans) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_SMART;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 2, 4));
}
2009
// RAW usecase: a KNOWLEDGE-sourced span that overlaps a later span does not
// suppress it; both are kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeFirst) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "address", 0.6),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2028
// RAW usecase: a KNOWLEDGE-sourced span nested inside an earlier span is kept
// alongside it.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedKnowledgeSecond) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2047
// RAW usecase: two overlapping KNOWLEDGE-sourced spans are both kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsAllowedBothKnowledge) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "entity", 0.7,
                        AnnotatedSpan::Source::KNOWLEDGE),
      MakeAnnotatedSpan({5, 10}, "entity", 0.6,
                        AnnotatedSpan::Source::KNOWLEDGE),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2067
// RAW usecase with default options: two overlapping non-knowledge spans are
// still resolved, and only the first is kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeOverlapsNotAllowed) {
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get());

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0));
}
2085
// RAW usecase with `do_conflict_resolution_in_raw_mode` turned off: overlapping
// non-knowledge spans are all kept.
TEST_F(AnnotatorTest, ResolveConflictsRawModeGeneralOverlapsAllowed) {
  auto disable_raw_mode_resolution = [](ModelT* model) {
    model->conflict_resolution_options.reset(
        new Model_::ConflictResolutionOptionsT);
    model->conflict_resolution_options->do_conflict_resolution_in_raw_mode =
        false;
  };
  TestingAnnotator annotator(unilib_.get(), calendarlib_.get(),
                             disable_raw_mode_resolution);

  BaseOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  std::vector<Locale> locales = {Locale::FromBCP47("en")};
  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 15}, "address", 0.7),
      MakeAnnotatedSpan({5, 10}, "date", 0.6),
  };

  std::vector<int> selected;
  annotator.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                             locales, options,
                             /*interpreter_manager=*/nullptr, &selected);
  EXPECT_THAT(selected, ElementsAre(0, 1));
}
2109
VerifyLongInput(const Annotator * classifier)2110 void VerifyLongInput(const Annotator* classifier) {
2111 ASSERT_TRUE(classifier);
2112
2113 for (const auto& type_value_pair :
2114 std::vector<std::pair<std::string, std::string>>{
2115 {"address", "350 Third Street, Cambridge"},
2116 {"phone", "123 456-7890"},
2117 {"url", "www.google.com"},
2118 {"email", "someone@gmail.com"},
2119 {"flight", "LX 38"},
2120 {"date", "September 1, 2018"}}) {
2121 const std::string input_100k = std::string(50000, ' ') +
2122 type_value_pair.second +
2123 std::string(50000, ' ');
2124 const int value_length = type_value_pair.second.size();
2125
2126 AnnotationOptions annotation_options;
2127 annotation_options.locales = "en";
2128 EXPECT_THAT(classifier->Annotate(input_100k, annotation_options),
2129 ElementsAreArray({IsAnnotatedSpan(50000, 50000 + value_length,
2130 type_value_pair.first)}));
2131 SelectionOptions selection_options;
2132 selection_options.locales = "en";
2133 EXPECT_EQ(classifier->SuggestSelection(input_100k, {50000, 50001},
2134 selection_options),
2135 CodepointSpan(50000, 50000 + value_length));
2136
2137 ClassificationOptions classification_options;
2138 classification_options.locales = "en";
2139 EXPECT_EQ(type_value_pair.first,
2140 FirstResult(classifier->ClassifyText(
2141 input_100k, {50000, 50000 + value_length},
2142 classification_options)));
2143 }
2144 }
2145
// Long-input checks using the default test model.
TEST_F(AnnotatorTest, LongInput) {
  const std::string model_path = GetTestModelPath();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromPath(model_path, unilib_.get(), calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2151
// Long-input checks, using the test model with its datetime grammar model
// removed (see GetTestModelWithDatetimeRegEx).
TEST_F(AnnotatorTest, LongInputWithRegExDatetime) {
  const std::string regex_model = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(regex_model.data(), regex_model.size(),
                                   unilib_.get(), calendarlib_.get());
  VerifyLongInput(annotator.get());
}
2159
// These coarse tests are only there to make sure that execution finishes in a
// reasonable amount of time.
// Runs Annotate, SuggestSelection and ClassifyText on ~100k-character inputs.
// Results are deliberately ignored.
TEST_F(AnnotatorTest, LongInputNoResultCheck) {
  std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::vector<std::string> values = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
  for (const std::string& value : values) {
    const std::string padding(50000, ' ');
    const std::string input_100k = padding + value + padding;
    const int value_end = 50000 + static_cast<int>(value.size());

    annotator->Annotate(input_100k);
    annotator->SuggestSelection(input_100k, {50000, 50001});
    annotator->ClassifyText(input_100k, {50000, value_end});
  }
}
2178
// Checks that classification_options.max_num_tokens gates ClassifyText: with
// no limit (-1) the span "350 Third Street, Cambridge" is classified as
// "address"; with max_num_tokens = 3 the same span is expected to come back as
// "other".
//
// Annotators created with FromUnownedBuffer do not copy the model buffer, so
// each annotator lives in an inner scope that is nested inside the lifetime of
// the FlatBufferBuilder owning its buffer.
TEST_F(AnnotatorTest, MaxTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  {
    // With unrestricted number of tokens should behave normally.
    unpacked_model->classification_options->max_num_tokens = -1;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "address");
  }

  {
    // Lower the maximum number of tokens to suppress the classification.
    unpacked_model->classification_options->max_num_tokens = 3;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "other");
  }
}
2213
// Checks that classification_options.address_min_num_tokens gates the
// "address" classification: with a minimum of 0 the span is an address; with a
// minimum of 5 the same span is expected to come back as "other".
//
// Annotators created with FromUnownedBuffer do not copy the model buffer, so
// each annotator lives in an inner scope that is nested inside the lifetime of
// the FlatBufferBuilder owning its buffer.
TEST_F(AnnotatorTest, MinAddressTokenLength) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  {
    // With unrestricted number of address tokens should behave normally.
    unpacked_model->classification_options->address_min_num_tokens = 0;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "address");
  }

  {
    // Raise number of address tokens to suppress the address classification.
    unpacked_model->classification_options->address_min_num_tokens = 5;

    flatbuffers::FlatBufferBuilder builder;
    FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
    std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder.GetBufferPointer()),
        builder.GetSize(), unilib_.get(), calendarlib_.get());
    ASSERT_TRUE(classifier);

    EXPECT_EQ(FirstResult(classifier->ClassifyText(
                  "I live at 350 Third Street, Cambridge.", {10, 37})),
              "other");
  }
}
2248
// With other_collection_priority_score set to 1.0, "LX37" is expected to be
// classified as "other" rather than "flight".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighOtherIsPreferredToFlight) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());
  model->triggering_options->other_collection_priority_score = 1.0;

  // The builder owns the serialized buffer and is declared before the
  // annotator that reads it without copying.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const serialized =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      serialized, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::vector<ClassificationResult> results =
      classifier->ClassifyText("LX37", {0, 4});
  EXPECT_EQ(FirstResult(results), "other");
}
2263
// Despite the "High" in its name, this test sets a very low
// other_collection_priority_score (-100.0); "LX37" is then expected to be
// classified as "flight".
TEST_F(AnnotatorTest, WhenOtherCollectionPriorityHighFlightIsPreferredToOther) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> model = UnPackModel(test_model.c_str());
  model->triggering_options->other_collection_priority_score = -100.0;

  // The builder owns the serialized buffer and is declared before the
  // annotator that reads it without copying.
  flatbuffers::FlatBufferBuilder fbb;
  FinishModelBuffer(fbb, Model::Pack(fbb, model.get()));
  const char* const serialized =
      reinterpret_cast<const char*>(fbb.GetBufferPointer());
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      serialized, fbb.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  const std::vector<ClassificationResult> results =
      classifier->ClassifyText("LX37", {0, 4});
  EXPECT_EQ(FirstResult(results), "flight");
}
2278
// The visitor below reports whether it was handed a non-null model. Visiting
// the real test model should yield true, and visiting a path that does not
// exist should yield false.
TEST_F(AnnotatorTest, VisitAnnotatorModel) {
  const auto model_is_present = [](const Model* model) {
    return model != nullptr;
  };
  EXPECT_TRUE(VisitAnnotatorModel<bool>(GetTestModelPath(), model_is_present));
  EXPECT_FALSE(VisitAnnotatorModel<bool>(
      GetModelPath() + "non_existing_model.fb", model_is_present));
}
2295
// With triggering locales restricted to "en,cs" and no detected language
// supplied, Annotate, ClassifyText and SuggestSelection should all still
// recognize the phone number.
TEST_F(AnnotatorTest, TriggersWhenNoLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string phone = "(555) 225-3556";
  EXPECT_THAT(annotator->Annotate(phone),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
  EXPECT_EQ(FirstResult(annotator->ClassifyText(phone, {0, 14})), "phone");
  EXPECT_EQ(annotator->SuggestSelection(phone, {6, 9}), CodepointSpan(0, 14));
}
2312
// Annotate() should find the phone number when the detected language ("cs")
// is one of the model's triggering locales ("en,cs").
TEST_F(AnnotatorTest, AnnotateTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";
  EXPECT_THAT(annotator->Annotate("(555) 225-3556", annotation_options),
              ElementsAreArray({IsAnnotatedSpan(0, 14, "phone")}));
}
2327
// Annotate() should return nothing when the detected language ("de") is not
// among the model's triggering locales ("en,cs").
TEST_F(AnnotatorTest, AnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";
  EXPECT_THAT(annotator->Annotate("(555) 225-3556", annotation_options),
              IsEmpty());
}
2341
// ClassifyText() should label the phone number when the detected language
// ("cs") is one of the model's triggering locales ("en,cs").
TEST_F(AnnotatorTest, ClassifyTextTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";
  const auto results = annotator->ClassifyText("(555) 225-3556", {0, 14},
                                               classification_options);
  EXPECT_EQ(FirstResult(results), "phone");
}
2356
// ClassifyText() should return no results when the detected language ("de")
// is not among the model's triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       ClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";
  const auto results = annotator->ClassifyText("(555) 225-3556", {0, 14},
                                               classification_options);
  EXPECT_THAT(results, IsEmpty());
}
2372
// SuggestSelection() should expand the click to the whole phone number when
// the detected language ("cs") is one of the model's triggering locales.
TEST_F(AnnotatorTest, SuggestSelectionTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";
  const CodepointSpan selection = annotator->SuggestSelection(
      "(555) 225-3556", {6, 9}, selection_options);
  EXPECT_EQ(selection, CodepointSpan(0, 14));
}
2387
// SuggestSelection() should leave the click span unchanged when the detected
// language ("de") is not among the model's triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       SuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";
  const CodepointSpan selection = annotator->SuggestSelection(
      "(555) 225-3556", {6, 9}, selection_options);
  EXPECT_EQ(selection, CodepointSpan(6, 9));
}
2403
// With both the model-level and the triggering_options locales restricted to
// "en,cs" and no detected language supplied, the address should still be
// found by Annotate, ClassifyText and SuggestSelection.
TEST_F(AnnotatorTest, MlModelTriggersWhenNoLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const std::string address = "350 Third Street, Cambridge";
  EXPECT_THAT(annotator->Annotate(address),
              ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
  EXPECT_EQ(FirstResult(annotator->ClassifyText(address, {0, 27})),
            "address");
  EXPECT_EQ(annotator->SuggestSelection(address, {4, 9}),
            CodepointSpan(0, 27));
}
2422
// Annotate() should find the address when the detected language ("cs") is
// one of the restricted triggering locales ("en,cs").
TEST_F(AnnotatorTest, MlModelAnnotateTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "cs";
  EXPECT_THAT(
      annotator->Annotate("350 Third Street, Cambridge", annotation_options),
      ElementsAreArray({IsAnnotatedSpan(0, 27, "address")}));
}
2439
// Annotate() should return nothing for the address when the detected language
// ("de") is outside the restricted triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       MlModelAnnotateDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  AnnotationOptions annotation_options;
  annotation_options.detected_text_language_tags = "de";
  EXPECT_THAT(
      annotator->Annotate("350 Third Street, Cambridge", annotation_options),
      IsEmpty());
}
2457
// ClassifyText() should label the address when the detected language ("cs")
// is one of the restricted triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       MlModelClassifyTextTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "cs";
  const auto results = annotator->ClassifyText(
      "350 Third Street, Cambridge", {0, 27}, classification_options);
  EXPECT_EQ(FirstResult(results), "address");
}
2475
// ClassifyText() should return no results for the address when the detected
// language ("de") is outside the restricted triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       MlModelClassifyTextDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  ClassificationOptions classification_options;
  classification_options.detected_text_language_tags = "de";
  const auto results = annotator->ClassifyText(
      "350 Third Street, Cambridge", {0, 27}, classification_options);
  EXPECT_THAT(results, IsEmpty());
}
2494
// SuggestSelection() should expand the click to the whole address when the
// detected language ("cs") is one of the restricted triggering locales.
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionTriggersWhenSupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "cs";
  const CodepointSpan selection = annotator->SuggestSelection(
      "350 Third Street, Cambridge", {4, 9}, selection_options);
  EXPECT_EQ(selection, CodepointSpan(0, 27));
}
2513
// SuggestSelection() should leave the click span unchanged when the detected
// language ("de") is outside the restricted triggering locales ("en,cs").
TEST_F(AnnotatorTest,
       MlModelSuggestSelectionDoesntTriggerWhenUnsupportedLanguageDetected) {
  const auto restrict_to_en_cs = [](ModelT* model) {
    model->triggering_locales = "en,cs";
    model->triggering_options->locales = "en,cs";
  };
  std::string model_buffer = ReadFile(GetTestModelPath());
  model_buffer = ModifyAnnotatorModel(model_buffer, restrict_to_en_cs);
  std::unique_ptr<Annotator> annotator =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  SelectionOptions selection_options;
  selection_options.detected_text_language_tags = "de";
  const CodepointSpan selection = annotator->SuggestSelection(
      "350 Third Street, Cambridge", {4, 9}, selection_options);
  EXPECT_EQ(selection, CodepointSpan(4, 9));
}
2532
VerifyClassifyTextOutputsDatetimeEntityData(const Annotator * classifier)2533 void VerifyClassifyTextOutputsDatetimeEntityData(const Annotator* classifier) {
2534 EXPECT_TRUE(classifier);
2535 std::vector<ClassificationResult> result;
2536 ClassificationOptions options;
2537 options.locales = "en-US";
2538
2539 result = classifier->ClassifyText("03.05.1970 00:00am", {0, 18}, options);
2540
2541 ASSERT_GE(result.size(), 0);
2542 const EntityData* entity_data =
2543 GetEntityData(result[0].serialized_entity_data.data());
2544 ASSERT_NE(entity_data, nullptr);
2545 ASSERT_NE(entity_data->datetime(), nullptr);
2546 EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 5443200000L);
2547 EXPECT_EQ(entity_data->datetime()->granularity(),
2548 EntityData_::Datetime_::Granularity_GRANULARITY_MINUTE);
2549 EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 6);
2550
2551 auto* meridiem = entity_data->datetime()->datetime_component()->Get(0);
2552 EXPECT_EQ(meridiem->component_type(),
2553 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MERIDIEM);
2554 EXPECT_EQ(meridiem->absolute_value(), 0);
2555 EXPECT_EQ(meridiem->relative_count(), 0);
2556 EXPECT_EQ(meridiem->relation_type(),
2557 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2558
2559 auto* minute = entity_data->datetime()->datetime_component()->Get(1);
2560 EXPECT_EQ(minute->component_type(),
2561 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MINUTE);
2562 EXPECT_EQ(minute->absolute_value(), 0);
2563 EXPECT_EQ(minute->relative_count(), 0);
2564 EXPECT_EQ(minute->relation_type(),
2565 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2566
2567 auto* hour = entity_data->datetime()->datetime_component()->Get(2);
2568 EXPECT_EQ(hour->component_type(),
2569 EntityData_::Datetime_::DatetimeComponent_::ComponentType_HOUR);
2570 EXPECT_EQ(hour->absolute_value(), 0);
2571 EXPECT_EQ(hour->relative_count(), 0);
2572 EXPECT_EQ(hour->relation_type(),
2573 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2574
2575 auto* day = entity_data->datetime()->datetime_component()->Get(3);
2576 EXPECT_EQ(
2577 day->component_type(),
2578 EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2579 EXPECT_EQ(day->absolute_value(), 5);
2580 EXPECT_EQ(day->relative_count(), 0);
2581 EXPECT_EQ(day->relation_type(),
2582 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2583
2584 auto* month = entity_data->datetime()->datetime_component()->Get(4);
2585 EXPECT_EQ(month->component_type(),
2586 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2587 EXPECT_EQ(month->absolute_value(), 3);
2588 EXPECT_EQ(month->relative_count(), 0);
2589 EXPECT_EQ(month->relation_type(),
2590 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2591
2592 auto* year = entity_data->datetime()->datetime_component()->Get(5);
2593 EXPECT_EQ(year->component_type(),
2594 EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2595 EXPECT_EQ(year->absolute_value(), 1970);
2596 EXPECT_EQ(year->relative_count(), 0);
2597 EXPECT_EQ(year->relation_type(),
2598 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2599 }
2600
// Verifies ClassifyText datetime entity data with the unmodified test model.
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityData) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fail fast here rather than handing a null annotator to the verifier.
  ASSERT_TRUE(classifier);
  VerifyClassifyTextOutputsDatetimeEntityData(classifier.get());
}
2606
// Verifies ClassifyText datetime entity data with the model returned by
// GetTestModelWithDatetimeRegEx() (datetime grammar model removed).
TEST_F(AnnotatorTest, ClassifyTextOutputsDatetimeEntityDataWithDatetimeRegEx) {
  // Declared before the annotator: FromUnownedBuffer does not copy it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // Fail fast here rather than handing a null annotator to the verifier.
  ASSERT_TRUE(classifier);
  VerifyClassifyTextOutputsDatetimeEntityData(classifier.get());
}
2614
VerifyAnnotateOutputsDatetimeEntityData(const Annotator * classifier)2615 void VerifyAnnotateOutputsDatetimeEntityData(const Annotator* classifier) {
2616 EXPECT_TRUE(classifier);
2617 std::vector<AnnotatedSpan> result;
2618 AnnotationOptions options;
2619 options.is_serialized_entity_data_enabled = true;
2620 options.locales = "en";
2621
2622 result = classifier->Annotate("September 1, 2019", options);
2623
2624 ASSERT_GE(result.size(), 0);
2625 ASSERT_GE(result[0].classification.size(), 0);
2626 ASSERT_EQ(result[0].classification[0].collection, "date");
2627 const EntityData* entity_data =
2628 GetEntityData(result[0].classification[0].serialized_entity_data.data());
2629 ASSERT_NE(entity_data, nullptr);
2630 ASSERT_NE(entity_data->datetime(), nullptr);
2631 EXPECT_EQ(entity_data->datetime()->time_ms_utc(), 1567296000000L);
2632 EXPECT_EQ(entity_data->datetime()->granularity(),
2633 EntityData_::Datetime_::Granularity_GRANULARITY_DAY);
2634 EXPECT_EQ(entity_data->datetime()->datetime_component()->size(), 3);
2635
2636 auto* day = entity_data->datetime()->datetime_component()->Get(0);
2637 EXPECT_EQ(
2638 day->component_type(),
2639 EntityData_::Datetime_::DatetimeComponent_::ComponentType_DAY_OF_MONTH);
2640 EXPECT_EQ(day->absolute_value(), 1);
2641 EXPECT_EQ(day->relative_count(), 0);
2642 EXPECT_EQ(day->relation_type(),
2643 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2644
2645 auto* month = entity_data->datetime()->datetime_component()->Get(1);
2646 EXPECT_EQ(month->component_type(),
2647 EntityData_::Datetime_::DatetimeComponent_::ComponentType_MONTH);
2648 EXPECT_EQ(month->absolute_value(), 9);
2649 EXPECT_EQ(month->relative_count(), 0);
2650 EXPECT_EQ(month->relation_type(),
2651 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2652
2653 auto* year = entity_data->datetime()->datetime_component()->Get(2);
2654 EXPECT_EQ(year->component_type(),
2655 EntityData_::Datetime_::DatetimeComponent_::ComponentType_YEAR);
2656 EXPECT_EQ(year->absolute_value(), 2019);
2657 EXPECT_EQ(year->relative_count(), 0);
2658 EXPECT_EQ(year->relation_type(),
2659 EntityData_::Datetime_::DatetimeComponent_::RelationType_ABSOLUTE);
2660 }
2661
// Verifies Annotate datetime entity data with the unmodified test model.
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityData) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Fail fast here rather than handing a null annotator to the verifier.
  ASSERT_TRUE(classifier);
  VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
}
2667
// Verifies Annotate datetime entity data with the model returned by
// GetTestModelWithDatetimeRegEx() (datetime grammar model removed). The
// "Datatime" spelling in the test name is kept so existing test filters keep
// matching.
TEST_F(AnnotatorTest, AnnotateOutputsDatetimeEntityDataWithDatatimeRegEx) {
  // Declared before the annotator: FromUnownedBuffer does not copy it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  // Fail fast here rather than handing a null annotator to the verifier.
  ASSERT_TRUE(classifier);
  VerifyAnnotateOutputsDatetimeEntityData(classifier.get());
}
2675
// Checks the serialized money entity data produced by Annotate for a range of
// currency/amount layouts:
//   - currency before or after the amount;
//   - thousands separators and decimal parts;
//   - quantity suffixes such as "Million", "k", "m" and "Billion".
//
// NOTE(review): a few expectations (e.g. "$59.3 Billion" -> whole_part 59 and
// "7.05k €" -> decimal_part 5, nanos 0) appear to pin the annotator's current
// output rather than an ideal parse; confirm before relying on them.
TEST_F(AnnotatorTest, AnnotateOutputsMoneyEntityData) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // ASSERT (not EXPECT): every call below dereferences the annotator.
  ASSERT_TRUE(classifier);
  AnnotationOptions options;
  options.is_serialized_entity_data_enabled = true;

  ExpectFirstEntityIsMoney(classifier->Annotate("3.5 CHF", options), "CHF",
                           /*amount=*/"3.5", /*whole_part=*/3,
                           /*decimal_part=*/5, /*nanos=*/500000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.5", options), "CHF",
                           /*amount=*/"3.5", /*whole_part=*/3,
                           /*decimal_part=*/5, /*nanos=*/500000000);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("For online purchase of CHF 23.00 enter", options),
      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("For online purchase of 23.00 CHF enter", options),
      "CHF", /*amount=*/"23.00", /*whole_part=*/23, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("4.8198£", options), "£",
                           /*amount=*/"4.8198", /*whole_part=*/4,
                           /*decimal_part=*/8198, /*nanos=*/819800000);
  ExpectFirstEntityIsMoney(classifier->Annotate("£4.8198", options), "£",
                           /*amount=*/"4.8198", /*whole_part=*/4,
                           /*decimal_part=*/8198, /*nanos=*/819800000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$0.0255", options), "$",
                           /*amount=*/"0.0255", /*whole_part=*/0,
                           /*decimal_part=*/255, /*nanos=*/25500000);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("for txn of INR 000.00 at RAZOR-PAY ZOMATO ONLINE "
                           "OR on card ending 0000.",
                           options),
      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("for txn of 000.00 INR at RAZOR-PAY ZOMATO ONLINE "
                           "OR on card ending 0000.",
                           options),
      "INR", /*amount=*/"000.00", /*whole_part=*/0, /*decimal_part=*/0,
      /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("35 CHF", options), "CHF",
                           /*amount=*/"35",
                           /*whole_part=*/35, /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 35", options), "CHF",
                           /*amount=*/"35", /*whole_part=*/35,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("and win back up to CHF 150 - with digitec",
                           options),
      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
      /*nanos=*/0);
  ExpectFirstEntityIsMoney(
      classifier->Annotate("and win back up to 150 CHF - with digitec",
                           options),
      "CHF", /*amount=*/"150", /*whole_part=*/150, /*decimal_part=*/0,
      /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("3.555.333 CHF", options),
                           "CHF", /*amount=*/"3.555.333",
                           /*whole_part=*/3555333, /*decimal_part=*/0,
                           /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3.555.333", options),
                           "CHF", /*amount=*/"3.555.333",
                           /*whole_part=*/3555333, /*decimal_part=*/0,
                           /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("10,000 CHF", options), "CHF",
                           /*amount=*/"10,000", /*whole_part=*/10000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 10,000", options), "CHF",
                           /*amount=*/"10,000", /*whole_part=*/10000,
                           /*decimal_part=*/0, /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("3,555.33 CHF", options), "CHF",
                           /*amount=*/"3,555.33", /*whole_part=*/3555,
                           /*decimal_part=*/33, /*nanos=*/330000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF 3,555.33", options), "CHF",
                           /*amount=*/"3,555.33", /*whole_part=*/3555,
                           /*decimal_part=*/33, /*nanos=*/330000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$3,000.00", options), "$",
                           /*amount=*/"3,000.00", /*whole_part=*/3000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("3,000.00$", options), "$",
                           /*amount=*/"3,000.00", /*whole_part=*/3000,
                           /*decimal_part=*/0, /*nanos=*/0);

  ExpectFirstEntityIsMoney(classifier->Annotate("1.2 CHF", options), "CHF",
                           /*amount=*/"1.2", /*whole_part=*/1,
                           /*decimal_part=*/2, /*nanos=*/200000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("CHF1.2", options), "CHF",
                           /*amount=*/"1.2", /*whole_part=*/1,
                           /*decimal_part=*/2, /*nanos=*/200000000);

  ExpectFirstEntityIsMoney(classifier->Annotate("$1.123456789", options), "$",
                           /*amount=*/"1.123456789", /*whole_part=*/1,
                           /*decimal_part=*/123456789, /*nanos=*/123456789);
  ExpectFirstEntityIsMoney(classifier->Annotate("10.01 CHF", options), "CHF",
                           /*amount=*/"10.01", /*whole_part=*/10,
                           /*decimal_part=*/1, /*nanos=*/10000000);

  ExpectFirstEntityIsMoney(classifier->Annotate("$59 Million", options), "$",
                           /*amount=*/"59 million", /*whole_part=*/59000000,
                           /*decimal_part=*/0, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.05k €", options), "€",
                           /*amount=*/"7.05 k", /*whole_part=*/7050,
                           /*decimal_part=*/5, /*nanos=*/0);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.123456789m €", options), "€",
                           /*amount=*/"7.123456789 m", /*whole_part=*/7123456,
                           /*decimal_part=*/123456789, /*nanos=*/789000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("7.000056789k €", options), "€",
                           /*amount=*/"7.000056789 k", /*whole_part=*/7000,
                           /*decimal_part=*/56789, /*nanos=*/56789000);

  ExpectFirstEntityIsMoney(classifier->Annotate("$59.3 Billion", options), "$",
                           /*amount=*/"59.3 billion", /*whole_part=*/59,
                           /*decimal_part=*/3, /*nanos=*/300000000);
  ExpectFirstEntityIsMoney(classifier->Annotate("$1.5 Billion", options), "$",
                           /*amount=*/"1.5 billion", /*whole_part=*/1500000000,
                           /*decimal_part=*/5, /*nanos=*/0);
}
2805
// With a LangId model attached and "de" as the user's familiar language,
// classifying "are" inside an English sentence should yield exactly one
// result, the "translate" action.
TEST_F(AnnotatorTest, TranslateAction) {
  // SetLangId receives a raw pointer while langid_model keeps ownership, so the
  // LangId model is declared first. Locals are destroyed in reverse
  // declaration order, so it outlives the annotator that refers to it.
  std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
      libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
          GetModelPath() + "lang_id.smfb");
  ASSERT_TRUE(langid_model);

  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  classifier->SetLangId(langid_model.get());

  ClassificationOptions options;
  options.user_familiar_language_tags = "de";

  std::vector<ClassificationResult> classifications =
      classifier->ClassifyText("hello, how are you doing?", {11, 14}, options);
  // ASSERT (not EXPECT): classifications[0] is read on the next line.
  ASSERT_EQ(classifications.size(), 1);
  EXPECT_EQ(classifications[0].collection, "translate");
}
2822
// AnnotateStructuredInput should annotate each fragment independently. Here
// the first fragment yields a "money" span and the second a "date" span, with
// offsets relative to each fragment's own text.
TEST_F(AnnotatorTest, AnnotateStructuredInputCallsMultipleAnnotators) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // ASSERT (not EXPECT): the annotator is dereferenced below.
  ASSERT_TRUE(classifier);

  std::vector<InputFragment> string_fragments = {
      {.text = "He owes me 3.5 CHF."},
      {.text = "...was born on 13/12/1989."},
  };

  AnnotationOptions annotation_options;
  annotation_options.locales = "en";
  StatusOr<Annotations> annotations_status =
      classifier->AnnotateStructuredInput(string_fragments, annotation_options);
  ASSERT_TRUE(annotations_status.ok());
  Annotations annotations = annotations_status.ValueOrDie();
  ASSERT_EQ(annotations.annotated_spans.size(), 2);
  EXPECT_THAT(annotations.annotated_spans[0],
              ElementsAreArray({IsAnnotatedSpan(11, 18, "money")}));
  EXPECT_THAT(annotations.annotated_spans[1],
              ElementsAreArray({IsAnnotatedSpan(15, 25, "date")}));
}
2844
VerifyInputFragmentTimestampOverridesAnnotationOptions(const Annotator * classifier)2845 void VerifyInputFragmentTimestampOverridesAnnotationOptions(
2846 const Annotator* classifier) {
2847 AnnotationOptions annotation_options;
2848 annotation_options.locales = "en";
2849 annotation_options.reference_time_ms_utc =
2850 1554465190000; // 04/05/2019 11:53 am
2851 int64 fragment_reference_time = 946727580000; // 01/01/2000 11:53 am
2852 std::vector<InputFragment> string_fragments = {
2853 {.text = "New event at 17:20"},
2854 {
2855 .text = "New event at 17:20",
2856 .datetime_options = Optional<DatetimeOptions>(
2857 {.reference_time_ms_utc = fragment_reference_time}),
2858 }};
2859 StatusOr<Annotations> annotations_status =
2860 classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2861 ASSERT_TRUE(annotations_status.ok());
2862 Annotations annotations = annotations_status.ValueOrDie();
2863 ASSERT_EQ(annotations.annotated_spans.size(), 2);
2864 EXPECT_THAT(annotations.annotated_spans[0],
2865 ElementsAreArray({IsDatetimeSpan(
2866 /*start=*/13, /*end=*/18, /*time_ms_utc=*/1554484800000,
2867 DatetimeGranularity::GRANULARITY_MINUTE)}));
2868 EXPECT_THAT(annotations.annotated_spans[1],
2869 ElementsAreArray({IsDatetimeSpan(
2870 /*start=*/13, /*end=*/18, /*time_ms_utc=*/946747200000,
2871 DatetimeGranularity::GRANULARITY_MINUTE)}));
2872 }
2873
// Runs the fragment-timestamp override checks against the test model with its
// datetime grammar model removed (see GetTestModelWithDatetimeRegEx()).
TEST_F(AnnotatorTest,
       InputFragmentTimestampOverridesAnnotationOptionsWithDatetimeRegEx) {
  // Declared before `classifier` so the buffer outlives the annotator, which
  // is built from an unowned view of it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
}
2882
// Runs the fragment-timestamp override checks against the default test model.
TEST_F(AnnotatorTest, InputFragmentTimestampOverridesAnnotationOptions) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Report a load failure as a test failure rather than a null dereference.
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimestampOverridesAnnotationOptions(classifier.get());
}
2888
VerifyInputFragmentTimezoneOverridesAnnotationOptions(const Annotator * classifier)2889 void VerifyInputFragmentTimezoneOverridesAnnotationOptions(
2890 const Annotator* classifier) {
2891 std::vector<InputFragment> string_fragments = {
2892 {.text = "11/12/2020 17:20"},
2893 {
2894 .text = "11/12/2020 17:20",
2895 .datetime_options = Optional<DatetimeOptions>(
2896 {.reference_timezone = "Europe/Zurich"}),
2897 }};
2898 AnnotationOptions annotation_options;
2899 annotation_options.locales = "en-US";
2900 StatusOr<Annotations> annotations_status =
2901 classifier->AnnotateStructuredInput(string_fragments, annotation_options);
2902 ASSERT_TRUE(annotations_status.ok());
2903 Annotations annotations = annotations_status.ValueOrDie();
2904 ASSERT_EQ(annotations.annotated_spans.size(), 2);
2905 EXPECT_THAT(annotations.annotated_spans[0],
2906 ElementsAreArray({IsDatetimeSpan(
2907 /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605201600000,
2908 DatetimeGranularity::GRANULARITY_MINUTE)}));
2909 EXPECT_THAT(annotations.annotated_spans[1],
2910 ElementsAreArray({IsDatetimeSpan(
2911 /*start=*/0, /*end=*/16, /*time_ms_utc=*/1605198000000,
2912 DatetimeGranularity::GRANULARITY_MINUTE)}));
2913 }
2914
// Runs the fragment-timezone override checks against the default test model.
TEST_F(AnnotatorTest, InputFragmentTimezoneOverridesAnnotationOptions) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  // Report a load failure as a test failure rather than a null dereference.
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
}
2920
// Runs the fragment-timezone override checks against the test model with its
// datetime grammar model removed (see GetTestModelWithDatetimeRegEx()).
TEST_F(AnnotatorTest,
       InputFragmentTimezoneOverridesAnnotationOptionsWithDatetimeRegEx) {
  // Declared before `classifier` so the buffer outlives the annotator, which
  // is built from an unowned view of it.
  std::string model_buffer = GetTestModelWithDatetimeRegEx();
  std::unique_ptr<Annotator> classifier =
      Annotator::FromUnownedBuffer(model_buffer.data(), model_buffer.size(),
                                   unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  VerifyInputFragmentTimezoneOverridesAnnotationOptions(classifier.get());
}
2929
2930 namespace {
AddDummyRegexDatetimeModel(ModelT * unpacked_model)2931 void AddDummyRegexDatetimeModel(ModelT* unpacked_model) {
2932 unpacked_model->datetime_model.reset(new DatetimeModelT);
2933 // This needs to be false otherwise we'd have to define some extractor. When
2934 // this is false, the 0-th capturing group (whole match) from the pattern is
2935 // used to come up with the indices.
2936 unpacked_model->datetime_model->use_extractors_for_locating = false;
2937 unpacked_model->datetime_model->locales.push_back("en-US");
2938 unpacked_model->datetime_model->default_locales.push_back(0); // en-US
2939 unpacked_model->datetime_model->patterns.push_back(
2940 std::unique_ptr<DatetimeModelPatternT>(new DatetimeModelPatternT));
2941 unpacked_model->datetime_model->patterns.back()->locales.push_back(
2942 0); // en-US
2943 unpacked_model->datetime_model->patterns.back()->regexes.push_back(
2944 std::unique_ptr<DatetimeModelPattern_::RegexT>(
2945 new DatetimeModelPattern_::RegexT));
2946 unpacked_model->datetime_model->patterns.back()->regexes.back()->pattern =
2947 "THIS_MATCHES_IN_REGEX_MODEL";
2948 unpacked_model->datetime_model->patterns.back()
2949 ->regexes.back()
2950 ->groups.push_back(DatetimeGroupType_GROUP_UNUSED);
2951 }
2952 } // namespace
2953
// Checks that a span reported identically by more than one annotator shows up
// only once in the output of Annotate().
TEST_F(AnnotatorTest, AnnotateFiltersOutExactDuplicates) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  // This test assumes that both ML model and Regex model trigger on the
  // following text and output "phone" annotation for it.
  const std::string test_string = "1000000000";
  AnnotationOptions options;
  // Same enumerator spelling as the other RAW-mode tests in this file.
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  int num_phones = 0;
  for (const AnnotatedSpan& span : classifier->Annotate(test_string, options)) {
    // Guard the [0] access: a span without classifications is not a phone.
    if (!span.classification.empty() &&
        span.classification[0].collection == "phone") {
      ++num_phones;
    }
  }

  EXPECT_EQ(num_phones, 1);
}
2973
// This test covers the Annotator optimizations that skip some annotators in
// RAW mode when their collections were not requested. It checks that the
// results indeed contain no such annotations. However, this is a bit hacky:
// the same results could also come from post-filtering, in which case this
// test would pass trivially without any annotator actually being skipped.
// Requests only an entity type that no annotator produces and expects RAW-mode
// annotation of a text full of other entities to come back empty.
TEST_F(AnnotatorTest, RawModeOptimizationWorks) {
  std::unique_ptr<Annotator> classifier = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
  // A made-up type, so that the request overlaps none of the existing types.
  options.entity_types.insert("some_unknown_entity_type");

  // Without the entity-type restriction, annotating `text` would produce:
  //   Span(19, 24, date, 1.000000),
  //   Span(53, 56, number, 1.000000),
  //   Span(53, 80, address, 1.000000),
  //   Span(128, 142, phone, 1.000000),
  //   Span(129, 132, number, 1.000000),
  //   Span(192, 200, phone, 1.000000),
  //   Span(192, 206, datetime, 1.000000),
  //   Span(246, 253, number, 1.000000),
  //   Span(246, 253, phone, 1.000000),
  //   Span(292, 293, number, 1.000000),
  //   Span(292, 301, duration, 1.000000)
  // The offsets count the 28 spaces that indent each continuation line of the
  // raw string below; keep that indentation if the text is edited.
  // None of these collections was requested, so the optimizations make
  // Annotate() return nothing.
  const std::string text = R"--(I saw Barack Obama today
                            350 Third Street, Cambridge
                            my phone number is (853) 225-3556
                            this is when we met: 1.9.2021 13:00
                            my number: 1234567
                            duration: 3 minutes
                            )--";
  EXPECT_THAT(classifier->Annotate(text, options), IsEmpty());
}
3013
VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(const Annotator * classifier)3014 void VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(
3015 const Annotator* classifier) {
3016 ASSERT_TRUE(classifier);
3017 struct Example {
3018 std::string collection;
3019 std::string text;
3020 };
3021
3022 // These examples contain one example per annotator, to check that each of
3023 // the annotators can work in the RAW mode on its own.
3024 //
3025 // WARNING: This list doesn't contain yet entries for the app, contact, and
3026 // person annotators. Hopefully this won't be needed once b/155214735 is
3027 // fixed and the piping shared across annotators.
3028 std::vector<Example> examples{
3029 // ML Model.
3030 {.collection = Collections::Address(),
3031 .text = "... 350 Third Street, Cambridge ..."},
3032 // Datetime annotator.
3033 {.collection = Collections::DateTime(), .text = "... 1.9.2020 10:00 ..."},
3034 // Duration annotator.
3035 {.collection = Collections::Duration(),
3036 .text = "... 3 hours and 9 seconds ..."},
3037 // Regex annotator.
3038 {.collection = Collections::Email(),
3039 .text = "... platypus@theanimal.org ..."},
3040 // Number annotator.
3041 {.collection = Collections::Number(), .text = "... 100 ..."},
3042 };
3043
3044 for (const Example& example : examples) {
3045 AnnotationOptions options;
3046 options.locales = "en";
3047 options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;
3048 options.entity_types.insert(example.collection);
3049
3050 EXPECT_THAT(classifier->Annotate(example.text, options),
3051 Contains(IsAnnotationWithType(example.collection)))
3052 << " text: '" << example.text
3053 << "', collection: " << example.collection;
3054 }
3055 }
3056
// Runs the RAW-mode pointwise collection filtering checks against the default
// test model. The helper itself asserts that the annotator was created.
TEST_F(AnnotatorTest, AnnotateSupportsPointwiseCollectionFilteringInRawMode) {
  const std::unique_ptr<Annotator> annotator = Annotator::FromPath(
      GetTestModelPath(), unilib_.get(), calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3062
// Runs the RAW-mode pointwise collection filtering checks against the test
// model with its datetime grammar model removed. The helper itself asserts
// that the annotator was created.
TEST_F(AnnotatorTest,
       AnnotateSupportsPointwiseCollectionFilteringInRawModeWithDatetimeRegEx) {
  // Declared first so the buffer outlives the annotator built on top of it.
  const std::string buffer = GetTestModelWithDatetimeRegEx();
  const std::unique_ptr<Annotator> annotator = Annotator::FromUnownedBuffer(
      buffer.data(), buffer.size(), unilib_.get(), calendarlib_.get());
  VerifyAnnotateSupportsPointwiseCollectionFilteringInRawMode(annotator.get());
}
3071
// Checks that an annotator built with FromString() from the model file's
// contents loads and annotates a phone-number-like text.
TEST_F(AnnotatorTest, InitializeFromString) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<Annotator> classifier =
      Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);
  // Qualified explicitly so the call does not depend on argument-dependent
  // lookup to find the gMock matcher.
  EXPECT_THAT(classifier->Annotate("(857) 225-3556"),
              ::testing::Not(IsEmpty()));
}
3079
// Regression test for cl/338280366. In previous code releases, enabling
// only_use_line_with_click caused some annotators to receive only the last line
// of the input text. This test places the entity (a duration) on the first
// line, so that such truncation would make it go missing.
// Enables only_use_line_with_click and checks that a duration on the first of
// two input lines is still found in RAW mode.
TEST_F(AnnotatorTest, RegressionTestOnlyUseLineWithClickLastLine) {
  const std::string test_model = ReadFile(GetTestModelPath());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  ASSERT_TRUE(unpacked_model);
  ASSERT_TRUE(unpacked_model->selection_feature_options);

  // The setting under test: restrict selection features to the clicked line.
  unpacked_model->selection_feature_options->only_use_line_with_click = true;

  // `builder` holds the repacked model that the annotator reads without owning
  // it. It is declared before `classifier` so that it is destroyed after the
  // annotator, never before.
  flatbuffers::FlatBufferBuilder builder;
  FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<Annotator> classifier = Annotator::FromUnownedBuffer(
      reinterpret_cast<const char*>(builder.GetBufferPointer()),
      builder.GetSize(), unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(classifier);

  AnnotationOptions options;
  options.annotation_usecase = AnnotationUsecase_ANNOTATION_USECASE_RAW;

  // "3 hours" occupies codepoints [14, 21) of the first line.
  const std::vector<AnnotatedSpan> annotations =
      classifier->Annotate("let's meet in 3 hours\nbut not now", options);

  EXPECT_THAT(annotations, Contains(IsDurationSpan(
                               /*start=*/14, /*end=*/21,
                               /*duration_ms=*/3 * 60 * 60 * 1000)));
}
3110
// Checks that input which is not valid UTF-8 is not processed:
// - Annotate() and ClassifyText() return nothing.
// - SuggestSelection() hands back the requested span unchanged.
TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
  const std::string test_model = ReadFile(GetTestModelPath());
  // A valid phone number followed by four copies of the bytes ED A0 80, the
  // three-byte form of the surrogate U+D800, which UTF-8 does not allow.
  const std::string invalid_utf8_text_with_phone_number =
      "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";

  std::unique_ptr<Annotator> annotator =
      Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
  ASSERT_TRUE(annotator);

  const CodepointSpan requested_selection{1, 4};
  EXPECT_THAT(annotator->Annotate(invalid_utf8_text_with_phone_number),
              IsEmpty());
  EXPECT_THAT(annotator->SuggestSelection(invalid_utf8_text_with_phone_number,
                                          requested_selection),
              Eq(requested_selection));
  EXPECT_THAT(
      annotator->ClassifyText(invalid_utf8_text_with_phone_number, {0, 14}),
      IsEmpty());
}
3128
3129 } // namespace test_internal
3130 } // namespace libtextclassifier3
3131