1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/grammar/grammar-annotator.h"
18
19 #include <memory>
20
21 #include "annotator/grammar/test-utils.h"
22 #include "annotator/grammar/utils.h"
23 #include "annotator/model_generated.h"
24 #include "utils/flatbuffers/flatbuffers.h"
25 #include "utils/flatbuffers/mutable.h"
26 #include "utils/grammar/utils/locale-shard-map.h"
27 #include "utils/grammar/utils/rules.h"
28 #include "utils/tokenizer.h"
29 #include "utils/utf8/unicodetext.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32
33 namespace libtextclassifier3 {
34 namespace {
35
36 using testing::ElementsAre;
37
PackModel(const GrammarModelT & model)38 flatbuffers::DetachedBuffer PackModel(const GrammarModelT& model) {
39 flatbuffers::FlatBufferBuilder builder;
40 builder.Finish(GrammarModel::Pack(builder, &model));
41 return builder.Release();
42 }
43
// Checks that annotating a text with carrier + flight-code rules yields one
// "flight" span per match.
TEST_F(GrammarAnnotatorTest, AnnotesWithGrammarRules) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);

  // Carrier codes.
  flight_rules.Add("<carrier>", {"lx"});
  flight_rules.Add("<carrier>", {"aa"});

  // Flight codes of two to four digits.
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});

  // Root rule: a carrier followed by a flight code is classified as "flight".
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  flight_rules.Add("<flight>", {"<carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/flight_result_id);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));

  // Expect exactly the two flight mentions, in order.
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
                                 IsAnnotatedSpan(51, 57, "flight")));
}
79
// Checks that a negative right-context assertion suppresses flight matches
// that are followed by further digits (e.g. "LX 38.00").
TEST_F(GrammarAnnotatorTest, HandlesAssertions) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);
  flight_rules.Add("<carrier>", {"lx"});
  flight_rules.Add("<carrier>", {"aa"});
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});

  // Flight: carrier + flight code, optionally followed by the context
  // assertion that inspects what comes next.
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  flight_rules.Add(
      "<flight>", {"<carrier>", "<flight_code>", "<context_assertion>?"},
      /*callback=*/root_rule_callback,
      /*callback_param=*/flight_result_id);

  // Negative assertion: rejects matches such as "LX 38.00".
  flight_rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                            /*negative=*/true);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text =
      UTF8ToUnicodeText("My flight: LX 38 arriving at 4pm, I'll fly back on "
                        "AA2014 on LX 38.00",
                        /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));

  // The trailing "LX 38.00" must not show up among the annotations.
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(11, 16, "flight"),
                                 IsAnnotatedSpan(51, 57, "flight")));
}
122
// Checks that a capturing group with `extend_selection` set restricts the
// annotated span to the captured digits rather than the whole rule match.
TEST_F(GrammarAnnotatorTest, HandlesCapturingGroups) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules phone_rules(shard_map);

  // The digits are mapped to value 0 (capturing group 0 below).
  phone_rules.AddValueMapping("<low_confidence_phone>", {"<digits>"},
                              /*value=*/0);

  // Classification result with a single capturing group that extends the
  // selection.
  const int phone_result_id =
      AddRuleClassificationResult("phone", ModeFlag_ALL, 1.0, &model);
  auto& capturing_groups =
      model.rule_classification_result[phone_result_id]->capturing_group;
  capturing_groups.emplace_back(new CapturingGroupT);
  capturing_groups.back()->extend_selection = true;

  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  phone_rules.Add("<phone>", {"please", "call", "<low_confidence_phone>"},
                  /*callback=*/root_rule_callback,
                  /*callback_param=*/phone_result_id);

  phone_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                   model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text =
      UTF8ToUnicodeText("Please call 911 before 10 am!", /*do_copy=*/false);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(annotator.Annotate(locales, text, &spans));
  EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(12, 15, "phone")));
}
161
// Checks that classifying a span that matches the flight rule yields a
// "flight" classification.
TEST_F(GrammarAnnotatorTest, ClassifiesTextWithGrammarRules) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);
  flight_rules.Add("<carrier>", {"lx"});
  flight_rules.Add("<carrier>", {"aa"});
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});

  // Root rule: carrier followed by flight code.
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  flight_rules.Add("<flight>", {"<carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/flight_result_id);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);
  const CodepointSpan flight_span{11, 16};

  ClassificationResult classification;
  EXPECT_TRUE(
      annotator.ClassifyText(locales, text, flight_span, &classification));
  EXPECT_THAT(classification, IsClassificationResult("flight"));
}
195
// Checks that classification honours a negative right-context assertion: the
// same span is rejected when digits follow it and accepted otherwise.
TEST_F(GrammarAnnotatorTest, ClassifiesTextWithAssertions) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  // Use unbounded context on both sides of the selection.
  model.context_left_num_tokens = -1;
  model.context_right_num_tokens = -1;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);
  flight_rules.Add("<carrier>", {"lx"});
  flight_rules.Add("<carrier>", {"aa"});
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});
  flight_rules.AddValueMapping("<flight_selection>",
                               {"<carrier>", "<flight_code>"},
                               /*value=*/0);

  // Flight: captured carrier + flight code, with an optional check of the
  // right context.
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  flight_rules.Add("<flight>", {"<flight_selection>", "<context_assertion>?"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/flight_result_id);

  // Single capturing group that extends the selection.
  auto& capturing_groups =
      model.rule_classification_result[flight_result_id]->capturing_group;
  capturing_groups.emplace_back(new CapturingGroupT);
  capturing_groups.back()->extend_selection = true;

  // Negative assertion: rejects matches such as "LX 38.00".
  flight_rules.AddAssertion("<context_assertion>", {".?", "<digits>"},
                            /*negative=*/true);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const CodepointSpan flight_span{4, 9};

  // Digits after the flight code (with or without a separating dot) must
  // prevent classification.
  EXPECT_FALSE(annotator.ClassifyText(
      locales, UTF8ToUnicodeText("See LX 38.00", /*do_copy=*/false),
      flight_span, nullptr));
  EXPECT_FALSE(annotator.ClassifyText(
      locales, UTF8ToUnicodeText("See LX 38 00", /*do_copy=*/false),
      flight_span, nullptr));

  // Without trailing digits the span classifies as a flight.
  ClassificationResult classification;
  EXPECT_TRUE(annotator.ClassifyText(
      locales, UTF8ToUnicodeText("See LX 38, seat 5", /*do_copy=*/false),
      flight_span, &classification));
  EXPECT_THAT(classification, IsClassificationResult("flight"));
}
257
// Checks that classification can depend on left context: a digit sequence is
// a parcel tracking number only when preceded by the trigger phrase.
TEST_F(GrammarAnnotatorTest, ClassifiesTextWithContext) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  // Max three tokens to the left ("tracking number: ..."), none to the right.
  model.context_left_num_tokens = 3;
  model.context_right_num_tokens = 0;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules tracking_rules(shard_map);

  // Tracking numbers are five to ten digits long.
  tracking_rules.Add("<tracking_number>", {"<5_digits>"});
  tracking_rules.Add("<tracking_number>", {"<6_digits>"});
  tracking_rules.Add("<tracking_number>", {"<7_digits>"});
  tracking_rules.Add("<tracking_number>", {"<8_digits>"});
  tracking_rules.Add("<tracking_number>", {"<9_digits>"});
  tracking_rules.Add("<tracking_number>", {"<10_digits>"});
  tracking_rules.AddValueMapping("<captured_tracking_number>",
                                 {"<tracking_number>"},
                                 /*value=*/0);
  tracking_rules.Add("<parcel_tracking_trigger>",
                     {"tracking", "number?", ":?"});

  // Root rule: trigger phrase followed by the captured number.
  const int tracking_result_id = AddRuleClassificationResult(
      "parcel_tracking", ModeFlag_ALL, 1.0, &model);
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  tracking_rules.Add(
      "<parcel_tracking>",
      {"<parcel_tracking_trigger>", "<captured_tracking_number>"},
      /*callback=*/root_rule_callback,
      /*callback_param=*/tracking_result_id);

  // Single capturing group that extends the selection.
  auto& capturing_groups =
      model.rule_classification_result[tracking_result_id]->capturing_group;
  capturing_groups.emplace_back(new CapturingGroupT);
  capturing_groups.back()->extend_selection = true;

  tracking_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                      model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  ClassificationResult classification;

  // With the trigger phrase in the left context the number is classified.
  const UnicodeText with_trigger = UTF8ToUnicodeText(
      "Use tracking number 012345 for live parcel tracking.",
      /*do_copy=*/false);
  EXPECT_TRUE(annotator.ClassifyText(locales, with_trigger,
                                     CodepointSpan{20, 26}, &classification));
  EXPECT_THAT(classification, IsClassificationResult("parcel_tracking"));

  // Without the trigger phrase the same digits are not classified.
  const UnicodeText without_trigger =
      UTF8ToUnicodeText("Call phone 012345 for live parcel tracking.",
                        /*do_copy=*/false);
  EXPECT_FALSE(annotator.ClassifyText(locales, without_trigger,
                                      CodepointSpan{11, 17}, &classification));
}
316
// Checks that a selection inside a flight mention is expanded to the full
// span matched by the flight rule.
TEST_F(GrammarAnnotatorTest, SuggestsTextSelection) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);
  flight_rules.Add("<carrier>", {"lx"});
  flight_rules.Add("<carrier>", {"aa"});
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});

  // Root rule: carrier followed by flight code.
  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);
  const int flight_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  flight_rules.Add("<flight>", {"<carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/flight_result_id);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const std::vector<Locale> locales = {Locale::FromBCP47("en")};
  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on AA2014",
      /*do_copy=*/false);

  // A one-codepoint selection inside the match is grown to the whole match.
  AnnotatedSpan suggested;
  EXPECT_TRUE(annotator.SuggestSelection(locales, text,
                                         /*selection=*/CodepointSpan{14, 15},
                                         &suggested));
  EXPECT_THAT(suggested, IsAnnotatedSpan(11, 16, "flight"));
}
350
// Checks that entity data attached to a rule classification result is copied
// verbatim into the classification of the resulting annotation.
TEST_F(GrammarAnnotatorTest, SetsFixedEntityData) {
  // Create test rules.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  const int person_result =
      AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);
  rules.Add(
      "<person>", {"barack", "obama"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/person_result);

  // Add test entity data.
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();
  entity_data->Set("person", "Former President Barack Obama");
  grammar_model.rule_classification_result[person_result]
      ->serialized_entity_data = entity_data->Serialize();

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));

  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(annotator.Annotate(
      {Locale::FromBCP47("en")},
      UTF8ToUnicodeText("I saw Barack Obama today", /*do_copy=*/false),
      &result));

  // Fatal assertions from here on: the code below dereferences
  // `result.front()` and its first classification, which is undefined
  // behaviour on empty containers. Abort the test instead of crashing.
  ASSERT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 18, "person")));
  ASSERT_FALSE(result.front().classification.empty());
  const auto& serialized_entity_data =
      result.front().classification.front().serialized_entity_data;
  ASSERT_FALSE(serialized_entity_data.empty());

  // Check entity data.
  // As we don't have generated code for the ad-hoc generated entity data
  // schema, we have to check manually using field offsets.
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  // The field pointer is checked before `->str()` so that an unset field is
  // reported as a test failure rather than a null dereference.
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Former President Barack Obama");
}
396
// Checks that the text of a capturing group is written into the configured
// entity data field after applying the group's normalization options.
TEST_F(GrammarAnnotatorTest, SetsEntityDataFromCapturingMatches) {
  // Create test rules.
  GrammarModelT grammar_model;
  SetTestTokenizerOptions(&grammar_model);
  grammar_model.rules.reset(new grammar::RulesSetT);
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules rules(locale_shard_map);
  const int person_result =
      AddRuleClassificationResult("person", ModeFlag_ALL, 1.0, &grammar_model);

  rules.Add("<person>", {"barack?", "obama"});
  rules.Add("<person>", {"zapp?", "brannigan"});
  rules.AddValueMapping("<captured_person>", {"<person>"},
                        /*value=*/0);
  rules.Add(
      "<test>", {"<captured_person>"},
      /*callback=*/
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
      /*callback_param=*/person_result);

  // Set capturing group entity data information: the captured text goes into
  // the "person" field, uppercased codepoint by codepoint.
  grammar_model.rule_classification_result[person_result]
      ->capturing_group.emplace_back(new CapturingGroupT);
  CapturingGroupT* group =
      grammar_model.rule_classification_result[person_result]
          ->capturing_group.back()
          .get();
  group->entity_field_path.reset(new FlatbufferFieldPathT);
  group->entity_field_path->field.emplace_back(new FlatbufferFieldT);
  group->entity_field_path->field.back()->field_name = "person";
  group->normalization_options.reset(new NormalizationOptionsT);
  group->normalization_options->codepointwise_normalization =
      NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;

  rules.Finalize().Serialize(/*include_debug_information=*/false,
                             grammar_model.rules.get());
  flatbuffers::DetachedBuffer serialized_model = PackModel(grammar_model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(serialized_model));

  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(annotator.Annotate(
      {Locale::FromBCP47("en")},
      UTF8ToUnicodeText("I saw Zapp Brannigan today", /*do_copy=*/false),
      &result));

  // Fatal assertions from here on: the code below dereferences
  // `result.front()` and its first classification, which is undefined
  // behaviour on empty containers. Abort the test instead of crashing.
  ASSERT_THAT(result, ElementsAre(IsAnnotatedSpan(6, 20, "person")));
  ASSERT_FALSE(result.front().classification.empty());
  const auto& serialized_entity_data =
      result.front().classification.front().serialized_entity_data;
  ASSERT_FALSE(serialized_entity_data.empty());

  // Check entity data.
  // As we don't have generated code for the ad-hoc generated entity data
  // schema, we have to check manually using field offsets.
  const flatbuffers::Table* entity = flatbuffers::GetAnyRoot(
      reinterpret_cast<const unsigned char*>(serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  // The field pointer is checked before `->str()` so that an unset field is
  // reported as a test failure rather than a null dereference.
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "ZAPP BRANNIGAN");
}
454
// Checks that each rule is only applied in the modes enabled for its
// classification result (annotation, selection, classification).
TEST_F(GrammarAnnotatorTest, RespectsRuleModes) {
  // Model under construction; the finalized rules are serialized into
  // `model.rules` further down.
  GrammarModelT model;
  SetTestTokenizerOptions(&model);
  model.rules.reset(new grammar::RulesSetT);

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  grammar::Rules flight_rules(shard_map);

  // One carrier set per mode combination.
  flight_rules.Add("<classification_carrier>", {"ei"});
  flight_rules.Add("<classification_carrier>", {"en"});
  flight_rules.Add("<selection_carrier>", {"ai"});
  flight_rules.Add("<selection_carrier>", {"bx"});
  flight_rules.Add("<annotation_carrier>", {"aa"});
  flight_rules.Add("<annotation_carrier>", {"lx"});
  flight_rules.Add("<flight_code>", {"<2_digits>"});
  flight_rules.Add("<flight_code>", {"<3_digits>"});
  flight_rules.Add("<flight_code>", {"<4_digits>"});

  const grammar::CallbackId root_rule_callback =
      static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule);

  // Enabled in all modes.
  const int all_modes_result_id =
      AddRuleClassificationResult("flight", ModeFlag_ALL, 1.0, &model);
  flight_rules.Add("<flight>", {"<annotation_carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/all_modes_result_id);

  // Enabled for classification and selection.
  const int classification_and_selection_result_id =
      AddRuleClassificationResult(
          "flight", ModeFlag_CLASSIFICATION_AND_SELECTION, 1.0, &model);
  flight_rules.Add("<flight>", {"<selection_carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/classification_and_selection_result_id);

  // Enabled for classification only.
  const int classification_only_result_id = AddRuleClassificationResult(
      "flight", ModeFlag_CLASSIFICATION, 1.0, &model);
  flight_rules.Add("<flight>", {"<classification_carrier>", "<flight_code>"},
                   /*callback=*/root_rule_callback,
                   /*callback_param=*/classification_only_result_id);

  flight_rules.Finalize().Serialize(/*include_debug_information=*/false,
                                    model.rules.get());
  flatbuffers::DetachedBuffer packed_model = PackModel(model);
  GrammarAnnotator annotator(CreateGrammarAnnotator(packed_model));

  const UnicodeText text = UTF8ToUnicodeText(
      "My flight: LX 38 arriving at 4pm, I'll fly back on EI2014 but maybe "
      "also on bx 222");
  const std::vector<Locale> locales = {Locale::FromBCP47("en")};

  // Annotation, only high confidence pattern.
  {
    std::vector<AnnotatedSpan> spans;
    EXPECT_TRUE(annotator.Annotate(locales, text, &spans));
    EXPECT_THAT(spans, ElementsAre(IsAnnotatedSpan(11, 16, "flight")));
  }

  // Selection, annotation patterns + selection.
  {
    AnnotatedSpan suggested;

    // Selects 'LX 38'.
    const CodepointSpan inside_lx{14, 15};
    EXPECT_TRUE(annotator.SuggestSelection(locales, text,
                                           /*selection=*/inside_lx,
                                           &suggested));
    EXPECT_THAT(suggested, IsAnnotatedSpan(11, 16, "flight"));

    // Selects 'bx 222'.
    const CodepointSpan inside_bx{76, 77};
    EXPECT_TRUE(annotator.SuggestSelection(locales, text,
                                           /*selection=*/inside_bx,
                                           &suggested));
    EXPECT_THAT(suggested, IsAnnotatedSpan(76, 82, "flight"));

    // Doesn't select 'EI2014'.
    // NOTE(review): this span is empty ({51, 51}), unlike the one-codepoint
    // spans above - confirm that is the intended click position.
    const CodepointSpan at_ei{51, 51};
    EXPECT_FALSE(annotator.SuggestSelection(locales, text,
                                            /*selection=*/at_ei, &suggested));
  }

  // Classification, all patterns.
  {
    ClassificationResult classification;

    // Classifies 'LX 38'.
    EXPECT_TRUE(annotator.ClassifyText(locales, text, CodepointSpan{11, 16},
                                       &classification));
    EXPECT_THAT(classification, IsClassificationResult("flight"));

    // Classifies 'EI2014'.
    EXPECT_TRUE(annotator.ClassifyText(locales, text, CodepointSpan{51, 57},
                                       &classification));
    EXPECT_THAT(classification, IsClassificationResult("flight"));

    // Classifies 'bx 222'.
    EXPECT_TRUE(annotator.ClassifyText(locales, text, CodepointSpan{76, 82},
                                       &classification));
    EXPECT_THAT(classification, IsClassificationResult("flight"));
  }
}
552
553 } // namespace
554 } // namespace libtextclassifier3
555