1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator.h"
18
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54
55 namespace libtextclassifier3 {
56
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58
59 const std::string& Annotator::kPhoneCollection =
__anon109669fc0102() 60 *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anon109669fc0202() 62 *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anon109669fc0302() 64 *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anon109669fc0402() 66 *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anon109669fc0502() 68 *[]() { return new std::string("email"); }();
69
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73 if (VerifyModelBuffer(verifier)) {
74 return GetModel(addr);
75 } else {
76 return nullptr;
77 }
78 }
79
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81 int size) {
82 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83 if (VerifyPersonNameModelBuffer(verifier)) {
84 return GetPersonNameModel(addr);
85 } else {
86 return nullptr;
87 }
88 }
89
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93 std::unique_ptr<UniLib>* owned_lib) {
94 if (lib) {
95 return lib;
96 } else {
97 owned_lib->reset(new UniLib);
98 return owned_lib->get();
99 }
100 }
101
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105 if (lib) {
106 return lib;
107 } else {
108 owned_lib->reset(new CalendarLib);
109 return owned_lib->get();
110 }
111 }
112
113 // Returns whether the provided input is valid:
114 // * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116 return (span.first >= 0 && span.first < span.second &&
117 span.second <= context.size_codepoints());
118 }
119
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121 const flatbuffers::Vector<int32_t>* ints) {
122 if (ints == nullptr) {
123 return {};
124 }
125 std::unordered_set<char32> ints_set;
126 for (auto value : *ints) {
127 ints_set.insert(static_cast<char32>(value));
128 }
129 return ints_set;
130 }
131
132 } // namespace
133
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135 if (!selection_interpreter_) {
136 TC3_CHECK(selection_executor_);
137 selection_interpreter_ = selection_executor_->CreateInterpreter();
138 if (!selection_interpreter_) {
139 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140 }
141 }
142 return selection_interpreter_.get();
143 }
144
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146 if (!classification_interpreter_) {
147 TC3_CHECK(classification_executor_);
148 classification_interpreter_ = classification_executor_->CreateInterpreter();
149 if (!classification_interpreter_) {
150 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151 }
152 }
153 return classification_interpreter_.get();
154 }
155
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157 const char* buffer, int size, const UniLib* unilib,
158 const CalendarLib* calendarlib) {
159 const Model* model = LoadAndVerifyModel(buffer, size);
160 if (model == nullptr) {
161 return nullptr;
162 }
163
164 auto classifier = std::unique_ptr<Annotator>(new Annotator());
165 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166 calendarlib =
167 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168 classifier->ValidateAndInitialize(model, unilib, calendarlib);
169 if (!classifier->IsInitialized()) {
170 return nullptr;
171 }
172
173 return classifier;
174 }
175
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177 const std::string& buffer, const UniLib* unilib,
178 const CalendarLib* calendarlib) {
179 auto classifier = std::unique_ptr<Annotator>(new Annotator());
180 classifier->owned_buffer_ = buffer;
181 const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182 classifier->owned_buffer_.size());
183 if (model == nullptr) {
184 return nullptr;
185 }
186 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187 calendarlib =
188 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189 classifier->ValidateAndInitialize(model, unilib, calendarlib);
190 if (!classifier->IsInitialized()) {
191 return nullptr;
192 }
193
194 return classifier;
195 }
196
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199 const CalendarLib* calendarlib) {
200 if (!(*mmap)->handle().ok()) {
201 TC3_VLOG(1) << "Mmap failed.";
202 return nullptr;
203 }
204
205 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206 (*mmap)->handle().num_bytes());
207 if (!model) {
208 TC3_LOG(ERROR) << "Model verification failed.";
209 return nullptr;
210 }
211
212 auto classifier = std::unique_ptr<Annotator>(new Annotator());
213 classifier->mmap_ = std::move(*mmap);
214 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215 calendarlib =
216 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217 classifier->ValidateAndInitialize(model, unilib, calendarlib);
218 if (!classifier->IsInitialized()) {
219 return nullptr;
220 }
221
222 return classifier;
223 }
224
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227 std::unique_ptr<CalendarLib> calendarlib) {
228 if (!(*mmap)->handle().ok()) {
229 TC3_VLOG(1) << "Mmap failed.";
230 return nullptr;
231 }
232
233 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234 (*mmap)->handle().num_bytes());
235 if (model == nullptr) {
236 TC3_LOG(ERROR) << "Model verification failed.";
237 return nullptr;
238 }
239
240 auto classifier = std::unique_ptr<Annotator>(new Annotator());
241 classifier->mmap_ = std::move(*mmap);
242 classifier->owned_unilib_ = std::move(unilib);
243 classifier->owned_calendarlib_ = std::move(calendarlib);
244 classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245 classifier->owned_calendarlib_.get());
246 if (!classifier->IsInitialized()) {
247 return nullptr;
248 }
249
250 return classifier;
251 }
252
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254 int fd, int offset, int size, const UniLib* unilib,
255 const CalendarLib* calendarlib) {
256 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257 return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262 std::unique_ptr<CalendarLib> calendarlib) {
263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270 return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274 int fd, std::unique_ptr<UniLib> unilib,
275 std::unique_ptr<CalendarLib> calendarlib) {
276 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281 const UniLib* unilib,
282 const CalendarLib* calendarlib) {
283 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284 return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288 const std::string& path, std::unique_ptr<UniLib> unilib,
289 std::unique_ptr<CalendarLib> calendarlib) {
290 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295 const CalendarLib* calendarlib) {
296 model_ = model;
297 unilib_ = unilib;
298 calendarlib_ = calendarlib;
299
300 initialized_ = false;
301
302 if (model_ == nullptr) {
303 TC3_LOG(ERROR) << "No model specified.";
304 return;
305 }
306
307 const bool model_enabled_for_annotation =
308 (model_->triggering_options() != nullptr &&
309 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310 const bool model_enabled_for_classification =
311 (model_->triggering_options() != nullptr &&
312 (model_->triggering_options()->enabled_modes() &
313 ModeFlag_CLASSIFICATION));
314 const bool model_enabled_for_selection =
315 (model_->triggering_options() != nullptr &&
316 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317
318 // Annotation requires the selection model.
319 if (model_enabled_for_annotation || model_enabled_for_selection) {
320 if (!model_->selection_options()) {
321 TC3_LOG(ERROR) << "No selection options.";
322 return;
323 }
324 if (!model_->selection_feature_options()) {
325 TC3_LOG(ERROR) << "No selection feature options.";
326 return;
327 }
328 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330 return;
331 }
332 if (!model_->selection_model()) {
333 TC3_LOG(ERROR) << "No selection model.";
334 return;
335 }
336 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337 if (!selection_executor_) {
338 TC3_LOG(ERROR) << "Could not initialize selection executor.";
339 return;
340 }
341 }
342
343 // Even if the annotation mode is not enabled (for the neural network model),
344 // the selection feature processor is needed to tokenize the text for other
345 // models.
346 if (model_->selection_feature_options()) {
347 selection_feature_processor_.reset(
348 new FeatureProcessor(model_->selection_feature_options(), unilib_));
349 }
350
351 // Annotation requires the classification model for conflict resolution and
352 // scoring.
353 // Selection requires the classification model for conflict resolution.
354 if (model_enabled_for_annotation || model_enabled_for_classification ||
355 model_enabled_for_selection) {
356 if (!model_->classification_options()) {
357 TC3_LOG(ERROR) << "No classification options.";
358 return;
359 }
360
361 if (!model_->classification_feature_options()) {
362 TC3_LOG(ERROR) << "No classification feature options.";
363 return;
364 }
365
366 if (!model_->classification_feature_options()
367 ->bounds_sensitive_features()) {
368 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
369 return;
370 }
371 if (!model_->classification_model()) {
372 TC3_LOG(ERROR) << "No clf model.";
373 return;
374 }
375
376 classification_executor_ =
377 ModelExecutor::FromBuffer(model_->classification_model());
378 if (!classification_executor_) {
379 TC3_LOG(ERROR) << "Could not initialize classification executor.";
380 return;
381 }
382
383 classification_feature_processor_.reset(new FeatureProcessor(
384 model_->classification_feature_options(), unilib_));
385 }
386
387 // The embeddings need to be specified if the model is to be used for
388 // classification or selection.
389 if (model_enabled_for_annotation || model_enabled_for_classification ||
390 model_enabled_for_selection) {
391 if (!model_->embedding_model()) {
392 TC3_LOG(ERROR) << "No embedding model.";
393 return;
394 }
395
396 // Check that the embedding size of the selection and classification model
397 // matches, as they are using the same embeddings.
398 if (model_enabled_for_selection &&
399 (model_->selection_feature_options()->embedding_size() !=
400 model_->classification_feature_options()->embedding_size() ||
401 model_->selection_feature_options()->embedding_quantization_bits() !=
402 model_->classification_feature_options()
403 ->embedding_quantization_bits())) {
404 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
405 return;
406 }
407
408 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
409 model_->embedding_model(),
410 model_->classification_feature_options()->embedding_size(),
411 model_->classification_feature_options()->embedding_quantization_bits(),
412 model_->embedding_pruning_mask());
413 if (!embedding_executor_) {
414 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
415 return;
416 }
417 }
418
419 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
420 if (model_->regex_model()) {
421 if (!InitializeRegexModel(decompressor.get())) {
422 TC3_LOG(ERROR) << "Could not initialize regex model.";
423 return;
424 }
425 }
426
427 if (model_->datetime_grammar_model()) {
428 if (model_->datetime_grammar_model()->rules()) {
429 analyzer_ = std::make_unique<grammar::Analyzer>(
430 unilib_, model_->datetime_grammar_model()->rules());
431 datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
432 datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
433 *analyzer_, *datetime_grounder_,
434 /*target_classification_score=*/1.0,
435 /*priority_score=*/1.0,
436 model_->datetime_grammar_model()->enabled_modes());
437 }
438 } else if (model_->datetime_model()) {
439 datetime_parser_ = RegexDatetimeParser::Instance(
440 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
441 if (!datetime_parser_) {
442 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
443 return;
444 }
445 }
446
447 if (model_->output_options()) {
448 if (model_->output_options()->filtered_collections_annotation()) {
449 for (const auto collection :
450 *model_->output_options()->filtered_collections_annotation()) {
451 filtered_collections_annotation_.insert(collection->str());
452 }
453 }
454 if (model_->output_options()->filtered_collections_classification()) {
455 for (const auto collection :
456 *model_->output_options()->filtered_collections_classification()) {
457 filtered_collections_classification_.insert(collection->str());
458 }
459 }
460 if (model_->output_options()->filtered_collections_selection()) {
461 for (const auto collection :
462 *model_->output_options()->filtered_collections_selection()) {
463 filtered_collections_selection_.insert(collection->str());
464 }
465 }
466 }
467
468 if (model_->number_annotator_options() &&
469 model_->number_annotator_options()->enabled()) {
470 number_annotator_.reset(
471 new NumberAnnotator(model_->number_annotator_options(), unilib_));
472 }
473
474 if (model_->money_parsing_options()) {
475 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
476 model_->money_parsing_options()->separators());
477 }
478
479 if (model_->duration_annotator_options() &&
480 model_->duration_annotator_options()->enabled()) {
481 duration_annotator_.reset(
482 new DurationAnnotator(model_->duration_annotator_options(),
483 selection_feature_processor_.get(), unilib_));
484 }
485
486 if (model_->grammar_model()) {
487 grammar_annotator_.reset(new GrammarAnnotator(
488 unilib_, model_->grammar_model(), entity_data_builder_.get()));
489 }
490
491 // The following #ifdef is here to aid quality evaluation of a situation, when
492 // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
493 // it.
494 #if !defined(TC3_DISABLE_POD_NER)
495 if (model_->pod_ner_model()) {
496 pod_ner_annotator_ =
497 PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
498 }
499 #endif
500
501 if (model_->vocab_model()) {
502 vocab_annotator_ = VocabAnnotator::Create(
503 model_->vocab_model(), *selection_feature_processor_, *unilib_);
504 }
505
506 if (model_->entity_data_schema()) {
507 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
508 model_->entity_data_schema()->Data(),
509 model_->entity_data_schema()->size());
510 if (entity_data_schema_ == nullptr) {
511 TC3_LOG(ERROR) << "Could not load entity data schema data.";
512 return;
513 }
514
515 entity_data_builder_.reset(
516 new MutableFlatbufferBuilder(entity_data_schema_));
517 } else {
518 entity_data_schema_ = nullptr;
519 entity_data_builder_ = nullptr;
520 }
521
522 if (model_->triggering_locales() &&
523 !ParseLocales(model_->triggering_locales()->c_str(),
524 &model_triggering_locales_)) {
525 TC3_LOG(ERROR) << "Could not parse model supported locales.";
526 return;
527 }
528
529 if (model_->triggering_options() != nullptr &&
530 model_->triggering_options()->locales() != nullptr &&
531 !ParseLocales(model_->triggering_options()->locales()->c_str(),
532 &ml_model_triggering_locales_)) {
533 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
534 return;
535 }
536
537 if (model_->triggering_options() != nullptr &&
538 model_->triggering_options()->dictionary_locales() != nullptr &&
539 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
540 &dictionary_locales_)) {
541 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
542 return;
543 }
544
545 if (model_->conflict_resolution_options() != nullptr) {
546 prioritize_longest_annotation_ =
547 model_->conflict_resolution_options()->prioritize_longest_annotation();
548 do_conflict_resolution_in_raw_mode_ =
549 model_->conflict_resolution_options()
550 ->do_conflict_resolution_in_raw_mode();
551 }
552
553 #ifdef TC3_EXPERIMENTAL
554 TC3_LOG(WARNING) << "Enabling experimental annotators.";
555 InitializeExperimentalAnnotators();
556 #endif
557
558 initialized_ = true;
559 }
560
InitializeRegexModel(ZlibDecompressor * decompressor)561 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
562 if (!model_->regex_model()->patterns()) {
563 return true;
564 }
565
566 // Initialize pattern recognizers.
567 int regex_pattern_id = 0;
568 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
569 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
570 UncompressMakeRegexPattern(
571 *unilib_, regex_pattern->pattern(),
572 regex_pattern->compressed_pattern(),
573 model_->regex_model()->lazy_regex_compilation(), decompressor);
574 if (!compiled_pattern) {
575 TC3_LOG(INFO) << "Failed to load regex pattern";
576 return false;
577 }
578
579 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
580 annotation_regex_patterns_.push_back(regex_pattern_id);
581 }
582 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
583 classification_regex_patterns_.push_back(regex_pattern_id);
584 }
585 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
586 selection_regex_patterns_.push_back(regex_pattern_id);
587 }
588 regex_patterns_.push_back({
589 regex_pattern,
590 std::move(compiled_pattern),
591 });
592 ++regex_pattern_id;
593 }
594
595 return true;
596 }
597
InitializeKnowledgeEngine(const std::string & serialized_config)598 bool Annotator::InitializeKnowledgeEngine(
599 const std::string& serialized_config) {
600 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
601 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
602 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
603 return false;
604 }
605 if (model_->triggering_options() != nullptr) {
606 knowledge_engine->SetPriorityScore(
607 model_->triggering_options()->knowledge_priority_score());
608 knowledge_engine->SetEnabledModes(
609 model_->triggering_options()->knowledge_enabled_modes());
610 }
611 knowledge_engine_ = std::move(knowledge_engine);
612 return true;
613 }
614
InitializeContactEngine(const std::string & serialized_config)615 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
616 std::unique_ptr<ContactEngine> contact_engine(
617 new ContactEngine(selection_feature_processor_.get(), unilib_,
618 model_->contact_annotator_options()));
619 if (!contact_engine->Initialize(serialized_config)) {
620 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
621 return false;
622 }
623 contact_engine_ = std::move(contact_engine);
624 return true;
625 }
626
CleanUpContactEngine()627 void Annotator::CleanUpContactEngine() {
628 if (contact_engine_ == nullptr) {
629 TC3_LOG(INFO)
630 << "Attempting to clean up contact engine that does not exist.";
631 return;
632 }
633 contact_engine_->CleanUp();
634 }
635
InitializeInstalledAppEngine(const std::string & serialized_config)636 bool Annotator::InitializeInstalledAppEngine(
637 const std::string& serialized_config) {
638 std::unique_ptr<InstalledAppEngine> installed_app_engine(
639 new InstalledAppEngine(
640 selection_feature_processor_.get(), unilib_,
641 model_->triggering_options()->installed_app_enabled_modes()));
642 if (!installed_app_engine->Initialize(serialized_config)) {
643 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
644 return false;
645 }
646 installed_app_engine_ = std::move(installed_app_engine);
647 return true;
648 }
649
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)650 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
651 if (lang_id == nullptr) {
652 return false;
653 }
654
655 lang_id_ = lang_id;
656 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
657 model_->translate_annotator_options()->enabled()) {
658 translate_annotator_.reset(new TranslateAnnotator(
659 model_->translate_annotator_options(), lang_id_, unilib_));
660 } else {
661 translate_annotator_.reset(nullptr);
662 }
663 return true;
664 }
665
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)666 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
667 int size) {
668 const PersonNameModel* person_name_model =
669 LoadAndVerifyPersonNameModel(buffer, size);
670
671 if (person_name_model == nullptr) {
672 TC3_LOG(ERROR) << "Person name model verification failed.";
673 return false;
674 }
675
676 if (!person_name_model->enabled()) {
677 return true;
678 }
679
680 std::unique_ptr<PersonNameEngine> person_name_engine(
681 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
682 if (!person_name_engine->Initialize(person_name_model)) {
683 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
684 return false;
685 }
686 person_name_engine_ = std::move(person_name_engine);
687 return true;
688 }
689
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)690 bool Annotator::InitializePersonNameEngineFromScopedMmap(
691 const ScopedMmap& mmap) {
692 if (!mmap.handle().ok()) {
693 TC3_LOG(ERROR) << "Mmap for person name model failed.";
694 return false;
695 }
696
697 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
698 mmap.handle().num_bytes());
699 }
700
InitializePersonNameEngineFromPath(const std::string & path)701 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
702 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
703 return InitializePersonNameEngineFromScopedMmap(*mmap);
704 }
705
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)706 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
707 int size) {
708 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
709 return InitializePersonNameEngineFromScopedMmap(*mmap);
710 }
711
InitializeExperimentalAnnotators()712 bool Annotator::InitializeExperimentalAnnotators() {
713 if (ExperimentalAnnotator::IsEnabled()) {
714 experimental_annotator_.reset(new ExperimentalAnnotator(
715 model_->experimental_model(), *selection_feature_processor_, *unilib_));
716 return true;
717 }
718 return false;
719 }
720
721 namespace internal {
722 // Helper function, which if the initial 'span' contains only white-spaces,
723 // moves the selection to a single-codepoint selection on a left or right side
724 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)725 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
726 const UnicodeText& context_unicode,
727 const UniLib& unilib) {
728 TC3_CHECK(span.IsValid() && !span.IsEmpty());
729
730 UnicodeText::const_iterator it;
731
732 // Check that the current selection is all whitespaces.
733 it = context_unicode.begin();
734 std::advance(it, span.first);
735 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
736 if (!unilib.IsWhitespace(*it)) {
737 return span;
738 }
739 }
740
741 // Try moving left.
742 CodepointSpan result = span;
743 it = context_unicode.begin();
744 std::advance(it, span.first);
745 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
746 --result.first;
747 --it;
748 }
749 result.second = result.first + 1;
750 if (!unilib.IsWhitespace(*it)) {
751 return result;
752 }
753
754 // If moving left didn't find a non-whitespace character, just return the
755 // original span.
756 return span;
757 }
758 } // namespace internal
759
FilteredForAnnotation(const AnnotatedSpan & span) const760 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
761 return !span.classification.empty() &&
762 filtered_collections_annotation_.find(
763 span.classification[0].collection) !=
764 filtered_collections_annotation_.end();
765 }
766
FilteredForClassification(const ClassificationResult & classification) const767 bool Annotator::FilteredForClassification(
768 const ClassificationResult& classification) const {
769 return filtered_collections_classification_.find(classification.collection) !=
770 filtered_collections_classification_.end();
771 }
772
FilteredForSelection(const AnnotatedSpan & span) const773 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
774 return !span.classification.empty() &&
775 filtered_collections_selection_.find(
776 span.classification[0].collection) !=
777 filtered_collections_selection_.end();
778 }
779
780 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)781 inline bool ClassifiedAsOther(
782 const std::vector<ClassificationResult>& classification) {
783 return !classification.empty() &&
784 classification[0].collection == Collections::Other();
785 }
786
787 } // namespace
788
GetPriorityScore(const std::vector<ClassificationResult> & classification) const789 float Annotator::GetPriorityScore(
790 const std::vector<ClassificationResult>& classification) const {
791 if (!classification.empty() && !ClassifiedAsOther(classification)) {
792 return classification[0].priority_score;
793 } else {
794 if (model_->triggering_options() != nullptr) {
795 return model_->triggering_options()->other_collection_priority_score();
796 } else {
797 return -1000.0;
798 }
799 }
800 }
801
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const802 bool Annotator::VerifyRegexMatchCandidate(
803 const std::string& context, const VerificationOptions* verification_options,
804 const std::string& match, const UniLib::RegexMatcher* matcher) const {
805 if (verification_options == nullptr) {
806 return true;
807 }
808 if (verification_options->verify_luhn_checksum() &&
809 !VerifyLuhnChecksum(match)) {
810 return false;
811 }
812 const int lua_verifier = verification_options->lua_verifier();
813 if (lua_verifier >= 0) {
814 if (model_->regex_model()->lua_verifier() == nullptr ||
815 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
816 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
817 return false;
818 }
819 return VerifyMatch(
820 context, matcher,
821 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
822 }
823 return true;
824 }
825
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const826 CodepointSpan Annotator::SuggestSelection(
827 const std::string& context, CodepointSpan click_indices,
828 const SelectionOptions& options) const {
829 if (context.size() > std::numeric_limits<int>::max()) {
830 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
831 return {};
832 }
833
834 CodepointSpan original_click_indices = click_indices;
835 if (!initialized_) {
836 TC3_LOG(ERROR) << "Not initialized";
837 return original_click_indices;
838 }
839 if (options.annotation_usecase !=
840 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
841 TC3_LOG(WARNING)
842 << "Invoking SuggestSelection, which is not supported in RAW mode.";
843 return original_click_indices;
844 }
845 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
846 return original_click_indices;
847 }
848
849 std::vector<Locale> detected_text_language_tags;
850 if (!ParseLocales(options.detected_text_language_tags,
851 &detected_text_language_tags)) {
852 TC3_LOG(WARNING)
853 << "Failed to parse the detected_text_language_tags in options: "
854 << options.detected_text_language_tags;
855 }
856 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
857 model_triggering_locales_,
858 /*default_value=*/true)) {
859 return original_click_indices;
860 }
861
862 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
863 /*do_copy=*/false);
864
865 if (!unilib_->IsValidUtf8(context_unicode)) {
866 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
867 return original_click_indices;
868 }
869
870 if (!IsValidSpanInput(context_unicode, click_indices)) {
871 TC3_VLOG(1)
872 << "Trying to run SuggestSelection with invalid input, indices: "
873 << click_indices.first << " " << click_indices.second;
874 return original_click_indices;
875 }
876
877 if (model_->snap_whitespace_selections()) {
878 // We want to expand a purely white-space selection to a multi-selection it
879 // would've been part of. But with this feature disabled we would do a no-
880 // op, because no token is found. Therefore, we need to modify the
881 // 'click_indices' a bit to include a part of the token, so that the click-
882 // finding logic finds the clicked token correctly. This modification is
883 // done by the following function. Note, that it's enough to check the left
884 // side of the current selection, because if the white-space is a part of a
885 // multi-selection, necessarily both tokens - on the left and the right
886 // sides need to be selected. Thus snapping only to the left is sufficient
887 // (there's a check at the bottom that makes sure that if we snap to the
888 // left token but the result does not contain the initial white-space,
889 // returns the original indices).
890 click_indices = internal::SnapLeftIfWhitespaceSelection(
891 click_indices, context_unicode, *unilib_);
892 }
893
894 Annotations candidates;
895 // As we process a single string of context, the candidates will only
896 // contain one vector of AnnotatedSpan.
897 candidates.annotated_spans.resize(1);
898 InterpreterManager interpreter_manager(selection_executor_.get(),
899 classification_executor_.get());
900 std::vector<Token> tokens;
901 if (!ModelSuggestSelection(context_unicode, click_indices,
902 detected_text_language_tags, &interpreter_manager,
903 &tokens, &candidates.annotated_spans[0])) {
904 TC3_LOG(ERROR) << "Model suggest selection failed.";
905 return original_click_indices;
906 }
907 const std::unordered_set<std::string> set;
908 const EnabledEntityTypes is_entity_type_enabled(set);
909 if (!RegexChunk(context_unicode, selection_regex_patterns_,
910 /*is_serialized_entity_data_enabled=*/false,
911 is_entity_type_enabled, options.annotation_usecase,
912 &candidates.annotated_spans[0])) {
913 TC3_LOG(ERROR) << "Regex suggest selection failed.";
914 return original_click_indices;
915 }
916 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
917 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
918 options.locales, ModeFlag_SELECTION,
919 options.annotation_usecase,
920 /*is_serialized_entity_data_enabled=*/false,
921 &candidates.annotated_spans[0])) {
922 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
923 return original_click_indices;
924 }
925 if (knowledge_engine_ != nullptr &&
926 !knowledge_engine_
927 ->Chunk(context, options.annotation_usecase,
928 options.location_context, Permissions(),
929 AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION,
930 &candidates)
931 .ok()) {
932 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
933 return original_click_indices;
934 }
935 if (contact_engine_ != nullptr &&
936 !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
937 &candidates.annotated_spans[0])) {
938 TC3_LOG(ERROR) << "Contact suggest selection failed.";
939 return original_click_indices;
940 }
941 if (installed_app_engine_ != nullptr &&
942 !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
943 &candidates.annotated_spans[0])) {
944 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
945 return original_click_indices;
946 }
947 if (number_annotator_ != nullptr &&
948 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
949 ModeFlag_SELECTION,
950 &candidates.annotated_spans[0])) {
951 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
952 return original_click_indices;
953 }
954 if (duration_annotator_ != nullptr &&
955 !duration_annotator_->FindAll(
956 context_unicode, tokens, options.annotation_usecase,
957 ModeFlag_SELECTION, &candidates.annotated_spans[0])) {
958 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
959 return original_click_indices;
960 }
961 if (person_name_engine_ != nullptr &&
962 !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
963 &candidates.annotated_spans[0])) {
964 TC3_LOG(ERROR) << "Person name suggest selection failed.";
965 return original_click_indices;
966 }
967
968 AnnotatedSpan grammar_suggested_span;
969 if (grammar_annotator_ != nullptr &&
970 grammar_annotator_->SuggestSelection(detected_text_language_tags,
971 context_unicode, click_indices,
972 &grammar_suggested_span)) {
973 candidates.annotated_spans[0].push_back(grammar_suggested_span);
974 }
975
976 AnnotatedSpan pod_ner_suggested_span;
977 if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
978 pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
979 &pod_ner_suggested_span)) {
980 candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
981 }
982
983 if (experimental_annotator_ != nullptr &&
984 (model_->triggering_options()->experimental_enabled_modes() &
985 ModeFlag_SELECTION)) {
986 candidates.annotated_spans[0].push_back(
987 experimental_annotator_->SuggestSelection(context_unicode,
988 click_indices));
989 }
990
991 // Sort candidates according to their position in the input, so that the next
992 // code can assume that any connected component of overlapping spans forms a
993 // contiguous block.
994 std::stable_sort(candidates.annotated_spans[0].begin(),
995 candidates.annotated_spans[0].end(),
996 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
997 return a.span.first < b.span.first;
998 });
999
1000 std::vector<int> candidate_indices;
1001 if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
1002 detected_text_language_tags, options,
1003 &interpreter_manager, &candidate_indices)) {
1004 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1005 return original_click_indices;
1006 }
1007
1008 std::stable_sort(
1009 candidate_indices.begin(), candidate_indices.end(),
1010 [this, &candidates](int a, int b) {
1011 return GetPriorityScore(
1012 candidates.annotated_spans[0][a].classification) >
1013 GetPriorityScore(
1014 candidates.annotated_spans[0][b].classification);
1015 });
1016
1017 for (const int i : candidate_indices) {
1018 if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1019 SpansOverlap(candidates.annotated_spans[0][i].span,
1020 original_click_indices)) {
1021 // Run model classification if not present but requested and there's a
1022 // classification collection filter specified.
1023 if (candidates.annotated_spans[0][i].classification.empty() &&
1024 model_->selection_options()->always_classify_suggested_selection() &&
1025 !filtered_collections_selection_.empty()) {
1026 if (!ModelClassifyText(context, /*cached_tokens=*/{},
1027 detected_text_language_tags,
1028 candidates.annotated_spans[0][i].span, options,
1029 &interpreter_manager,
1030 /*embedding_cache=*/nullptr,
1031 &candidates.annotated_spans[0][i].classification,
1032 /*tokens=*/nullptr)) {
1033 return original_click_indices;
1034 }
1035 }
1036
1037 // Ignore if span classification is filtered.
1038 if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1039 return original_click_indices;
1040 }
1041
1042 // We return a suggested span contains the original span.
1043 // This compensates for "select all" selection that may come from
1044 // other apps. See http://b/179890518.
1045 if (SpanContains(candidates.annotated_spans[0][i].span,
1046 original_click_indices)) {
1047 return candidates.annotated_spans[0][i].span;
1048 }
1049 }
1050 }
1051
1052 return original_click_indices;
1053 }
1054
1055 namespace {
1056 // Helper function that returns the index of the first candidate that
1057 // transitively does not overlap with the candidate on 'start_index'. If the end
1058 // of 'candidates' is reached, it returns the index that points right behind the
1059 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1060 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1061 int start_index) {
1062 int first_non_overlapping = start_index + 1;
1063 CodepointSpan conflicting_span = candidates[start_index].span;
1064 while (
1065 first_non_overlapping < candidates.size() &&
1066 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1067 // Grow the span to include the current one.
1068 conflicting_span.second = std::max(
1069 conflicting_span.second, candidates[first_non_overlapping].span.second);
1070
1071 ++first_non_overlapping;
1072 }
1073 return first_non_overlapping;
1074 }
1075 } // namespace
1076
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1077 bool Annotator::ResolveConflicts(
1078 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1079 const std::vector<Token>& cached_tokens,
1080 const std::vector<Locale>& detected_text_language_tags,
1081 const BaseOptions& options, InterpreterManager* interpreter_manager,
1082 std::vector<int>* result) const {
1083 result->clear();
1084 result->reserve(candidates.size());
1085 for (int i = 0; i < candidates.size();) {
1086 int first_non_overlapping =
1087 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1088
1089 const bool conflict_found = first_non_overlapping != (i + 1);
1090 if (conflict_found) {
1091 std::vector<int> candidate_indices;
1092 if (!ResolveConflict(context, cached_tokens, candidates,
1093 detected_text_language_tags, i,
1094 first_non_overlapping, options, interpreter_manager,
1095 &candidate_indices)) {
1096 return false;
1097 }
1098 result->insert(result->end(), candidate_indices.begin(),
1099 candidate_indices.end());
1100 } else {
1101 result->push_back(i);
1102 }
1103
1104 // Skip over the whole conflicting group/go to next candidate.
1105 i = first_non_overlapping;
1106 }
1107 return true;
1108 }
1109
1110 namespace {
1111 // Returns true, if the given two sources do conflict in given annotation
1112 // usecase.
1113 // - In SMART usecase, all sources do conflict, because there's only 1 possible
1114 // annotation for a given span.
1115 // - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1116 // and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1117 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1118 const AnnotatedSpan::Source source1,
1119 const AnnotatedSpan::Source source2) {
1120 uint32 source_mask =
1121 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1122
1123 switch (annotation_usecase) {
1124 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1125 // In the SMART mode, all annotations conflict.
1126 return true;
1127
1128 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1129 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1130 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1131 // hours" (duration).
1132 if ((source_mask &
1133 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1134 (source_mask &
1135 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1136 return false;
1137 }
1138
1139 // A KNOWLEDGE entity does not conflict with anything.
1140 if ((source_mask &
1141 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1142 return false;
1143 }
1144
1145 // A PERSONNAME entity does not conflict with anything.
1146 if ((source_mask &
1147 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1148 return false;
1149 }
1150
1151 // Entities from other sources can conflict.
1152 return true;
1153 }
1154 }
1155 } // namespace
1156
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1157 bool Annotator::ResolveConflict(
1158 const std::string& context, const std::vector<Token>& cached_tokens,
1159 const std::vector<AnnotatedSpan>& candidates,
1160 const std::vector<Locale>& detected_text_language_tags, int start_index,
1161 int end_index, const BaseOptions& options,
1162 InterpreterManager* interpreter_manager,
1163 std::vector<int>* chosen_indices) const {
1164 std::vector<int> conflicting_indices;
1165 std::unordered_map<int, std::pair<float, int>> scores_lengths;
1166 for (int i = start_index; i < end_index; ++i) {
1167 conflicting_indices.push_back(i);
1168 if (!candidates[i].classification.empty()) {
1169 scores_lengths[i] = {
1170 GetPriorityScore(candidates[i].classification),
1171 candidates[i].span.second - candidates[i].span.first};
1172 continue;
1173 }
1174
1175 // OPTIMIZATION: So that we don't have to classify all the ML model
1176 // spans apriori, we wait until we get here, when they conflict with
1177 // something and we need the actual classification scores. So if the
1178 // candidate conflicts and comes from the model, we need to run a
1179 // classification to determine its priority:
1180 std::vector<ClassificationResult> classification;
1181 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1182 candidates[i].span, options, interpreter_manager,
1183 /*embedding_cache=*/nullptr, &classification,
1184 /*tokens=*/nullptr)) {
1185 return false;
1186 }
1187
1188 if (!classification.empty()) {
1189 scores_lengths[i] = {
1190 GetPriorityScore(classification),
1191 candidates[i].span.second - candidates[i].span.first};
1192 }
1193 }
1194
1195 std::stable_sort(
1196 conflicting_indices.begin(), conflicting_indices.end(),
1197 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1198 if (scores_lengths[i].first == scores_lengths[j].first &&
1199 prioritize_longest_annotation_) {
1200 return scores_lengths[i].second > scores_lengths[j].second;
1201 }
1202 return scores_lengths[i].first > scores_lengths[j].first;
1203 });
1204
1205 // Here we keep a set of indices that were chosen, per-source, to enable
1206 // effective computation.
1207 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1208 chosen_indices_for_source_map;
1209
1210 // Greedily place the candidates if they don't conflict with the already
1211 // placed ones.
1212 for (int i = 0; i < conflicting_indices.size(); ++i) {
1213 const int considered_candidate = conflicting_indices[i];
1214
1215 // See if there is a conflict between the candidate and all already placed
1216 // candidates.
1217 bool conflict = false;
1218 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1219 for (auto& source_set_pair : chosen_indices_for_source_map) {
1220 if (source_set_pair.first == candidates[considered_candidate].source) {
1221 chosen_indices_for_source_ptr = &source_set_pair.second;
1222 }
1223
1224 const bool needs_conflict_resolution =
1225 options.annotation_usecase ==
1226 AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1227 (options.annotation_usecase ==
1228 AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1229 do_conflict_resolution_in_raw_mode_);
1230 if (needs_conflict_resolution &&
1231 DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1232 candidates[considered_candidate].source) &&
1233 DoesCandidateConflict(considered_candidate, candidates,
1234 source_set_pair.second)) {
1235 conflict = true;
1236 break;
1237 }
1238 }
1239
1240 // Skip the candidate if a conflict was found.
1241 if (conflict) {
1242 continue;
1243 }
1244
1245 // If the set of indices for the current source doesn't exist yet,
1246 // initialize it.
1247 if (chosen_indices_for_source_ptr == nullptr) {
1248 SortedIntSet new_set([&candidates](int a, int b) {
1249 return candidates[a].span.first < candidates[b].span.first;
1250 });
1251 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1252 std::move(new_set);
1253 chosen_indices_for_source_ptr =
1254 &chosen_indices_for_source_map[candidates[considered_candidate]
1255 .source];
1256 }
1257
1258 // Place the candidate to the output and to the per-source conflict set.
1259 chosen_indices->push_back(considered_candidate);
1260 chosen_indices_for_source_ptr->insert(considered_candidate);
1261 }
1262
1263 std::stable_sort(chosen_indices->begin(), chosen_indices->end());
1264
1265 return true;
1266 }
1267
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1268 bool Annotator::ModelSuggestSelection(
1269 const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1270 const std::vector<Locale>& detected_text_language_tags,
1271 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1272 std::vector<AnnotatedSpan>* result) const {
1273 if (model_->triggering_options() == nullptr ||
1274 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1275 return true;
1276 }
1277
1278 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1279 ml_model_triggering_locales_,
1280 /*default_value=*/true)) {
1281 return true;
1282 }
1283
1284 int click_pos;
1285 *tokens = selection_feature_processor_->Tokenize(context_unicode);
1286 const auto [click_begin, click_end] =
1287 CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1288 selection_feature_processor_->RetokenizeAndFindClick(
1289 context_unicode, click_begin, click_end, click_indices,
1290 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1291 tokens, &click_pos);
1292 if (click_pos == kInvalidIndex) {
1293 TC3_VLOG(1) << "Could not calculate the click position.";
1294 return false;
1295 }
1296
1297 const int symmetry_context_size =
1298 model_->selection_options()->symmetry_context_size();
1299 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1300 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1301 ->bounds_sensitive_features();
1302
1303 // The symmetry context span is the clicked token with symmetry_context_size
1304 // tokens on either side.
1305 const TokenSpan symmetry_context_span =
1306 IntersectTokenSpans(TokenSpan(click_pos).Expand(
1307 /*num_tokens_left=*/symmetry_context_size,
1308 /*num_tokens_right=*/symmetry_context_size),
1309 AllOf(*tokens));
1310
1311 // Compute the extraction span based on the model type.
1312 TokenSpan extraction_span;
1313 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1314 // The extraction span is the symmetry context span expanded to include
1315 // max_selection_span tokens on either side, which is how far a selection
1316 // can stretch from the click, plus a relevant number of tokens outside of
1317 // the bounds of the selection.
1318 const int max_selection_span =
1319 selection_feature_processor_->GetOptions()->max_selection_span();
1320 extraction_span = symmetry_context_span.Expand(
1321 /*num_tokens_left=*/max_selection_span +
1322 bounds_sensitive_features->num_tokens_before(),
1323 /*num_tokens_right=*/max_selection_span +
1324 bounds_sensitive_features->num_tokens_after());
1325 } else {
1326 // The extraction span is the symmetry context span expanded to include
1327 // context_size tokens on either side.
1328 const int context_size =
1329 selection_feature_processor_->GetOptions()->context_size();
1330 extraction_span = symmetry_context_span.Expand(
1331 /*num_tokens_left=*/context_size,
1332 /*num_tokens_right=*/context_size);
1333 }
1334 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1335
1336 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1337 *tokens, extraction_span)) {
1338 return true;
1339 }
1340
1341 std::unique_ptr<CachedFeatures> cached_features;
1342 if (!selection_feature_processor_->ExtractFeatures(
1343 *tokens, extraction_span,
1344 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1345 embedding_executor_.get(),
1346 /*embedding_cache=*/nullptr,
1347 selection_feature_processor_->EmbeddingSize() +
1348 selection_feature_processor_->DenseFeaturesCount(),
1349 &cached_features)) {
1350 TC3_LOG(ERROR) << "Could not extract features.";
1351 return false;
1352 }
1353
1354 // Produce selection model candidates.
1355 std::vector<TokenSpan> chunks;
1356 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1357 interpreter_manager->SelectionInterpreter(), *cached_features,
1358 &chunks)) {
1359 TC3_LOG(ERROR) << "Could not chunk.";
1360 return false;
1361 }
1362
1363 for (const TokenSpan& chunk : chunks) {
1364 AnnotatedSpan candidate;
1365 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1366 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1367 if (model_->selection_options()->strip_unpaired_brackets()) {
1368 candidate.span =
1369 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1370 }
1371
1372 // Only output non-empty spans.
1373 if (candidate.span.first != candidate.span.second) {
1374 result->push_back(candidate);
1375 }
1376 }
1377 return true;
1378 }
1379
1380 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1381 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1382 const CodepointSpan& selection_indices,
1383 TokenSpan tokens_around_selection_to_copy) {
1384 const auto first_selection_token = std::upper_bound(
1385 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1386 [](int selection_start, const Token& token) {
1387 return selection_start < token.end;
1388 });
1389 const auto last_selection_token = std::lower_bound(
1390 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1391 [](const Token& token, int selection_end) {
1392 return token.start < selection_end;
1393 });
1394
1395 const int64 first_token = std::max(
1396 static_cast<int64>(0),
1397 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1398 tokens_around_selection_to_copy.first));
1399 const int64 last_token = std::min(
1400 static_cast<int64>(cached_tokens.size()),
1401 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1402 tokens_around_selection_to_copy.second));
1403
1404 std::vector<Token> tokens;
1405 tokens.reserve(last_token - first_token);
1406 for (int i = first_token; i < last_token; ++i) {
1407 tokens.push_back(cached_tokens[i]);
1408 }
1409 return tokens;
1410 }
1411 } // namespace internal
1412
ClassifyTextUpperBoundNeededTokens() const1413 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1414 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1415 bounds_sensitive_features =
1416 classification_feature_processor_->GetOptions()
1417 ->bounds_sensitive_features();
1418 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1419 // The extraction span is the selection span expanded to include a relevant
1420 // number of tokens outside of the bounds of the selection.
1421 return {bounds_sensitive_features->num_tokens_before(),
1422 bounds_sensitive_features->num_tokens_after()};
1423 } else {
1424 // The extraction span is the clicked token with context_size tokens on
1425 // either side.
1426 const int context_size =
1427 selection_feature_processor_->GetOptions()->context_size();
1428 return {context_size, context_size};
1429 }
1430 }
1431
1432 namespace {
1433 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1434 void SortClassificationResults(
1435 std::vector<ClassificationResult>* classification_results) {
1436 std::stable_sort(
1437 classification_results->begin(), classification_results->end(),
1438 [](const ClassificationResult& a, const ClassificationResult& b) {
1439 return a.score > b.score;
1440 });
1441 }
1442 } // namespace
1443
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1444 bool Annotator::ModelClassifyText(
1445 const std::string& context, const std::vector<Token>& cached_tokens,
1446 const std::vector<Locale>& detected_text_language_tags,
1447 const CodepointSpan& selection_indices, const BaseOptions& options,
1448 InterpreterManager* interpreter_manager,
1449 FeatureProcessor::EmbeddingCache* embedding_cache,
1450 std::vector<ClassificationResult>* classification_results,
1451 std::vector<Token>* tokens) const {
1452 const UnicodeText context_unicode =
1453 UTF8ToUnicodeText(context, /*do_copy=*/false);
1454 const auto [span_begin, span_end] =
1455 CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1456 return ModelClassifyText(context_unicode, cached_tokens,
1457 detected_text_language_tags, span_begin, span_end,
1458 /*line=*/nullptr, selection_indices, options,
1459 interpreter_manager, embedding_cache,
1460 classification_results, tokens);
1461 }
1462
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1463 bool Annotator::ModelClassifyText(
1464 const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1465 const std::vector<Locale>& detected_text_language_tags,
1466 const UnicodeText::const_iterator& span_begin,
1467 const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1468 const CodepointSpan& selection_indices, const BaseOptions& options,
1469 InterpreterManager* interpreter_manager,
1470 FeatureProcessor::EmbeddingCache* embedding_cache,
1471 std::vector<ClassificationResult>* classification_results,
1472 std::vector<Token>* tokens) const {
1473 if (model_->triggering_options() == nullptr ||
1474 !(model_->triggering_options()->enabled_modes() &
1475 ModeFlag_CLASSIFICATION)) {
1476 return true;
1477 }
1478
1479 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1480 ml_model_triggering_locales_,
1481 /*default_value=*/true)) {
1482 return true;
1483 }
1484
1485 std::vector<Token> local_tokens;
1486 if (tokens == nullptr) {
1487 tokens = &local_tokens;
1488 }
1489
1490 if (cached_tokens.empty()) {
1491 *tokens = classification_feature_processor_->Tokenize(context_unicode);
1492 } else {
1493 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1494 ClassifyTextUpperBoundNeededTokens());
1495 }
1496
1497 int click_pos;
1498 classification_feature_processor_->RetokenizeAndFindClick(
1499 context_unicode, span_begin, span_end, selection_indices,
1500 classification_feature_processor_->GetOptions()
1501 ->only_use_line_with_click(),
1502 tokens, &click_pos);
1503 const TokenSpan selection_token_span =
1504 CodepointSpanToTokenSpan(*tokens, selection_indices);
1505 const int selection_num_tokens = selection_token_span.Size();
1506 if (model_->classification_options()->max_num_tokens() > 0 &&
1507 model_->classification_options()->max_num_tokens() <
1508 selection_num_tokens) {
1509 *classification_results = {{Collections::Other(), 1.0}};
1510 return true;
1511 }
1512
1513 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1514 bounds_sensitive_features =
1515 classification_feature_processor_->GetOptions()
1516 ->bounds_sensitive_features();
1517 if (selection_token_span.first == kInvalidIndex ||
1518 selection_token_span.second == kInvalidIndex) {
1519 TC3_LOG(ERROR) << "Could not determine span.";
1520 return false;
1521 }
1522
1523 // Compute the extraction span based on the model type.
1524 TokenSpan extraction_span;
1525 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1526 // The extraction span is the selection span expanded to include a relevant
1527 // number of tokens outside of the bounds of the selection.
1528 extraction_span = selection_token_span.Expand(
1529 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1530 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1531 } else {
1532 if (click_pos == kInvalidIndex) {
1533 TC3_LOG(ERROR) << "Couldn't choose a click position.";
1534 return false;
1535 }
1536 // The extraction span is the clicked token with context_size tokens on
1537 // either side.
1538 const int context_size =
1539 classification_feature_processor_->GetOptions()->context_size();
1540 extraction_span = TokenSpan(click_pos).Expand(
1541 /*num_tokens_left=*/context_size,
1542 /*num_tokens_right=*/context_size);
1543 }
1544 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1545
1546 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1547 *tokens, extraction_span)) {
1548 *classification_results = {{Collections::Other(), 1.0}};
1549 return true;
1550 }
1551
1552 std::unique_ptr<CachedFeatures> cached_features;
1553 if (!classification_feature_processor_->ExtractFeatures(
1554 *tokens, extraction_span, selection_indices,
1555 embedding_executor_.get(), embedding_cache,
1556 classification_feature_processor_->EmbeddingSize() +
1557 classification_feature_processor_->DenseFeaturesCount(),
1558 &cached_features)) {
1559 TC3_LOG(ERROR) << "Could not extract features.";
1560 return false;
1561 }
1562
1563 std::vector<float> features;
1564 features.reserve(cached_features->OutputFeaturesSize());
1565 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1566 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1567 &features);
1568 } else {
1569 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1570 }
1571
1572 TensorView<float> logits = classification_executor_->ComputeLogits(
1573 TensorView<float>(features.data(),
1574 {1, static_cast<int>(features.size())}),
1575 interpreter_manager->ClassificationInterpreter());
1576 if (!logits.is_valid()) {
1577 TC3_LOG(ERROR) << "Couldn't compute logits.";
1578 return false;
1579 }
1580
1581 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1582 logits.dim(1) != classification_feature_processor_->NumCollections()) {
1583 TC3_LOG(ERROR) << "Mismatching output";
1584 return false;
1585 }
1586
1587 const std::vector<float> scores =
1588 ComputeSoftmax(logits.data(), logits.dim(1));
1589
1590 if (scores.empty()) {
1591 *classification_results = {{Collections::Other(), 1.0}};
1592 return true;
1593 }
1594
1595 const int best_score_index =
1596 std::max_element(scores.begin(), scores.end()) - scores.begin();
1597 const std::string top_collection =
1598 classification_feature_processor_->LabelToCollection(best_score_index);
1599
1600 // Sanity checks.
1601 if (top_collection == Collections::Phone()) {
1602 const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1603 if (digit_count <
1604 model_->classification_options()->phone_min_num_digits() ||
1605 digit_count >
1606 model_->classification_options()->phone_max_num_digits()) {
1607 *classification_results = {{Collections::Other(), 1.0}};
1608 return true;
1609 }
1610 } else if (top_collection == Collections::Address()) {
1611 if (selection_num_tokens <
1612 model_->classification_options()->address_min_num_tokens()) {
1613 *classification_results = {{Collections::Other(), 1.0}};
1614 return true;
1615 }
1616 } else if (top_collection == Collections::Dictionary()) {
1617 if ((options.use_vocab_annotator && vocab_annotator_) ||
1618 !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1619 dictionary_locales_,
1620 /*default_value=*/false)) {
1621 *classification_results = {{Collections::Other(), 1.0}};
1622 return true;
1623 }
1624 }
1625 *classification_results = {{top_collection, /*arg_score=*/1.0,
1626 /*arg_priority_score=*/scores[best_score_index]}};
1627
1628 // For some entities, we might want to clamp the priority score, for better
1629 // conflict resolution between entities.
1630 if (model_->triggering_options() != nullptr &&
1631 model_->triggering_options()->collection_to_priority() != nullptr) {
1632 if (auto entry =
1633 model_->triggering_options()->collection_to_priority()->LookupByKey(
1634 top_collection.c_str())) {
1635 (*classification_results)[0].priority_score *= entry->value();
1636 }
1637 }
1638 return true;
1639 }
1640
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1641 bool Annotator::RegexClassifyText(
1642 const std::string& context, const CodepointSpan& selection_indices,
1643 std::vector<ClassificationResult>* classification_result) const {
1644 const std::string selection_text =
1645 UTF8ToUnicodeText(context, /*do_copy=*/false)
1646 .UTF8Substring(selection_indices.first, selection_indices.second);
1647 const UnicodeText selection_text_unicode(
1648 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1649
1650 // Check whether any of the regular expressions match.
1651 for (const int pattern_id : classification_regex_patterns_) {
1652 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1653 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1654 regex_pattern.pattern->Matcher(selection_text_unicode);
1655 int status = UniLib::RegexMatcher::kNoError;
1656 bool matches;
1657 if (regex_pattern.config->use_approximate_matching()) {
1658 matches = matcher->ApproximatelyMatches(&status);
1659 } else {
1660 matches = matcher->Matches(&status);
1661 }
1662 if (status != UniLib::RegexMatcher::kNoError) {
1663 return false;
1664 }
1665 if (matches && VerifyRegexMatchCandidate(
1666 context, regex_pattern.config->verification_options(),
1667 selection_text, matcher.get())) {
1668 classification_result->push_back(
1669 {regex_pattern.config->collection_name()->str(),
1670 regex_pattern.config->target_classification_score(),
1671 regex_pattern.config->priority_score()});
1672 if (!SerializedEntityDataFromRegexMatch(
1673 regex_pattern.config, matcher.get(),
1674 &classification_result->back().serialized_entity_data)) {
1675 TC3_LOG(ERROR) << "Could not get entity data.";
1676 return false;
1677 }
1678 }
1679 }
1680
1681 return true;
1682 }
1683
1684 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1685 std::string PickCollectionForDatetime(
1686 const DatetimeParseResult& datetime_parse_result) {
1687 switch (datetime_parse_result.granularity) {
1688 case GRANULARITY_HOUR:
1689 case GRANULARITY_MINUTE:
1690 case GRANULARITY_SECOND:
1691 return Collections::DateTime();
1692 default:
1693 return Collections::Date();
1694 }
1695 }
1696
1697 } // namespace
1698
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1699 bool Annotator::DatetimeClassifyText(
1700 const std::string& context, const CodepointSpan& selection_indices,
1701 const ClassificationOptions& options,
1702 std::vector<ClassificationResult>* classification_results) const {
1703 if (!datetime_parser_) {
1704 return true;
1705 }
1706
1707 const std::string selection_text =
1708 UTF8ToUnicodeText(context, /*do_copy=*/false)
1709 .UTF8Substring(selection_indices.first, selection_indices.second);
1710
1711 LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1712 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1713 datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1714 options.reference_timezone, locale_list,
1715 ModeFlag_CLASSIFICATION,
1716 options.annotation_usecase,
1717 /*anchor_start_end=*/true);
1718 if (!result_status.ok()) {
1719 TC3_LOG(ERROR) << "Error during parsing datetime.";
1720 return false;
1721 }
1722
1723 for (const DatetimeParseResultSpan& datetime_span :
1724 result_status.ValueOrDie()) {
1725 // Only consider the result valid if the selection and extracted datetime
1726 // spans exactly match.
1727 if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1728 datetime_span.span.second + selection_indices.first) ==
1729 selection_indices) {
1730 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1731 classification_results->emplace_back(
1732 PickCollectionForDatetime(parse_result),
1733 datetime_span.target_classification_score);
1734 classification_results->back().datetime_parse_result = parse_result;
1735 classification_results->back().serialized_entity_data =
1736 CreateDatetimeSerializedEntityData(parse_result);
1737 classification_results->back().priority_score =
1738 datetime_span.priority_score;
1739 }
1740 return true;
1741 }
1742 }
1743 return true;
1744 }
1745
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1746 std::vector<ClassificationResult> Annotator::ClassifyText(
1747 const std::string& context, const CodepointSpan& selection_indices,
1748 const ClassificationOptions& options) const {
1749 if (context.size() > std::numeric_limits<int>::max()) {
1750 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1751 return {};
1752 }
1753 if (!initialized_) {
1754 TC3_LOG(ERROR) << "Not initialized";
1755 return {};
1756 }
1757 if (options.annotation_usecase !=
1758 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1759 TC3_LOG(WARNING)
1760 << "Invoking ClassifyText, which is not supported in RAW mode.";
1761 return {};
1762 }
1763 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1764 return {};
1765 }
1766
1767 std::vector<Locale> detected_text_language_tags;
1768 if (!ParseLocales(options.detected_text_language_tags,
1769 &detected_text_language_tags)) {
1770 TC3_LOG(WARNING)
1771 << "Failed to parse the detected_text_language_tags in options: "
1772 << options.detected_text_language_tags;
1773 }
1774 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1775 model_triggering_locales_,
1776 /*default_value=*/true)) {
1777 return {};
1778 }
1779
1780 const UnicodeText context_unicode =
1781 UTF8ToUnicodeText(context, /*do_copy=*/false);
1782
1783 if (!unilib_->IsValidUtf8(context_unicode)) {
1784 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1785 return {};
1786 }
1787
1788 if (!IsValidSpanInput(context_unicode, selection_indices)) {
1789 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1790 << selection_indices.first << " " << selection_indices.second;
1791 return {};
1792 }
1793
1794 // We'll accumulate a list of candidates, and pick the best candidate in the
1795 // end.
1796 std::vector<AnnotatedSpan> candidates;
1797
1798 // Try the knowledge engine.
1799 // TODO(b/126579108): Propagate error status.
1800 ClassificationResult knowledge_result;
1801 if (knowledge_engine_ &&
1802 knowledge_engine_
1803 ->ClassifyText(context, selection_indices, options.annotation_usecase,
1804 options.location_context, Permissions(),
1805 &knowledge_result)
1806 .ok()) {
1807 candidates.push_back({selection_indices, {knowledge_result}});
1808 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1809 }
1810
1811 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1812
1813 // Try the contact engine.
1814 // TODO(b/126579108): Propagate error status.
1815 ClassificationResult contact_result;
1816 if (contact_engine_ && contact_engine_->ClassifyText(
1817 context, selection_indices, &contact_result)) {
1818 candidates.push_back({selection_indices, {contact_result}});
1819 }
1820
1821 // Try the person name engine.
1822 ClassificationResult person_name_result;
1823 if (person_name_engine_ &&
1824 person_name_engine_->ClassifyText(context, selection_indices,
1825 &person_name_result)) {
1826 candidates.push_back({selection_indices, {person_name_result}});
1827 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1828 }
1829
1830 // Try the installed app engine.
1831 // TODO(b/126579108): Propagate error status.
1832 ClassificationResult installed_app_result;
1833 if (installed_app_engine_ &&
1834 installed_app_engine_->ClassifyText(context, selection_indices,
1835 &installed_app_result)) {
1836 candidates.push_back({selection_indices, {installed_app_result}});
1837 }
1838
1839 // Try the regular expression models.
1840 std::vector<ClassificationResult> regex_results;
1841 if (!RegexClassifyText(context, selection_indices, ®ex_results)) {
1842 return {};
1843 }
1844 for (const ClassificationResult& result : regex_results) {
1845 candidates.push_back({selection_indices, {result}});
1846 }
1847
1848 // Try the date model.
1849 //
1850 // DatetimeClassifyText only returns the first result, which can however have
1851 // more interpretations. They are inserted in the candidates as a single
1852 // AnnotatedSpan, so that they get treated together by the conflict resolution
1853 // algorithm.
1854 std::vector<ClassificationResult> datetime_results;
1855 if (!DatetimeClassifyText(context, selection_indices, options,
1856 &datetime_results)) {
1857 return {};
1858 }
1859 if (!datetime_results.empty()) {
1860 candidates.push_back({selection_indices, std::move(datetime_results)});
1861 candidates.back().source = AnnotatedSpan::Source::DATETIME;
1862 }
1863
1864 // Try the number annotator.
1865 // TODO(b/126579108): Propagate error status.
1866 ClassificationResult number_annotator_result;
1867 if (number_annotator_ &&
1868 number_annotator_->ClassifyText(context_unicode, selection_indices,
1869 options.annotation_usecase,
1870 &number_annotator_result)) {
1871 candidates.push_back({selection_indices, {number_annotator_result}});
1872 }
1873
1874 // Try the duration annotator.
1875 ClassificationResult duration_annotator_result;
1876 if (duration_annotator_ &&
1877 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1878 options.annotation_usecase,
1879 &duration_annotator_result)) {
1880 candidates.push_back({selection_indices, {duration_annotator_result}});
1881 candidates.back().source = AnnotatedSpan::Source::DURATION;
1882 }
1883
1884 // Try the translate annotator.
1885 ClassificationResult translate_annotator_result;
1886 if (translate_annotator_ &&
1887 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1888 options.user_familiar_language_tags,
1889 &translate_annotator_result)) {
1890 candidates.push_back({selection_indices, {translate_annotator_result}});
1891 }
1892
1893 // Try the grammar model.
1894 ClassificationResult grammar_annotator_result;
1895 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1896 detected_text_language_tags, context_unicode,
1897 selection_indices, &grammar_annotator_result)) {
1898 candidates.push_back({selection_indices, {grammar_annotator_result}});
1899 }
1900
1901 ClassificationResult pod_ner_annotator_result;
1902 if (pod_ner_annotator_ && options.use_pod_ner &&
1903 pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1904 &pod_ner_annotator_result)) {
1905 candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1906 }
1907
1908 ClassificationResult vocab_annotator_result;
1909 if (vocab_annotator_ && options.use_vocab_annotator &&
1910 vocab_annotator_->ClassifyText(
1911 context_unicode, selection_indices, detected_text_language_tags,
1912 options.trigger_dictionary_on_beginner_words,
1913 &vocab_annotator_result)) {
1914 candidates.push_back({selection_indices, {vocab_annotator_result}});
1915 }
1916
1917 if (experimental_annotator_ &&
1918 (model_->triggering_options()->experimental_enabled_modes() &
1919 ModeFlag_CLASSIFICATION)) {
1920 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1921 candidates);
1922 }
1923
1924 // Try the ML model.
1925 //
1926 // The output of the model is considered as an exclusive 1-of-N choice. That's
1927 // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1928 // span for each candidate, like e.g. the regex model.
1929 InterpreterManager interpreter_manager(selection_executor_.get(),
1930 classification_executor_.get());
1931 std::vector<ClassificationResult> model_results;
1932 std::vector<Token> tokens;
1933 if (!ModelClassifyText(
1934 context, /*cached_tokens=*/{}, detected_text_language_tags,
1935 selection_indices, options, &interpreter_manager,
1936 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1937 return {};
1938 }
1939 if (!model_results.empty()) {
1940 candidates.push_back({selection_indices, std::move(model_results)});
1941 }
1942
1943 std::vector<int> candidate_indices;
1944 if (!ResolveConflicts(candidates, context, tokens,
1945 detected_text_language_tags, options,
1946 &interpreter_manager, &candidate_indices)) {
1947 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1948 return {};
1949 }
1950
1951 std::vector<ClassificationResult> results;
1952 for (const int i : candidate_indices) {
1953 for (const ClassificationResult& result : candidates[i].classification) {
1954 if (!FilteredForClassification(result)) {
1955 results.push_back(result);
1956 }
1957 }
1958 }
1959
1960 // Sort results according to score.
1961 std::stable_sort(
1962 results.begin(), results.end(),
1963 [](const ClassificationResult& a, const ClassificationResult& b) {
1964 return a.score > b.score;
1965 });
1966
1967 if (results.empty()) {
1968 results = {{Collections::Other(), 1.0}};
1969 }
1970 return results;
1971 }
1972
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1973 bool Annotator::ModelAnnotate(
1974 const std::string& context,
1975 const std::vector<Locale>& detected_text_language_tags,
1976 const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1977 std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1978 bool skip_model_annotatation = false;
1979 if (model_->triggering_options() == nullptr ||
1980 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1981 skip_model_annotatation = true;
1982 }
1983 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1984 ml_model_triggering_locales_,
1985 /*default_value=*/true)) {
1986 skip_model_annotatation = true;
1987 }
1988
1989 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1990 /*do_copy=*/false);
1991 std::vector<UnicodeTextRange> lines;
1992 if (!selection_feature_processor_ ||
1993 !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1994 lines.push_back({context_unicode.begin(), context_unicode.end()});
1995 } else {
1996 lines = selection_feature_processor_->SplitContext(
1997 context_unicode, selection_feature_processor_->GetOptions()
1998 ->use_pipe_character_for_newline());
1999 }
2000
2001 const float min_annotate_confidence =
2002 (model_->triggering_options() != nullptr
2003 ? model_->triggering_options()->min_annotate_confidence()
2004 : 0.f);
2005
2006 for (const UnicodeTextRange& line : lines) {
2007 const std::string line_str =
2008 UnicodeText::UTF8Substring(line.first, line.second);
2009
2010 std::vector<Token> line_tokens;
2011 line_tokens = selection_feature_processor_->Tokenize(line_str);
2012
2013 selection_feature_processor_->RetokenizeAndFindClick(
2014 line_str, {0, std::distance(line.first, line.second)},
2015 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
2016 &line_tokens,
2017 /*click_pos=*/nullptr);
2018 const TokenSpan full_line_span = {
2019 0, static_cast<TokenIndex>(line_tokens.size())};
2020
2021 tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2022
2023 if (skip_model_annotatation) {
2024 // We do not annotate, we only output the tokens.
2025 continue;
2026 }
2027
2028 // TODO(zilka): Add support for greater granularity of this check.
2029 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2030 line_tokens, full_line_span)) {
2031 continue;
2032 }
2033
2034 std::unique_ptr<CachedFeatures> cached_features;
2035 if (!selection_feature_processor_->ExtractFeatures(
2036 line_tokens, full_line_span,
2037 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2038 embedding_executor_.get(),
2039 /*embedding_cache=*/nullptr,
2040 selection_feature_processor_->EmbeddingSize() +
2041 selection_feature_processor_->DenseFeaturesCount(),
2042 &cached_features)) {
2043 TC3_LOG(ERROR) << "Could not extract features.";
2044 return false;
2045 }
2046
2047 std::vector<TokenSpan> local_chunks;
2048 if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2049 interpreter_manager->SelectionInterpreter(),
2050 *cached_features, &local_chunks)) {
2051 TC3_LOG(ERROR) << "Could not chunk.";
2052 return false;
2053 }
2054
2055 const int offset = std::distance(context_unicode.begin(), line.first);
2056 if (local_chunks.empty()) {
2057 continue;
2058 }
2059 const UnicodeText line_unicode =
2060 UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2061 std::vector<UnicodeText::const_iterator> line_codepoints =
2062 line_unicode.Codepoints();
2063 line_codepoints.push_back(line_unicode.end());
2064
2065 FeatureProcessor::EmbeddingCache embedding_cache;
2066 for (const TokenSpan& chunk : local_chunks) {
2067 CodepointSpan codepoint_span =
2068 TokenSpanToCodepointSpan(line_tokens, chunk);
2069 if (!codepoint_span.IsValid() ||
2070 codepoint_span.second > line_codepoints.size()) {
2071 continue;
2072 }
2073 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2074 /*span_begin=*/line_codepoints[codepoint_span.first],
2075 /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
2076 if (model_->selection_options()->strip_unpaired_brackets()) {
2077 codepoint_span = StripUnpairedBrackets(
2078 /*span_begin=*/line_codepoints[codepoint_span.first],
2079 /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
2080 *unilib_);
2081 }
2082
2083 // Skip empty spans.
2084 if (codepoint_span.first != codepoint_span.second) {
2085 std::vector<ClassificationResult> classification;
2086 if (!ModelClassifyText(
2087 line_unicode, line_tokens, detected_text_language_tags,
2088 /*span_begin=*/line_codepoints[codepoint_span.first],
2089 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2090 codepoint_span, options, interpreter_manager, &embedding_cache,
2091 &classification, /*tokens=*/nullptr)) {
2092 TC3_LOG(ERROR) << "Could not classify text: "
2093 << (codepoint_span.first + offset) << " "
2094 << (codepoint_span.second + offset);
2095 return false;
2096 }
2097
2098 // Do not include the span if it's classified as "other".
2099 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2100 classification[0].score >= min_annotate_confidence) {
2101 AnnotatedSpan result_span;
2102 result_span.span = {codepoint_span.first + offset,
2103 codepoint_span.second + offset};
2104 result_span.classification = std::move(classification);
2105 result->push_back(std::move(result_span));
2106 }
2107 }
2108 }
2109 }
2110 return true;
2111 }
2112
SelectionFeatureProcessorForTests() const2113 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2114 return selection_feature_processor_.get();
2115 }
2116
ClassificationFeatureProcessorForTests() const2117 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2118 const {
2119 return classification_feature_processor_.get();
2120 }
2121
DatetimeParserForTests() const2122 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2123 return datetime_parser_.get();
2124 }
2125
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2126 void Annotator::RemoveNotEnabledEntityTypes(
2127 const EnabledEntityTypes& is_entity_type_enabled,
2128 std::vector<AnnotatedSpan>* annotated_spans) const {
2129 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2130 std::vector<ClassificationResult>& classifications =
2131 annotated_span.classification;
2132 classifications.erase(
2133 std::remove_if(classifications.begin(), classifications.end(),
2134 [&is_entity_type_enabled](
2135 const ClassificationResult& classification_result) {
2136 return !is_entity_type_enabled(
2137 classification_result.collection);
2138 }),
2139 classifications.end());
2140 }
2141 annotated_spans->erase(
2142 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2143 [](const AnnotatedSpan& annotated_span) {
2144 return annotated_span.classification.empty();
2145 }),
2146 annotated_spans->end());
2147 }
2148
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2149 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2150 std::vector<AnnotatedSpan>* candidates) const {
2151 if (candidates == nullptr || contact_engine_ == nullptr) {
2152 return;
2153 }
2154 for (auto& candidate : *candidates) {
2155 for (auto& classification_result : candidate.classification) {
2156 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2157 &classification_result);
2158 }
2159 }
2160 }
2161
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2162 Status Annotator::AnnotateSingleInput(
2163 const std::string& context, const AnnotationOptions& options,
2164 std::vector<AnnotatedSpan>* candidates) const {
2165 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2166 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2167 }
2168
2169 const UnicodeText context_unicode =
2170 UTF8ToUnicodeText(context, /*do_copy=*/false);
2171
2172 std::vector<Locale> detected_text_language_tags;
2173 if (!ParseLocales(options.detected_text_language_tags,
2174 &detected_text_language_tags)) {
2175 TC3_LOG(WARNING)
2176 << "Failed to parse the detected_text_language_tags in options: "
2177 << options.detected_text_language_tags;
2178 }
2179 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2180 model_triggering_locales_,
2181 /*default_value=*/true)) {
2182 return Status(
2183 StatusCode::UNAVAILABLE,
2184 "The detected language tags are not in the supported locales.");
2185 }
2186
2187 InterpreterManager interpreter_manager(selection_executor_.get(),
2188 classification_executor_.get());
2189
2190 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2191 const bool is_raw_usecase =
2192 options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2193
2194 // Annotate with the selection model.
2195 const bool model_annotations_enabled =
2196 !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2197 std::vector<Token> tokens;
2198 if (model_annotations_enabled &&
2199 !ModelAnnotate(context, detected_text_language_tags, options,
2200 &interpreter_manager, &tokens, candidates)) {
2201 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2202 } else if (!model_annotations_enabled) {
2203 // If the ML model didn't run, we need to tokenize to support the other
2204 // annotators that depend on the tokens.
2205 // Optimization could be made to only do this when an annotator that uses
2206 // the tokens is enabled, but it's unclear if the added complexity is worth
2207 // it.
2208 if (selection_feature_processor_ != nullptr) {
2209 tokens = selection_feature_processor_->Tokenize(context_unicode);
2210 }
2211 }
2212
2213 // Annotate with the regular expression models.
2214 const bool regex_annotations_enabled =
2215 !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2216 if (regex_annotations_enabled &&
2217 !RegexChunk(
2218 UTF8ToUnicodeText(context, /*do_copy=*/false),
2219 annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2220 is_entity_type_enabled, options.annotation_usecase, candidates)) {
2221 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2222 }
2223
2224 // Annotate with the datetime model.
2225 // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2226 // relatively slow for some clients.
2227 if ((is_entity_type_enabled(Collections::Date()) ||
2228 is_entity_type_enabled(Collections::DateTime())) &&
2229 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2230 options.reference_time_ms_utc, options.reference_timezone,
2231 options.locales, ModeFlag_ANNOTATION,
2232 options.annotation_usecase,
2233 options.is_serialized_entity_data_enabled, candidates)) {
2234 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2235 }
2236
2237 // Annotate with the contact engine.
2238 const bool contact_annotations_enabled =
2239 !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2240 if (contact_annotations_enabled && contact_engine_ &&
2241 !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2242 candidates)) {
2243 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2244 }
2245
2246 // Annotate with the installed app engine.
2247 const bool app_annotations_enabled =
2248 !is_raw_usecase || is_entity_type_enabled(Collections::App());
2249 if (app_annotations_enabled && installed_app_engine_ &&
2250 !installed_app_engine_->Chunk(context_unicode, tokens,
2251 ModeFlag_ANNOTATION, candidates)) {
2252 return Status(StatusCode::INTERNAL,
2253 "Couldn't run installed app engine Chunk.");
2254 }
2255
2256 // Annotate with the number annotator.
2257 const bool number_annotations_enabled =
2258 !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2259 is_entity_type_enabled(Collections::Percentage()));
2260 if (number_annotations_enabled && number_annotator_ != nullptr &&
2261 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2262 ModeFlag_ANNOTATION, candidates)) {
2263 return Status(StatusCode::INTERNAL,
2264 "Couldn't run number annotator FindAll.");
2265 }
2266
2267 // Annotate with the duration annotator.
2268 const bool duration_annotations_enabled =
2269 !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2270 if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2271 !duration_annotator_->FindAll(context_unicode, tokens,
2272 options.annotation_usecase,
2273 ModeFlag_ANNOTATION, candidates)) {
2274 return Status(StatusCode::INTERNAL,
2275 "Couldn't run duration annotator FindAll.");
2276 }
2277
2278 // Annotate with the person name engine.
2279 const bool person_annotations_enabled =
2280 !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2281 if (person_annotations_enabled && person_name_engine_ &&
2282 !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2283 candidates)) {
2284 return Status(StatusCode::INTERNAL,
2285 "Couldn't run person name engine Chunk.");
2286 }
2287
2288 // Annotate with the grammar annotators.
2289 if (grammar_annotator_ != nullptr &&
2290 !grammar_annotator_->Annotate(detected_text_language_tags,
2291 context_unicode, candidates)) {
2292 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2293 }
2294
2295 // Annotate with the POD NER annotator.
2296 const bool pod_ner_annotations_enabled =
2297 !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2298 if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2299 options.use_pod_ner &&
2300 !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2301 return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2302 }
2303
2304 // Annotate with the vocab annotator.
2305 const bool vocab_annotations_enabled =
2306 !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2307 if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2308 options.use_vocab_annotator &&
2309 !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2310 options.trigger_dictionary_on_beginner_words,
2311 candidates)) {
2312 return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2313 }
2314
2315 // Annotate with the experimental annotator.
2316 if (experimental_annotator_ != nullptr &&
2317 (model_->triggering_options()->experimental_enabled_modes() &
2318 ModeFlag_ANNOTATION) &&
2319 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2320 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2321 }
2322
2323 // Sort candidates according to their position in the input, so that the next
2324 // code can assume that any connected component of overlapping spans forms a
2325 // contiguous block.
2326 // Also sort them according to the end position and collection, so that the
2327 // deduplication code below can assume that same spans and classifications
2328 // form contiguous blocks.
2329 std::stable_sort(candidates->begin(), candidates->end(),
2330 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2331 if (a.span.first != b.span.first) {
2332 return a.span.first < b.span.first;
2333 }
2334
2335 if (a.span.second != b.span.second) {
2336 return a.span.second < b.span.second;
2337 }
2338
2339 return a.classification[0].collection <
2340 b.classification[0].collection;
2341 });
2342
2343 std::vector<int> candidate_indices;
2344 if (!ResolveConflicts(*candidates, context, tokens,
2345 detected_text_language_tags, options,
2346 &interpreter_manager, &candidate_indices)) {
2347 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2348 }
2349
2350 // Remove candidates that overlap exactly and have the same collection.
2351 // This can e.g. happen for phone coming from both ML model and regex.
2352 candidate_indices.erase(
2353 std::unique(candidate_indices.begin(), candidate_indices.end(),
2354 [&candidates](const int a_index, const int b_index) {
2355 const AnnotatedSpan& a = (*candidates)[a_index];
2356 const AnnotatedSpan& b = (*candidates)[b_index];
2357 return a.span == b.span &&
2358 a.classification[0].collection ==
2359 b.classification[0].collection;
2360 }),
2361 candidate_indices.end());
2362
2363 std::vector<AnnotatedSpan> result;
2364 result.reserve(candidate_indices.size());
2365 for (const int i : candidate_indices) {
2366 if ((*candidates)[i].classification.empty() ||
2367 ClassifiedAsOther((*candidates)[i].classification) ||
2368 FilteredForAnnotation((*candidates)[i])) {
2369 continue;
2370 }
2371 result.push_back(std::move((*candidates)[i]));
2372 }
2373
2374 // We generate all candidates and remove them later (with the exception of
2375 // date/time/duration entities) because there are complex interdependencies
2376 // between the entity types. E.g., the TLD of an email can be interpreted as a
2377 // URL, but most likely a user of the API does not want such annotations if
2378 // "url" is enabled and "email" is not.
2379 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2380
2381 for (AnnotatedSpan& annotated_span : result) {
2382 SortClassificationResults(&annotated_span.classification);
2383 }
2384 *candidates = result;
2385 return Status::OK;
2386 }
2387
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2388 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2389 const std::vector<InputFragment>& string_fragments,
2390 const AnnotationOptions& options) const {
2391 Annotations annotation_candidates;
2392 annotation_candidates.annotated_spans.resize(string_fragments.size());
2393
2394 std::vector<std::string> text_to_annotate;
2395 text_to_annotate.reserve(string_fragments.size());
2396 std::vector<FragmentMetadata> fragment_metadata;
2397 fragment_metadata.reserve(string_fragments.size());
2398 for (const auto& string_fragment : string_fragments) {
2399 text_to_annotate.push_back(string_fragment.text);
2400 fragment_metadata.push_back(
2401 {.relative_bounding_box_top = string_fragment.bounding_box_top,
2402 .relative_bounding_box_height = string_fragment.bounding_box_height});
2403 }
2404
2405 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2406 const bool is_raw_usecase =
2407 options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2408
2409 const bool knowledge_engine_annotations_enabled =
2410 !is_raw_usecase || is_entity_type_enabled(Collections::Entity());
2411 // KnowledgeEngine is special, because it supports annotation of multiple
2412 // fragments at once.
2413 if (knowledge_engine_annotations_enabled && knowledge_engine_ &&
2414 !knowledge_engine_
2415 ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2416 options.annotation_usecase,
2417 options.location_context, options.permissions,
2418 options.annotate_mode, ModeFlag_ANNOTATION,
2419 &annotation_candidates)
2420 .ok()) {
2421 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2422 }
2423 // The annotator engines shouldn't change the number of annotation vectors.
2424 if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2425 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2426 << " texts to annotate but generated a different number of "
2427 "lists of annotations:"
2428 << annotation_candidates.annotated_spans.size();
2429 return Status(StatusCode::INTERNAL,
2430 "Number of annotation candidates differs from "
2431 "number of texts to annotate.");
2432 }
2433
2434 // As an optimization, if the only annotated type is Entity, we skip all the
2435 // other annotators than the KnowledgeEngine. This only happens in the raw
2436 // mode, to make sure it does not affect the result.
2437 if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2438 options.entity_types.size() == 1 &&
2439 *options.entity_types.begin() == Collections::Entity()) {
2440 return annotation_candidates;
2441 }
2442
2443 // Other annotators run on each fragment independently.
2444 for (int i = 0; i < text_to_annotate.size(); ++i) {
2445 AnnotationOptions annotation_options = options;
2446 if (string_fragments[i].datetime_options.has_value()) {
2447 DatetimeOptions reference_datetime =
2448 string_fragments[i].datetime_options.value();
2449 annotation_options.reference_time_ms_utc =
2450 reference_datetime.reference_time_ms_utc;
2451 annotation_options.reference_timezone =
2452 reference_datetime.reference_timezone;
2453 }
2454
2455 AddContactMetadataToKnowledgeClassificationResults(
2456 &annotation_candidates.annotated_spans[i]);
2457
2458 Status annotation_status =
2459 AnnotateSingleInput(text_to_annotate[i], annotation_options,
2460 &annotation_candidates.annotated_spans[i]);
2461 if (!annotation_status.ok()) {
2462 return annotation_status;
2463 }
2464 }
2465 return annotation_candidates;
2466 }
2467
Annotate(const std::string & context,const AnnotationOptions & options) const2468 std::vector<AnnotatedSpan> Annotator::Annotate(
2469 const std::string& context, const AnnotationOptions& options) const {
2470 if (context.size() > std::numeric_limits<int>::max()) {
2471 TC3_LOG(ERROR) << "Rejecting too long input.";
2472 return {};
2473 }
2474
2475 const UnicodeText context_unicode =
2476 UTF8ToUnicodeText(context, /*do_copy=*/false);
2477 if (!unilib_->IsValidUtf8(context_unicode)) {
2478 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2479 return {};
2480 }
2481
2482 std::vector<InputFragment> string_fragments;
2483 string_fragments.push_back({.text = context});
2484 StatusOr<Annotations> annotations =
2485 AnnotateStructuredInput(string_fragments, options);
2486 if (!annotations.ok()) {
2487 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2488 << annotations.status().error_message();
2489 return {};
2490 }
2491 return annotations.ValueOrDie().annotated_spans[0];
2492 }
2493
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2494 CodepointSpan Annotator::ComputeSelectionBoundaries(
2495 const UniLib::RegexMatcher* match,
2496 const RegexModel_::Pattern* config) const {
2497 if (config->capturing_group() == nullptr) {
2498 // Use first capturing group to specify the selection.
2499 int status = UniLib::RegexMatcher::kNoError;
2500 const CodepointSpan result = {match->Start(1, &status),
2501 match->End(1, &status)};
2502 if (status != UniLib::RegexMatcher::kNoError) {
2503 return {kInvalidIndex, kInvalidIndex};
2504 }
2505 return result;
2506 }
2507
2508 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2509 const int num_groups = config->capturing_group()->size();
2510 for (int i = 0; i < num_groups; i++) {
2511 if (!config->capturing_group()->Get(i)->extend_selection()) {
2512 continue;
2513 }
2514
2515 int status = UniLib::RegexMatcher::kNoError;
2516 // Check match and adjust bounds.
2517 const int group_start = match->Start(i, &status);
2518 const int group_end = match->End(i, &status);
2519 if (status != UniLib::RegexMatcher::kNoError) {
2520 return {kInvalidIndex, kInvalidIndex};
2521 }
2522 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2523 continue;
2524 }
2525 if (result.first == kInvalidIndex) {
2526 result = {group_start, group_end};
2527 } else {
2528 result.first = std::min(result.first, group_start);
2529 result.second = std::max(result.second, group_end);
2530 }
2531 }
2532 return result;
2533 }
2534
HasEntityData(const RegexModel_::Pattern * pattern) const2535 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2536 if (pattern->serialized_entity_data() != nullptr ||
2537 pattern->entity_data() != nullptr) {
2538 return true;
2539 }
2540 if (pattern->capturing_group() != nullptr) {
2541 for (const CapturingGroup* group : *pattern->capturing_group()) {
2542 if (group->entity_field_path() != nullptr) {
2543 return true;
2544 }
2545 if (group->serialized_entity_data() != nullptr ||
2546 group->entity_data() != nullptr) {
2547 return true;
2548 }
2549 }
2550 }
2551 return false;
2552 }
2553
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2554 bool Annotator::SerializedEntityDataFromRegexMatch(
2555 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2556 std::string* serialized_entity_data) const {
2557 if (!HasEntityData(pattern)) {
2558 serialized_entity_data->clear();
2559 return true;
2560 }
2561 TC3_CHECK(entity_data_builder_ != nullptr);
2562
2563 std::unique_ptr<MutableFlatbuffer> entity_data =
2564 entity_data_builder_->NewRoot();
2565
2566 TC3_CHECK(entity_data != nullptr);
2567
2568 // Set fixed entity data.
2569 if (pattern->serialized_entity_data() != nullptr) {
2570 entity_data->MergeFromSerializedFlatbuffer(
2571 StringPiece(pattern->serialized_entity_data()->c_str(),
2572 pattern->serialized_entity_data()->size()));
2573 }
2574 if (pattern->entity_data() != nullptr) {
2575 entity_data->MergeFrom(
2576 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2577 }
2578
2579 // Add entity data from rule capturing groups.
2580 if (pattern->capturing_group() != nullptr) {
2581 const int num_groups = pattern->capturing_group()->size();
2582 for (int i = 0; i < num_groups; i++) {
2583 const CapturingGroup* group = pattern->capturing_group()->Get(i);
2584
2585 // Check whether the group matched.
2586 Optional<std::string> group_match_text =
2587 GetCapturingGroupText(matcher, /*group_id=*/i);
2588 if (!group_match_text.has_value()) {
2589 continue;
2590 }
2591
2592 // Set fixed entity data from capturing group match.
2593 if (group->serialized_entity_data() != nullptr) {
2594 entity_data->MergeFromSerializedFlatbuffer(
2595 StringPiece(group->serialized_entity_data()->c_str(),
2596 group->serialized_entity_data()->size()));
2597 }
2598 if (group->entity_data() != nullptr) {
2599 entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2600 pattern->entity_data()));
2601 }
2602
2603 // Set entity field from capturing group text.
2604 if (group->entity_field_path() != nullptr) {
2605 UnicodeText normalized_group_match_text =
2606 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2607
2608 // Apply normalization if specified.
2609 if (group->normalization_options() != nullptr) {
2610 normalized_group_match_text =
2611 NormalizeText(*unilib_, group->normalization_options(),
2612 normalized_group_match_text);
2613 }
2614
2615 if (!entity_data->ParseAndSet(
2616 group->entity_field_path(),
2617 normalized_group_match_text.ToUTF8String())) {
2618 TC3_LOG(ERROR)
2619 << "Could not set entity data from rule capturing group.";
2620 return false;
2621 }
2622 }
2623 }
2624 }
2625
2626 *serialized_entity_data = entity_data->Serialize();
2627 return true;
2628 }
2629
// Returns the codepoints of `amount` that come before `it_decimal_separator`,
// dropping every codepoint that appears in `decimal_separators`. Passing
// `amount.end()` as the separator keeps the whole amount.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    // Use the set's hash lookup rather than a linear std::find over it.
    if (decimal_separators.count(static_cast<char32>(*it)) == 0) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2644
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2645 void Annotator::GetMoneyQuantityFromCapturingGroup(
2646 const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2647 const UnicodeText& context_unicode, std::string* quantity,
2648 int* exponent) const {
2649 if (config->capturing_group() == nullptr) {
2650 *exponent = 0;
2651 return;
2652 }
2653
2654 const int num_groups = config->capturing_group()->size();
2655 for (int i = 0; i < num_groups; i++) {
2656 int status = UniLib::RegexMatcher::kNoError;
2657 const int group_start = match->Start(i, &status);
2658 const int group_end = match->End(i, &status);
2659 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2660 continue;
2661 }
2662
2663 *quantity =
2664 unilib_
2665 ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2666 group_end, /*do_copy=*/false))
2667 .ToUTF8String();
2668
2669 if (auto entry = model_->money_parsing_options()
2670 ->quantities_name_to_exponent()
2671 ->LookupByKey((*quantity).c_str())) {
2672 *exponent = entry->value();
2673 return;
2674 }
2675 }
2676 *exponent = 0;
2677 }
2678
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2679 bool Annotator::ParseAndFillInMoneyAmount(
2680 std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2681 const RegexModel_::Pattern* config,
2682 const UnicodeText& context_unicode) const {
2683 std::unique_ptr<EntityDataT> data =
2684 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2685 *serialized_entity_data);
2686 if (data == nullptr) {
2687 if (model_->version() >= 706) {
2688 // This way of parsing money entity data is enabled for models newer than
2689 // v706, consequently logging errors only for them (b/156634162).
2690 TC3_LOG(ERROR)
2691 << "Data field is null when trying to parse Money Entity Data";
2692 }
2693 return false;
2694 }
2695 if (data->money->unnormalized_amount.empty()) {
2696 if (model_->version() >= 706) {
2697 // This way of parsing money entity data is enabled for models newer than
2698 // v706, consequently logging errors only for them (b/156634162).
2699 TC3_LOG(ERROR)
2700 << "Data unnormalized_amount is empty when trying to parse "
2701 "Money Entity Data";
2702 }
2703 return false;
2704 }
2705
2706 UnicodeText amount =
2707 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2708 int separator_back_index = 0;
2709 auto it_decimal_separator = --amount.end();
2710 for (; it_decimal_separator != amount.begin();
2711 --it_decimal_separator, ++separator_back_index) {
2712 if (std::find(money_separators_.begin(), money_separators_.end(),
2713 static_cast<char32>(*it_decimal_separator)) !=
2714 money_separators_.end()) {
2715 break;
2716 }
2717 }
2718
2719 // If there are 3 digits after the last separator, we consider that a
2720 // thousands separator => the number is an int (e.g. 1.234 is considered int).
2721 // If there is no separator in number, also that number is an int.
2722 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2723 it_decimal_separator = amount.end();
2724 }
2725
2726 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2727 it_decimal_separator),
2728 &data->money->amount_whole_part)) {
2729 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2730 "amount: "
2731 << data->money->unnormalized_amount;
2732 return false;
2733 }
2734
2735 if (it_decimal_separator == amount.end()) {
2736 data->money->amount_decimal_part = 0;
2737 data->money->nanos = 0;
2738 } else {
2739 const int amount_codepoints_size = amount.size_codepoints();
2740 const UnicodeText decimal_part = UnicodeText::Substring(
2741 amount, amount_codepoints_size - separator_back_index,
2742 amount_codepoints_size, /*do_copy=*/false);
2743 if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2744 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2745 "the amount: "
2746 << data->money->unnormalized_amount;
2747 return false;
2748 }
2749 data->money->nanos = data->money->amount_decimal_part *
2750 pow(10, 9 - decimal_part.size_codepoints());
2751 }
2752
2753 if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2754 nullptr) {
2755 int quantity_exponent;
2756 std::string quantity;
2757 GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2758 &quantity, &quantity_exponent);
2759 if (quantity_exponent > 0 && quantity_exponent <= 9) {
2760 const double amount_whole_part =
2761 data->money->amount_whole_part * pow(10, quantity_exponent) +
2762 data->money->nanos / pow(10, 9 - quantity_exponent);
2763 // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2764 // (and `std::numeric_limits<int>::max()` to
2765 // `std::numeric_limits<int64>::max()`).
2766 if (amount_whole_part < std::numeric_limits<int>::max()) {
2767 data->money->amount_whole_part = amount_whole_part;
2768 data->money->nanos = data->money->nanos %
2769 static_cast<int>(pow(10, 9 - quantity_exponent)) *
2770 pow(10, quantity_exponent);
2771 }
2772 }
2773 if (quantity_exponent > 0) {
2774 data->money->unnormalized_amount = strings::JoinStrings(
2775 " ", {data->money->unnormalized_amount, quantity});
2776 }
2777 }
2778
2779 *serialized_entity_data =
2780 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2781 return true;
2782 }
2783
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2784 bool Annotator::IsAnyModelEntityTypeEnabled(
2785 const EnabledEntityTypes& is_entity_type_enabled) const {
2786 if (model_->classification_feature_options() == nullptr ||
2787 model_->classification_feature_options()->collections() == nullptr) {
2788 return false;
2789 }
2790 for (int i = 0;
2791 i < model_->classification_feature_options()->collections()->size();
2792 i++) {
2793 if (is_entity_type_enabled(model_->classification_feature_options()
2794 ->collections()
2795 ->Get(i)
2796 ->str())) {
2797 return true;
2798 }
2799 }
2800 return false;
2801 }
2802
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2803 bool Annotator::IsAnyRegexEntityTypeEnabled(
2804 const EnabledEntityTypes& is_entity_type_enabled) const {
2805 if (model_->regex_model() == nullptr ||
2806 model_->regex_model()->patterns() == nullptr) {
2807 return false;
2808 }
2809 for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2810 if (is_entity_type_enabled(model_->regex_model()
2811 ->patterns()
2812 ->Get(i)
2813 ->collection_name()
2814 ->str())) {
2815 return true;
2816 }
2817 }
2818 return false;
2819 }
2820
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2821 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2822 const EnabledEntityTypes& is_entity_type_enabled) const {
2823 if (pod_ner_annotator_ == nullptr) {
2824 return false;
2825 }
2826
2827 for (const std::string& collection :
2828 pod_ner_annotator_->GetSupportedCollections()) {
2829 if (is_entity_type_enabled(collection)) {
2830 return true;
2831 }
2832 }
2833 return false;
2834 }
2835
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2836 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2837 const std::vector<int>& rules,
2838 bool is_serialized_entity_data_enabled,
2839 const EnabledEntityTypes& enabled_entity_types,
2840 const AnnotationUsecase& annotation_usecase,
2841 std::vector<AnnotatedSpan>* result) const {
2842 for (int pattern_id : rules) {
2843 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2844 if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2845 annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2846 // No regex annotation type has been requested, skip regex annotation.
2847 continue;
2848 }
2849 const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2850 if (!matcher) {
2851 TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2852 << pattern_id;
2853 return false;
2854 }
2855
2856 int status = UniLib::RegexMatcher::kNoError;
2857 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2858 if (regex_pattern.config->verification_options()) {
2859 if (!VerifyRegexMatchCandidate(
2860 context_unicode.ToUTF8String(),
2861 regex_pattern.config->verification_options(),
2862 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2863 continue;
2864 }
2865 }
2866
2867 std::string serialized_entity_data;
2868 if (is_serialized_entity_data_enabled) {
2869 if (!SerializedEntityDataFromRegexMatch(
2870 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2871 TC3_LOG(ERROR) << "Could not get entity data.";
2872 return false;
2873 }
2874
2875 // Further parsing of money amount. Need this since regexes cannot have
2876 // empty groups that fill in entity data (amount_decimal_part and
2877 // quantity might be empty groups).
2878 if (regex_pattern.config->collection_name()->str() ==
2879 Collections::Money()) {
2880 if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2881 regex_pattern.config,
2882 context_unicode)) {
2883 if (model_->version() >= 706) {
2884 // This way of parsing money entity data is enabled for models
2885 // newer than v706 => logging errors only for them (b/156634162).
2886 TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2887 }
2888 }
2889 }
2890 }
2891
2892 result->emplace_back();
2893
2894 // Selection/annotation regular expressions need to specify a capturing
2895 // group specifying the selection.
2896 result->back().span =
2897 ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2898
2899 result->back().classification = {
2900 {regex_pattern.config->collection_name()->str(),
2901 regex_pattern.config->target_classification_score(),
2902 regex_pattern.config->priority_score()}};
2903
2904 result->back().classification[0].serialized_entity_data =
2905 serialized_entity_data;
2906 }
2907 }
2908 return true;
2909 }
2910
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2911 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2912 tflite::Interpreter* selection_interpreter,
2913 const CachedFeatures& cached_features,
2914 std::vector<TokenSpan>* chunks) const {
2915 const int max_selection_span =
2916 selection_feature_processor_->GetOptions()->max_selection_span();
2917 // The inference span is the span of interest expanded to include
2918 // max_selection_span tokens on either side, which is how far a selection can
2919 // stretch from the click.
2920 const TokenSpan inference_span =
2921 IntersectTokenSpans(span_of_interest.Expand(
2922 /*num_tokens_left=*/max_selection_span,
2923 /*num_tokens_right=*/max_selection_span),
2924 {0, num_tokens});
2925
2926 std::vector<ScoredChunk> scored_chunks;
2927 if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2928 selection_feature_processor_->GetOptions()
2929 ->bounds_sensitive_features()
2930 ->enabled()) {
2931 if (!ModelBoundsSensitiveScoreChunks(
2932 num_tokens, span_of_interest, inference_span, cached_features,
2933 selection_interpreter, &scored_chunks)) {
2934 return false;
2935 }
2936 } else {
2937 if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2938 cached_features, selection_interpreter,
2939 &scored_chunks)) {
2940 return false;
2941 }
2942 }
2943 std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
2944 [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2945 return lhs.score < rhs.score;
2946 });
2947
2948 // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2949 // them greedily as long as they do not overlap with any previously picked
2950 // chunks.
2951 std::vector<bool> token_used(inference_span.Size());
2952 chunks->clear();
2953 for (const ScoredChunk& scored_chunk : scored_chunks) {
2954 bool feasible = true;
2955 for (int i = scored_chunk.token_span.first;
2956 i < scored_chunk.token_span.second; ++i) {
2957 if (token_used[i - inference_span.first]) {
2958 feasible = false;
2959 break;
2960 }
2961 }
2962
2963 if (!feasible) {
2964 continue;
2965 }
2966
2967 for (int i = scored_chunk.token_span.first;
2968 i < scored_chunk.token_span.second; ++i) {
2969 token_used[i - inference_span.first] = true;
2970 }
2971
2972 chunks->push_back(scored_chunk.token_span);
2973 }
2974
2975 std::stable_sort(chunks->begin(), chunks->end());
2976
2977 return true;
2978 }
2979
namespace {
// Stores `value` under `key` in `map`. If `key` is already present, the larger
// of the stored value and `value` is kept. Uses a single lookup.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  auto insertion = map->emplace(key, value);
  if (!insertion.second) {
    auto& stored = insertion.first->second;
    stored = std::max(stored, value);
  }
}
}  // namespace
2994
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2995 bool Annotator::ModelClickContextScoreChunks(
2996 int num_tokens, const TokenSpan& span_of_interest,
2997 const CachedFeatures& cached_features,
2998 tflite::Interpreter* selection_interpreter,
2999 std::vector<ScoredChunk>* scored_chunks) const {
3000 const int max_batch_size = model_->selection_options()->batch_size();
3001
3002 std::vector<float> all_features;
3003 std::map<TokenSpan, float> chunk_scores;
3004 for (int batch_start = span_of_interest.first;
3005 batch_start < span_of_interest.second; batch_start += max_batch_size) {
3006 const int batch_end =
3007 std::min(batch_start + max_batch_size, span_of_interest.second);
3008
3009 // Prepare features for the whole batch.
3010 all_features.clear();
3011 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3012 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3013 cached_features.AppendClickContextFeaturesForClick(click_pos,
3014 &all_features);
3015 }
3016
3017 // Run batched inference.
3018 const int batch_size = batch_end - batch_start;
3019 const int features_size = cached_features.OutputFeaturesSize();
3020 TensorView<float> logits = selection_executor_->ComputeLogits(
3021 TensorView<float>(all_features.data(), {batch_size, features_size}),
3022 selection_interpreter);
3023 if (!logits.is_valid()) {
3024 TC3_LOG(ERROR) << "Couldn't compute logits.";
3025 return false;
3026 }
3027 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3028 logits.dim(1) !=
3029 selection_feature_processor_->GetSelectionLabelCount()) {
3030 TC3_LOG(ERROR) << "Mismatching output.";
3031 return false;
3032 }
3033
3034 // Save results.
3035 for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3036 const std::vector<float> scores = ComputeSoftmax(
3037 logits.data() + logits.dim(1) * (click_pos - batch_start),
3038 logits.dim(1));
3039 for (int j = 0;
3040 j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3041 TokenSpan relative_token_span;
3042 if (!selection_feature_processor_->LabelToTokenSpan(
3043 j, &relative_token_span)) {
3044 TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3045 return false;
3046 }
3047 const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3048 relative_token_span.first, relative_token_span.second);
3049 if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3050 UpdateMax(&chunk_scores, candidate_span, scores[j]);
3051 }
3052 }
3053 }
3054 }
3055
3056 scored_chunks->clear();
3057 scored_chunks->reserve(chunk_scores.size());
3058 for (const auto& entry : chunk_scores) {
3059 scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3060 }
3061
3062 return true;
3063 }
3064
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3065 bool Annotator::ModelBoundsSensitiveScoreChunks(
3066 int num_tokens, const TokenSpan& span_of_interest,
3067 const TokenSpan& inference_span, const CachedFeatures& cached_features,
3068 tflite::Interpreter* selection_interpreter,
3069 std::vector<ScoredChunk>* scored_chunks) const {
3070 const int max_selection_span =
3071 selection_feature_processor_->GetOptions()->max_selection_span();
3072 const int max_chunk_length = selection_feature_processor_->GetOptions()
3073 ->selection_reduced_output_space()
3074 ? max_selection_span + 1
3075 : 2 * max_selection_span + 1;
3076 const bool score_single_token_spans_as_zero =
3077 selection_feature_processor_->GetOptions()
3078 ->bounds_sensitive_features()
3079 ->score_single_token_spans_as_zero();
3080
3081 scored_chunks->clear();
3082 if (score_single_token_spans_as_zero) {
3083 scored_chunks->reserve(span_of_interest.Size());
3084 }
3085
3086 // Prepare all chunk candidates into one batch:
3087 // - Are contained in the inference span
3088 // - Have a non-empty intersection with the span of interest
3089 // - Are at least one token long
3090 // - Are not longer than the maximum chunk length
3091 std::vector<TokenSpan> candidate_spans;
3092 for (int start = inference_span.first; start < span_of_interest.second;
3093 ++start) {
3094 const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3095 for (int end = leftmost_end_index;
3096 end <= inference_span.second && end - start <= max_chunk_length;
3097 ++end) {
3098 const TokenSpan candidate_span = {start, end};
3099 if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3100 // Do not include the single token span in the batch, add a zero score
3101 // for it directly to the output.
3102 scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3103 } else {
3104 candidate_spans.push_back(candidate_span);
3105 }
3106 }
3107 }
3108
3109 const int max_batch_size = model_->selection_options()->batch_size();
3110
3111 std::vector<float> all_features;
3112 scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3113 for (int batch_start = 0; batch_start < candidate_spans.size();
3114 batch_start += max_batch_size) {
3115 const int batch_end = std::min(batch_start + max_batch_size,
3116 static_cast<int>(candidate_spans.size()));
3117
3118 // Prepare features for the whole batch.
3119 all_features.clear();
3120 all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3121 for (int i = batch_start; i < batch_end; ++i) {
3122 cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3123 &all_features);
3124 }
3125
3126 // Run batched inference.
3127 const int batch_size = batch_end - batch_start;
3128 const int features_size = cached_features.OutputFeaturesSize();
3129 TensorView<float> logits = selection_executor_->ComputeLogits(
3130 TensorView<float>(all_features.data(), {batch_size, features_size}),
3131 selection_interpreter);
3132 if (!logits.is_valid()) {
3133 TC3_LOG(ERROR) << "Couldn't compute logits.";
3134 return false;
3135 }
3136 if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3137 logits.dim(1) != 1) {
3138 TC3_LOG(ERROR) << "Mismatching output.";
3139 return false;
3140 }
3141
3142 // Save results.
3143 for (int i = batch_start; i < batch_end; ++i) {
3144 scored_chunks->push_back(
3145 ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3146 }
3147 }
3148
3149 return true;
3150 }
3151
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3152 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3153 int64 reference_time_ms_utc,
3154 const std::string& reference_timezone,
3155 const std::string& locales, ModeFlag mode,
3156 AnnotationUsecase annotation_usecase,
3157 bool is_serialized_entity_data_enabled,
3158 std::vector<AnnotatedSpan>* result) const {
3159 if (!datetime_parser_) {
3160 return true;
3161 }
3162 LocaleList locale_list = LocaleList::ParseFrom(locales);
3163 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3164 datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3165 reference_timezone, locale_list, mode,
3166 annotation_usecase,
3167 /*anchor_start_end=*/false);
3168 if (!result_status.ok()) {
3169 return false;
3170 }
3171
3172 for (const DatetimeParseResultSpan& datetime_span :
3173 result_status.ValueOrDie()) {
3174 AnnotatedSpan annotated_span;
3175 annotated_span.span = datetime_span.span;
3176 for (const DatetimeParseResult& parse_result : datetime_span.data) {
3177 annotated_span.classification.emplace_back(
3178 PickCollectionForDatetime(parse_result),
3179 datetime_span.target_classification_score,
3180 datetime_span.priority_score);
3181 annotated_span.classification.back().datetime_parse_result = parse_result;
3182 if (is_serialized_entity_data_enabled) {
3183 annotated_span.classification.back().serialized_entity_data =
3184 CreateDatetimeSerializedEntityData(parse_result);
3185 }
3186 }
3187 annotated_span.source = AnnotatedSpan::Source::DATETIME;
3188 result->push_back(std::move(annotated_span));
3189 }
3190 return true;
3191 }
3192
model() const3193 const Model* Annotator::model() const { return model_; }
entity_data_schema() const3194 const reflection::Schema* Annotator::entity_data_schema() const {
3195 return entity_data_schema_;
3196 }
3197
ViewModel(const void * buffer,int size)3198 const Model* ViewModel(const void* buffer, int size) {
3199 if (!buffer) {
3200 return nullptr;
3201 }
3202
3203 return LoadAndVerifyModel(buffer, size);
3204 }
3205
LookUpKnowledgeEntity(const std::string & id) const3206 StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3207 const std::string& id) const {
3208 if (!knowledge_engine_) {
3209 return Status(StatusCode::FAILED_PRECONDITION,
3210 "knowledge_engine_ is nullptr");
3211 }
3212 return knowledge_engine_->LookUpEntity(id);
3213 }
3214
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3215 StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3216 const std::string& mid_str, const std::string& property) const {
3217 if (!knowledge_engine_) {
3218 return Status(StatusCode::FAILED_PRECONDITION,
3219 "knowledge_engine_ is nullptr");
3220 }
3221 return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3222 }
3223
3224 } // namespace libtextclassifier3
3225