1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator.h"
18
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54
55 namespace libtextclassifier3 {
56
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58
59 const std::string& Annotator::kPhoneCollection =
__anon17cae9720102() 60 *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anon17cae9720202() 62 *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anon17cae9720302() 64 *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anon17cae9720402() 66 *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anon17cae9720502() 68 *[]() { return new std::string("email"); }();
69
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73 if (VerifyModelBuffer(verifier)) {
74 return GetModel(addr);
75 } else {
76 return nullptr;
77 }
78 }
79
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81 int size) {
82 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83 if (VerifyPersonNameModelBuffer(verifier)) {
84 return GetPersonNameModel(addr);
85 } else {
86 return nullptr;
87 }
88 }
89
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93 std::unique_ptr<UniLib>* owned_lib) {
94 if (lib) {
95 return lib;
96 } else {
97 owned_lib->reset(new UniLib);
98 return owned_lib->get();
99 }
100 }
101
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105 if (lib) {
106 return lib;
107 } else {
108 owned_lib->reset(new CalendarLib);
109 return owned_lib->get();
110 }
111 }
112
113 // Returns whether the provided input is valid:
114 // * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116 return (span.first >= 0 && span.first < span.second &&
117 span.second <= context.size_codepoints());
118 }
119
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121 const flatbuffers::Vector<int32_t>* ints) {
122 if (ints == nullptr) {
123 return {};
124 }
125 std::unordered_set<char32> ints_set;
126 for (auto value : *ints) {
127 ints_set.insert(static_cast<char32>(value));
128 }
129 return ints_set;
130 }
131
132 } // namespace
133
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135 if (!selection_interpreter_) {
136 TC3_CHECK(selection_executor_);
137 selection_interpreter_ = selection_executor_->CreateInterpreter();
138 if (!selection_interpreter_) {
139 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140 }
141 }
142 return selection_interpreter_.get();
143 }
144
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146 if (!classification_interpreter_) {
147 TC3_CHECK(classification_executor_);
148 classification_interpreter_ = classification_executor_->CreateInterpreter();
149 if (!classification_interpreter_) {
150 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151 }
152 }
153 return classification_interpreter_.get();
154 }
155
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157 const char* buffer, int size, const UniLib* unilib,
158 const CalendarLib* calendarlib) {
159 const Model* model = LoadAndVerifyModel(buffer, size);
160 if (model == nullptr) {
161 return nullptr;
162 }
163
164 auto classifier = std::unique_ptr<Annotator>(new Annotator());
165 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166 calendarlib =
167 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168 classifier->ValidateAndInitialize(model, unilib, calendarlib);
169 if (!classifier->IsInitialized()) {
170 return nullptr;
171 }
172
173 return classifier;
174 }
175
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177 const std::string& buffer, const UniLib* unilib,
178 const CalendarLib* calendarlib) {
179 auto classifier = std::unique_ptr<Annotator>(new Annotator());
180 classifier->owned_buffer_ = buffer;
181 const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182 classifier->owned_buffer_.size());
183 if (model == nullptr) {
184 return nullptr;
185 }
186 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187 calendarlib =
188 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189 classifier->ValidateAndInitialize(model, unilib, calendarlib);
190 if (!classifier->IsInitialized()) {
191 return nullptr;
192 }
193
194 return classifier;
195 }
196
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199 const CalendarLib* calendarlib) {
200 if (!(*mmap)->handle().ok()) {
201 TC3_VLOG(1) << "Mmap failed.";
202 return nullptr;
203 }
204
205 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206 (*mmap)->handle().num_bytes());
207 if (!model) {
208 TC3_LOG(ERROR) << "Model verification failed.";
209 return nullptr;
210 }
211
212 auto classifier = std::unique_ptr<Annotator>(new Annotator());
213 classifier->mmap_ = std::move(*mmap);
214 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215 calendarlib =
216 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217 classifier->ValidateAndInitialize(model, unilib, calendarlib);
218 if (!classifier->IsInitialized()) {
219 return nullptr;
220 }
221
222 return classifier;
223 }
224
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227 std::unique_ptr<CalendarLib> calendarlib) {
228 if (!(*mmap)->handle().ok()) {
229 TC3_VLOG(1) << "Mmap failed.";
230 return nullptr;
231 }
232
233 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234 (*mmap)->handle().num_bytes());
235 if (model == nullptr) {
236 TC3_LOG(ERROR) << "Model verification failed.";
237 return nullptr;
238 }
239
240 auto classifier = std::unique_ptr<Annotator>(new Annotator());
241 classifier->mmap_ = std::move(*mmap);
242 classifier->owned_unilib_ = std::move(unilib);
243 classifier->owned_calendarlib_ = std::move(calendarlib);
244 classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245 classifier->owned_calendarlib_.get());
246 if (!classifier->IsInitialized()) {
247 return nullptr;
248 }
249
250 return classifier;
251 }
252
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254 int fd, int offset, int size, const UniLib* unilib,
255 const CalendarLib* calendarlib) {
256 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257 return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262 std::unique_ptr<CalendarLib> calendarlib) {
263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270 return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274 int fd, std::unique_ptr<UniLib> unilib,
275 std::unique_ptr<CalendarLib> calendarlib) {
276 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281 const UniLib* unilib,
282 const CalendarLib* calendarlib) {
283 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284 return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288 const std::string& path, std::unique_ptr<UniLib> unilib,
289 std::unique_ptr<CalendarLib> calendarlib) {
290 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295 const CalendarLib* calendarlib) {
296 model_ = model;
297 unilib_ = unilib;
298 calendarlib_ = calendarlib;
299
300 initialized_ = false;
301
302 if (model_ == nullptr) {
303 TC3_LOG(ERROR) << "No model specified.";
304 return;
305 }
306
307 const bool model_enabled_for_annotation =
308 (model_->triggering_options() != nullptr &&
309 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310 const bool model_enabled_for_classification =
311 (model_->triggering_options() != nullptr &&
312 (model_->triggering_options()->enabled_modes() &
313 ModeFlag_CLASSIFICATION));
314 const bool model_enabled_for_selection =
315 (model_->triggering_options() != nullptr &&
316 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317
318 // Annotation requires the selection model.
319 if (model_enabled_for_annotation || model_enabled_for_selection) {
320 if (!model_->selection_options()) {
321 TC3_LOG(ERROR) << "No selection options.";
322 return;
323 }
324 if (!model_->selection_feature_options()) {
325 TC3_LOG(ERROR) << "No selection feature options.";
326 return;
327 }
328 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330 return;
331 }
332 if (!model_->selection_model()) {
333 TC3_LOG(ERROR) << "No selection model.";
334 return;
335 }
336 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337 if (!selection_executor_) {
338 TC3_LOG(ERROR) << "Could not initialize selection executor.";
339 return;
340 }
341 }
342
343 // Even if the annotation mode is not enabled (for the neural network model),
344 // the selection feature processor is needed to tokenize the text for other
345 // models.
346 if (model_->selection_feature_options()) {
347 selection_feature_processor_.reset(
348 new FeatureProcessor(model_->selection_feature_options(), unilib_));
349 }
350
351 // Annotation requires the classification model for conflict resolution and
352 // scoring.
353 // Selection requires the classification model for conflict resolution.
354 if (model_enabled_for_annotation || model_enabled_for_classification ||
355 model_enabled_for_selection) {
356 if (!model_->classification_options()) {
357 TC3_LOG(ERROR) << "No classification options.";
358 return;
359 }
360
361 if (!model_->classification_feature_options()) {
362 TC3_LOG(ERROR) << "No classification feature options.";
363 return;
364 }
365
366 if (!model_->classification_feature_options()
367 ->bounds_sensitive_features()) {
368 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
369 return;
370 }
371 if (!model_->classification_model()) {
372 TC3_LOG(ERROR) << "No clf model.";
373 return;
374 }
375
376 classification_executor_ =
377 ModelExecutor::FromBuffer(model_->classification_model());
378 if (!classification_executor_) {
379 TC3_LOG(ERROR) << "Could not initialize classification executor.";
380 return;
381 }
382
383 classification_feature_processor_.reset(new FeatureProcessor(
384 model_->classification_feature_options(), unilib_));
385 }
386
387 // The embeddings need to be specified if the model is to be used for
388 // classification or selection.
389 if (model_enabled_for_annotation || model_enabled_for_classification ||
390 model_enabled_for_selection) {
391 if (!model_->embedding_model()) {
392 TC3_LOG(ERROR) << "No embedding model.";
393 return;
394 }
395
396 // Check that the embedding size of the selection and classification model
397 // matches, as they are using the same embeddings.
398 if (model_enabled_for_selection &&
399 (model_->selection_feature_options()->embedding_size() !=
400 model_->classification_feature_options()->embedding_size() ||
401 model_->selection_feature_options()->embedding_quantization_bits() !=
402 model_->classification_feature_options()
403 ->embedding_quantization_bits())) {
404 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
405 return;
406 }
407
408 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
409 model_->embedding_model(),
410 model_->classification_feature_options()->embedding_size(),
411 model_->classification_feature_options()->embedding_quantization_bits(),
412 model_->embedding_pruning_mask());
413 if (!embedding_executor_) {
414 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
415 return;
416 }
417 }
418
419 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
420 if (model_->regex_model()) {
421 if (!InitializeRegexModel(decompressor.get())) {
422 TC3_LOG(ERROR) << "Could not initialize regex model.";
423 return;
424 }
425 }
426
427 if (model_->datetime_grammar_model()) {
428 if (model_->datetime_grammar_model()->rules()) {
429 analyzer_ = std::make_unique<grammar::Analyzer>(
430 unilib_, model_->datetime_grammar_model()->rules());
431 datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
432 datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
433 *analyzer_, *datetime_grounder_,
434 /*target_classification_score=*/1.0,
435 /*priority_score=*/1.0);
436 }
437 } else if (model_->datetime_model()) {
438 datetime_parser_ = RegexDatetimeParser::Instance(
439 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
440 if (!datetime_parser_) {
441 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
442 return;
443 }
444 }
445
446 if (model_->output_options()) {
447 if (model_->output_options()->filtered_collections_annotation()) {
448 for (const auto collection :
449 *model_->output_options()->filtered_collections_annotation()) {
450 filtered_collections_annotation_.insert(collection->str());
451 }
452 }
453 if (model_->output_options()->filtered_collections_classification()) {
454 for (const auto collection :
455 *model_->output_options()->filtered_collections_classification()) {
456 filtered_collections_classification_.insert(collection->str());
457 }
458 }
459 if (model_->output_options()->filtered_collections_selection()) {
460 for (const auto collection :
461 *model_->output_options()->filtered_collections_selection()) {
462 filtered_collections_selection_.insert(collection->str());
463 }
464 }
465 }
466
467 if (model_->number_annotator_options() &&
468 model_->number_annotator_options()->enabled()) {
469 number_annotator_.reset(
470 new NumberAnnotator(model_->number_annotator_options(), unilib_));
471 }
472
473 if (model_->money_parsing_options()) {
474 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
475 model_->money_parsing_options()->separators());
476 }
477
478 if (model_->duration_annotator_options() &&
479 model_->duration_annotator_options()->enabled()) {
480 duration_annotator_.reset(
481 new DurationAnnotator(model_->duration_annotator_options(),
482 selection_feature_processor_.get(), unilib_));
483 }
484
485 if (model_->grammar_model()) {
486 grammar_annotator_.reset(new GrammarAnnotator(
487 unilib_, model_->grammar_model(), entity_data_builder_.get()));
488 }
489
490 // The following #ifdef is here to aid quality evaluation of a situation, when
491 // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
492 // it.
493 #if !defined(TC3_DISABLE_POD_NER)
494 if (model_->pod_ner_model()) {
495 pod_ner_annotator_ =
496 PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
497 }
498 #endif
499
500 if (model_->vocab_model()) {
501 vocab_annotator_ = VocabAnnotator::Create(
502 model_->vocab_model(), *selection_feature_processor_, *unilib_);
503 }
504
505 if (model_->entity_data_schema()) {
506 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
507 model_->entity_data_schema()->Data(),
508 model_->entity_data_schema()->size());
509 if (entity_data_schema_ == nullptr) {
510 TC3_LOG(ERROR) << "Could not load entity data schema data.";
511 return;
512 }
513
514 entity_data_builder_.reset(
515 new MutableFlatbufferBuilder(entity_data_schema_));
516 } else {
517 entity_data_schema_ = nullptr;
518 entity_data_builder_ = nullptr;
519 }
520
521 if (model_->triggering_locales() &&
522 !ParseLocales(model_->triggering_locales()->c_str(),
523 &model_triggering_locales_)) {
524 TC3_LOG(ERROR) << "Could not parse model supported locales.";
525 return;
526 }
527
528 if (model_->triggering_options() != nullptr &&
529 model_->triggering_options()->locales() != nullptr &&
530 !ParseLocales(model_->triggering_options()->locales()->c_str(),
531 &ml_model_triggering_locales_)) {
532 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
533 return;
534 }
535
536 if (model_->triggering_options() != nullptr &&
537 model_->triggering_options()->dictionary_locales() != nullptr &&
538 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
539 &dictionary_locales_)) {
540 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
541 return;
542 }
543
544 if (model_->conflict_resolution_options() != nullptr) {
545 prioritize_longest_annotation_ =
546 model_->conflict_resolution_options()->prioritize_longest_annotation();
547 do_conflict_resolution_in_raw_mode_ =
548 model_->conflict_resolution_options()
549 ->do_conflict_resolution_in_raw_mode();
550 }
551
552 #ifdef TC3_EXPERIMENTAL
553 TC3_LOG(WARNING) << "Enabling experimental annotators.";
554 InitializeExperimentalAnnotators();
555 #endif
556
557 initialized_ = true;
558 }
559
InitializeRegexModel(ZlibDecompressor * decompressor)560 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
561 if (!model_->regex_model()->patterns()) {
562 return true;
563 }
564
565 // Initialize pattern recognizers.
566 int regex_pattern_id = 0;
567 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
568 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
569 UncompressMakeRegexPattern(
570 *unilib_, regex_pattern->pattern(),
571 regex_pattern->compressed_pattern(),
572 model_->regex_model()->lazy_regex_compilation(), decompressor);
573 if (!compiled_pattern) {
574 TC3_LOG(INFO) << "Failed to load regex pattern";
575 return false;
576 }
577
578 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
579 annotation_regex_patterns_.push_back(regex_pattern_id);
580 }
581 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
582 classification_regex_patterns_.push_back(regex_pattern_id);
583 }
584 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
585 selection_regex_patterns_.push_back(regex_pattern_id);
586 }
587 regex_patterns_.push_back({
588 regex_pattern,
589 std::move(compiled_pattern),
590 });
591 ++regex_pattern_id;
592 }
593
594 return true;
595 }
596
InitializeKnowledgeEngine(const std::string & serialized_config)597 bool Annotator::InitializeKnowledgeEngine(
598 const std::string& serialized_config) {
599 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
600 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
601 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
602 return false;
603 }
604 if (model_->triggering_options() != nullptr) {
605 knowledge_engine->SetPriorityScore(
606 model_->triggering_options()->knowledge_priority_score());
607 }
608 knowledge_engine_ = std::move(knowledge_engine);
609 return true;
610 }
611
InitializeContactEngine(const std::string & serialized_config)612 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
613 std::unique_ptr<ContactEngine> contact_engine(
614 new ContactEngine(selection_feature_processor_.get(), unilib_,
615 model_->contact_annotator_options()));
616 if (!contact_engine->Initialize(serialized_config)) {
617 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
618 return false;
619 }
620 contact_engine_ = std::move(contact_engine);
621 return true;
622 }
623
InitializeInstalledAppEngine(const std::string & serialized_config)624 bool Annotator::InitializeInstalledAppEngine(
625 const std::string& serialized_config) {
626 std::unique_ptr<InstalledAppEngine> installed_app_engine(
627 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
628 if (!installed_app_engine->Initialize(serialized_config)) {
629 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
630 return false;
631 }
632 installed_app_engine_ = std::move(installed_app_engine);
633 return true;
634 }
635
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)636 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
637 if (lang_id == nullptr) {
638 return false;
639 }
640
641 lang_id_ = lang_id;
642 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
643 model_->translate_annotator_options()->enabled()) {
644 translate_annotator_.reset(new TranslateAnnotator(
645 model_->translate_annotator_options(), lang_id_, unilib_));
646 } else {
647 translate_annotator_.reset(nullptr);
648 }
649 return true;
650 }
651
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)652 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
653 int size) {
654 const PersonNameModel* person_name_model =
655 LoadAndVerifyPersonNameModel(buffer, size);
656
657 if (person_name_model == nullptr) {
658 TC3_LOG(ERROR) << "Person name model verification failed.";
659 return false;
660 }
661
662 if (!person_name_model->enabled()) {
663 return true;
664 }
665
666 std::unique_ptr<PersonNameEngine> person_name_engine(
667 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
668 if (!person_name_engine->Initialize(person_name_model)) {
669 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
670 return false;
671 }
672 person_name_engine_ = std::move(person_name_engine);
673 return true;
674 }
675
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)676 bool Annotator::InitializePersonNameEngineFromScopedMmap(
677 const ScopedMmap& mmap) {
678 if (!mmap.handle().ok()) {
679 TC3_LOG(ERROR) << "Mmap for person name model failed.";
680 return false;
681 }
682
683 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
684 mmap.handle().num_bytes());
685 }
686
InitializePersonNameEngineFromPath(const std::string & path)687 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
688 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
689 return InitializePersonNameEngineFromScopedMmap(*mmap);
690 }
691
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)692 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
693 int size) {
694 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
695 return InitializePersonNameEngineFromScopedMmap(*mmap);
696 }
697
InitializeExperimentalAnnotators()698 bool Annotator::InitializeExperimentalAnnotators() {
699 if (ExperimentalAnnotator::IsEnabled()) {
700 experimental_annotator_.reset(new ExperimentalAnnotator(
701 model_->experimental_model(), *selection_feature_processor_, *unilib_));
702 return true;
703 }
704 return false;
705 }
706
707 namespace internal {
708 // Helper function, which if the initial 'span' contains only white-spaces,
709 // moves the selection to a single-codepoint selection on a left or right side
710 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)711 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
712 const UnicodeText& context_unicode,
713 const UniLib& unilib) {
714 TC3_CHECK(span.IsValid() && !span.IsEmpty());
715
716 UnicodeText::const_iterator it;
717
718 // Check that the current selection is all whitespaces.
719 it = context_unicode.begin();
720 std::advance(it, span.first);
721 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
722 if (!unilib.IsWhitespace(*it)) {
723 return span;
724 }
725 }
726
727 // Try moving left.
728 CodepointSpan result = span;
729 it = context_unicode.begin();
730 std::advance(it, span.first);
731 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
732 --result.first;
733 --it;
734 }
735 result.second = result.first + 1;
736 if (!unilib.IsWhitespace(*it)) {
737 return result;
738 }
739
740 // If moving left didn't find a non-whitespace character, just return the
741 // original span.
742 return span;
743 }
744 } // namespace internal
745
FilteredForAnnotation(const AnnotatedSpan & span) const746 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
747 return !span.classification.empty() &&
748 filtered_collections_annotation_.find(
749 span.classification[0].collection) !=
750 filtered_collections_annotation_.end();
751 }
752
FilteredForClassification(const ClassificationResult & classification) const753 bool Annotator::FilteredForClassification(
754 const ClassificationResult& classification) const {
755 return filtered_collections_classification_.find(classification.collection) !=
756 filtered_collections_classification_.end();
757 }
758
FilteredForSelection(const AnnotatedSpan & span) const759 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
760 return !span.classification.empty() &&
761 filtered_collections_selection_.find(
762 span.classification[0].collection) !=
763 filtered_collections_selection_.end();
764 }
765
766 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)767 inline bool ClassifiedAsOther(
768 const std::vector<ClassificationResult>& classification) {
769 return !classification.empty() &&
770 classification[0].collection == Collections::Other();
771 }
772
773 } // namespace
774
GetPriorityScore(const std::vector<ClassificationResult> & classification) const775 float Annotator::GetPriorityScore(
776 const std::vector<ClassificationResult>& classification) const {
777 if (!classification.empty() && !ClassifiedAsOther(classification)) {
778 return classification[0].priority_score;
779 } else {
780 if (model_->triggering_options() != nullptr) {
781 return model_->triggering_options()->other_collection_priority_score();
782 } else {
783 return -1000.0;
784 }
785 }
786 }
787
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const788 bool Annotator::VerifyRegexMatchCandidate(
789 const std::string& context, const VerificationOptions* verification_options,
790 const std::string& match, const UniLib::RegexMatcher* matcher) const {
791 if (verification_options == nullptr) {
792 return true;
793 }
794 if (verification_options->verify_luhn_checksum() &&
795 !VerifyLuhnChecksum(match)) {
796 return false;
797 }
798 const int lua_verifier = verification_options->lua_verifier();
799 if (lua_verifier >= 0) {
800 if (model_->regex_model()->lua_verifier() == nullptr ||
801 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
802 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
803 return false;
804 }
805 return VerifyMatch(
806 context, matcher,
807 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
808 }
809 return true;
810 }
811
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const812 CodepointSpan Annotator::SuggestSelection(
813 const std::string& context, CodepointSpan click_indices,
814 const SelectionOptions& options) const {
815 if (context.size() > std::numeric_limits<int>::max()) {
816 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
817 return {};
818 }
819
820 CodepointSpan original_click_indices = click_indices;
821 if (!initialized_) {
822 TC3_LOG(ERROR) << "Not initialized";
823 return original_click_indices;
824 }
825 if (options.annotation_usecase !=
826 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
827 TC3_LOG(WARNING)
828 << "Invoking SuggestSelection, which is not supported in RAW mode.";
829 return original_click_indices;
830 }
831 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
832 return original_click_indices;
833 }
834
835 std::vector<Locale> detected_text_language_tags;
836 if (!ParseLocales(options.detected_text_language_tags,
837 &detected_text_language_tags)) {
838 TC3_LOG(WARNING)
839 << "Failed to parse the detected_text_language_tags in options: "
840 << options.detected_text_language_tags;
841 }
842 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
843 model_triggering_locales_,
844 /*default_value=*/true)) {
845 return original_click_indices;
846 }
847
848 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
849 /*do_copy=*/false);
850
851 if (!unilib_->IsValidUtf8(context_unicode)) {
852 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
853 return original_click_indices;
854 }
855
856 if (!IsValidSpanInput(context_unicode, click_indices)) {
857 TC3_VLOG(1)
858 << "Trying to run SuggestSelection with invalid input, indices: "
859 << click_indices.first << " " << click_indices.second;
860 return original_click_indices;
861 }
862
863 if (model_->snap_whitespace_selections()) {
864 // We want to expand a purely white-space selection to a multi-selection it
865 // would've been part of. But with this feature disabled we would do a no-
866 // op, because no token is found. Therefore, we need to modify the
867 // 'click_indices' a bit to include a part of the token, so that the click-
868 // finding logic finds the clicked token correctly. This modification is
869 // done by the following function. Note, that it's enough to check the left
870 // side of the current selection, because if the white-space is a part of a
871 // multi-selection, necessarily both tokens - on the left and the right
872 // sides need to be selected. Thus snapping only to the left is sufficient
873 // (there's a check at the bottom that makes sure that if we snap to the
874 // left token but the result does not contain the initial white-space,
875 // returns the original indices).
876 click_indices = internal::SnapLeftIfWhitespaceSelection(
877 click_indices, context_unicode, *unilib_);
878 }
879
880 Annotations candidates;
881 // As we process a single string of context, the candidates will only
882 // contain one vector of AnnotatedSpan.
883 candidates.annotated_spans.resize(1);
884 InterpreterManager interpreter_manager(selection_executor_.get(),
885 classification_executor_.get());
886 std::vector<Token> tokens;
887 if (!ModelSuggestSelection(context_unicode, click_indices,
888 detected_text_language_tags, &interpreter_manager,
889 &tokens, &candidates.annotated_spans[0])) {
890 TC3_LOG(ERROR) << "Model suggest selection failed.";
891 return original_click_indices;
892 }
893 const std::unordered_set<std::string> set;
894 const EnabledEntityTypes is_entity_type_enabled(set);
895 if (!RegexChunk(context_unicode, selection_regex_patterns_,
896 /*is_serialized_entity_data_enabled=*/false,
897 is_entity_type_enabled, options.annotation_usecase,
898 &candidates.annotated_spans[0])) {
899 TC3_LOG(ERROR) << "Regex suggest selection failed.";
900 return original_click_indices;
901 }
902 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
903 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
904 options.locales, ModeFlag_SELECTION,
905 options.annotation_usecase,
906 /*is_serialized_entity_data_enabled=*/false,
907 &candidates.annotated_spans[0])) {
908 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
909 return original_click_indices;
910 }
911 if (knowledge_engine_ != nullptr &&
912 !knowledge_engine_
913 ->Chunk(context, options.annotation_usecase,
914 options.location_context, Permissions(),
915 AnnotateMode::kEntityAnnotation, &candidates)
916 .ok()) {
917 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
918 return original_click_indices;
919 }
920 if (contact_engine_ != nullptr &&
921 !contact_engine_->Chunk(context_unicode, tokens,
922 &candidates.annotated_spans[0])) {
923 TC3_LOG(ERROR) << "Contact suggest selection failed.";
924 return original_click_indices;
925 }
926 if (installed_app_engine_ != nullptr &&
927 !installed_app_engine_->Chunk(context_unicode, tokens,
928 &candidates.annotated_spans[0])) {
929 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
930 return original_click_indices;
931 }
932 if (number_annotator_ != nullptr &&
933 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
934 &candidates.annotated_spans[0])) {
935 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
936 return original_click_indices;
937 }
938 if (duration_annotator_ != nullptr &&
939 !duration_annotator_->FindAll(context_unicode, tokens,
940 options.annotation_usecase,
941 &candidates.annotated_spans[0])) {
942 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
943 return original_click_indices;
944 }
945 if (person_name_engine_ != nullptr &&
946 !person_name_engine_->Chunk(context_unicode, tokens,
947 &candidates.annotated_spans[0])) {
948 TC3_LOG(ERROR) << "Person name suggest selection failed.";
949 return original_click_indices;
950 }
951
952 AnnotatedSpan grammar_suggested_span;
953 if (grammar_annotator_ != nullptr &&
954 grammar_annotator_->SuggestSelection(detected_text_language_tags,
955 context_unicode, click_indices,
956 &grammar_suggested_span)) {
957 candidates.annotated_spans[0].push_back(grammar_suggested_span);
958 }
959
960 AnnotatedSpan pod_ner_suggested_span;
961 if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
962 pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
963 &pod_ner_suggested_span)) {
964 candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
965 }
966
967 if (experimental_annotator_ != nullptr) {
968 candidates.annotated_spans[0].push_back(
969 experimental_annotator_->SuggestSelection(context_unicode,
970 click_indices));
971 }
972
973 // Sort candidates according to their position in the input, so that the next
974 // code can assume that any connected component of overlapping spans forms a
975 // contiguous block.
976 std::stable_sort(candidates.annotated_spans[0].begin(),
977 candidates.annotated_spans[0].end(),
978 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
979 return a.span.first < b.span.first;
980 });
981
982 std::vector<int> candidate_indices;
983 if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
984 detected_text_language_tags, options,
985 &interpreter_manager, &candidate_indices)) {
986 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
987 return original_click_indices;
988 }
989
990 std::stable_sort(
991 candidate_indices.begin(), candidate_indices.end(),
992 [this, &candidates](int a, int b) {
993 return GetPriorityScore(
994 candidates.annotated_spans[0][a].classification) >
995 GetPriorityScore(
996 candidates.annotated_spans[0][b].classification);
997 });
998
999 for (const int i : candidate_indices) {
1000 if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1001 SpansOverlap(candidates.annotated_spans[0][i].span,
1002 original_click_indices)) {
1003 // Run model classification if not present but requested and there's a
1004 // classification collection filter specified.
1005 if (candidates.annotated_spans[0][i].classification.empty() &&
1006 model_->selection_options()->always_classify_suggested_selection() &&
1007 !filtered_collections_selection_.empty()) {
1008 if (!ModelClassifyText(context, /*cached_tokens=*/{},
1009 detected_text_language_tags,
1010 candidates.annotated_spans[0][i].span, options,
1011 &interpreter_manager,
1012 /*embedding_cache=*/nullptr,
1013 &candidates.annotated_spans[0][i].classification,
1014 /*tokens=*/nullptr)) {
1015 return original_click_indices;
1016 }
1017 }
1018
1019 // Ignore if span classification is filtered.
1020 if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1021 return original_click_indices;
1022 }
1023
1024 // We return a suggested span contains the original span.
1025 // This compensates for "select all" selection that may come from
1026 // other apps. See http://b/179890518.
1027 if (SpanContains(candidates.annotated_spans[0][i].span,
1028 original_click_indices)) {
1029 return candidates.annotated_spans[0][i].span;
1030 }
1031 }
1032 }
1033
1034 return original_click_indices;
1035 }
1036
1037 namespace {
1038 // Helper function that returns the index of the first candidate that
1039 // transitively does not overlap with the candidate on 'start_index'. If the end
1040 // of 'candidates' is reached, it returns the index that points right behind the
1041 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1042 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1043 int start_index) {
1044 int first_non_overlapping = start_index + 1;
1045 CodepointSpan conflicting_span = candidates[start_index].span;
1046 while (
1047 first_non_overlapping < candidates.size() &&
1048 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1049 // Grow the span to include the current one.
1050 conflicting_span.second = std::max(
1051 conflicting_span.second, candidates[first_non_overlapping].span.second);
1052
1053 ++first_non_overlapping;
1054 }
1055 return first_non_overlapping;
1056 }
1057 } // namespace
1058
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1059 bool Annotator::ResolveConflicts(
1060 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1061 const std::vector<Token>& cached_tokens,
1062 const std::vector<Locale>& detected_text_language_tags,
1063 const BaseOptions& options, InterpreterManager* interpreter_manager,
1064 std::vector<int>* result) const {
1065 result->clear();
1066 result->reserve(candidates.size());
1067 for (int i = 0; i < candidates.size();) {
1068 int first_non_overlapping =
1069 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1070
1071 const bool conflict_found = first_non_overlapping != (i + 1);
1072 if (conflict_found) {
1073 std::vector<int> candidate_indices;
1074 if (!ResolveConflict(context, cached_tokens, candidates,
1075 detected_text_language_tags, i,
1076 first_non_overlapping, options, interpreter_manager,
1077 &candidate_indices)) {
1078 return false;
1079 }
1080 result->insert(result->end(), candidate_indices.begin(),
1081 candidate_indices.end());
1082 } else {
1083 result->push_back(i);
1084 }
1085
1086 // Skip over the whole conflicting group/go to next candidate.
1087 i = first_non_overlapping;
1088 }
1089 return true;
1090 }
1091
1092 namespace {
1093 // Returns true, if the given two sources do conflict in given annotation
1094 // usecase.
1095 // - In SMART usecase, all sources do conflict, because there's only 1 possible
1096 // annotation for a given span.
1097 // - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1098 // and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1099 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1100 const AnnotatedSpan::Source source1,
1101 const AnnotatedSpan::Source source2) {
1102 uint32 source_mask =
1103 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1104
1105 switch (annotation_usecase) {
1106 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1107 // In the SMART mode, all annotations conflict.
1108 return true;
1109
1110 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1111 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1112 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1113 // hours" (duration).
1114 if ((source_mask &
1115 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1116 (source_mask &
1117 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1118 return false;
1119 }
1120
1121 // A KNOWLEDGE entity does not conflict with anything.
1122 if ((source_mask &
1123 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1124 return false;
1125 }
1126
1127 // A PERSONNAME entity does not conflict with anything.
1128 if ((source_mask &
1129 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1130 return false;
1131 }
1132
1133 // Entities from other sources can conflict.
1134 return true;
1135 }
1136 }
1137 } // namespace
1138
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1139 bool Annotator::ResolveConflict(
1140 const std::string& context, const std::vector<Token>& cached_tokens,
1141 const std::vector<AnnotatedSpan>& candidates,
1142 const std::vector<Locale>& detected_text_language_tags, int start_index,
1143 int end_index, const BaseOptions& options,
1144 InterpreterManager* interpreter_manager,
1145 std::vector<int>* chosen_indices) const {
1146 std::vector<int> conflicting_indices;
1147 std::unordered_map<int, std::pair<float, int>> scores_lengths;
1148 for (int i = start_index; i < end_index; ++i) {
1149 conflicting_indices.push_back(i);
1150 if (!candidates[i].classification.empty()) {
1151 scores_lengths[i] = {
1152 GetPriorityScore(candidates[i].classification),
1153 candidates[i].span.second - candidates[i].span.first};
1154 continue;
1155 }
1156
1157 // OPTIMIZATION: So that we don't have to classify all the ML model
1158 // spans apriori, we wait until we get here, when they conflict with
1159 // something and we need the actual classification scores. So if the
1160 // candidate conflicts and comes from the model, we need to run a
1161 // classification to determine its priority:
1162 std::vector<ClassificationResult> classification;
1163 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1164 candidates[i].span, options, interpreter_manager,
1165 /*embedding_cache=*/nullptr, &classification,
1166 /*tokens=*/nullptr)) {
1167 return false;
1168 }
1169
1170 if (!classification.empty()) {
1171 scores_lengths[i] = {
1172 GetPriorityScore(classification),
1173 candidates[i].span.second - candidates[i].span.first};
1174 }
1175 }
1176
1177 std::stable_sort(
1178 conflicting_indices.begin(), conflicting_indices.end(),
1179 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1180 if (scores_lengths[i].first == scores_lengths[j].first &&
1181 prioritize_longest_annotation_) {
1182 return scores_lengths[i].second > scores_lengths[j].second;
1183 }
1184 return scores_lengths[i].first > scores_lengths[j].first;
1185 });
1186
1187 // Here we keep a set of indices that were chosen, per-source, to enable
1188 // effective computation.
1189 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1190 chosen_indices_for_source_map;
1191
1192 // Greedily place the candidates if they don't conflict with the already
1193 // placed ones.
1194 for (int i = 0; i < conflicting_indices.size(); ++i) {
1195 const int considered_candidate = conflicting_indices[i];
1196
1197 // See if there is a conflict between the candidate and all already placed
1198 // candidates.
1199 bool conflict = false;
1200 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1201 for (auto& source_set_pair : chosen_indices_for_source_map) {
1202 if (source_set_pair.first == candidates[considered_candidate].source) {
1203 chosen_indices_for_source_ptr = &source_set_pair.second;
1204 }
1205
1206 const bool needs_conflict_resolution =
1207 options.annotation_usecase ==
1208 AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1209 (options.annotation_usecase ==
1210 AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1211 do_conflict_resolution_in_raw_mode_);
1212 if (needs_conflict_resolution &&
1213 DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1214 candidates[considered_candidate].source) &&
1215 DoesCandidateConflict(considered_candidate, candidates,
1216 source_set_pair.second)) {
1217 conflict = true;
1218 break;
1219 }
1220 }
1221
1222 // Skip the candidate if a conflict was found.
1223 if (conflict) {
1224 continue;
1225 }
1226
1227 // If the set of indices for the current source doesn't exist yet,
1228 // initialize it.
1229 if (chosen_indices_for_source_ptr == nullptr) {
1230 SortedIntSet new_set([&candidates](int a, int b) {
1231 return candidates[a].span.first < candidates[b].span.first;
1232 });
1233 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1234 std::move(new_set);
1235 chosen_indices_for_source_ptr =
1236 &chosen_indices_for_source_map[candidates[considered_candidate]
1237 .source];
1238 }
1239
1240 // Place the candidate to the output and to the per-source conflict set.
1241 chosen_indices->push_back(considered_candidate);
1242 chosen_indices_for_source_ptr->insert(considered_candidate);
1243 }
1244
1245 std::stable_sort(chosen_indices->begin(), chosen_indices->end());
1246
1247 return true;
1248 }
1249
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1250 bool Annotator::ModelSuggestSelection(
1251 const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1252 const std::vector<Locale>& detected_text_language_tags,
1253 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1254 std::vector<AnnotatedSpan>* result) const {
1255 if (model_->triggering_options() == nullptr ||
1256 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1257 return true;
1258 }
1259
1260 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1261 ml_model_triggering_locales_,
1262 /*default_value=*/true)) {
1263 return true;
1264 }
1265
1266 int click_pos;
1267 *tokens = selection_feature_processor_->Tokenize(context_unicode);
1268 const auto [click_begin, click_end] =
1269 CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1270 selection_feature_processor_->RetokenizeAndFindClick(
1271 context_unicode, click_begin, click_end, click_indices,
1272 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1273 tokens, &click_pos);
1274 if (click_pos == kInvalidIndex) {
1275 TC3_VLOG(1) << "Could not calculate the click position.";
1276 return false;
1277 }
1278
1279 const int symmetry_context_size =
1280 model_->selection_options()->symmetry_context_size();
1281 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1282 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1283 ->bounds_sensitive_features();
1284
1285 // The symmetry context span is the clicked token with symmetry_context_size
1286 // tokens on either side.
1287 const TokenSpan symmetry_context_span =
1288 IntersectTokenSpans(TokenSpan(click_pos).Expand(
1289 /*num_tokens_left=*/symmetry_context_size,
1290 /*num_tokens_right=*/symmetry_context_size),
1291 AllOf(*tokens));
1292
1293 // Compute the extraction span based on the model type.
1294 TokenSpan extraction_span;
1295 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1296 // The extraction span is the symmetry context span expanded to include
1297 // max_selection_span tokens on either side, which is how far a selection
1298 // can stretch from the click, plus a relevant number of tokens outside of
1299 // the bounds of the selection.
1300 const int max_selection_span =
1301 selection_feature_processor_->GetOptions()->max_selection_span();
1302 extraction_span = symmetry_context_span.Expand(
1303 /*num_tokens_left=*/max_selection_span +
1304 bounds_sensitive_features->num_tokens_before(),
1305 /*num_tokens_right=*/max_selection_span +
1306 bounds_sensitive_features->num_tokens_after());
1307 } else {
1308 // The extraction span is the symmetry context span expanded to include
1309 // context_size tokens on either side.
1310 const int context_size =
1311 selection_feature_processor_->GetOptions()->context_size();
1312 extraction_span = symmetry_context_span.Expand(
1313 /*num_tokens_left=*/context_size,
1314 /*num_tokens_right=*/context_size);
1315 }
1316 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1317
1318 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1319 *tokens, extraction_span)) {
1320 return true;
1321 }
1322
1323 std::unique_ptr<CachedFeatures> cached_features;
1324 if (!selection_feature_processor_->ExtractFeatures(
1325 *tokens, extraction_span,
1326 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1327 embedding_executor_.get(),
1328 /*embedding_cache=*/nullptr,
1329 selection_feature_processor_->EmbeddingSize() +
1330 selection_feature_processor_->DenseFeaturesCount(),
1331 &cached_features)) {
1332 TC3_LOG(ERROR) << "Could not extract features.";
1333 return false;
1334 }
1335
1336 // Produce selection model candidates.
1337 std::vector<TokenSpan> chunks;
1338 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1339 interpreter_manager->SelectionInterpreter(), *cached_features,
1340 &chunks)) {
1341 TC3_LOG(ERROR) << "Could not chunk.";
1342 return false;
1343 }
1344
1345 for (const TokenSpan& chunk : chunks) {
1346 AnnotatedSpan candidate;
1347 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1348 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1349 if (model_->selection_options()->strip_unpaired_brackets()) {
1350 candidate.span =
1351 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1352 }
1353
1354 // Only output non-empty spans.
1355 if (candidate.span.first != candidate.span.second) {
1356 result->push_back(candidate);
1357 }
1358 }
1359 return true;
1360 }
1361
1362 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1363 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1364 const CodepointSpan& selection_indices,
1365 TokenSpan tokens_around_selection_to_copy) {
1366 const auto first_selection_token = std::upper_bound(
1367 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1368 [](int selection_start, const Token& token) {
1369 return selection_start < token.end;
1370 });
1371 const auto last_selection_token = std::lower_bound(
1372 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1373 [](const Token& token, int selection_end) {
1374 return token.start < selection_end;
1375 });
1376
1377 const int64 first_token = std::max(
1378 static_cast<int64>(0),
1379 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1380 tokens_around_selection_to_copy.first));
1381 const int64 last_token = std::min(
1382 static_cast<int64>(cached_tokens.size()),
1383 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1384 tokens_around_selection_to_copy.second));
1385
1386 std::vector<Token> tokens;
1387 tokens.reserve(last_token - first_token);
1388 for (int i = first_token; i < last_token; ++i) {
1389 tokens.push_back(cached_tokens[i]);
1390 }
1391 return tokens;
1392 }
1393 } // namespace internal
1394
ClassifyTextUpperBoundNeededTokens() const1395 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1396 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1397 bounds_sensitive_features =
1398 classification_feature_processor_->GetOptions()
1399 ->bounds_sensitive_features();
1400 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1401 // The extraction span is the selection span expanded to include a relevant
1402 // number of tokens outside of the bounds of the selection.
1403 return {bounds_sensitive_features->num_tokens_before(),
1404 bounds_sensitive_features->num_tokens_after()};
1405 } else {
1406 // The extraction span is the clicked token with context_size tokens on
1407 // either side.
1408 const int context_size =
1409 selection_feature_processor_->GetOptions()->context_size();
1410 return {context_size, context_size};
1411 }
1412 }
1413
1414 namespace {
1415 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1416 void SortClassificationResults(
1417 std::vector<ClassificationResult>* classification_results) {
1418 std::stable_sort(
1419 classification_results->begin(), classification_results->end(),
1420 [](const ClassificationResult& a, const ClassificationResult& b) {
1421 return a.score > b.score;
1422 });
1423 }
1424 } // namespace
1425
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1426 bool Annotator::ModelClassifyText(
1427 const std::string& context, const std::vector<Token>& cached_tokens,
1428 const std::vector<Locale>& detected_text_language_tags,
1429 const CodepointSpan& selection_indices, const BaseOptions& options,
1430 InterpreterManager* interpreter_manager,
1431 FeatureProcessor::EmbeddingCache* embedding_cache,
1432 std::vector<ClassificationResult>* classification_results,
1433 std::vector<Token>* tokens) const {
1434 const UnicodeText context_unicode =
1435 UTF8ToUnicodeText(context, /*do_copy=*/false);
1436 const auto [span_begin, span_end] =
1437 CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1438 return ModelClassifyText(context_unicode, cached_tokens,
1439 detected_text_language_tags, span_begin, span_end,
1440 /*line=*/nullptr, selection_indices, options,
1441 interpreter_manager, embedding_cache,
1442 classification_results, tokens);
1443 }
1444
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1445 bool Annotator::ModelClassifyText(
1446 const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1447 const std::vector<Locale>& detected_text_language_tags,
1448 const UnicodeText::const_iterator& span_begin,
1449 const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1450 const CodepointSpan& selection_indices, const BaseOptions& options,
1451 InterpreterManager* interpreter_manager,
1452 FeatureProcessor::EmbeddingCache* embedding_cache,
1453 std::vector<ClassificationResult>* classification_results,
1454 std::vector<Token>* tokens) const {
1455 if (model_->triggering_options() == nullptr ||
1456 !(model_->triggering_options()->enabled_modes() &
1457 ModeFlag_CLASSIFICATION)) {
1458 return true;
1459 }
1460
1461 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1462 ml_model_triggering_locales_,
1463 /*default_value=*/true)) {
1464 return true;
1465 }
1466
1467 std::vector<Token> local_tokens;
1468 if (tokens == nullptr) {
1469 tokens = &local_tokens;
1470 }
1471
1472 if (cached_tokens.empty()) {
1473 *tokens = classification_feature_processor_->Tokenize(context_unicode);
1474 } else {
1475 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1476 ClassifyTextUpperBoundNeededTokens());
1477 }
1478
1479 int click_pos;
1480 classification_feature_processor_->RetokenizeAndFindClick(
1481 context_unicode, span_begin, span_end, selection_indices,
1482 classification_feature_processor_->GetOptions()
1483 ->only_use_line_with_click(),
1484 tokens, &click_pos);
1485 const TokenSpan selection_token_span =
1486 CodepointSpanToTokenSpan(*tokens, selection_indices);
1487 const int selection_num_tokens = selection_token_span.Size();
1488 if (model_->classification_options()->max_num_tokens() > 0 &&
1489 model_->classification_options()->max_num_tokens() <
1490 selection_num_tokens) {
1491 *classification_results = {{Collections::Other(), 1.0}};
1492 return true;
1493 }
1494
1495 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1496 bounds_sensitive_features =
1497 classification_feature_processor_->GetOptions()
1498 ->bounds_sensitive_features();
1499 if (selection_token_span.first == kInvalidIndex ||
1500 selection_token_span.second == kInvalidIndex) {
1501 TC3_LOG(ERROR) << "Could not determine span.";
1502 return false;
1503 }
1504
1505 // Compute the extraction span based on the model type.
1506 TokenSpan extraction_span;
1507 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1508 // The extraction span is the selection span expanded to include a relevant
1509 // number of tokens outside of the bounds of the selection.
1510 extraction_span = selection_token_span.Expand(
1511 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1512 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1513 } else {
1514 if (click_pos == kInvalidIndex) {
1515 TC3_LOG(ERROR) << "Couldn't choose a click position.";
1516 return false;
1517 }
1518 // The extraction span is the clicked token with context_size tokens on
1519 // either side.
1520 const int context_size =
1521 classification_feature_processor_->GetOptions()->context_size();
1522 extraction_span = TokenSpan(click_pos).Expand(
1523 /*num_tokens_left=*/context_size,
1524 /*num_tokens_right=*/context_size);
1525 }
1526 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1527
1528 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1529 *tokens, extraction_span)) {
1530 *classification_results = {{Collections::Other(), 1.0}};
1531 return true;
1532 }
1533
1534 std::unique_ptr<CachedFeatures> cached_features;
1535 if (!classification_feature_processor_->ExtractFeatures(
1536 *tokens, extraction_span, selection_indices,
1537 embedding_executor_.get(), embedding_cache,
1538 classification_feature_processor_->EmbeddingSize() +
1539 classification_feature_processor_->DenseFeaturesCount(),
1540 &cached_features)) {
1541 TC3_LOG(ERROR) << "Could not extract features.";
1542 return false;
1543 }
1544
1545 std::vector<float> features;
1546 features.reserve(cached_features->OutputFeaturesSize());
1547 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1548 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1549 &features);
1550 } else {
1551 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1552 }
1553
1554 TensorView<float> logits = classification_executor_->ComputeLogits(
1555 TensorView<float>(features.data(),
1556 {1, static_cast<int>(features.size())}),
1557 interpreter_manager->ClassificationInterpreter());
1558 if (!logits.is_valid()) {
1559 TC3_LOG(ERROR) << "Couldn't compute logits.";
1560 return false;
1561 }
1562
1563 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1564 logits.dim(1) != classification_feature_processor_->NumCollections()) {
1565 TC3_LOG(ERROR) << "Mismatching output";
1566 return false;
1567 }
1568
1569 const std::vector<float> scores =
1570 ComputeSoftmax(logits.data(), logits.dim(1));
1571
1572 if (scores.empty()) {
1573 *classification_results = {{Collections::Other(), 1.0}};
1574 return true;
1575 }
1576
1577 const int best_score_index =
1578 std::max_element(scores.begin(), scores.end()) - scores.begin();
1579 const std::string top_collection =
1580 classification_feature_processor_->LabelToCollection(best_score_index);
1581
1582 // Sanity checks.
1583 if (top_collection == Collections::Phone()) {
1584 const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1585 if (digit_count <
1586 model_->classification_options()->phone_min_num_digits() ||
1587 digit_count >
1588 model_->classification_options()->phone_max_num_digits()) {
1589 *classification_results = {{Collections::Other(), 1.0}};
1590 return true;
1591 }
1592 } else if (top_collection == Collections::Address()) {
1593 if (selection_num_tokens <
1594 model_->classification_options()->address_min_num_tokens()) {
1595 *classification_results = {{Collections::Other(), 1.0}};
1596 return true;
1597 }
1598 } else if (top_collection == Collections::Dictionary()) {
1599 if ((options.use_vocab_annotator && vocab_annotator_) ||
1600 !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1601 dictionary_locales_,
1602 /*default_value=*/false)) {
1603 *classification_results = {{Collections::Other(), 1.0}};
1604 return true;
1605 }
1606 }
1607 *classification_results = {{top_collection, /*arg_score=*/1.0,
1608 /*arg_priority_score=*/scores[best_score_index]}};
1609
1610 // For some entities, we might want to clamp the priority score, for better
1611 // conflict resolution between entities.
1612 if (model_->triggering_options() != nullptr &&
1613 model_->triggering_options()->collection_to_priority() != nullptr) {
1614 if (auto entry =
1615 model_->triggering_options()->collection_to_priority()->LookupByKey(
1616 top_collection.c_str())) {
1617 (*classification_results)[0].priority_score *= entry->value();
1618 }
1619 }
1620 return true;
1621 }
1622
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1623 bool Annotator::RegexClassifyText(
1624 const std::string& context, const CodepointSpan& selection_indices,
1625 std::vector<ClassificationResult>* classification_result) const {
1626 const std::string selection_text =
1627 UTF8ToUnicodeText(context, /*do_copy=*/false)
1628 .UTF8Substring(selection_indices.first, selection_indices.second);
1629 const UnicodeText selection_text_unicode(
1630 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1631
1632 // Check whether any of the regular expressions match.
1633 for (const int pattern_id : classification_regex_patterns_) {
1634 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1635 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1636 regex_pattern.pattern->Matcher(selection_text_unicode);
1637 int status = UniLib::RegexMatcher::kNoError;
1638 bool matches;
1639 if (regex_pattern.config->use_approximate_matching()) {
1640 matches = matcher->ApproximatelyMatches(&status);
1641 } else {
1642 matches = matcher->Matches(&status);
1643 }
1644 if (status != UniLib::RegexMatcher::kNoError) {
1645 return false;
1646 }
1647 if (matches && VerifyRegexMatchCandidate(
1648 context, regex_pattern.config->verification_options(),
1649 selection_text, matcher.get())) {
1650 classification_result->push_back(
1651 {regex_pattern.config->collection_name()->str(),
1652 regex_pattern.config->target_classification_score(),
1653 regex_pattern.config->priority_score()});
1654 if (!SerializedEntityDataFromRegexMatch(
1655 regex_pattern.config, matcher.get(),
1656 &classification_result->back().serialized_entity_data)) {
1657 TC3_LOG(ERROR) << "Could not get entity data.";
1658 return false;
1659 }
1660 }
1661 }
1662
1663 return true;
1664 }
1665
1666 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1667 std::string PickCollectionForDatetime(
1668 const DatetimeParseResult& datetime_parse_result) {
1669 switch (datetime_parse_result.granularity) {
1670 case GRANULARITY_HOUR:
1671 case GRANULARITY_MINUTE:
1672 case GRANULARITY_SECOND:
1673 return Collections::DateTime();
1674 default:
1675 return Collections::Date();
1676 }
1677 }
1678
1679 } // namespace
1680
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1681 bool Annotator::DatetimeClassifyText(
1682 const std::string& context, const CodepointSpan& selection_indices,
1683 const ClassificationOptions& options,
1684 std::vector<ClassificationResult>* classification_results) const {
1685 if (!datetime_parser_) {
1686 return true;
1687 }
1688
1689 const std::string selection_text =
1690 UTF8ToUnicodeText(context, /*do_copy=*/false)
1691 .UTF8Substring(selection_indices.first, selection_indices.second);
1692
1693 LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1694 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1695 datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1696 options.reference_timezone, locale_list,
1697 ModeFlag_CLASSIFICATION,
1698 options.annotation_usecase,
1699 /*anchor_start_end=*/true);
1700 if (!result_status.ok()) {
1701 TC3_LOG(ERROR) << "Error during parsing datetime.";
1702 return false;
1703 }
1704
1705 for (const DatetimeParseResultSpan& datetime_span :
1706 result_status.ValueOrDie()) {
1707 // Only consider the result valid if the selection and extracted datetime
1708 // spans exactly match.
1709 if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1710 datetime_span.span.second + selection_indices.first) ==
1711 selection_indices) {
1712 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1713 classification_results->emplace_back(
1714 PickCollectionForDatetime(parse_result),
1715 datetime_span.target_classification_score);
1716 classification_results->back().datetime_parse_result = parse_result;
1717 classification_results->back().serialized_entity_data =
1718 CreateDatetimeSerializedEntityData(parse_result);
1719 classification_results->back().priority_score =
1720 datetime_span.priority_score;
1721 }
1722 return true;
1723 }
1724 }
1725 return true;
1726 }
1727
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1728 std::vector<ClassificationResult> Annotator::ClassifyText(
1729 const std::string& context, const CodepointSpan& selection_indices,
1730 const ClassificationOptions& options) const {
1731 if (context.size() > std::numeric_limits<int>::max()) {
1732 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1733 return {};
1734 }
1735 if (!initialized_) {
1736 TC3_LOG(ERROR) << "Not initialized";
1737 return {};
1738 }
1739 if (options.annotation_usecase !=
1740 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1741 TC3_LOG(WARNING)
1742 << "Invoking ClassifyText, which is not supported in RAW mode.";
1743 return {};
1744 }
1745 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1746 return {};
1747 }
1748
1749 std::vector<Locale> detected_text_language_tags;
1750 if (!ParseLocales(options.detected_text_language_tags,
1751 &detected_text_language_tags)) {
1752 TC3_LOG(WARNING)
1753 << "Failed to parse the detected_text_language_tags in options: "
1754 << options.detected_text_language_tags;
1755 }
1756 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1757 model_triggering_locales_,
1758 /*default_value=*/true)) {
1759 return {};
1760 }
1761
1762 const UnicodeText context_unicode =
1763 UTF8ToUnicodeText(context, /*do_copy=*/false);
1764
1765 if (!unilib_->IsValidUtf8(context_unicode)) {
1766 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1767 return {};
1768 }
1769
1770 if (!IsValidSpanInput(context_unicode, selection_indices)) {
1771 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1772 << selection_indices.first << " " << selection_indices.second;
1773 return {};
1774 }
1775
1776 // We'll accumulate a list of candidates, and pick the best candidate in the
1777 // end.
1778 std::vector<AnnotatedSpan> candidates;
1779
1780 // Try the knowledge engine.
1781 // TODO(b/126579108): Propagate error status.
1782 ClassificationResult knowledge_result;
1783 if (knowledge_engine_ &&
1784 knowledge_engine_
1785 ->ClassifyText(context, selection_indices, options.annotation_usecase,
1786 options.location_context, Permissions(),
1787 &knowledge_result)
1788 .ok()) {
1789 candidates.push_back({selection_indices, {knowledge_result}});
1790 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1791 }
1792
1793 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1794
1795 // Try the contact engine.
1796 // TODO(b/126579108): Propagate error status.
1797 ClassificationResult contact_result;
1798 if (contact_engine_ && contact_engine_->ClassifyText(
1799 context, selection_indices, &contact_result)) {
1800 candidates.push_back({selection_indices, {contact_result}});
1801 }
1802
1803 // Try the person name engine.
1804 ClassificationResult person_name_result;
1805 if (person_name_engine_ &&
1806 person_name_engine_->ClassifyText(context, selection_indices,
1807 &person_name_result)) {
1808 candidates.push_back({selection_indices, {person_name_result}});
1809 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1810 }
1811
1812 // Try the installed app engine.
1813 // TODO(b/126579108): Propagate error status.
1814 ClassificationResult installed_app_result;
1815 if (installed_app_engine_ &&
1816 installed_app_engine_->ClassifyText(context, selection_indices,
1817 &installed_app_result)) {
1818 candidates.push_back({selection_indices, {installed_app_result}});
1819 }
1820
1821 // Try the regular expression models.
1822 std::vector<ClassificationResult> regex_results;
1823 if (!RegexClassifyText(context, selection_indices, ®ex_results)) {
1824 return {};
1825 }
1826 for (const ClassificationResult& result : regex_results) {
1827 candidates.push_back({selection_indices, {result}});
1828 }
1829
1830 // Try the date model.
1831 //
1832 // DatetimeClassifyText only returns the first result, which can however have
1833 // more interpretations. They are inserted in the candidates as a single
1834 // AnnotatedSpan, so that they get treated together by the conflict resolution
1835 // algorithm.
1836 std::vector<ClassificationResult> datetime_results;
1837 if (!DatetimeClassifyText(context, selection_indices, options,
1838 &datetime_results)) {
1839 return {};
1840 }
1841 if (!datetime_results.empty()) {
1842 candidates.push_back({selection_indices, std::move(datetime_results)});
1843 candidates.back().source = AnnotatedSpan::Source::DATETIME;
1844 }
1845
1846 // Try the number annotator.
1847 // TODO(b/126579108): Propagate error status.
1848 ClassificationResult number_annotator_result;
1849 if (number_annotator_ &&
1850 number_annotator_->ClassifyText(context_unicode, selection_indices,
1851 options.annotation_usecase,
1852 &number_annotator_result)) {
1853 candidates.push_back({selection_indices, {number_annotator_result}});
1854 }
1855
1856 // Try the duration annotator.
1857 ClassificationResult duration_annotator_result;
1858 if (duration_annotator_ &&
1859 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1860 options.annotation_usecase,
1861 &duration_annotator_result)) {
1862 candidates.push_back({selection_indices, {duration_annotator_result}});
1863 candidates.back().source = AnnotatedSpan::Source::DURATION;
1864 }
1865
1866 // Try the translate annotator.
1867 ClassificationResult translate_annotator_result;
1868 if (translate_annotator_ &&
1869 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1870 options.user_familiar_language_tags,
1871 &translate_annotator_result)) {
1872 candidates.push_back({selection_indices, {translate_annotator_result}});
1873 }
1874
1875 // Try the grammar model.
1876 ClassificationResult grammar_annotator_result;
1877 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1878 detected_text_language_tags, context_unicode,
1879 selection_indices, &grammar_annotator_result)) {
1880 candidates.push_back({selection_indices, {grammar_annotator_result}});
1881 }
1882
1883 ClassificationResult pod_ner_annotator_result;
1884 if (pod_ner_annotator_ && options.use_pod_ner &&
1885 pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1886 &pod_ner_annotator_result)) {
1887 candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1888 }
1889
1890 ClassificationResult vocab_annotator_result;
1891 if (vocab_annotator_ && options.use_vocab_annotator &&
1892 vocab_annotator_->ClassifyText(
1893 context_unicode, selection_indices, detected_text_language_tags,
1894 options.trigger_dictionary_on_beginner_words,
1895 &vocab_annotator_result)) {
1896 candidates.push_back({selection_indices, {vocab_annotator_result}});
1897 }
1898
1899 if (experimental_annotator_) {
1900 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1901 candidates);
1902 }
1903
1904 // Try the ML model.
1905 //
1906 // The output of the model is considered as an exclusive 1-of-N choice. That's
1907 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1908 // span for each candidate, like e.g. the regex model.
1909 InterpreterManager interpreter_manager(selection_executor_.get(),
1910 classification_executor_.get());
1911 std::vector<ClassificationResult> model_results;
1912 std::vector<Token> tokens;
1913 if (!ModelClassifyText(
1914 context, /*cached_tokens=*/{}, detected_text_language_tags,
1915 selection_indices, options, &interpreter_manager,
1916 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1917 return {};
1918 }
1919 if (!model_results.empty()) {
1920 candidates.push_back({selection_indices, std::move(model_results)});
1921 }
1922
1923 std::vector<int> candidate_indices;
1924 if (!ResolveConflicts(candidates, context, tokens,
1925 detected_text_language_tags, options,
1926 &interpreter_manager, &candidate_indices)) {
1927 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1928 return {};
1929 }
1930
1931 std::vector<ClassificationResult> results;
1932 for (const int i : candidate_indices) {
1933 for (const ClassificationResult& result : candidates[i].classification) {
1934 if (!FilteredForClassification(result)) {
1935 results.push_back(result);
1936 }
1937 }
1938 }
1939
1940 // Sort results according to score.
1941 std::stable_sort(
1942 results.begin(), results.end(),
1943 [](const ClassificationResult& a, const ClassificationResult& b) {
1944 return a.score > b.score;
1945 });
1946
1947 if (results.empty()) {
1948 results = {{Collections::Other(), 1.0}};
1949 }
1950 return results;
1951 }
1952
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1953 bool Annotator::ModelAnnotate(
1954 const std::string& context,
1955 const std::vector<Locale>& detected_text_language_tags,
1956 const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1957 std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1958 bool skip_model_annotatation = false;
1959 if (model_->triggering_options() == nullptr ||
1960 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1961 skip_model_annotatation = true;
1962 }
1963 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1964 ml_model_triggering_locales_,
1965 /*default_value=*/true)) {
1966 skip_model_annotatation = true;
1967 }
1968
1969 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1970 /*do_copy=*/false);
1971 std::vector<UnicodeTextRange> lines;
1972 if (!selection_feature_processor_ ||
1973 !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1974 lines.push_back({context_unicode.begin(), context_unicode.end()});
1975 } else {
1976 lines = selection_feature_processor_->SplitContext(
1977 context_unicode, selection_feature_processor_->GetOptions()
1978 ->use_pipe_character_for_newline());
1979 }
1980
1981 const float min_annotate_confidence =
1982 (model_->triggering_options() != nullptr
1983 ? model_->triggering_options()->min_annotate_confidence()
1984 : 0.f);
1985
1986 for (const UnicodeTextRange& line : lines) {
1987 const std::string line_str =
1988 UnicodeText::UTF8Substring(line.first, line.second);
1989
1990 std::vector<Token> line_tokens;
1991 line_tokens = selection_feature_processor_->Tokenize(line_str);
1992
1993 selection_feature_processor_->RetokenizeAndFindClick(
1994 line_str, {0, std::distance(line.first, line.second)},
1995 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1996 &line_tokens,
1997 /*click_pos=*/nullptr);
1998 const TokenSpan full_line_span = {
1999 0, static_cast<TokenIndex>(line_tokens.size())};
2000
2001 tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2002
2003 if (skip_model_annotatation) {
2004 // We do not annotate, we only output the tokens.
2005 continue;
2006 }
2007
2008 // TODO(zilka): Add support for greater granularity of this check.
2009 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2010 line_tokens, full_line_span)) {
2011 continue;
2012 }
2013
2014 std::unique_ptr<CachedFeatures> cached_features;
2015 if (!selection_feature_processor_->ExtractFeatures(
2016 line_tokens, full_line_span,
2017 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2018 embedding_executor_.get(),
2019 /*embedding_cache=*/nullptr,
2020 selection_feature_processor_->EmbeddingSize() +
2021 selection_feature_processor_->DenseFeaturesCount(),
2022 &cached_features)) {
2023 TC3_LOG(ERROR) << "Could not extract features.";
2024 return false;
2025 }
2026
2027 std::vector<TokenSpan> local_chunks;
2028 if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2029 interpreter_manager->SelectionInterpreter(),
2030 *cached_features, &local_chunks)) {
2031 TC3_LOG(ERROR) << "Could not chunk.";
2032 return false;
2033 }
2034
2035 const int offset = std::distance(context_unicode.begin(), line.first);
2036 if (local_chunks.empty()) {
2037 continue;
2038 }
2039 const UnicodeText line_unicode =
2040 UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2041 std::vector<UnicodeText::const_iterator> line_codepoints =
2042 line_unicode.Codepoints();
2043 line_codepoints.push_back(line_unicode.end());
2044
2045 FeatureProcessor::EmbeddingCache embedding_cache;
2046 for (const TokenSpan& chunk : local_chunks) {
2047 CodepointSpan codepoint_span =
2048 TokenSpanToCodepointSpan(line_tokens, chunk);
2049 if (!codepoint_span.IsValid() ||
2050 codepoint_span.second > line_codepoints.size()) {
2051 continue;
2052 }
2053 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2054 /*span_begin=*/line_codepoints[codepoint_span.first],
2055 /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
2056 if (model_->selection_options()->strip_unpaired_brackets()) {
2057 codepoint_span = StripUnpairedBrackets(
2058 /*span_begin=*/line_codepoints[codepoint_span.first],
2059 /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
2060 *unilib_);
2061 }
2062
2063 // Skip empty spans.
2064 if (codepoint_span.first != codepoint_span.second) {
2065 std::vector<ClassificationResult> classification;
2066 if (!ModelClassifyText(
2067 line_unicode, line_tokens, detected_text_language_tags,
2068 /*span_begin=*/line_codepoints[codepoint_span.first],
2069 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2070 codepoint_span, options, interpreter_manager, &embedding_cache,
2071 &classification, /*tokens=*/nullptr)) {
2072 TC3_LOG(ERROR) << "Could not classify text: "
2073 << (codepoint_span.first + offset) << " "
2074 << (codepoint_span.second + offset);
2075 return false;
2076 }
2077
2078 // Do not include the span if it's classified as "other".
2079 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2080 classification[0].score >= min_annotate_confidence) {
2081 AnnotatedSpan result_span;
2082 result_span.span = {codepoint_span.first + offset,
2083 codepoint_span.second + offset};
2084 result_span.classification = std::move(classification);
2085 result->push_back(std::move(result_span));
2086 }
2087 }
2088 }
2089 }
2090 return true;
2091 }
2092
SelectionFeatureProcessorForTests() const2093 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2094 return selection_feature_processor_.get();
2095 }
2096
ClassificationFeatureProcessorForTests() const2097 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2098 const {
2099 return classification_feature_processor_.get();
2100 }
2101
DatetimeParserForTests() const2102 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2103 return datetime_parser_.get();
2104 }
2105
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2106 void Annotator::RemoveNotEnabledEntityTypes(
2107 const EnabledEntityTypes& is_entity_type_enabled,
2108 std::vector<AnnotatedSpan>* annotated_spans) const {
2109 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2110 std::vector<ClassificationResult>& classifications =
2111 annotated_span.classification;
2112 classifications.erase(
2113 std::remove_if(classifications.begin(), classifications.end(),
2114 [&is_entity_type_enabled](
2115 const ClassificationResult& classification_result) {
2116 return !is_entity_type_enabled(
2117 classification_result.collection);
2118 }),
2119 classifications.end());
2120 }
2121 annotated_spans->erase(
2122 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2123 [](const AnnotatedSpan& annotated_span) {
2124 return annotated_span.classification.empty();
2125 }),
2126 annotated_spans->end());
2127 }
2128
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2129 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2130 std::vector<AnnotatedSpan>* candidates) const {
2131 if (candidates == nullptr || contact_engine_ == nullptr) {
2132 return;
2133 }
2134 for (auto& candidate : *candidates) {
2135 for (auto& classification_result : candidate.classification) {
2136 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2137 &classification_result);
2138 }
2139 }
2140 }
2141
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2142 Status Annotator::AnnotateSingleInput(
2143 const std::string& context, const AnnotationOptions& options,
2144 std::vector<AnnotatedSpan>* candidates) const {
2145 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2146 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2147 }
2148
2149 const UnicodeText context_unicode =
2150 UTF8ToUnicodeText(context, /*do_copy=*/false);
2151
2152 std::vector<Locale> detected_text_language_tags;
2153 if (!ParseLocales(options.detected_text_language_tags,
2154 &detected_text_language_tags)) {
2155 TC3_LOG(WARNING)
2156 << "Failed to parse the detected_text_language_tags in options: "
2157 << options.detected_text_language_tags;
2158 }
2159 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2160 model_triggering_locales_,
2161 /*default_value=*/true)) {
2162 return Status(
2163 StatusCode::UNAVAILABLE,
2164 "The detected language tags are not in the supported locales.");
2165 }
2166
2167 InterpreterManager interpreter_manager(selection_executor_.get(),
2168 classification_executor_.get());
2169
2170 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2171 const bool is_raw_usecase =
2172 options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2173
2174 // Annotate with the selection model.
2175 const bool model_annotations_enabled =
2176 !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2177 std::vector<Token> tokens;
2178 if (model_annotations_enabled &&
2179 !ModelAnnotate(context, detected_text_language_tags, options,
2180 &interpreter_manager, &tokens, candidates)) {
2181 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2182 } else if (!model_annotations_enabled) {
2183 // If the ML model didn't run, we need to tokenize to support the other
2184 // annotators that depend on the tokens.
2185 // Optimization could be made to only do this when an annotator that uses
2186 // the tokens is enabled, but it's unclear if the added complexity is worth
2187 // it.
2188 if (selection_feature_processor_ != nullptr) {
2189 tokens = selection_feature_processor_->Tokenize(context_unicode);
2190 }
2191 }
2192
2193 // Annotate with the regular expression models.
2194 const bool regex_annotations_enabled =
2195 !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2196 if (regex_annotations_enabled &&
2197 !RegexChunk(
2198 UTF8ToUnicodeText(context, /*do_copy=*/false),
2199 annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2200 is_entity_type_enabled, options.annotation_usecase, candidates)) {
2201 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2202 }
2203
2204 // Annotate with the datetime model.
2205 // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2206 // relatively slow for some clients.
2207 if ((is_entity_type_enabled(Collections::Date()) ||
2208 is_entity_type_enabled(Collections::DateTime())) &&
2209 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2210 options.reference_time_ms_utc, options.reference_timezone,
2211 options.locales, ModeFlag_ANNOTATION,
2212 options.annotation_usecase,
2213 options.is_serialized_entity_data_enabled, candidates)) {
2214 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2215 }
2216
2217 // Annotate with the contact engine.
2218 const bool contact_annotations_enabled =
2219 !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2220 if (contact_annotations_enabled && contact_engine_ &&
2221 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2222 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2223 }
2224
2225 // Annotate with the installed app engine.
2226 const bool app_annotations_enabled =
2227 !is_raw_usecase || is_entity_type_enabled(Collections::App());
2228 if (app_annotations_enabled && installed_app_engine_ &&
2229 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2230 return Status(StatusCode::INTERNAL,
2231 "Couldn't run installed app engine Chunk.");
2232 }
2233
2234 // Annotate with the number annotator.
2235 const bool number_annotations_enabled =
2236 !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2237 is_entity_type_enabled(Collections::Percentage()));
2238 if (number_annotations_enabled && number_annotator_ != nullptr &&
2239 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2240 candidates)) {
2241 return Status(StatusCode::INTERNAL,
2242 "Couldn't run number annotator FindAll.");
2243 }
2244
2245 // Annotate with the duration annotator.
2246 const bool duration_annotations_enabled =
2247 !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2248 if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2249 !duration_annotator_->FindAll(context_unicode, tokens,
2250 options.annotation_usecase, candidates)) {
2251 return Status(StatusCode::INTERNAL,
2252 "Couldn't run duration annotator FindAll.");
2253 }
2254
2255 // Annotate with the person name engine.
2256 const bool person_annotations_enabled =
2257 !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2258 if (person_annotations_enabled && person_name_engine_ &&
2259 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2260 return Status(StatusCode::INTERNAL,
2261 "Couldn't run person name engine Chunk.");
2262 }
2263
2264 // Annotate with the grammar annotators.
2265 if (grammar_annotator_ != nullptr &&
2266 !grammar_annotator_->Annotate(detected_text_language_tags,
2267 context_unicode, candidates)) {
2268 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2269 }
2270
2271 // Annotate with the POD NER annotator.
2272 const bool pod_ner_annotations_enabled =
2273 !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2274 if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2275 options.use_pod_ner &&
2276 !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2277 return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2278 }
2279
2280 // Annotate with the vocab annotator.
2281 const bool vocab_annotations_enabled =
2282 !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2283 if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2284 options.use_vocab_annotator &&
2285 !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2286 options.trigger_dictionary_on_beginner_words,
2287 candidates)) {
2288 return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2289 }
2290
2291 // Annotate with the experimental annotator.
2292 if (experimental_annotator_ != nullptr &&
2293 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2294 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2295 }
2296
2297 // Sort candidates according to their position in the input, so that the next
2298 // code can assume that any connected component of overlapping spans forms a
2299 // contiguous block.
2300 // Also sort them according to the end position and collection, so that the
2301 // deduplication code below can assume that same spans and classifications
2302 // form contiguous blocks.
2303 std::stable_sort(candidates->begin(), candidates->end(),
2304 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2305 if (a.span.first != b.span.first) {
2306 return a.span.first < b.span.first;
2307 }
2308
2309 if (a.span.second != b.span.second) {
2310 return a.span.second < b.span.second;
2311 }
2312
2313 return a.classification[0].collection <
2314 b.classification[0].collection;
2315 });
2316
2317 std::vector<int> candidate_indices;
2318 if (!ResolveConflicts(*candidates, context, tokens,
2319 detected_text_language_tags, options,
2320 &interpreter_manager, &candidate_indices)) {
2321 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2322 }
2323
2324 // Remove candidates that overlap exactly and have the same collection.
2325 // This can e.g. happen for phone coming from both ML model and regex.
2326 candidate_indices.erase(
2327 std::unique(candidate_indices.begin(), candidate_indices.end(),
2328 [&candidates](const int a_index, const int b_index) {
2329 const AnnotatedSpan& a = (*candidates)[a_index];
2330 const AnnotatedSpan& b = (*candidates)[b_index];
2331 return a.span == b.span &&
2332 a.classification[0].collection ==
2333 b.classification[0].collection;
2334 }),
2335 candidate_indices.end());
2336
2337 std::vector<AnnotatedSpan> result;
2338 result.reserve(candidate_indices.size());
2339 for (const int i : candidate_indices) {
2340 if ((*candidates)[i].classification.empty() ||
2341 ClassifiedAsOther((*candidates)[i].classification) ||
2342 FilteredForAnnotation((*candidates)[i])) {
2343 continue;
2344 }
2345 result.push_back(std::move((*candidates)[i]));
2346 }
2347
2348 // We generate all candidates and remove them later (with the exception of
2349 // date/time/duration entities) because there are complex interdependencies
2350 // between the entity types. E.g., the TLD of an email can be interpreted as a
2351 // URL, but most likely a user of the API does not want such annotations if
2352 // "url" is enabled and "email" is not.
2353 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2354
2355 for (AnnotatedSpan& annotated_span : result) {
2356 SortClassificationResults(&annotated_span.classification);
2357 }
2358 *candidates = result;
2359 return Status::OK;
2360 }
2361
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2362 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2363 const std::vector<InputFragment>& string_fragments,
2364 const AnnotationOptions& options) const {
2365 Annotations annotation_candidates;
2366 annotation_candidates.annotated_spans.resize(string_fragments.size());
2367
2368 std::vector<std::string> text_to_annotate;
2369 text_to_annotate.reserve(string_fragments.size());
2370 std::vector<FragmentMetadata> fragment_metadata;
2371 fragment_metadata.reserve(string_fragments.size());
2372 for (const auto& string_fragment : string_fragments) {
2373 text_to_annotate.push_back(string_fragment.text);
2374 fragment_metadata.push_back(
2375 {.relative_bounding_box_top = string_fragment.bounding_box_top,
2376 .relative_bounding_box_height = string_fragment.bounding_box_height});
2377 }
2378
2379 // KnowledgeEngine is special, because it supports annotation of multiple
2380 // fragments at once.
2381 if (knowledge_engine_ &&
2382 !knowledge_engine_
2383 ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2384 options.annotation_usecase,
2385 options.location_context, options.permissions,
2386 options.annotate_mode, &annotation_candidates)
2387 .ok()) {
2388 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2389 }
2390 // The annotator engines shouldn't change the number of annotation vectors.
2391 if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2392 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2393 << " texts to annotate but generated a different number of "
2394 "lists of annotations:"
2395 << annotation_candidates.annotated_spans.size();
2396 return Status(StatusCode::INTERNAL,
2397 "Number of annotation candidates differs from "
2398 "number of texts to annotate.");
2399 }
2400
2401 // As an optimization, if the only annotated type is Entity, we skip all the
2402 // other annotators than the KnowledgeEngine. This only happens in the raw
2403 // mode, to make sure it does not affect the result.
2404 if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2405 options.entity_types.size() == 1 &&
2406 *options.entity_types.begin() == Collections::Entity()) {
2407 return annotation_candidates;
2408 }
2409
2410 // Other annotators run on each fragment independently.
2411 for (int i = 0; i < text_to_annotate.size(); ++i) {
2412 AnnotationOptions annotation_options = options;
2413 if (string_fragments[i].datetime_options.has_value()) {
2414 DatetimeOptions reference_datetime =
2415 string_fragments[i].datetime_options.value();
2416 annotation_options.reference_time_ms_utc =
2417 reference_datetime.reference_time_ms_utc;
2418 annotation_options.reference_timezone =
2419 reference_datetime.reference_timezone;
2420 }
2421
2422 AddContactMetadataToKnowledgeClassificationResults(
2423 &annotation_candidates.annotated_spans[i]);
2424
2425 Status annotation_status =
2426 AnnotateSingleInput(text_to_annotate[i], annotation_options,
2427 &annotation_candidates.annotated_spans[i]);
2428 if (!annotation_status.ok()) {
2429 return annotation_status;
2430 }
2431 }
2432 return annotation_candidates;
2433 }
2434
Annotate(const std::string & context,const AnnotationOptions & options) const2435 std::vector<AnnotatedSpan> Annotator::Annotate(
2436 const std::string& context, const AnnotationOptions& options) const {
2437 if (context.size() > std::numeric_limits<int>::max()) {
2438 TC3_LOG(ERROR) << "Rejecting too long input.";
2439 return {};
2440 }
2441
2442 const UnicodeText context_unicode =
2443 UTF8ToUnicodeText(context, /*do_copy=*/false);
2444 if (!unilib_->IsValidUtf8(context_unicode)) {
2445 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2446 return {};
2447 }
2448
2449 std::vector<InputFragment> string_fragments;
2450 string_fragments.push_back({.text = context});
2451 StatusOr<Annotations> annotations =
2452 AnnotateStructuredInput(string_fragments, options);
2453 if (!annotations.ok()) {
2454 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2455 << annotations.status().error_message();
2456 return {};
2457 }
2458 return annotations.ValueOrDie().annotated_spans[0];
2459 }
2460
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2461 CodepointSpan Annotator::ComputeSelectionBoundaries(
2462 const UniLib::RegexMatcher* match,
2463 const RegexModel_::Pattern* config) const {
2464 if (config->capturing_group() == nullptr) {
2465 // Use first capturing group to specify the selection.
2466 int status = UniLib::RegexMatcher::kNoError;
2467 const CodepointSpan result = {match->Start(1, &status),
2468 match->End(1, &status)};
2469 if (status != UniLib::RegexMatcher::kNoError) {
2470 return {kInvalidIndex, kInvalidIndex};
2471 }
2472 return result;
2473 }
2474
2475 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2476 const int num_groups = config->capturing_group()->size();
2477 for (int i = 0; i < num_groups; i++) {
2478 if (!config->capturing_group()->Get(i)->extend_selection()) {
2479 continue;
2480 }
2481
2482 int status = UniLib::RegexMatcher::kNoError;
2483 // Check match and adjust bounds.
2484 const int group_start = match->Start(i, &status);
2485 const int group_end = match->End(i, &status);
2486 if (status != UniLib::RegexMatcher::kNoError) {
2487 return {kInvalidIndex, kInvalidIndex};
2488 }
2489 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2490 continue;
2491 }
2492 if (result.first == kInvalidIndex) {
2493 result = {group_start, group_end};
2494 } else {
2495 result.first = std::min(result.first, group_start);
2496 result.second = std::max(result.second, group_end);
2497 }
2498 }
2499 return result;
2500 }
2501
HasEntityData(const RegexModel_::Pattern * pattern) const2502 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2503 if (pattern->serialized_entity_data() != nullptr ||
2504 pattern->entity_data() != nullptr) {
2505 return true;
2506 }
2507 if (pattern->capturing_group() != nullptr) {
2508 for (const CapturingGroup* group : *pattern->capturing_group()) {
2509 if (group->entity_field_path() != nullptr) {
2510 return true;
2511 }
2512 if (group->serialized_entity_data() != nullptr ||
2513 group->entity_data() != nullptr) {
2514 return true;
2515 }
2516 }
2517 }
2518 return false;
2519 }
2520
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2521 bool Annotator::SerializedEntityDataFromRegexMatch(
2522 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2523 std::string* serialized_entity_data) const {
2524 if (!HasEntityData(pattern)) {
2525 serialized_entity_data->clear();
2526 return true;
2527 }
2528 TC3_CHECK(entity_data_builder_ != nullptr);
2529
2530 std::unique_ptr<MutableFlatbuffer> entity_data =
2531 entity_data_builder_->NewRoot();
2532
2533 TC3_CHECK(entity_data != nullptr);
2534
2535 // Set fixed entity data.
2536 if (pattern->serialized_entity_data() != nullptr) {
2537 entity_data->MergeFromSerializedFlatbuffer(
2538 StringPiece(pattern->serialized_entity_data()->c_str(),
2539 pattern->serialized_entity_data()->size()));
2540 }
2541 if (pattern->entity_data() != nullptr) {
2542 entity_data->MergeFrom(
2543 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2544 }
2545
2546 // Add entity data from rule capturing groups.
2547 if (pattern->capturing_group() != nullptr) {
2548 const int num_groups = pattern->capturing_group()->size();
2549 for (int i = 0; i < num_groups; i++) {
2550 const CapturingGroup* group = pattern->capturing_group()->Get(i);
2551
2552 // Check whether the group matched.
2553 Optional<std::string> group_match_text =
2554 GetCapturingGroupText(matcher, /*group_id=*/i);
2555 if (!group_match_text.has_value()) {
2556 continue;
2557 }
2558
2559 // Set fixed entity data from capturing group match.
2560 if (group->serialized_entity_data() != nullptr) {
2561 entity_data->MergeFromSerializedFlatbuffer(
2562 StringPiece(group->serialized_entity_data()->c_str(),
2563 group->serialized_entity_data()->size()));
2564 }
2565 if (group->entity_data() != nullptr) {
2566 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2567 pattern->entity_data()));
2568 }
2569
2570 // Set entity field from capturing group text.
2571 if (group->entity_field_path() != nullptr) {
2572 UnicodeText normalized_group_match_text =
2573 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2574
2575 // Apply normalization if specified.
2576 if (group->normalization_options() != nullptr) {
2577 normalized_group_match_text =
2578 NormalizeText(*unilib_, group->normalization_options(),
2579 normalized_group_match_text);
2580 }
2581
2582 if (!entity_data->ParseAndSet(
2583 group->entity_field_path(),
2584 normalized_group_match_text.ToUTF8String())) {
2585 TC3_LOG(ERROR)
2586 << "Could not set entity data from rule capturing group.";
2587 return false;
2588 }
2589 }
2590 }
2591 }
2592
2593 *serialized_entity_data = entity_data->Serialize();
2594 return true;
2595 }
2596
// Returns a copy of the codepoints of `amount` that precede
// `it_decimal_separator` (or all of them if that iterator equals
// `amount.end()`), with every codepoint contained in `decimal_separators`
// dropped.
//
// Membership is tested with the set's own hash lookup rather than a linear
// scan over its elements.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    const bool is_separator =
        decimal_separators.find(static_cast<char32>(*it)) !=
        decimal_separators.end();
    if (!is_separator) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2611
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2612 void Annotator::GetMoneyQuantityFromCapturingGroup(
2613 const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2614 const UnicodeText& context_unicode, std::string* quantity,
2615 int* exponent) const {
2616 if (config->capturing_group() == nullptr) {
2617 *exponent = 0;
2618 return;
2619 }
2620
2621 const int num_groups = config->capturing_group()->size();
2622 for (int i = 0; i < num_groups; i++) {
2623 int status = UniLib::RegexMatcher::kNoError;
2624 const int group_start = match->Start(i, &status);
2625 const int group_end = match->End(i, &status);
2626 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2627 continue;
2628 }
2629
2630 *quantity =
2631 unilib_
2632 ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2633 group_end, /*do_copy=*/false))
2634 .ToUTF8String();
2635
2636 if (auto entry = model_->money_parsing_options()
2637 ->quantities_name_to_exponent()
2638 ->LookupByKey((*quantity).c_str())) {
2639 *exponent = entry->value();
2640 return;
2641 }
2642 }
2643 *exponent = 0;
2644 }
2645
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2646 bool Annotator::ParseAndFillInMoneyAmount(
2647 std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2648 const RegexModel_::Pattern* config,
2649 const UnicodeText& context_unicode) const {
2650 std::unique_ptr<EntityDataT> data =
2651 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2652 *serialized_entity_data);
2653 if (data == nullptr) {
2654 if (model_->version() >= 706) {
2655 // This way of parsing money entity data is enabled for models newer than
2656 // v706, consequently logging errors only for them (b/156634162).
2657 TC3_LOG(ERROR)
2658 << "Data field is null when trying to parse Money Entity Data";
2659 }
2660 return false;
2661 }
2662 if (data->money->unnormalized_amount.empty()) {
2663 if (model_->version() >= 706) {
2664 // This way of parsing money entity data is enabled for models newer than
2665 // v706, consequently logging errors only for them (b/156634162).
2666 TC3_LOG(ERROR)
2667 << "Data unnormalized_amount is empty when trying to parse "
2668 "Money Entity Data";
2669 }
2670 return false;
2671 }
2672
2673 UnicodeText amount =
2674 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2675 int separator_back_index = 0;
2676 auto it_decimal_separator = --amount.end();
2677 for (; it_decimal_separator != amount.begin();
2678 --it_decimal_separator, ++separator_back_index) {
2679 if (std::find(money_separators_.begin(), money_separators_.end(),
2680 static_cast<char32>(*it_decimal_separator)) !=
2681 money_separators_.end()) {
2682 break;
2683 }
2684 }
2685
2686 // If there are 3 digits after the last separator, we consider that a
2687 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2688 // If there is no separator in number, also that number is an int.
2689 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2690 it_decimal_separator = amount.end();
2691 }
2692
2693 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2694 it_decimal_separator),
2695 &data->money->amount_whole_part)) {
2696 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2697 "amount: "
2698 << data->money->unnormalized_amount;
2699 return false;
2700 }
2701
2702 if (it_decimal_separator == amount.end()) {
2703 data->money->amount_decimal_part = 0;
2704 data->money->nanos = 0;
2705 } else {
2706 const int amount_codepoints_size = amount.size_codepoints();
2707 const UnicodeText decimal_part = UnicodeText::Substring(
2708 amount, amount_codepoints_size - separator_back_index,
2709 amount_codepoints_size, /*do_copy=*/false);
2710 if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2711 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2712 "the amount: "
2713 << data->money->unnormalized_amount;
2714 return false;
2715 }
2716 data->money->nanos = data->money->amount_decimal_part *
2717 pow(10, 9 - decimal_part.size_codepoints());
2718 }
2719
2720 if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2721 nullptr) {
2722 int quantity_exponent;
2723 std::string quantity;
2724 GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2725 &quantity, &quantity_exponent);
2726 if (quantity_exponent > 0 && quantity_exponent <= 9) {
2727 const double amount_whole_part =
2728 data->money->amount_whole_part * pow(10, quantity_exponent) +
2729 data->money->nanos / pow(10, 9 - quantity_exponent);
2730 // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2731 // (and `std::numeric_limits<int>::max()` to
2732 // `std::numeric_limits<int64>::max()`).
2733 if (amount_whole_part < std::numeric_limits<int>::max()) {
2734 data->money->amount_whole_part = amount_whole_part;
2735 data->money->nanos = data->money->nanos %
2736 static_cast<int>(pow(10, 9 - quantity_exponent)) *
2737 pow(10, quantity_exponent);
2738 }
2739 }
2740 if (quantity_exponent > 0) {
2741 data->money->unnormalized_amount = strings::JoinStrings(
2742 " ", {data->money->unnormalized_amount, quantity});
2743 }
2744 }
2745
2746 *serialized_entity_data =
2747 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2748 return true;
2749 }
2750
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2751 bool Annotator::IsAnyModelEntityTypeEnabled(
2752 const EnabledEntityTypes& is_entity_type_enabled) const {
2753 if (model_->classification_feature_options() == nullptr ||
2754 model_->classification_feature_options()->collections() == nullptr) {
2755 return false;
2756 }
2757 for (int i = 0;
2758 i < model_->classification_feature_options()->collections()->size();
2759 i++) {
2760 if (is_entity_type_enabled(model_->classification_feature_options()
2761 ->collections()
2762 ->Get(i)
2763 ->str())) {
2764 return true;
2765 }
2766 }
2767 return false;
2768 }
2769
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2770 bool Annotator::IsAnyRegexEntityTypeEnabled(
2771 const EnabledEntityTypes& is_entity_type_enabled) const {
2772 if (model_->regex_model() == nullptr ||
2773 model_->regex_model()->patterns() == nullptr) {
2774 return false;
2775 }
2776 for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2777 if (is_entity_type_enabled(model_->regex_model()
2778 ->patterns()
2779 ->Get(i)
2780 ->collection_name()
2781 ->str())) {
2782 return true;
2783 }
2784 }
2785 return false;
2786 }
2787
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2788 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2789 const EnabledEntityTypes& is_entity_type_enabled) const {
2790 if (pod_ner_annotator_ == nullptr) {
2791 return false;
2792 }
2793
2794 for (const std::string& collection :
2795 pod_ner_annotator_->GetSupportedCollections()) {
2796 if (is_entity_type_enabled(collection)) {
2797 return true;
2798 }
2799 }
2800 return false;
2801 }
2802
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2803 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2804 const std::vector<int>& rules,
2805 bool is_serialized_entity_data_enabled,
2806 const EnabledEntityTypes& enabled_entity_types,
2807 const AnnotationUsecase& annotation_usecase,
2808 std::vector<AnnotatedSpan>* result) const {
2809 for (int pattern_id : rules) {
2810 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2811 if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2812 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2813 // No regex annotation type has been requested, skip regex annotation.
2814 continue;
2815 }
2816 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2817 if (!matcher) {
2818 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2819 << pattern_id;
2820 return false;
2821 }
2822
2823 int status = UniLib::RegexMatcher::kNoError;
2824 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2825 if (regex_pattern.config->verification_options()) {
2826 if (!VerifyRegexMatchCandidate(
2827 context_unicode.ToUTF8String(),
2828 regex_pattern.config->verification_options(),
2829 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2830 continue;
2831 }
2832 }
2833
2834 std::string serialized_entity_data;
2835 if (is_serialized_entity_data_enabled) {
2836 if (!SerializedEntityDataFromRegexMatch(
2837 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2838 TC3_LOG(ERROR) << "Could not get entity data.";
2839 return false;
2840 }
2841
2842 // Further parsing of money amount. Need this since regexes cannot have
2843 // empty groups that fill in entity data (amount_decimal_part and
2844 // quantity might be empty groups).
2845 if (regex_pattern.config->collection_name()->str() ==
2846 Collections::Money()) {
2847 if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2848 regex_pattern.config,
2849 context_unicode)) {
2850 if (model_->version() >= 706) {
2851 // This way of parsing money entity data is enabled for models
2852 // newer than v706 => logging errors only for them (b/156634162).
2853 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2854 }
2855 }
2856 }
2857 }
2858
2859 result->emplace_back();
2860
2861 // Selection/annotation regular expressions need to specify a capturing
2862 // group specifying the selection.
2863 result->back().span =
2864 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2865
2866 result->back().classification = {
2867 {regex_pattern.config->collection_name()->str(),
2868 regex_pattern.config->target_classification_score(),
2869 regex_pattern.config->priority_score()}};
2870
2871 result->back().classification[0].serialized_entity_data =
2872 serialized_entity_data;
2873 }
2874 }
2875 return true;
2876 }
2877
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2878 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2879 tflite::Interpreter* selection_interpreter,
2880 const CachedFeatures& cached_features,
2881 std::vector<TokenSpan>* chunks) const {
2882 const int max_selection_span =
2883 selection_feature_processor_->GetOptions()->max_selection_span();
2884 // The inference span is the span of interest expanded to include
2885 // max_selection_span tokens on either side, which is how far a selection can
2886 // stretch from the click.
2887 const TokenSpan inference_span =
2888 IntersectTokenSpans(span_of_interest.Expand(
2889 /*num_tokens_left=*/max_selection_span,
2890 /*num_tokens_right=*/max_selection_span),
2891 {0, num_tokens});
2892
2893 std::vector<ScoredChunk> scored_chunks;
2894 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2895 selection_feature_processor_->GetOptions()
2896 ->bounds_sensitive_features()
2897 ->enabled()) {
2898 if (!ModelBoundsSensitiveScoreChunks(
2899 num_tokens, span_of_interest, inference_span, cached_features,
2900 selection_interpreter, &scored_chunks)) {
2901 return false;
2902 }
2903 } else {
2904 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2905 cached_features, selection_interpreter,
2906 &scored_chunks)) {
2907 return false;
2908 }
2909 }
2910 std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
2911 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2912 return lhs.score < rhs.score;
2913 });
2914
2915 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2916 // them greedily as long as they do not overlap with any previously picked
2917 // chunks.
2918 std::vector<bool> token_used(inference_span.Size());
2919 chunks->clear();
2920 for (const ScoredChunk& scored_chunk : scored_chunks) {
2921 bool feasible = true;
2922 for (int i = scored_chunk.token_span.first;
2923 i < scored_chunk.token_span.second; ++i) {
2924 if (token_used[i - inference_span.first]) {
2925 feasible = false;
2926 break;
2927 }
2928 }
2929
2930 if (!feasible) {
2931 continue;
2932 }
2933
2934 for (int i = scored_chunk.token_span.first;
2935 i < scored_chunk.token_span.second; ++i) {
2936 token_used[i - inference_span.first] = true;
2937 }
2938
2939 chunks->push_back(scored_chunk.token_span);
2940 }
2941
2942 std::stable_sort(chunks->begin(), chunks->end());
2943
2944 return true;
2945 }
2946
namespace {
// Raises the value stored under `key` to `value` if `value` is larger; when
// the key is not in the map yet, it is inserted with `value`.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  // emplace() leaves an existing entry untouched and tells us so.
  const auto insertion = map->emplace(key, value);
  if (!insertion.second) {
    auto& stored = insertion.first->second;
    stored = std::max(stored, value);
  }
}
}  // namespace
2961
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2962 bool Annotator::ModelClickContextScoreChunks(
2963 int num_tokens, const TokenSpan& span_of_interest,
2964 const CachedFeatures& cached_features,
2965 tflite::Interpreter* selection_interpreter,
2966 std::vector<ScoredChunk>* scored_chunks) const {
2967 const int max_batch_size = model_->selection_options()->batch_size();
2968
2969 std::vector<float> all_features;
2970 std::map<TokenSpan, float> chunk_scores;
2971 for (int batch_start = span_of_interest.first;
2972 batch_start < span_of_interest.second; batch_start += max_batch_size) {
2973 const int batch_end =
2974 std::min(batch_start + max_batch_size, span_of_interest.second);
2975
2976 // Prepare features for the whole batch.
2977 all_features.clear();
2978 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2979 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2980 cached_features.AppendClickContextFeaturesForClick(click_pos,
2981 &all_features);
2982 }
2983
2984 // Run batched inference.
2985 const int batch_size = batch_end - batch_start;
2986 const int features_size = cached_features.OutputFeaturesSize();
2987 TensorView<float> logits = selection_executor_->ComputeLogits(
2988 TensorView<float>(all_features.data(), {batch_size, features_size}),
2989 selection_interpreter);
2990 if (!logits.is_valid()) {
2991 TC3_LOG(ERROR) << "Couldn't compute logits.";
2992 return false;
2993 }
2994 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2995 logits.dim(1) !=
2996 selection_feature_processor_->GetSelectionLabelCount()) {
2997 TC3_LOG(ERROR) << "Mismatching output.";
2998 return false;
2999 }
3000
3001 // Save results.
3002 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3003 const std::vector<float> scores = ComputeSoftmax(
3004 logits.data() + logits.dim(1) * (click_pos - batch_start),
3005 logits.dim(1));
3006 for (int j = 0;
3007 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3008 TokenSpan relative_token_span;
3009 if (!selection_feature_processor_->LabelToTokenSpan(
3010 j, &relative_token_span)) {
3011 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3012 return false;
3013 }
3014 const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3015 relative_token_span.first, relative_token_span.second);
3016 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3017 UpdateMax(&chunk_scores, candidate_span, scores[j]);
3018 }
3019 }
3020 }
3021 }
3022
3023 scored_chunks->clear();
3024 scored_chunks->reserve(chunk_scores.size());
3025 for (const auto& entry : chunk_scores) {
3026 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3027 }
3028
3029 return true;
3030 }
3031
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3032 bool Annotator::ModelBoundsSensitiveScoreChunks(
3033 int num_tokens, const TokenSpan& span_of_interest,
3034 const TokenSpan& inference_span, const CachedFeatures& cached_features,
3035 tflite::Interpreter* selection_interpreter,
3036 std::vector<ScoredChunk>* scored_chunks) const {
3037 const int max_selection_span =
3038 selection_feature_processor_->GetOptions()->max_selection_span();
3039 const int max_chunk_length = selection_feature_processor_->GetOptions()
3040 ->selection_reduced_output_space()
3041 ? max_selection_span + 1
3042 : 2 * max_selection_span + 1;
3043 const bool score_single_token_spans_as_zero =
3044 selection_feature_processor_->GetOptions()
3045 ->bounds_sensitive_features()
3046 ->score_single_token_spans_as_zero();
3047
3048 scored_chunks->clear();
3049 if (score_single_token_spans_as_zero) {
3050 scored_chunks->reserve(span_of_interest.Size());
3051 }
3052
3053 // Prepare all chunk candidates into one batch:
3054 // - Are contained in the inference span
3055 // - Have a non-empty intersection with the span of interest
3056 // - Are at least one token long
3057 // - Are not longer than the maximum chunk length
3058 std::vector<TokenSpan> candidate_spans;
3059 for (int start = inference_span.first; start < span_of_interest.second;
3060 ++start) {
3061 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3062 for (int end = leftmost_end_index;
3063 end <= inference_span.second && end - start <= max_chunk_length;
3064 ++end) {
3065 const TokenSpan candidate_span = {start, end};
3066 if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3067 // Do not include the single token span in the batch, add a zero score
3068 // for it directly to the output.
3069 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3070 } else {
3071 candidate_spans.push_back(candidate_span);
3072 }
3073 }
3074 }
3075
3076 const int max_batch_size = model_->selection_options()->batch_size();
3077
3078 std::vector<float> all_features;
3079 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3080 for (int batch_start = 0; batch_start < candidate_spans.size();
3081 batch_start += max_batch_size) {
3082 const int batch_end = std::min(batch_start + max_batch_size,
3083 static_cast<int>(candidate_spans.size()));
3084
3085 // Prepare features for the whole batch.
3086 all_features.clear();
3087 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3088 for (int i = batch_start; i < batch_end; ++i) {
3089 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3090 &all_features);
3091 }
3092
3093 // Run batched inference.
3094 const int batch_size = batch_end - batch_start;
3095 const int features_size = cached_features.OutputFeaturesSize();
3096 TensorView<float> logits = selection_executor_->ComputeLogits(
3097 TensorView<float>(all_features.data(), {batch_size, features_size}),
3098 selection_interpreter);
3099 if (!logits.is_valid()) {
3100 TC3_LOG(ERROR) << "Couldn't compute logits.";
3101 return false;
3102 }
3103 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3104 logits.dim(1) != 1) {
3105 TC3_LOG(ERROR) << "Mismatching output.";
3106 return false;
3107 }
3108
3109 // Save results.
3110 for (int i = batch_start; i < batch_end; ++i) {
3111 scored_chunks->push_back(
3112 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3113 }
3114 }
3115
3116 return true;
3117 }
3118
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3119 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3120 int64 reference_time_ms_utc,
3121 const std::string& reference_timezone,
3122 const std::string& locales, ModeFlag mode,
3123 AnnotationUsecase annotation_usecase,
3124 bool is_serialized_entity_data_enabled,
3125 std::vector<AnnotatedSpan>* result) const {
3126 if (!datetime_parser_) {
3127 return true;
3128 }
3129 LocaleList locale_list = LocaleList::ParseFrom(locales);
3130 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3131 datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3132 reference_timezone, locale_list, mode,
3133 annotation_usecase,
3134 /*anchor_start_end=*/false);
3135 if (!result_status.ok()) {
3136 return false;
3137 }
3138
3139 for (const DatetimeParseResultSpan& datetime_span :
3140 result_status.ValueOrDie()) {
3141 AnnotatedSpan annotated_span;
3142 annotated_span.span = datetime_span.span;
3143 for (const DatetimeParseResult& parse_result : datetime_span.data) {
3144 annotated_span.classification.emplace_back(
3145 PickCollectionForDatetime(parse_result),
3146 datetime_span.target_classification_score,
3147 datetime_span.priority_score);
3148 annotated_span.classification.back().datetime_parse_result = parse_result;
3149 if (is_serialized_entity_data_enabled) {
3150 annotated_span.classification.back().serialized_entity_data =
3151 CreateDatetimeSerializedEntityData(parse_result);
3152 }
3153 }
3154 annotated_span.source = AnnotatedSpan::Source::DATETIME;
3155 result->push_back(std::move(annotated_span));
3156 }
3157 return true;
3158 }
3159
model() const3160 const Model* Annotator::model() const { return model_; }
entity_data_schema() const3161 const reflection::Schema* Annotator::entity_data_schema() const {
3162 return entity_data_schema_;
3163 }
3164
ViewModel(const void * buffer,int size)3165 const Model* ViewModel(const void* buffer, int size) {
3166 if (!buffer) {
3167 return nullptr;
3168 }
3169
3170 return LoadAndVerifyModel(buffer, size);
3171 }
3172
LookUpKnowledgeEntity(const std::string & id) const3173 StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3174 const std::string& id) const {
3175 if (!knowledge_engine_) {
3176 return Status(StatusCode::FAILED_PRECONDITION,
3177 "knowledge_engine_ is nullptr");
3178 }
3179 return knowledge_engine_->LookUpEntity(id);
3180 }
3181
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3182 StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3183 const std::string& mid_str, const std::string& property) const {
3184 if (!knowledge_engine_) {
3185 return Status(StatusCode::FAILED_PRECONDITION,
3186 "knowledge_engine_ is nullptr");
3187 }
3188 return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3189 }
3190
3191 } // namespace libtextclassifier3
3192