1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/annotator.h"
18
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54
55 namespace libtextclassifier3 {
56
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58
59 const std::string& Annotator::kPhoneCollection =
__anon0684102d0102() 60 *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anon0684102d0202() 62 *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anon0684102d0302() 64 *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anon0684102d0402() 66 *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anon0684102d0502() 68 *[]() { return new std::string("email"); }();
69
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73 if (VerifyModelBuffer(verifier)) {
74 return GetModel(addr);
75 } else {
76 return nullptr;
77 }
78 }
79
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81 int size) {
82 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83 if (VerifyPersonNameModelBuffer(verifier)) {
84 return GetPersonNameModel(addr);
85 } else {
86 return nullptr;
87 }
88 }
89
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93 std::unique_ptr<UniLib>* owned_lib) {
94 if (lib) {
95 return lib;
96 } else {
97 owned_lib->reset(new UniLib);
98 return owned_lib->get();
99 }
100 }
101
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104 const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105 if (lib) {
106 return lib;
107 } else {
108 owned_lib->reset(new CalendarLib);
109 return owned_lib->get();
110 }
111 }
112
113 // Returns whether the provided input is valid:
114 // * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116 return (span.first >= 0 && span.first < span.second &&
117 span.second <= context.size_codepoints());
118 }
119
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121 const flatbuffers::Vector<int32_t>* ints) {
122 if (ints == nullptr) {
123 return {};
124 }
125 std::unordered_set<char32> ints_set;
126 for (auto value : *ints) {
127 ints_set.insert(static_cast<char32>(value));
128 }
129 return ints_set;
130 }
131
132 } // namespace
133
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135 if (!selection_interpreter_) {
136 TC3_CHECK(selection_executor_);
137 selection_interpreter_ = selection_executor_->CreateInterpreter();
138 if (!selection_interpreter_) {
139 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140 }
141 }
142 return selection_interpreter_.get();
143 }
144
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146 if (!classification_interpreter_) {
147 TC3_CHECK(classification_executor_);
148 classification_interpreter_ = classification_executor_->CreateInterpreter();
149 if (!classification_interpreter_) {
150 TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151 }
152 }
153 return classification_interpreter_.get();
154 }
155
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157 const char* buffer, int size, const UniLib* unilib,
158 const CalendarLib* calendarlib) {
159 const Model* model = LoadAndVerifyModel(buffer, size);
160 if (model == nullptr) {
161 return nullptr;
162 }
163
164 auto classifier = std::unique_ptr<Annotator>(new Annotator());
165 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166 calendarlib =
167 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168 classifier->ValidateAndInitialize(model, unilib, calendarlib);
169 if (!classifier->IsInitialized()) {
170 return nullptr;
171 }
172
173 return classifier;
174 }
175
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177 const std::string& buffer, const UniLib* unilib,
178 const CalendarLib* calendarlib) {
179 auto classifier = std::unique_ptr<Annotator>(new Annotator());
180 classifier->owned_buffer_ = buffer;
181 const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182 classifier->owned_buffer_.size());
183 if (model == nullptr) {
184 return nullptr;
185 }
186 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187 calendarlib =
188 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189 classifier->ValidateAndInitialize(model, unilib, calendarlib);
190 if (!classifier->IsInitialized()) {
191 return nullptr;
192 }
193
194 return classifier;
195 }
196
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199 const CalendarLib* calendarlib) {
200 if (!(*mmap)->handle().ok()) {
201 TC3_VLOG(1) << "Mmap failed.";
202 return nullptr;
203 }
204
205 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206 (*mmap)->handle().num_bytes());
207 if (!model) {
208 TC3_LOG(ERROR) << "Model verification failed.";
209 return nullptr;
210 }
211
212 auto classifier = std::unique_ptr<Annotator>(new Annotator());
213 classifier->mmap_ = std::move(*mmap);
214 unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215 calendarlib =
216 MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217 classifier->ValidateAndInitialize(model, unilib, calendarlib);
218 if (!classifier->IsInitialized()) {
219 return nullptr;
220 }
221
222 return classifier;
223 }
224
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226 std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227 std::unique_ptr<CalendarLib> calendarlib) {
228 if (!(*mmap)->handle().ok()) {
229 TC3_VLOG(1) << "Mmap failed.";
230 return nullptr;
231 }
232
233 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234 (*mmap)->handle().num_bytes());
235 if (model == nullptr) {
236 TC3_LOG(ERROR) << "Model verification failed.";
237 return nullptr;
238 }
239
240 auto classifier = std::unique_ptr<Annotator>(new Annotator());
241 classifier->mmap_ = std::move(*mmap);
242 classifier->owned_unilib_ = std::move(unilib);
243 classifier->owned_calendarlib_ = std::move(calendarlib);
244 classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245 classifier->owned_calendarlib_.get());
246 if (!classifier->IsInitialized()) {
247 return nullptr;
248 }
249
250 return classifier;
251 }
252
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254 int fd, int offset, int size, const UniLib* unilib,
255 const CalendarLib* calendarlib) {
256 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257 return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261 int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262 std::unique_ptr<CalendarLib> calendarlib) {
263 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268 int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270 return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274 int fd, std::unique_ptr<UniLib> unilib,
275 std::unique_ptr<CalendarLib> calendarlib) {
276 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281 const UniLib* unilib,
282 const CalendarLib* calendarlib) {
283 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284 return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288 const std::string& path, std::unique_ptr<UniLib> unilib,
289 std::unique_ptr<CalendarLib> calendarlib) {
290 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291 return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295 const CalendarLib* calendarlib) {
296 model_ = model;
297 unilib_ = unilib;
298 calendarlib_ = calendarlib;
299
300 initialized_ = false;
301
302 if (model_ == nullptr) {
303 TC3_LOG(ERROR) << "No model specified.";
304 return;
305 }
306
307 const bool model_enabled_for_annotation =
308 (model_->triggering_options() != nullptr &&
309 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310 const bool model_enabled_for_classification =
311 (model_->triggering_options() != nullptr &&
312 (model_->triggering_options()->enabled_modes() &
313 ModeFlag_CLASSIFICATION));
314 const bool model_enabled_for_selection =
315 (model_->triggering_options() != nullptr &&
316 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317
318 // Annotation requires the selection model.
319 if (model_enabled_for_annotation || model_enabled_for_selection) {
320 if (!model_->selection_options()) {
321 TC3_LOG(ERROR) << "No selection options.";
322 return;
323 }
324 if (!model_->selection_feature_options()) {
325 TC3_LOG(ERROR) << "No selection feature options.";
326 return;
327 }
328 if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329 TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330 return;
331 }
332 if (!model_->selection_model()) {
333 TC3_LOG(ERROR) << "No selection model.";
334 return;
335 }
336 selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337 if (!selection_executor_) {
338 TC3_LOG(ERROR) << "Could not initialize selection executor.";
339 return;
340 }
341 selection_feature_processor_.reset(
342 new FeatureProcessor(model_->selection_feature_options(), unilib_));
343 }
344
345 // Annotation requires the classification model for conflict resolution and
346 // scoring.
347 // Selection requires the classification model for conflict resolution.
348 if (model_enabled_for_annotation || model_enabled_for_classification ||
349 model_enabled_for_selection) {
350 if (!model_->classification_options()) {
351 TC3_LOG(ERROR) << "No classification options.";
352 return;
353 }
354
355 if (!model_->classification_feature_options()) {
356 TC3_LOG(ERROR) << "No classification feature options.";
357 return;
358 }
359
360 if (!model_->classification_feature_options()
361 ->bounds_sensitive_features()) {
362 TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
363 return;
364 }
365 if (!model_->classification_model()) {
366 TC3_LOG(ERROR) << "No clf model.";
367 return;
368 }
369
370 classification_executor_ =
371 ModelExecutor::FromBuffer(model_->classification_model());
372 if (!classification_executor_) {
373 TC3_LOG(ERROR) << "Could not initialize classification executor.";
374 return;
375 }
376
377 classification_feature_processor_.reset(new FeatureProcessor(
378 model_->classification_feature_options(), unilib_));
379 }
380
381 // The embeddings need to be specified if the model is to be used for
382 // classification or selection.
383 if (model_enabled_for_annotation || model_enabled_for_classification ||
384 model_enabled_for_selection) {
385 if (!model_->embedding_model()) {
386 TC3_LOG(ERROR) << "No embedding model.";
387 return;
388 }
389
390 // Check that the embedding size of the selection and classification model
391 // matches, as they are using the same embeddings.
392 if (model_enabled_for_selection &&
393 (model_->selection_feature_options()->embedding_size() !=
394 model_->classification_feature_options()->embedding_size() ||
395 model_->selection_feature_options()->embedding_quantization_bits() !=
396 model_->classification_feature_options()
397 ->embedding_quantization_bits())) {
398 TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
399 return;
400 }
401
402 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
403 model_->embedding_model(),
404 model_->classification_feature_options()->embedding_size(),
405 model_->classification_feature_options()->embedding_quantization_bits(),
406 model_->embedding_pruning_mask());
407 if (!embedding_executor_) {
408 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
409 return;
410 }
411 }
412
413 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
414 if (model_->regex_model()) {
415 if (!InitializeRegexModel(decompressor.get())) {
416 TC3_LOG(ERROR) << "Could not initialize regex model.";
417 return;
418 }
419 }
420
421 if (model_->datetime_grammar_model()) {
422 if (model_->datetime_grammar_model()->rules()) {
423 analyzer_ = std::make_unique<grammar::Analyzer>(
424 unilib_, model_->datetime_grammar_model()->rules());
425 datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
426 datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
427 *analyzer_, *datetime_grounder_,
428 /*target_classification_score=*/1.0,
429 /*priority_score=*/1.0);
430 }
431 } else if (model_->datetime_model()) {
432 datetime_parser_ = RegexDatetimeParser::Instance(
433 model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
434 if (!datetime_parser_) {
435 TC3_LOG(ERROR) << "Could not initialize datetime parser.";
436 return;
437 }
438 }
439
440 if (model_->output_options()) {
441 if (model_->output_options()->filtered_collections_annotation()) {
442 for (const auto collection :
443 *model_->output_options()->filtered_collections_annotation()) {
444 filtered_collections_annotation_.insert(collection->str());
445 }
446 }
447 if (model_->output_options()->filtered_collections_classification()) {
448 for (const auto collection :
449 *model_->output_options()->filtered_collections_classification()) {
450 filtered_collections_classification_.insert(collection->str());
451 }
452 }
453 if (model_->output_options()->filtered_collections_selection()) {
454 for (const auto collection :
455 *model_->output_options()->filtered_collections_selection()) {
456 filtered_collections_selection_.insert(collection->str());
457 }
458 }
459 }
460
461 if (model_->number_annotator_options() &&
462 model_->number_annotator_options()->enabled()) {
463 number_annotator_.reset(
464 new NumberAnnotator(model_->number_annotator_options(), unilib_));
465 }
466
467 if (model_->money_parsing_options()) {
468 money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
469 model_->money_parsing_options()->separators());
470 }
471
472 if (model_->duration_annotator_options() &&
473 model_->duration_annotator_options()->enabled()) {
474 duration_annotator_.reset(
475 new DurationAnnotator(model_->duration_annotator_options(),
476 selection_feature_processor_.get(), unilib_));
477 }
478
479 if (model_->grammar_model()) {
480 grammar_annotator_.reset(new GrammarAnnotator(
481 unilib_, model_->grammar_model(), entity_data_builder_.get()));
482 }
483
484 // The following #ifdef is here to aid quality evaluation of a situation, when
485 // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
486 // it.
487 #if !defined(TC3_DISABLE_POD_NER)
488 if (model_->pod_ner_model()) {
489 pod_ner_annotator_ =
490 PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
491 }
492 #endif
493
494 if (model_->vocab_model()) {
495 vocab_annotator_ = VocabAnnotator::Create(
496 model_->vocab_model(), *selection_feature_processor_, *unilib_);
497 }
498
499 if (model_->entity_data_schema()) {
500 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
501 model_->entity_data_schema()->Data(),
502 model_->entity_data_schema()->size());
503 if (entity_data_schema_ == nullptr) {
504 TC3_LOG(ERROR) << "Could not load entity data schema data.";
505 return;
506 }
507
508 entity_data_builder_.reset(
509 new MutableFlatbufferBuilder(entity_data_schema_));
510 } else {
511 entity_data_schema_ = nullptr;
512 entity_data_builder_ = nullptr;
513 }
514
515 if (model_->triggering_locales() &&
516 !ParseLocales(model_->triggering_locales()->c_str(),
517 &model_triggering_locales_)) {
518 TC3_LOG(ERROR) << "Could not parse model supported locales.";
519 return;
520 }
521
522 if (model_->triggering_options() != nullptr &&
523 model_->triggering_options()->locales() != nullptr &&
524 !ParseLocales(model_->triggering_options()->locales()->c_str(),
525 &ml_model_triggering_locales_)) {
526 TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
527 return;
528 }
529
530 if (model_->triggering_options() != nullptr &&
531 model_->triggering_options()->dictionary_locales() != nullptr &&
532 !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
533 &dictionary_locales_)) {
534 TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
535 return;
536 }
537
538 if (model_->conflict_resolution_options() != nullptr) {
539 prioritize_longest_annotation_ =
540 model_->conflict_resolution_options()->prioritize_longest_annotation();
541 do_conflict_resolution_in_raw_mode_ =
542 model_->conflict_resolution_options()
543 ->do_conflict_resolution_in_raw_mode();
544 }
545
546 #ifdef TC3_EXPERIMENTAL
547 TC3_LOG(WARNING) << "Enabling experimental annotators.";
548 InitializeExperimentalAnnotators();
549 #endif
550
551 initialized_ = true;
552 }
553
InitializeRegexModel(ZlibDecompressor * decompressor)554 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
555 if (!model_->regex_model()->patterns()) {
556 return true;
557 }
558
559 // Initialize pattern recognizers.
560 int regex_pattern_id = 0;
561 for (const auto regex_pattern : *model_->regex_model()->patterns()) {
562 std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
563 UncompressMakeRegexPattern(
564 *unilib_, regex_pattern->pattern(),
565 regex_pattern->compressed_pattern(),
566 model_->regex_model()->lazy_regex_compilation(), decompressor);
567 if (!compiled_pattern) {
568 TC3_LOG(INFO) << "Failed to load regex pattern";
569 return false;
570 }
571
572 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
573 annotation_regex_patterns_.push_back(regex_pattern_id);
574 }
575 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
576 classification_regex_patterns_.push_back(regex_pattern_id);
577 }
578 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
579 selection_regex_patterns_.push_back(regex_pattern_id);
580 }
581 regex_patterns_.push_back({
582 regex_pattern,
583 std::move(compiled_pattern),
584 });
585 ++regex_pattern_id;
586 }
587
588 return true;
589 }
590
InitializeKnowledgeEngine(const std::string & serialized_config)591 bool Annotator::InitializeKnowledgeEngine(
592 const std::string& serialized_config) {
593 std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
594 if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
595 TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
596 return false;
597 }
598 if (model_->triggering_options() != nullptr) {
599 knowledge_engine->SetPriorityScore(
600 model_->triggering_options()->knowledge_priority_score());
601 }
602 knowledge_engine_ = std::move(knowledge_engine);
603 return true;
604 }
605
InitializeContactEngine(const std::string & serialized_config)606 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
607 std::unique_ptr<ContactEngine> contact_engine(
608 new ContactEngine(selection_feature_processor_.get(), unilib_,
609 model_->contact_annotator_options()));
610 if (!contact_engine->Initialize(serialized_config)) {
611 TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
612 return false;
613 }
614 contact_engine_ = std::move(contact_engine);
615 return true;
616 }
617
InitializeInstalledAppEngine(const std::string & serialized_config)618 bool Annotator::InitializeInstalledAppEngine(
619 const std::string& serialized_config) {
620 std::unique_ptr<InstalledAppEngine> installed_app_engine(
621 new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
622 if (!installed_app_engine->Initialize(serialized_config)) {
623 TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
624 return false;
625 }
626 installed_app_engine_ = std::move(installed_app_engine);
627 return true;
628 }
629
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)630 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
631 if (lang_id == nullptr) {
632 return false;
633 }
634
635 lang_id_ = lang_id;
636 if (lang_id_ != nullptr && model_->translate_annotator_options() &&
637 model_->translate_annotator_options()->enabled()) {
638 translate_annotator_.reset(new TranslateAnnotator(
639 model_->translate_annotator_options(), lang_id_, unilib_));
640 } else {
641 translate_annotator_.reset(nullptr);
642 }
643 return true;
644 }
645
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)646 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
647 int size) {
648 const PersonNameModel* person_name_model =
649 LoadAndVerifyPersonNameModel(buffer, size);
650
651 if (person_name_model == nullptr) {
652 TC3_LOG(ERROR) << "Person name model verification failed.";
653 return false;
654 }
655
656 if (!person_name_model->enabled()) {
657 return true;
658 }
659
660 std::unique_ptr<PersonNameEngine> person_name_engine(
661 new PersonNameEngine(selection_feature_processor_.get(), unilib_));
662 if (!person_name_engine->Initialize(person_name_model)) {
663 TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
664 return false;
665 }
666 person_name_engine_ = std::move(person_name_engine);
667 return true;
668 }
669
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)670 bool Annotator::InitializePersonNameEngineFromScopedMmap(
671 const ScopedMmap& mmap) {
672 if (!mmap.handle().ok()) {
673 TC3_LOG(ERROR) << "Mmap for person name model failed.";
674 return false;
675 }
676
677 return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
678 mmap.handle().num_bytes());
679 }
680
InitializePersonNameEngineFromPath(const std::string & path)681 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
682 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
683 return InitializePersonNameEngineFromScopedMmap(*mmap);
684 }
685
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)686 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
687 int size) {
688 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
689 return InitializePersonNameEngineFromScopedMmap(*mmap);
690 }
691
InitializeExperimentalAnnotators()692 bool Annotator::InitializeExperimentalAnnotators() {
693 if (ExperimentalAnnotator::IsEnabled()) {
694 experimental_annotator_.reset(new ExperimentalAnnotator(
695 model_->experimental_model(), *selection_feature_processor_, *unilib_));
696 return true;
697 }
698 return false;
699 }
700
701 namespace internal {
702 // Helper function, which if the initial 'span' contains only white-spaces,
703 // moves the selection to a single-codepoint selection on a left or right side
704 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)705 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
706 const UnicodeText& context_unicode,
707 const UniLib& unilib) {
708 TC3_CHECK(span.IsValid() && !span.IsEmpty());
709
710 UnicodeText::const_iterator it;
711
712 // Check that the current selection is all whitespaces.
713 it = context_unicode.begin();
714 std::advance(it, span.first);
715 for (int i = 0; i < (span.second - span.first); ++i, ++it) {
716 if (!unilib.IsWhitespace(*it)) {
717 return span;
718 }
719 }
720
721 // Try moving left.
722 CodepointSpan result = span;
723 it = context_unicode.begin();
724 std::advance(it, span.first);
725 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
726 --result.first;
727 --it;
728 }
729 result.second = result.first + 1;
730 if (!unilib.IsWhitespace(*it)) {
731 return result;
732 }
733
734 // If moving left didn't find a non-whitespace character, just return the
735 // original span.
736 return span;
737 }
738 } // namespace internal
739
FilteredForAnnotation(const AnnotatedSpan & span) const740 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
741 return !span.classification.empty() &&
742 filtered_collections_annotation_.find(
743 span.classification[0].collection) !=
744 filtered_collections_annotation_.end();
745 }
746
FilteredForClassification(const ClassificationResult & classification) const747 bool Annotator::FilteredForClassification(
748 const ClassificationResult& classification) const {
749 return filtered_collections_classification_.find(classification.collection) !=
750 filtered_collections_classification_.end();
751 }
752
FilteredForSelection(const AnnotatedSpan & span) const753 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
754 return !span.classification.empty() &&
755 filtered_collections_selection_.find(
756 span.classification[0].collection) !=
757 filtered_collections_selection_.end();
758 }
759
760 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)761 inline bool ClassifiedAsOther(
762 const std::vector<ClassificationResult>& classification) {
763 return !classification.empty() &&
764 classification[0].collection == Collections::Other();
765 }
766
767 } // namespace
768
GetPriorityScore(const std::vector<ClassificationResult> & classification) const769 float Annotator::GetPriorityScore(
770 const std::vector<ClassificationResult>& classification) const {
771 if (!classification.empty() && !ClassifiedAsOther(classification)) {
772 return classification[0].priority_score;
773 } else {
774 if (model_->triggering_options() != nullptr) {
775 return model_->triggering_options()->other_collection_priority_score();
776 } else {
777 return -1000.0;
778 }
779 }
780 }
781
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const782 bool Annotator::VerifyRegexMatchCandidate(
783 const std::string& context, const VerificationOptions* verification_options,
784 const std::string& match, const UniLib::RegexMatcher* matcher) const {
785 if (verification_options == nullptr) {
786 return true;
787 }
788 if (verification_options->verify_luhn_checksum() &&
789 !VerifyLuhnChecksum(match)) {
790 return false;
791 }
792 const int lua_verifier = verification_options->lua_verifier();
793 if (lua_verifier >= 0) {
794 if (model_->regex_model()->lua_verifier() == nullptr ||
795 lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
796 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
797 return false;
798 }
799 return VerifyMatch(
800 context, matcher,
801 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
802 }
803 return true;
804 }
805
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const806 CodepointSpan Annotator::SuggestSelection(
807 const std::string& context, CodepointSpan click_indices,
808 const SelectionOptions& options) const {
809 if (context.size() > std::numeric_limits<int>::max()) {
810 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
811 return {};
812 }
813
814 CodepointSpan original_click_indices = click_indices;
815 if (!initialized_) {
816 TC3_LOG(ERROR) << "Not initialized";
817 return original_click_indices;
818 }
819 if (options.annotation_usecase !=
820 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
821 TC3_LOG(WARNING)
822 << "Invoking SuggestSelection, which is not supported in RAW mode.";
823 return original_click_indices;
824 }
825 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
826 return original_click_indices;
827 }
828
829 std::vector<Locale> detected_text_language_tags;
830 if (!ParseLocales(options.detected_text_language_tags,
831 &detected_text_language_tags)) {
832 TC3_LOG(WARNING)
833 << "Failed to parse the detected_text_language_tags in options: "
834 << options.detected_text_language_tags;
835 }
836 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
837 model_triggering_locales_,
838 /*default_value=*/true)) {
839 return original_click_indices;
840 }
841
842 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
843 /*do_copy=*/false);
844
845 if (!unilib_->IsValidUtf8(context_unicode)) {
846 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
847 return original_click_indices;
848 }
849
850 if (!IsValidSpanInput(context_unicode, click_indices)) {
851 TC3_VLOG(1)
852 << "Trying to run SuggestSelection with invalid input, indices: "
853 << click_indices.first << " " << click_indices.second;
854 return original_click_indices;
855 }
856
857 if (model_->snap_whitespace_selections()) {
858 // We want to expand a purely white-space selection to a multi-selection it
859 // would've been part of. But with this feature disabled we would do a no-
860 // op, because no token is found. Therefore, we need to modify the
861 // 'click_indices' a bit to include a part of the token, so that the click-
862 // finding logic finds the clicked token correctly. This modification is
863 // done by the following function. Note, that it's enough to check the left
864 // side of the current selection, because if the white-space is a part of a
865 // multi-selection, necessarily both tokens - on the left and the right
866 // sides need to be selected. Thus snapping only to the left is sufficient
867 // (there's a check at the bottom that makes sure that if we snap to the
868 // left token but the result does not contain the initial white-space,
869 // returns the original indices).
870 click_indices = internal::SnapLeftIfWhitespaceSelection(
871 click_indices, context_unicode, *unilib_);
872 }
873
874 Annotations candidates;
875 // As we process a single string of context, the candidates will only
876 // contain one vector of AnnotatedSpan.
877 candidates.annotated_spans.resize(1);
878 InterpreterManager interpreter_manager(selection_executor_.get(),
879 classification_executor_.get());
880 std::vector<Token> tokens;
881 if (!ModelSuggestSelection(context_unicode, click_indices,
882 detected_text_language_tags, &interpreter_manager,
883 &tokens, &candidates.annotated_spans[0])) {
884 TC3_LOG(ERROR) << "Model suggest selection failed.";
885 return original_click_indices;
886 }
887 const std::unordered_set<std::string> set;
888 const EnabledEntityTypes is_entity_type_enabled(set);
889 if (!RegexChunk(context_unicode, selection_regex_patterns_,
890 /*is_serialized_entity_data_enabled=*/false,
891 is_entity_type_enabled, options.annotation_usecase,
892 &candidates.annotated_spans[0])) {
893 TC3_LOG(ERROR) << "Regex suggest selection failed.";
894 return original_click_indices;
895 }
896 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
897 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
898 options.locales, ModeFlag_SELECTION,
899 options.annotation_usecase,
900 /*is_serialized_entity_data_enabled=*/false,
901 &candidates.annotated_spans[0])) {
902 TC3_LOG(ERROR) << "Datetime suggest selection failed.";
903 return original_click_indices;
904 }
905 if (knowledge_engine_ != nullptr &&
906 !knowledge_engine_
907 ->Chunk(context, options.annotation_usecase,
908 options.location_context, Permissions(),
909 AnnotateMode::kEntityAnnotation, &candidates)
910 .ok()) {
911 TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
912 return original_click_indices;
913 }
914 if (contact_engine_ != nullptr &&
915 !contact_engine_->Chunk(context_unicode, tokens,
916 &candidates.annotated_spans[0])) {
917 TC3_LOG(ERROR) << "Contact suggest selection failed.";
918 return original_click_indices;
919 }
920 if (installed_app_engine_ != nullptr &&
921 !installed_app_engine_->Chunk(context_unicode, tokens,
922 &candidates.annotated_spans[0])) {
923 TC3_LOG(ERROR) << "Installed app suggest selection failed.";
924 return original_click_indices;
925 }
926 if (number_annotator_ != nullptr &&
927 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
928 &candidates.annotated_spans[0])) {
929 TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
930 return original_click_indices;
931 }
932 if (duration_annotator_ != nullptr &&
933 !duration_annotator_->FindAll(context_unicode, tokens,
934 options.annotation_usecase,
935 &candidates.annotated_spans[0])) {
936 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
937 return original_click_indices;
938 }
939 if (person_name_engine_ != nullptr &&
940 !person_name_engine_->Chunk(context_unicode, tokens,
941 &candidates.annotated_spans[0])) {
942 TC3_LOG(ERROR) << "Person name suggest selection failed.";
943 return original_click_indices;
944 }
945
946 AnnotatedSpan grammar_suggested_span;
947 if (grammar_annotator_ != nullptr &&
948 grammar_annotator_->SuggestSelection(detected_text_language_tags,
949 context_unicode, click_indices,
950 &grammar_suggested_span)) {
951 candidates.annotated_spans[0].push_back(grammar_suggested_span);
952 }
953
954 AnnotatedSpan pod_ner_suggested_span;
955 if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
956 pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
957 &pod_ner_suggested_span)) {
958 candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
959 }
960
961 if (experimental_annotator_ != nullptr) {
962 candidates.annotated_spans[0].push_back(
963 experimental_annotator_->SuggestSelection(context_unicode,
964 click_indices));
965 }
966
967 // Sort candidates according to their position in the input, so that the next
968 // code can assume that any connected component of overlapping spans forms a
969 // contiguous block.
970 std::sort(candidates.annotated_spans[0].begin(),
971 candidates.annotated_spans[0].end(),
972 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
973 return a.span.first < b.span.first;
974 });
975
976 std::vector<int> candidate_indices;
977 if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
978 detected_text_language_tags, options,
979 &interpreter_manager, &candidate_indices)) {
980 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
981 return original_click_indices;
982 }
983
984 std::sort(candidate_indices.begin(), candidate_indices.end(),
985 [this, &candidates](int a, int b) {
986 return GetPriorityScore(
987 candidates.annotated_spans[0][a].classification) >
988 GetPriorityScore(
989 candidates.annotated_spans[0][b].classification);
990 });
991
992 for (const int i : candidate_indices) {
993 if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
994 SpansOverlap(candidates.annotated_spans[0][i].span,
995 original_click_indices)) {
996 // Run model classification if not present but requested and there's a
997 // classification collection filter specified.
998 if (candidates.annotated_spans[0][i].classification.empty() &&
999 model_->selection_options()->always_classify_suggested_selection() &&
1000 !filtered_collections_selection_.empty()) {
1001 if (!ModelClassifyText(context, /*cached_tokens=*/{},
1002 detected_text_language_tags,
1003 candidates.annotated_spans[0][i].span, options,
1004 &interpreter_manager,
1005 /*embedding_cache=*/nullptr,
1006 &candidates.annotated_spans[0][i].classification,
1007 /*tokens=*/nullptr)) {
1008 return original_click_indices;
1009 }
1010 }
1011
1012 // Ignore if span classification is filtered.
1013 if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1014 return original_click_indices;
1015 }
1016
1017 // We return a suggested span contains the original span.
1018 // This compensates for "select all" selection that may come from
1019 // other apps. See http://b/179890518.
1020 if (SpanContains(candidates.annotated_spans[0][i].span,
1021 original_click_indices)) {
1022 return candidates.annotated_spans[0][i].span;
1023 }
1024 }
1025 }
1026
1027 return original_click_indices;
1028 }
1029
1030 namespace {
1031 // Helper function that returns the index of the first candidate that
1032 // transitively does not overlap with the candidate on 'start_index'. If the end
1033 // of 'candidates' is reached, it returns the index that points right behind the
1034 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1035 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1036 int start_index) {
1037 int first_non_overlapping = start_index + 1;
1038 CodepointSpan conflicting_span = candidates[start_index].span;
1039 while (
1040 first_non_overlapping < candidates.size() &&
1041 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1042 // Grow the span to include the current one.
1043 conflicting_span.second = std::max(
1044 conflicting_span.second, candidates[first_non_overlapping].span.second);
1045
1046 ++first_non_overlapping;
1047 }
1048 return first_non_overlapping;
1049 }
1050 } // namespace
1051
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1052 bool Annotator::ResolveConflicts(
1053 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1054 const std::vector<Token>& cached_tokens,
1055 const std::vector<Locale>& detected_text_language_tags,
1056 const BaseOptions& options, InterpreterManager* interpreter_manager,
1057 std::vector<int>* result) const {
1058 result->clear();
1059 result->reserve(candidates.size());
1060 for (int i = 0; i < candidates.size();) {
1061 int first_non_overlapping =
1062 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1063
1064 const bool conflict_found = first_non_overlapping != (i + 1);
1065 if (conflict_found) {
1066 std::vector<int> candidate_indices;
1067 if (!ResolveConflict(context, cached_tokens, candidates,
1068 detected_text_language_tags, i,
1069 first_non_overlapping, options, interpreter_manager,
1070 &candidate_indices)) {
1071 return false;
1072 }
1073 result->insert(result->end(), candidate_indices.begin(),
1074 candidate_indices.end());
1075 } else {
1076 result->push_back(i);
1077 }
1078
1079 // Skip over the whole conflicting group/go to next candidate.
1080 i = first_non_overlapping;
1081 }
1082 return true;
1083 }
1084
1085 namespace {
1086 // Returns true, if the given two sources do conflict in given annotation
1087 // usecase.
1088 // - In SMART usecase, all sources do conflict, because there's only 1 possible
1089 // annotation for a given span.
1090 // - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1091 // and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1092 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1093 const AnnotatedSpan::Source source1,
1094 const AnnotatedSpan::Source source2) {
1095 uint32 source_mask =
1096 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1097
1098 switch (annotation_usecase) {
1099 case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1100 // In the SMART mode, all annotations conflict.
1101 return true;
1102
1103 case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1104 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1105 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1106 // hours" (duration).
1107 if ((source_mask &
1108 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1109 (source_mask &
1110 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1111 return false;
1112 }
1113
1114 // A KNOWLEDGE entity does not conflict with anything.
1115 if ((source_mask &
1116 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1117 return false;
1118 }
1119
1120 // A PERSONNAME entity does not conflict with anything.
1121 if ((source_mask &
1122 (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1123 return false;
1124 }
1125
1126 // Entities from other sources can conflict.
1127 return true;
1128 }
1129 }
1130 } // namespace
1131
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1132 bool Annotator::ResolveConflict(
1133 const std::string& context, const std::vector<Token>& cached_tokens,
1134 const std::vector<AnnotatedSpan>& candidates,
1135 const std::vector<Locale>& detected_text_language_tags, int start_index,
1136 int end_index, const BaseOptions& options,
1137 InterpreterManager* interpreter_manager,
1138 std::vector<int>* chosen_indices) const {
1139 std::vector<int> conflicting_indices;
1140 std::unordered_map<int, std::pair<float, int>> scores_lengths;
1141 for (int i = start_index; i < end_index; ++i) {
1142 conflicting_indices.push_back(i);
1143 if (!candidates[i].classification.empty()) {
1144 scores_lengths[i] = {
1145 GetPriorityScore(candidates[i].classification),
1146 candidates[i].span.second - candidates[i].span.first};
1147 continue;
1148 }
1149
1150 // OPTIMIZATION: So that we don't have to classify all the ML model
1151 // spans apriori, we wait until we get here, when they conflict with
1152 // something and we need the actual classification scores. So if the
1153 // candidate conflicts and comes from the model, we need to run a
1154 // classification to determine its priority:
1155 std::vector<ClassificationResult> classification;
1156 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1157 candidates[i].span, options, interpreter_manager,
1158 /*embedding_cache=*/nullptr, &classification,
1159 /*tokens=*/nullptr)) {
1160 return false;
1161 }
1162
1163 if (!classification.empty()) {
1164 scores_lengths[i] = {
1165 GetPriorityScore(classification),
1166 candidates[i].span.second - candidates[i].span.first};
1167 }
1168 }
1169
1170 std::sort(
1171 conflicting_indices.begin(), conflicting_indices.end(),
1172 [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1173 if (scores_lengths[i].first == scores_lengths[j].first &&
1174 prioritize_longest_annotation_) {
1175 return scores_lengths[i].second > scores_lengths[j].second;
1176 }
1177 return scores_lengths[i].first > scores_lengths[j].first;
1178 });
1179
1180 // Here we keep a set of indices that were chosen, per-source, to enable
1181 // effective computation.
1182 std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1183 chosen_indices_for_source_map;
1184
1185 // Greedily place the candidates if they don't conflict with the already
1186 // placed ones.
1187 for (int i = 0; i < conflicting_indices.size(); ++i) {
1188 const int considered_candidate = conflicting_indices[i];
1189
1190 // See if there is a conflict between the candidate and all already placed
1191 // candidates.
1192 bool conflict = false;
1193 SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1194 for (auto& source_set_pair : chosen_indices_for_source_map) {
1195 if (source_set_pair.first == candidates[considered_candidate].source) {
1196 chosen_indices_for_source_ptr = &source_set_pair.second;
1197 }
1198
1199 const bool needs_conflict_resolution =
1200 options.annotation_usecase ==
1201 AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1202 (options.annotation_usecase ==
1203 AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1204 do_conflict_resolution_in_raw_mode_);
1205 if (needs_conflict_resolution &&
1206 DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1207 candidates[considered_candidate].source) &&
1208 DoesCandidateConflict(considered_candidate, candidates,
1209 source_set_pair.second)) {
1210 conflict = true;
1211 break;
1212 }
1213 }
1214
1215 // Skip the candidate if a conflict was found.
1216 if (conflict) {
1217 continue;
1218 }
1219
1220 // If the set of indices for the current source doesn't exist yet,
1221 // initialize it.
1222 if (chosen_indices_for_source_ptr == nullptr) {
1223 SortedIntSet new_set([&candidates](int a, int b) {
1224 return candidates[a].span.first < candidates[b].span.first;
1225 });
1226 chosen_indices_for_source_map[candidates[considered_candidate].source] =
1227 std::move(new_set);
1228 chosen_indices_for_source_ptr =
1229 &chosen_indices_for_source_map[candidates[considered_candidate]
1230 .source];
1231 }
1232
1233 // Place the candidate to the output and to the per-source conflict set.
1234 chosen_indices->push_back(considered_candidate);
1235 chosen_indices_for_source_ptr->insert(considered_candidate);
1236 }
1237
1238 std::sort(chosen_indices->begin(), chosen_indices->end());
1239
1240 return true;
1241 }
1242
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1243 bool Annotator::ModelSuggestSelection(
1244 const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1245 const std::vector<Locale>& detected_text_language_tags,
1246 InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1247 std::vector<AnnotatedSpan>* result) const {
1248 if (model_->triggering_options() == nullptr ||
1249 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1250 return true;
1251 }
1252
1253 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1254 ml_model_triggering_locales_,
1255 /*default_value=*/true)) {
1256 return true;
1257 }
1258
1259 int click_pos;
1260 *tokens = selection_feature_processor_->Tokenize(context_unicode);
1261 const auto [click_begin, click_end] =
1262 CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1263 selection_feature_processor_->RetokenizeAndFindClick(
1264 context_unicode, click_begin, click_end, click_indices,
1265 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1266 tokens, &click_pos);
1267 if (click_pos == kInvalidIndex) {
1268 TC3_VLOG(1) << "Could not calculate the click position.";
1269 return false;
1270 }
1271
1272 const int symmetry_context_size =
1273 model_->selection_options()->symmetry_context_size();
1274 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1275 bounds_sensitive_features = selection_feature_processor_->GetOptions()
1276 ->bounds_sensitive_features();
1277
1278 // The symmetry context span is the clicked token with symmetry_context_size
1279 // tokens on either side.
1280 const TokenSpan symmetry_context_span =
1281 IntersectTokenSpans(TokenSpan(click_pos).Expand(
1282 /*num_tokens_left=*/symmetry_context_size,
1283 /*num_tokens_right=*/symmetry_context_size),
1284 AllOf(*tokens));
1285
1286 // Compute the extraction span based on the model type.
1287 TokenSpan extraction_span;
1288 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1289 // The extraction span is the symmetry context span expanded to include
1290 // max_selection_span tokens on either side, which is how far a selection
1291 // can stretch from the click, plus a relevant number of tokens outside of
1292 // the bounds of the selection.
1293 const int max_selection_span =
1294 selection_feature_processor_->GetOptions()->max_selection_span();
1295 extraction_span = symmetry_context_span.Expand(
1296 /*num_tokens_left=*/max_selection_span +
1297 bounds_sensitive_features->num_tokens_before(),
1298 /*num_tokens_right=*/max_selection_span +
1299 bounds_sensitive_features->num_tokens_after());
1300 } else {
1301 // The extraction span is the symmetry context span expanded to include
1302 // context_size tokens on either side.
1303 const int context_size =
1304 selection_feature_processor_->GetOptions()->context_size();
1305 extraction_span = symmetry_context_span.Expand(
1306 /*num_tokens_left=*/context_size,
1307 /*num_tokens_right=*/context_size);
1308 }
1309 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1310
1311 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1312 *tokens, extraction_span)) {
1313 return true;
1314 }
1315
1316 std::unique_ptr<CachedFeatures> cached_features;
1317 if (!selection_feature_processor_->ExtractFeatures(
1318 *tokens, extraction_span,
1319 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1320 embedding_executor_.get(),
1321 /*embedding_cache=*/nullptr,
1322 selection_feature_processor_->EmbeddingSize() +
1323 selection_feature_processor_->DenseFeaturesCount(),
1324 &cached_features)) {
1325 TC3_LOG(ERROR) << "Could not extract features.";
1326 return false;
1327 }
1328
1329 // Produce selection model candidates.
1330 std::vector<TokenSpan> chunks;
1331 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1332 interpreter_manager->SelectionInterpreter(), *cached_features,
1333 &chunks)) {
1334 TC3_LOG(ERROR) << "Could not chunk.";
1335 return false;
1336 }
1337
1338 for (const TokenSpan& chunk : chunks) {
1339 AnnotatedSpan candidate;
1340 candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1341 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1342 if (model_->selection_options()->strip_unpaired_brackets()) {
1343 candidate.span =
1344 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1345 }
1346
1347 // Only output non-empty spans.
1348 if (candidate.span.first != candidate.span.second) {
1349 result->push_back(candidate);
1350 }
1351 }
1352 return true;
1353 }
1354
1355 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1356 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1357 const CodepointSpan& selection_indices,
1358 TokenSpan tokens_around_selection_to_copy) {
1359 const auto first_selection_token = std::upper_bound(
1360 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1361 [](int selection_start, const Token& token) {
1362 return selection_start < token.end;
1363 });
1364 const auto last_selection_token = std::lower_bound(
1365 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1366 [](const Token& token, int selection_end) {
1367 return token.start < selection_end;
1368 });
1369
1370 const int64 first_token = std::max(
1371 static_cast<int64>(0),
1372 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1373 tokens_around_selection_to_copy.first));
1374 const int64 last_token = std::min(
1375 static_cast<int64>(cached_tokens.size()),
1376 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1377 tokens_around_selection_to_copy.second));
1378
1379 std::vector<Token> tokens;
1380 tokens.reserve(last_token - first_token);
1381 for (int i = first_token; i < last_token; ++i) {
1382 tokens.push_back(cached_tokens[i]);
1383 }
1384 return tokens;
1385 }
1386 } // namespace internal
1387
ClassifyTextUpperBoundNeededTokens() const1388 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1389 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1390 bounds_sensitive_features =
1391 classification_feature_processor_->GetOptions()
1392 ->bounds_sensitive_features();
1393 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1394 // The extraction span is the selection span expanded to include a relevant
1395 // number of tokens outside of the bounds of the selection.
1396 return {bounds_sensitive_features->num_tokens_before(),
1397 bounds_sensitive_features->num_tokens_after()};
1398 } else {
1399 // The extraction span is the clicked token with context_size tokens on
1400 // either side.
1401 const int context_size =
1402 selection_feature_processor_->GetOptions()->context_size();
1403 return {context_size, context_size};
1404 }
1405 }
1406
1407 namespace {
1408 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1409 void SortClassificationResults(
1410 std::vector<ClassificationResult>* classification_results) {
1411 std::sort(classification_results->begin(), classification_results->end(),
1412 [](const ClassificationResult& a, const ClassificationResult& b) {
1413 return a.score > b.score;
1414 });
1415 }
1416 } // namespace
1417
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1418 bool Annotator::ModelClassifyText(
1419 const std::string& context, const std::vector<Token>& cached_tokens,
1420 const std::vector<Locale>& detected_text_language_tags,
1421 const CodepointSpan& selection_indices, const BaseOptions& options,
1422 InterpreterManager* interpreter_manager,
1423 FeatureProcessor::EmbeddingCache* embedding_cache,
1424 std::vector<ClassificationResult>* classification_results,
1425 std::vector<Token>* tokens) const {
1426 const UnicodeText context_unicode =
1427 UTF8ToUnicodeText(context, /*do_copy=*/false);
1428 const auto [span_begin, span_end] =
1429 CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1430 return ModelClassifyText(context_unicode, cached_tokens,
1431 detected_text_language_tags, span_begin, span_end,
1432 /*line=*/nullptr, selection_indices, options,
1433 interpreter_manager, embedding_cache,
1434 classification_results, tokens);
1435 }
1436
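// Runs the ML classification model on the given selection. The main steps are:
// (1) tokenize the context or reuse cached tokens, (2) locate the click
// position and the token span of the selection, (3) build the extraction span
// (bounds-sensitive window or click context), (4) extract and embed features,
// (5) compute logits with the classification executor and apply softmax, and
// (6) map the best label to a collection, applying sanity checks for phone
// numbers, addresses and dictionary entries. Returns false only on internal
// errors; an empty or "other" result is still reported as success.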
1437 bool Annotator::ModelClassifyText(
1438 const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1439 const std::vector<Locale>& detected_text_language_tags,
1440 const UnicodeText::const_iterator& span_begin,
1441 const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1442 const CodepointSpan& selection_indices, const BaseOptions& options,
1443 InterpreterManager* interpreter_manager,
1444 FeatureProcessor::EmbeddingCache* embedding_cache,
1445 std::vector<ClassificationResult>* classification_results,
1446 std::vector<Token>* tokens) const {
1447 if (model_->triggering_options() == nullptr ||
1448 !(model_->triggering_options()->enabled_modes() &
1449 ModeFlag_CLASSIFICATION)) {
1450 return true;
1451 }
1452
1453 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1454 ml_model_triggering_locales_,
1455 /*default_value=*/true)) {
1456 return true;
1457 }
1458
1459 std::vector<Token> local_tokens;
1460 if (tokens == nullptr) {
1461 tokens = &local_tokens;
1462 }
1463
1464 if (cached_tokens.empty()) {
1465 *tokens = classification_feature_processor_->Tokenize(context_unicode);
1466 } else {
1467 *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1468 ClassifyTextUpperBoundNeededTokens());
1469 }
1470
1471 int click_pos;
1472 classification_feature_processor_->RetokenizeAndFindClick(
1473 context_unicode, span_begin, span_end, selection_indices,
1474 classification_feature_processor_->GetOptions()
1475 ->only_use_line_with_click(),
1476 tokens, &click_pos);
1477 const TokenSpan selection_token_span =
1478 CodepointSpanToTokenSpan(*tokens, selection_indices);
1479 const int selection_num_tokens = selection_token_span.Size();
1480 if (model_->classification_options()->max_num_tokens() > 0 &&
1481 model_->classification_options()->max_num_tokens() <
1482 selection_num_tokens) {
1483 *classification_results = {{Collections::Other(), 1.0}};
1484 return true;
1485 }
1486
1487 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1488 bounds_sensitive_features =
1489 classification_feature_processor_->GetOptions()
1490 ->bounds_sensitive_features();
1491 if (selection_token_span.first == kInvalidIndex ||
1492 selection_token_span.second == kInvalidIndex) {
1493 TC3_LOG(ERROR) << "Could not determine span.";
1494 return false;
1495 }
1496
1497 // Compute the extraction span based on the model type.
1498 TokenSpan extraction_span;
1499 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1500 // The extraction span is the selection span expanded to include a relevant
1501 // number of tokens outside of the bounds of the selection.
1502 extraction_span = selection_token_span.Expand(
1503 /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1504 /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1505 } else {
1506 if (click_pos == kInvalidIndex) {
1507 TC3_LOG(ERROR) << "Couldn't choose a click position.";
1508 return false;
1509 }
1510 // The extraction span is the clicked token with context_size tokens on
1511 // either side.
1512 const int context_size =
1513 classification_feature_processor_->GetOptions()->context_size();
1514 extraction_span = TokenSpan(click_pos).Expand(
1515 /*num_tokens_left=*/context_size,
1516 /*num_tokens_right=*/context_size);
1517 }
1518 extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1519
1520 if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1521 *tokens, extraction_span)) {
1522 *classification_results = {{Collections::Other(), 1.0}};
1523 return true;
1524 }
1525
1526 std::unique_ptr<CachedFeatures> cached_features;
1527 if (!classification_feature_processor_->ExtractFeatures(
1528 *tokens, extraction_span, selection_indices,
1529 embedding_executor_.get(), embedding_cache,
1530 classification_feature_processor_->EmbeddingSize() +
1531 classification_feature_processor_->DenseFeaturesCount(),
1532 &cached_features)) {
1533 TC3_LOG(ERROR) << "Could not extract features.";
1534 return false;
1535 }
1536
1537 std::vector<float> features;
1538 features.reserve(cached_features->OutputFeaturesSize());
1539 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1540 cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1541 &features);
1542 } else {
1543 cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1544 }
1545
1546 TensorView<float> logits = classification_executor_->ComputeLogits(
1547 TensorView<float>(features.data(),
1548 {1, static_cast<int>(features.size())}),
1549 interpreter_manager->ClassificationInterpreter());
1550 if (!logits.is_valid()) {
1551 TC3_LOG(ERROR) << "Couldn't compute logits.";
1552 return false;
1553 }
1554
1555 if (logits.dims() != 2 || logits.dim(0) != 1 ||
1556 logits.dim(1) != classification_feature_processor_->NumCollections()) {
1557 TC3_LOG(ERROR) << "Mismatching output";
1558 return false;
1559 }
1560
1561 const std::vector<float> scores =
1562 ComputeSoftmax(logits.data(), logits.dim(1));
1563
1564 if (scores.empty()) {
1565 *classification_results = {{Collections::Other(), 1.0}};
1566 return true;
1567 }
1568
1569 const int best_score_index =
1570 std::max_element(scores.begin(), scores.end()) - scores.begin();
1571 const std::string top_collection =
1572 classification_feature_processor_->LabelToCollection(best_score_index);
1573
1574 // Sanity checks.
1575 if (top_collection == Collections::Phone()) {
1576 const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1577 if (digit_count <
1578 model_->classification_options()->phone_min_num_digits() ||
1579 digit_count >
1580 model_->classification_options()->phone_max_num_digits()) {
1581 *classification_results = {{Collections::Other(), 1.0}};
1582 return true;
1583 }
1584 } else if (top_collection == Collections::Address()) {
1585 if (selection_num_tokens <
1586 model_->classification_options()->address_min_num_tokens()) {
1587 *classification_results = {{Collections::Other(), 1.0}};
1588 return true;
1589 }
1590 } else if (top_collection == Collections::Dictionary()) {
1591 if ((options.use_vocab_annotator && vocab_annotator_) ||
1592 !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1593 dictionary_locales_,
1594 /*default_value=*/false)) {
1595 *classification_results = {{Collections::Other(), 1.0}};
1596 return true;
1597 }
1598 }
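  // The top label is reported with a fixed score of 1.0, while its softmax
  // probability is carried in priority_score; the priority may then be scaled
  // by the per-collection factor from the model's triggering options below.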
1599 *classification_results = {{top_collection, /*arg_score=*/1.0,
1600 /*arg_priority_score=*/scores[best_score_index]}};
1601
1602 // For some entities, we might want to clamp the priority score, for better
1603 // conflict resolution between entities.
1604 if (model_->triggering_options() != nullptr &&
1605 model_->triggering_options()->collection_to_priority() != nullptr) {
1606 if (auto entry =
1607 model_->triggering_options()->collection_to_priority()->LookupByKey(
1608 top_collection.c_str())) {
1609 (*classification_results)[0].priority_score *= entry->value();
1610 }
1611 }
1612 return true;
1613 }
1614
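// Matches the selected text against all classification regex patterns. Each
// pattern is run over the selection only (optionally with approximate
// matching); matches that pass the verification options are appended to
// `classification_result` together with their serialized entity data. Returns
// false if the regex matcher or the entity-data extraction reports an error.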
1615 bool Annotator::RegexClassifyText(
1616 const std::string& context, const CodepointSpan& selection_indices,
1617 std::vector<ClassificationResult>* classification_result) const {
1618 const std::string selection_text =
1619 UTF8ToUnicodeText(context, /*do_copy=*/false)
1620 .UTF8Substring(selection_indices.first, selection_indices.second);
1621 const UnicodeText selection_text_unicode(
1622 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1623
1624 // Check whether any of the regular expressions match.
1625 for (const int pattern_id : classification_regex_patterns_) {
1626 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1627 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1628 regex_pattern.pattern->Matcher(selection_text_unicode);
1629 int status = UniLib::RegexMatcher::kNoError;
1630 bool matches;
1631 if (regex_pattern.config->use_approximate_matching()) {
1632 matches = matcher->ApproximatelyMatches(&status);
1633 } else {
1634 matches = matcher->Matches(&status);
1635 }
1636 if (status != UniLib::RegexMatcher::kNoError) {
1637 return false;
1638 }
1639 if (matches && VerifyRegexMatchCandidate(
1640 context, regex_pattern.config->verification_options(),
1641 selection_text, matcher.get())) {
1642 classification_result->push_back(
1643 {regex_pattern.config->collection_name()->str(),
1644 regex_pattern.config->target_classification_score(),
1645 regex_pattern.config->priority_score()});
1646 if (!SerializedEntityDataFromRegexMatch(
1647 regex_pattern.config, matcher.get(),
1648 &classification_result->back().serialized_entity_data)) {
1649 TC3_LOG(ERROR) << "Could not get entity data.";
1650 return false;
1651 }
1652 }
1653 }
1654
1655 return true;
1656 }
1657
1658 namespace {
1659 std::string PickCollectionForDatetime(
1660 const DatetimeParseResult& datetime_parse_result) {
1661 switch (datetime_parse_result.granularity) {
1662 case GRANULARITY_HOUR:
1663 case GRANULARITY_MINUTE:
1664 case GRANULARITY_SECOND:
1665 return Collections::DateTime();
1666 default:
1667 return Collections::Date();
1668 }
1669 }
1670
1671 } // namespace
1672
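// Classifies the selection with the datetime parser. The parser is run on the
// selection text with anchoring enabled, and a result is accepted only if its
// span, shifted back into context coordinates, matches the selection exactly.
// All interpretations of the matching span are returned as separate
// classification results that share the span's scores.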
1673 bool Annotator::DatetimeClassifyText(
1674 const std::string& context, const CodepointSpan& selection_indices,
1675 const ClassificationOptions& options,
1676 std::vector<ClassificationResult>* classification_results) const {
1677 if (!datetime_parser_) {
1678 return true;
1679 }
1680
1681 const std::string selection_text =
1682 UTF8ToUnicodeText(context, /*do_copy=*/false)
1683 .UTF8Substring(selection_indices.first, selection_indices.second);
1684
1685 LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1686 StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1687 datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1688 options.reference_timezone, locale_list,
1689 ModeFlag_CLASSIFICATION,
1690 options.annotation_usecase,
1691 /*anchor_start_end=*/true);
1692 if (!result_status.ok()) {
1693 TC3_LOG(ERROR) << "Error during parsing datetime.";
1694 return false;
1695 }
1696
1697 for (const DatetimeParseResultSpan& datetime_span :
1698 result_status.ValueOrDie()) {
1699 // Only consider the result valid if the selection and extracted datetime
1700 // spans exactly match.
1701 if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1702 datetime_span.span.second + selection_indices.first) ==
1703 selection_indices) {
1704 for (const DatetimeParseResult& parse_result : datetime_span.data) {
1705 classification_results->emplace_back(
1706 PickCollectionForDatetime(parse_result),
1707 datetime_span.target_classification_score);
1708 classification_results->back().datetime_parse_result = parse_result;
1709 classification_results->back().serialized_entity_data =
1710 CreateDatetimeSerializedEntityData(parse_result);
1711 classification_results->back().priority_score =
1712 datetime_span.priority_score;
1713 }
1714 return true;
1715 }
1716 }
1717 return true;
1718 }
1719
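// Classifies the given selection by collecting candidates from every enabled
// engine (knowledge, contact, person name, installed app, regex, datetime,
// number, duration, translate, grammar, POD NER, vocab, experimental and the
// ML model), resolving conflicts between them, filtering, and sorting the
// surviving results by score. If nothing survives, a single "other" result is
// returned.
//
// Illustrative use (sketch only; `annotator` and the selection span are
// assumed to be set up by the caller):
//
//   ClassificationOptions options;
//   const std::vector<ClassificationResult> results =
//       annotator.ClassifyText("Call 555 123 4567 now",
//                              /*selection_indices=*/{5, 17}, options);
//   if (!results.empty()) {
//     TC3_LOG(INFO) << "Top collection: " << results[0].collection;
//   }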
1720 std::vector<ClassificationResult> Annotator::ClassifyText(
1721 const std::string& context, const CodepointSpan& selection_indices,
1722 const ClassificationOptions& options) const {
1723 if (context.size() > std::numeric_limits<int>::max()) {
1724 TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1725 return {};
1726 }
1727 if (!initialized_) {
1728 TC3_LOG(ERROR) << "Not initialized";
1729 return {};
1730 }
1731 if (options.annotation_usecase !=
1732 AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1733 TC3_LOG(WARNING)
1734 << "Invoking ClassifyText, which is not supported in RAW mode.";
1735 return {};
1736 }
1737 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1738 return {};
1739 }
1740
1741 std::vector<Locale> detected_text_language_tags;
1742 if (!ParseLocales(options.detected_text_language_tags,
1743 &detected_text_language_tags)) {
1744 TC3_LOG(WARNING)
1745 << "Failed to parse the detected_text_language_tags in options: "
1746 << options.detected_text_language_tags;
1747 }
1748 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1749 model_triggering_locales_,
1750 /*default_value=*/true)) {
1751 return {};
1752 }
1753
1754 const UnicodeText context_unicode =
1755 UTF8ToUnicodeText(context, /*do_copy=*/false);
1756
1757 if (!unilib_->IsValidUtf8(context_unicode)) {
1758 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1759 return {};
1760 }
1761
1762 if (!IsValidSpanInput(context_unicode, selection_indices)) {
1763 TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1764 << selection_indices.first << " " << selection_indices.second;
1765 return {};
1766 }
1767
1768 // We'll accumulate a list of candidates, and pick the best candidate in the
1769 // end.
1770 std::vector<AnnotatedSpan> candidates;
1771
1772 // Try the knowledge engine.
1773 // TODO(b/126579108): Propagate error status.
1774 ClassificationResult knowledge_result;
1775 if (knowledge_engine_ &&
1776 knowledge_engine_
1777 ->ClassifyText(context, selection_indices, options.annotation_usecase,
1778 options.location_context, Permissions(),
1779 &knowledge_result)
1780 .ok()) {
1781 candidates.push_back({selection_indices, {knowledge_result}});
1782 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1783 }
1784
1785 AddContactMetadataToKnowledgeClassificationResults(&candidates);
1786
1787 // Try the contact engine.
1788 // TODO(b/126579108): Propagate error status.
1789 ClassificationResult contact_result;
1790 if (contact_engine_ && contact_engine_->ClassifyText(
1791 context, selection_indices, &contact_result)) {
1792 candidates.push_back({selection_indices, {contact_result}});
1793 }
1794
1795 // Try the person name engine.
1796 ClassificationResult person_name_result;
1797 if (person_name_engine_ &&
1798 person_name_engine_->ClassifyText(context, selection_indices,
1799 &person_name_result)) {
1800 candidates.push_back({selection_indices, {person_name_result}});
1801 candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1802 }
1803
1804 // Try the installed app engine.
1805 // TODO(b/126579108): Propagate error status.
1806 ClassificationResult installed_app_result;
1807 if (installed_app_engine_ &&
1808 installed_app_engine_->ClassifyText(context, selection_indices,
1809 &installed_app_result)) {
1810 candidates.push_back({selection_indices, {installed_app_result}});
1811 }
1812
1813 // Try the regular expression models.
1814 std::vector<ClassificationResult> regex_results;
1815   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1816 return {};
1817 }
1818 for (const ClassificationResult& result : regex_results) {
1819 candidates.push_back({selection_indices, {result}});
1820 }
1821
1822 // Try the date model.
1823 //
1824   // DatetimeClassifyText only returns the first result, which may nevertheless
1825   // have multiple interpretations. These are inserted into the candidates as a
1826   // single AnnotatedSpan, so that the conflict resolution algorithm treats them
1827   // together.
1828 std::vector<ClassificationResult> datetime_results;
1829 if (!DatetimeClassifyText(context, selection_indices, options,
1830 &datetime_results)) {
1831 return {};
1832 }
1833 if (!datetime_results.empty()) {
1834 candidates.push_back({selection_indices, std::move(datetime_results)});
1835 candidates.back().source = AnnotatedSpan::Source::DATETIME;
1836 }
1837
1838 // Try the number annotator.
1839 // TODO(b/126579108): Propagate error status.
1840 ClassificationResult number_annotator_result;
1841 if (number_annotator_ &&
1842 number_annotator_->ClassifyText(context_unicode, selection_indices,
1843 options.annotation_usecase,
1844 &number_annotator_result)) {
1845 candidates.push_back({selection_indices, {number_annotator_result}});
1846 }
1847
1848 // Try the duration annotator.
1849 ClassificationResult duration_annotator_result;
1850 if (duration_annotator_ &&
1851 duration_annotator_->ClassifyText(context_unicode, selection_indices,
1852 options.annotation_usecase,
1853 &duration_annotator_result)) {
1854 candidates.push_back({selection_indices, {duration_annotator_result}});
1855 candidates.back().source = AnnotatedSpan::Source::DURATION;
1856 }
1857
1858 // Try the translate annotator.
1859 ClassificationResult translate_annotator_result;
1860 if (translate_annotator_ &&
1861 translate_annotator_->ClassifyText(context_unicode, selection_indices,
1862 options.user_familiar_language_tags,
1863 &translate_annotator_result)) {
1864 candidates.push_back({selection_indices, {translate_annotator_result}});
1865 }
1866
1867 // Try the grammar model.
1868 ClassificationResult grammar_annotator_result;
1869 if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1870 detected_text_language_tags, context_unicode,
1871 selection_indices, &grammar_annotator_result)) {
1872 candidates.push_back({selection_indices, {grammar_annotator_result}});
1873 }
1874
1875 ClassificationResult pod_ner_annotator_result;
1876 if (pod_ner_annotator_ && options.use_pod_ner &&
1877 pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1878 &pod_ner_annotator_result)) {
1879 candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1880 }
1881
1882 ClassificationResult vocab_annotator_result;
1883 if (vocab_annotator_ && options.use_vocab_annotator &&
1884 vocab_annotator_->ClassifyText(
1885 context_unicode, selection_indices, detected_text_language_tags,
1886 options.trigger_dictionary_on_beginner_words,
1887 &vocab_annotator_result)) {
1888 candidates.push_back({selection_indices, {vocab_annotator_result}});
1889 }
1890
1891 if (experimental_annotator_) {
1892 experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1893 candidates);
1894 }
1895
1896 // Try the ML model.
1897 //
1898   // The output of the model is treated as an exclusive 1-of-N choice. That's
1899   // why it's inserted into the candidates as a single AnnotatedSpan, as opposed
1900   // to one span per result, as the regex model does, for example.
1901 InterpreterManager interpreter_manager(selection_executor_.get(),
1902 classification_executor_.get());
1903 std::vector<ClassificationResult> model_results;
1904 std::vector<Token> tokens;
1905 if (!ModelClassifyText(
1906 context, /*cached_tokens=*/{}, detected_text_language_tags,
1907 selection_indices, options, &interpreter_manager,
1908 /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1909 return {};
1910 }
1911 if (!model_results.empty()) {
1912 candidates.push_back({selection_indices, std::move(model_results)});
1913 }
1914
1915 std::vector<int> candidate_indices;
1916 if (!ResolveConflicts(candidates, context, tokens,
1917 detected_text_language_tags, options,
1918 &interpreter_manager, &candidate_indices)) {
1919 TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1920 return {};
1921 }
1922
1923 std::vector<ClassificationResult> results;
1924 for (const int i : candidate_indices) {
1925 for (const ClassificationResult& result : candidates[i].classification) {
1926 if (!FilteredForClassification(result)) {
1927 results.push_back(result);
1928 }
1929 }
1930 }
1931
1932 // Sort results according to score.
1933 std::sort(results.begin(), results.end(),
1934 [](const ClassificationResult& a, const ClassificationResult& b) {
1935 return a.score > b.score;
1936 });
1937
1938 if (results.empty()) {
1939 results = {{Collections::Other(), 1.0}};
1940 }
1941 return results;
1942 }
1943
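// Annotates the context with the ML model. The context is optionally split
// into lines, each line is tokenized and chunked with the selection model, and
// every chunk is then classified; chunks classified as "other" or scoring
// below min_annotate_confidence are dropped. Boundary codepoints and,
// optionally, unpaired brackets are stripped from the resulting spans, which
// are reported as codepoint offsets relative to the full context.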
1944 bool Annotator::ModelAnnotate(
1945 const std::string& context,
1946 const std::vector<Locale>& detected_text_language_tags,
1947 const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1948 std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1949 if (model_->triggering_options() == nullptr ||
1950 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1951 return true;
1952 }
1953
1954 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1955 ml_model_triggering_locales_,
1956 /*default_value=*/true)) {
1957 return true;
1958 }
1959
1960 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1961 /*do_copy=*/false);
1962 std::vector<UnicodeTextRange> lines;
1963 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1964 lines.push_back({context_unicode.begin(), context_unicode.end()});
1965 } else {
1966 lines = selection_feature_processor_->SplitContext(
1967 context_unicode, selection_feature_processor_->GetOptions()
1968 ->use_pipe_character_for_newline());
1969 }
1970
1971 const float min_annotate_confidence =
1972 (model_->triggering_options() != nullptr
1973 ? model_->triggering_options()->min_annotate_confidence()
1974 : 0.f);
1975
1976 for (const UnicodeTextRange& line : lines) {
1977 FeatureProcessor::EmbeddingCache embedding_cache;
1978 const std::string line_str =
1979 UnicodeText::UTF8Substring(line.first, line.second);
1980
1981 std::vector<Token> line_tokens;
1982 line_tokens = selection_feature_processor_->Tokenize(line_str);
1983
1984 selection_feature_processor_->RetokenizeAndFindClick(
1985 line_str, {0, std::distance(line.first, line.second)},
1986 selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1987 &line_tokens,
1988 /*click_pos=*/nullptr);
1989 const TokenSpan full_line_span = {
1990 0, static_cast<TokenIndex>(line_tokens.size())};
1991
1992 // TODO(zilka): Add support for greater granularity of this check.
1993 if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1994 line_tokens, full_line_span)) {
1995 continue;
1996 }
1997
1998 std::unique_ptr<CachedFeatures> cached_features;
1999 if (!selection_feature_processor_->ExtractFeatures(
2000 line_tokens, full_line_span,
2001 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2002 embedding_executor_.get(),
2003 /*embedding_cache=*/nullptr,
2004 selection_feature_processor_->EmbeddingSize() +
2005 selection_feature_processor_->DenseFeaturesCount(),
2006 &cached_features)) {
2007 TC3_LOG(ERROR) << "Could not extract features.";
2008 return false;
2009 }
2010
2011 std::vector<TokenSpan> local_chunks;
2012 if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2013 interpreter_manager->SelectionInterpreter(),
2014 *cached_features, &local_chunks)) {
2015 TC3_LOG(ERROR) << "Could not chunk.";
2016 return false;
2017 }
2018
2019 const int offset = std::distance(context_unicode.begin(), line.first);
2020 UnicodeText line_unicode;
2021 std::vector<UnicodeText::const_iterator> line_codepoints;
2022 if (options.enable_optimization) {
2023 if (local_chunks.empty()) {
2024 continue;
2025 }
2026 line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2027 line_codepoints = line_unicode.Codepoints();
2028 line_codepoints.push_back(line_unicode.end());
2029 }
2030 for (const TokenSpan& chunk : local_chunks) {
2031 CodepointSpan codepoint_span =
2032 TokenSpanToCodepointSpan(line_tokens, chunk);
2033 if (options.enable_optimization) {
2034 if (!codepoint_span.IsValid() ||
2035 codepoint_span.second > line_codepoints.size()) {
2036 continue;
2037 }
2038 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2039 /*span_begin=*/line_codepoints[codepoint_span.first],
2040 /*span_end=*/line_codepoints[codepoint_span.second],
2041 codepoint_span);
2042 if (model_->selection_options()->strip_unpaired_brackets()) {
2043 codepoint_span = StripUnpairedBrackets(
2044 /*span_begin=*/line_codepoints[codepoint_span.first],
2045 /*span_end=*/line_codepoints[codepoint_span.second],
2046 codepoint_span, *unilib_);
2047 }
2048 } else {
2049 codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2050 line_str, codepoint_span);
2051 if (model_->selection_options()->strip_unpaired_brackets()) {
2052 codepoint_span =
2053 StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
2054 }
2055 }
2056
2057 // Skip empty spans.
2058 if (codepoint_span.first != codepoint_span.second) {
2059 std::vector<ClassificationResult> classification;
2060 if (options.enable_optimization) {
2061 if (!ModelClassifyText(
2062 line_unicode, line_tokens, detected_text_language_tags,
2063 /*span_begin=*/line_codepoints[codepoint_span.first],
2064 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2065 codepoint_span, options, interpreter_manager,
2066 &embedding_cache, &classification, /*tokens=*/nullptr)) {
2067 TC3_LOG(ERROR) << "Could not classify text: "
2068 << (codepoint_span.first + offset) << " "
2069 << (codepoint_span.second + offset);
2070 return false;
2071 }
2072 } else {
2073 if (!ModelClassifyText(line_str, line_tokens,
2074 detected_text_language_tags, codepoint_span,
2075 options, interpreter_manager, &embedding_cache,
2076 &classification, /*tokens=*/nullptr)) {
2077 TC3_LOG(ERROR) << "Could not classify text: "
2078 << (codepoint_span.first + offset) << " "
2079 << (codepoint_span.second + offset);
2080 return false;
2081 }
2082 }
2083
2084 // Do not include the span if it's classified as "other".
2085 if (!classification.empty() && !ClassifiedAsOther(classification) &&
2086 classification[0].score >= min_annotate_confidence) {
2087 AnnotatedSpan result_span;
2088 result_span.span = {codepoint_span.first + offset,
2089 codepoint_span.second + offset};
2090 result_span.classification = std::move(classification);
2091 result->push_back(std::move(result_span));
2092 }
2093 }
2094 }
2095
2096 // If we are going line-by-line, we need to insert the tokens for each line.
2097 // But if not, we can optimize and just std::move the current line vector to
2098 // the output.
2099 if (selection_feature_processor_->GetOptions()
2100 ->only_use_line_with_click()) {
2101 tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2102 } else {
2103 *tokens = std::move(line_tokens);
2104 }
2105 }
2106 return true;
2107 }
2108
2109 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2110 return selection_feature_processor_.get();
2111 }
2112
2113 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2114 const {
2115 return classification_feature_processor_.get();
2116 }
2117
2118 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2119 return datetime_parser_.get();
2120 }
2121
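// Drops classification results whose collection is not enabled and then
// removes any annotated spans that are left without classifications.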
2122 void Annotator::RemoveNotEnabledEntityTypes(
2123 const EnabledEntityTypes& is_entity_type_enabled,
2124 std::vector<AnnotatedSpan>* annotated_spans) const {
2125 for (AnnotatedSpan& annotated_span : *annotated_spans) {
2126 std::vector<ClassificationResult>& classifications =
2127 annotated_span.classification;
2128 classifications.erase(
2129 std::remove_if(classifications.begin(), classifications.end(),
2130 [&is_entity_type_enabled](
2131 const ClassificationResult& classification_result) {
2132 return !is_entity_type_enabled(
2133 classification_result.collection);
2134 }),
2135 classifications.end());
2136 }
2137 annotated_spans->erase(
2138 std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2139 [](const AnnotatedSpan& annotated_span) {
2140 return annotated_span.classification.empty();
2141 }),
2142 annotated_spans->end());
2143 }
2144
2145 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2146 std::vector<AnnotatedSpan>* candidates) const {
2147 if (candidates == nullptr || contact_engine_ == nullptr) {
2148 return;
2149 }
2150 for (auto& candidate : *candidates) {
2151 for (auto& classification_result : candidate.classification) {
2152 contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2153 &classification_result);
2154 }
2155 }
2156 }
2157
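// Annotates a single text fragment. Candidates are gathered from the ML model
// (or, if it is skipped, the text is still tokenized for annotators that need
// tokens), the regex patterns, the datetime model and the various dedicated
// annotators, each gated by the requested entity types in the raw use case.
// The candidates are then sorted by position, conflicts are resolved,
// duplicates with identical span and collection are removed, disabled entity
// types are filtered out, and each span's classifications are sorted by score.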
2158 Status Annotator::AnnotateSingleInput(
2159 const std::string& context, const AnnotationOptions& options,
2160 std::vector<AnnotatedSpan>* candidates) const {
2161 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2162 return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2163 }
2164
2165 const UnicodeText context_unicode =
2166 UTF8ToUnicodeText(context, /*do_copy=*/false);
2167
2168 std::vector<Locale> detected_text_language_tags;
2169 if (!ParseLocales(options.detected_text_language_tags,
2170 &detected_text_language_tags)) {
2171 TC3_LOG(WARNING)
2172 << "Failed to parse the detected_text_language_tags in options: "
2173 << options.detected_text_language_tags;
2174 }
2175 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2176 model_triggering_locales_,
2177 /*default_value=*/true)) {
2178 return Status(
2179 StatusCode::UNAVAILABLE,
2180 "The detected language tags are not in the supported locales.");
2181 }
2182
2183 InterpreterManager interpreter_manager(selection_executor_.get(),
2184 classification_executor_.get());
2185
2186 const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2187 const bool is_raw_usecase =
2188 options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2189
2190 // Annotate with the selection model.
2191 const bool model_annotations_enabled =
2192 !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2193 std::vector<Token> tokens;
2194 if (model_annotations_enabled &&
2195 !ModelAnnotate(context, detected_text_language_tags, options,
2196 &interpreter_manager, &tokens, candidates)) {
2197 return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2198 } else if (!model_annotations_enabled) {
2199 // If the ML model didn't run, we need to tokenize to support the other
2200 // annotators that depend on the tokens.
2201 // Optimization could be made to only do this when an annotator that uses
2202 // the tokens is enabled, but it's unclear if the added complexity is worth
2203 // it.
2204 if (selection_feature_processor_ != nullptr) {
2205 tokens = selection_feature_processor_->Tokenize(context_unicode);
2206 }
2207 }
2208
2209 // Annotate with the regular expression models.
2210 const bool regex_annotations_enabled =
2211 !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2212 if (regex_annotations_enabled &&
2213 !RegexChunk(
2214 UTF8ToUnicodeText(context, /*do_copy=*/false),
2215 annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2216 is_entity_type_enabled, options.annotation_usecase, candidates)) {
2217 return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2218 }
2219
2220 // Annotate with the datetime model.
2221 // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2222 // relatively slow for some clients.
2223 if ((is_entity_type_enabled(Collections::Date()) ||
2224 is_entity_type_enabled(Collections::DateTime())) &&
2225 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2226 options.reference_time_ms_utc, options.reference_timezone,
2227 options.locales, ModeFlag_ANNOTATION,
2228 options.annotation_usecase,
2229 options.is_serialized_entity_data_enabled, candidates)) {
2230 return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2231 }
2232
2233 // Annotate with the contact engine.
2234 const bool contact_annotations_enabled =
2235 !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2236 if (contact_annotations_enabled && contact_engine_ &&
2237 !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2238 return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2239 }
2240
2241 // Annotate with the installed app engine.
2242 const bool app_annotations_enabled =
2243 !is_raw_usecase || is_entity_type_enabled(Collections::App());
2244 if (app_annotations_enabled && installed_app_engine_ &&
2245 !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2246 return Status(StatusCode::INTERNAL,
2247 "Couldn't run installed app engine Chunk.");
2248 }
2249
2250 // Annotate with the number annotator.
2251 const bool number_annotations_enabled =
2252 !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2253 is_entity_type_enabled(Collections::Percentage()));
2254 if (number_annotations_enabled && number_annotator_ != nullptr &&
2255 !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2256 candidates)) {
2257 return Status(StatusCode::INTERNAL,
2258 "Couldn't run number annotator FindAll.");
2259 }
2260
2261 // Annotate with the duration annotator.
2262 const bool duration_annotations_enabled =
2263 !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2264 if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2265 !duration_annotator_->FindAll(context_unicode, tokens,
2266 options.annotation_usecase, candidates)) {
2267 return Status(StatusCode::INTERNAL,
2268 "Couldn't run duration annotator FindAll.");
2269 }
2270
2271 // Annotate with the person name engine.
2272 const bool person_annotations_enabled =
2273 !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2274 if (person_annotations_enabled && person_name_engine_ &&
2275 !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2276 return Status(StatusCode::INTERNAL,
2277 "Couldn't run person name engine Chunk.");
2278 }
2279
2280 // Annotate with the grammar annotators.
2281 if (grammar_annotator_ != nullptr &&
2282 !grammar_annotator_->Annotate(detected_text_language_tags,
2283 context_unicode, candidates)) {
2284 return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2285 }
2286
2287 // Annotate with the POD NER annotator.
2288 const bool pod_ner_annotations_enabled =
2289 !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2290 if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2291 options.use_pod_ner &&
2292 !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2293 return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2294 }
2295
2296 // Annotate with the vocab annotator.
2297 const bool vocab_annotations_enabled =
2298 !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2299 if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2300 options.use_vocab_annotator &&
2301 !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2302 options.trigger_dictionary_on_beginner_words,
2303 candidates)) {
2304 return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2305 }
2306
2307 // Annotate with the experimental annotator.
2308 if (experimental_annotator_ != nullptr &&
2309 !experimental_annotator_->Annotate(context_unicode, candidates)) {
2310 return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2311 }
2312
2313 // Sort candidates according to their position in the input, so that the next
2314 // code can assume that any connected component of overlapping spans forms a
2315 // contiguous block.
2316 // Also sort them according to the end position and collection, so that the
2317 // deduplication code below can assume that same spans and classifications
2318 // form contiguous blocks.
2319 std::sort(candidates->begin(), candidates->end(),
2320 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2321 if (a.span.first != b.span.first) {
2322 return a.span.first < b.span.first;
2323 }
2324
2325 if (a.span.second != b.span.second) {
2326 return a.span.second < b.span.second;
2327 }
2328
2329 return a.classification[0].collection <
2330 b.classification[0].collection;
2331 });
2332
2333 std::vector<int> candidate_indices;
2334 if (!ResolveConflicts(*candidates, context, tokens,
2335 detected_text_language_tags, options,
2336 &interpreter_manager, &candidate_indices)) {
2337 return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2338 }
2339
2340 // Remove candidates that overlap exactly and have the same collection.
2341 // This can e.g. happen for phone coming from both ML model and regex.
2342 candidate_indices.erase(
2343 std::unique(candidate_indices.begin(), candidate_indices.end(),
2344 [&candidates](const int a_index, const int b_index) {
2345 const AnnotatedSpan& a = (*candidates)[a_index];
2346 const AnnotatedSpan& b = (*candidates)[b_index];
2347 return a.span == b.span &&
2348 a.classification[0].collection ==
2349 b.classification[0].collection;
2350 }),
2351 candidate_indices.end());
2352
2353 std::vector<AnnotatedSpan> result;
2354 result.reserve(candidate_indices.size());
2355 for (const int i : candidate_indices) {
2356 if ((*candidates)[i].classification.empty() ||
2357 ClassifiedAsOther((*candidates)[i].classification) ||
2358 FilteredForAnnotation((*candidates)[i])) {
2359 continue;
2360 }
2361 result.push_back(std::move((*candidates)[i]));
2362 }
2363
2364   // We generate candidates for all entity types and only filter out the
2365   // disabled ones at the end (except for date/time/duration entities), because
2366   // there are complex interdependencies between the entity types. E.g., the TLD
2367   // of an email can be interpreted as a URL, but a user of the API most likely
2368   // does not want such annotations if "url" is enabled and "email" is not.
2369 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2370
2371 for (AnnotatedSpan& annotated_span : result) {
2372 SortClassificationResults(&annotated_span.classification);
2373 }
2374 *candidates = result;
2375 return Status::OK;
2376 }
2377
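// Annotates a list of input fragments. The knowledge engine is run first on
// all fragments at once; afterwards every fragment is annotated independently
// with AnnotateSingleInput, using the fragment's own datetime options (if any)
// as the reference time and timezone. In the raw use case, if the only
// requested entity type is Entity, the per-fragment annotators are skipped
// entirely.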
2378 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2379 const std::vector<InputFragment>& string_fragments,
2380 const AnnotationOptions& options) const {
2381 Annotations annotation_candidates;
2382 annotation_candidates.annotated_spans.resize(string_fragments.size());
2383
2384 std::vector<std::string> text_to_annotate;
2385 text_to_annotate.reserve(string_fragments.size());
2386 std::vector<FragmentMetadata> fragment_metadata;
2387 fragment_metadata.reserve(string_fragments.size());
2388 for (const auto& string_fragment : string_fragments) {
2389 text_to_annotate.push_back(string_fragment.text);
2390 fragment_metadata.push_back(
2391 {.relative_bounding_box_top = string_fragment.bounding_box_top,
2392 .relative_bounding_box_height = string_fragment.bounding_box_height});
2393 }
2394
2395 // KnowledgeEngine is special, because it supports annotation of multiple
2396 // fragments at once.
2397 if (knowledge_engine_ &&
2398 !knowledge_engine_
2399 ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2400 options.annotation_usecase,
2401 options.location_context, options.permissions,
2402 options.annotate_mode, &annotation_candidates)
2403 .ok()) {
2404 return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2405 }
2406 // The annotator engines shouldn't change the number of annotation vectors.
2407 if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2408 TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2409 << " texts to annotate but generated a different number of "
2410 "lists of annotations:"
2411 << annotation_candidates.annotated_spans.size();
2412 return Status(StatusCode::INTERNAL,
2413 "Number of annotation candidates differs from "
2414 "number of texts to annotate.");
2415 }
2416
2417   // As an optimization, if the only requested entity type is Entity, we skip
2418   // all annotators other than the KnowledgeEngine. This is done only in the raw
2419   // mode, to make sure it does not affect the result.
2420 if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2421 options.entity_types.size() == 1 &&
2422 *options.entity_types.begin() == Collections::Entity()) {
2423 return annotation_candidates;
2424 }
2425
2426 // Other annotators run on each fragment independently.
2427 for (int i = 0; i < text_to_annotate.size(); ++i) {
2428 AnnotationOptions annotation_options = options;
2429 if (string_fragments[i].datetime_options.has_value()) {
2430 DatetimeOptions reference_datetime =
2431 string_fragments[i].datetime_options.value();
2432 annotation_options.reference_time_ms_utc =
2433 reference_datetime.reference_time_ms_utc;
2434 annotation_options.reference_timezone =
2435 reference_datetime.reference_timezone;
2436 }
2437
2438 AddContactMetadataToKnowledgeClassificationResults(
2439 &annotation_candidates.annotated_spans[i]);
2440
2441 Status annotation_status =
2442 AnnotateSingleInput(text_to_annotate[i], annotation_options,
2443 &annotation_candidates.annotated_spans[i]);
2444 if (!annotation_status.ok()) {
2445 return annotation_status;
2446 }
2447 }
2448 return annotation_candidates;
2449 }
2450
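// Annotates a plain string by wrapping it into a single InputFragment and
// delegating to AnnotateStructuredInput.
//
// Illustrative use (sketch only; assumes a successfully created `annotator`):
//
//   AnnotationOptions options;
//   for (const AnnotatedSpan& span : annotator.Annotate(text, options)) {
//     TC3_LOG(INFO) << span.span.first << "-" << span.span.second << ": "
//                   << span.classification[0].collection;
//   }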
2451 std::vector<AnnotatedSpan> Annotator::Annotate(
2452 const std::string& context, const AnnotationOptions& options) const {
2453 if (context.size() > std::numeric_limits<int>::max()) {
2454 TC3_LOG(ERROR) << "Rejecting too long input.";
2455 return {};
2456 }
2457
2458 const UnicodeText context_unicode =
2459 UTF8ToUnicodeText(context, /*do_copy=*/false);
2460 if (!unilib_->IsValidUtf8(context_unicode)) {
2461 TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2462 return {};
2463 }
2464
2465 std::vector<InputFragment> string_fragments;
2466 string_fragments.push_back({.text = context});
2467 StatusOr<Annotations> annotations =
2468 AnnotateStructuredInput(string_fragments, options);
2469 if (!annotations.ok()) {
2470 TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2471 << annotations.status().error_message();
2472 return {};
2473 }
2474 return annotations.ValueOrDie().annotated_spans[0];
2475 }
2476
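// Computes the selection span for a regex match. Without a capturing-group
// configuration, capturing group 1 defines the selection; otherwise the
// selection is the union of all matched groups marked with extend_selection.
// Returns {kInvalidIndex, kInvalidIndex} on matcher errors.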
2477 CodepointSpan Annotator::ComputeSelectionBoundaries(
2478 const UniLib::RegexMatcher* match,
2479 const RegexModel_::Pattern* config) const {
2480 if (config->capturing_group() == nullptr) {
2481 // Use first capturing group to specify the selection.
2482 int status = UniLib::RegexMatcher::kNoError;
2483 const CodepointSpan result = {match->Start(1, &status),
2484 match->End(1, &status)};
2485 if (status != UniLib::RegexMatcher::kNoError) {
2486 return {kInvalidIndex, kInvalidIndex};
2487 }
2488 return result;
2489 }
2490
2491 CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2492 const int num_groups = config->capturing_group()->size();
2493 for (int i = 0; i < num_groups; i++) {
2494 if (!config->capturing_group()->Get(i)->extend_selection()) {
2495 continue;
2496 }
2497
2498 int status = UniLib::RegexMatcher::kNoError;
2499 // Check match and adjust bounds.
2500 const int group_start = match->Start(i, &status);
2501 const int group_end = match->End(i, &status);
2502 if (status != UniLib::RegexMatcher::kNoError) {
2503 return {kInvalidIndex, kInvalidIndex};
2504 }
2505 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2506 continue;
2507 }
2508 if (result.first == kInvalidIndex) {
2509 result = {group_start, group_end};
2510 } else {
2511 result.first = std::min(result.first, group_start);
2512 result.second = std::max(result.second, group_end);
2513 }
2514 }
2515 return result;
2516 }
2517
2518 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2519 if (pattern->serialized_entity_data() != nullptr ||
2520 pattern->entity_data() != nullptr) {
2521 return true;
2522 }
2523 if (pattern->capturing_group() != nullptr) {
2524 for (const CapturingGroup* group : *pattern->capturing_group()) {
2525 if (group->entity_field_path() != nullptr) {
2526 return true;
2527 }
2528 if (group->serialized_entity_data() != nullptr ||
2529 group->entity_data() != nullptr) {
2530 return true;
2531 }
2532 }
2533 }
2534 return false;
2535 }
2536
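// Builds the serialized entity data for a regex match. Fixed entity data from
// the pattern is merged first, then data from every matched capturing group:
// its fixed entity data and, if entity_field_path is set, the (optionally
// normalized) text that the group matched. Clears the output and succeeds when
// the pattern carries no entity data at all.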
2537 bool Annotator::SerializedEntityDataFromRegexMatch(
2538 const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2539 std::string* serialized_entity_data) const {
2540 if (!HasEntityData(pattern)) {
2541 serialized_entity_data->clear();
2542 return true;
2543 }
2544 TC3_CHECK(entity_data_builder_ != nullptr);
2545
2546 std::unique_ptr<MutableFlatbuffer> entity_data =
2547 entity_data_builder_->NewRoot();
2548
2549 TC3_CHECK(entity_data != nullptr);
2550
2551 // Set fixed entity data.
2552 if (pattern->serialized_entity_data() != nullptr) {
2553 entity_data->MergeFromSerializedFlatbuffer(
2554 StringPiece(pattern->serialized_entity_data()->c_str(),
2555 pattern->serialized_entity_data()->size()));
2556 }
2557 if (pattern->entity_data() != nullptr) {
2558 entity_data->MergeFrom(
2559 reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2560 }
2561
2562 // Add entity data from rule capturing groups.
2563 if (pattern->capturing_group() != nullptr) {
2564 const int num_groups = pattern->capturing_group()->size();
2565 for (int i = 0; i < num_groups; i++) {
2566 const CapturingGroup* group = pattern->capturing_group()->Get(i);
2567
2568 // Check whether the group matched.
2569 Optional<std::string> group_match_text =
2570 GetCapturingGroupText(matcher, /*group_id=*/i);
2571 if (!group_match_text.has_value()) {
2572 continue;
2573 }
2574
2575 // Set fixed entity data from capturing group match.
2576 if (group->serialized_entity_data() != nullptr) {
2577 entity_data->MergeFromSerializedFlatbuffer(
2578 StringPiece(group->serialized_entity_data()->c_str(),
2579 group->serialized_entity_data()->size()));
2580 }
2581 if (group->entity_data() != nullptr) {
2582         entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2583             group->entity_data()));
2584 }
2585
2586 // Set entity field from capturing group text.
2587 if (group->entity_field_path() != nullptr) {
2588 UnicodeText normalized_group_match_text =
2589 UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2590
2591 // Apply normalization if specified.
2592 if (group->normalization_options() != nullptr) {
2593 normalized_group_match_text =
2594 NormalizeText(*unilib_, group->normalization_options(),
2595 normalized_group_match_text);
2596 }
2597
2598 if (!entity_data->ParseAndSet(
2599 group->entity_field_path(),
2600 normalized_group_match_text.ToUTF8String())) {
2601 TC3_LOG(ERROR)
2602 << "Could not set entity data from rule capturing group.";
2603 return false;
2604 }
2605 }
2606 }
2607 }
2608
2609 *serialized_entity_data = entity_data->Serialize();
2610 return true;
2611 }
2612
2613 UnicodeText RemoveMoneySeparators(
2614 const std::unordered_set<char32>& decimal_separators,
2615 const UnicodeText& amount,
2616 UnicodeText::const_iterator it_decimal_separator) {
2617 UnicodeText whole_amount;
2618 for (auto it = amount.begin();
2619 it != amount.end() && it != it_decimal_separator; ++it) {
2620 if (std::find(decimal_separators.begin(), decimal_separators.end(),
2621 static_cast<char32>(*it)) == decimal_separators.end()) {
2622 whole_amount.push_back(*it);
2623 }
2624 }
2625 return whole_amount;
2626 }
2627
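// Looks for a quantity word among the capturing groups of a money match and
// returns its exponent from the model's quantities_name_to_exponent table
// (e.g. a word such as "million" could map to exponent 6, depending on the
// model). If no matched group yields a known quantity, the exponent is 0.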
2628 void Annotator::GetMoneyQuantityFromCapturingGroup(
2629 const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2630 const UnicodeText& context_unicode, std::string* quantity,
2631 int* exponent) const {
2632 if (config->capturing_group() == nullptr) {
2633 *exponent = 0;
2634 return;
2635 }
2636
2637 const int num_groups = config->capturing_group()->size();
2638 for (int i = 0; i < num_groups; i++) {
2639 int status = UniLib::RegexMatcher::kNoError;
2640 const int group_start = match->Start(i, &status);
2641 const int group_end = match->End(i, &status);
2642 if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2643 continue;
2644 }
2645
2646 *quantity =
2647 unilib_
2648 ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2649 group_end, /*do_copy=*/false))
2650 .ToUTF8String();
2651
2652 if (auto entry = model_->money_parsing_options()
2653 ->quantities_name_to_exponent()
2654 ->LookupByKey((*quantity).c_str())) {
2655 *exponent = entry->value();
2656 return;
2657 }
2658 }
2659 *exponent = 0;
2660 }
2661
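// Parses the unnormalized money amount captured by the regex into
// amount_whole_part, amount_decimal_part and nanos, and applies a quantity
// exponent if one was captured. The last separator is treated as a decimal
// separator unless exactly three digits follow it, in which case it is taken
// to be a thousands separator.
//
// Worked example (assuming both '.' and ',' are configured as money
// separators): "1.234,56" -> whole part 1234, decimal part 56,
// nanos 560000000; "1.234" -> whole part 1234, decimal part 0, nanos 0.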
2662 bool Annotator::ParseAndFillInMoneyAmount(
2663 std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2664 const RegexModel_::Pattern* config,
2665 const UnicodeText& context_unicode) const {
2666 std::unique_ptr<EntityDataT> data =
2667 LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2668 *serialized_entity_data);
2669 if (data == nullptr) {
2670 if (model_->version() >= 706) {
2671 // This way of parsing money entity data is enabled for models newer than
2672 // v706, consequently logging errors only for them (b/156634162).
2673 TC3_LOG(ERROR)
2674 << "Data field is null when trying to parse Money Entity Data";
2675 }
2676 return false;
2677 }
2678 if (data->money->unnormalized_amount.empty()) {
2679 if (model_->version() >= 706) {
2680 // This way of parsing money entity data is enabled for models newer than
2681 // v706, consequently logging errors only for them (b/156634162).
2682 TC3_LOG(ERROR)
2683 << "Data unnormalized_amount is empty when trying to parse "
2684 "Money Entity Data";
2685 }
2686 return false;
2687 }
2688
2689 UnicodeText amount =
2690 UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2691 int separator_back_index = 0;
2692 auto it_decimal_separator = --amount.end();
2693 for (; it_decimal_separator != amount.begin();
2694 --it_decimal_separator, ++separator_back_index) {
2695 if (std::find(money_separators_.begin(), money_separators_.end(),
2696 static_cast<char32>(*it_decimal_separator)) !=
2697 money_separators_.end()) {
2698 break;
2699 }
2700 }
2701
2702   // If there are 3 digits after the last separator, we treat it as a thousands
2703   // separator and the number as an integer (e.g. 1.234 is considered an int).
2704   // A number without any separator is also treated as an integer.
2705 if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2706 it_decimal_separator = amount.end();
2707 }
2708
2709 if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2710 it_decimal_separator),
2711 &data->money->amount_whole_part)) {
2712 TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2713 "amount: "
2714 << data->money->unnormalized_amount;
2715 return false;
2716 }
2717
2718 if (it_decimal_separator == amount.end()) {
2719 data->money->amount_decimal_part = 0;
2720 data->money->nanos = 0;
2721 } else {
2722 const int amount_codepoints_size = amount.size_codepoints();
2723 const UnicodeText decimal_part = UnicodeText::Substring(
2724 amount, amount_codepoints_size - separator_back_index,
2725 amount_codepoints_size, /*do_copy=*/false);
2726 if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2727 TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2728 "the amount: "
2729 << data->money->unnormalized_amount;
2730 return false;
2731 }
2732 data->money->nanos = data->money->amount_decimal_part *
2733 pow(10, 9 - decimal_part.size_codepoints());
2734 }
2735
2736 if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2737 nullptr) {
2738 int quantity_exponent;
2739 std::string quantity;
2740 GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2741 &quantity, &quantity_exponent);
2742 if (quantity_exponent > 0 && quantity_exponent <= 9) {
2743 const double amount_whole_part =
2744 data->money->amount_whole_part * pow(10, quantity_exponent) +
2745 data->money->nanos / pow(10, 9 - quantity_exponent);
2746 // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2747 // (and `std::numeric_limits<int>::max()` to
2748 // `std::numeric_limits<int64>::max()`).
2749 if (amount_whole_part < std::numeric_limits<int>::max()) {
2750 data->money->amount_whole_part = amount_whole_part;
2751 data->money->nanos = data->money->nanos %
2752 static_cast<int>(pow(10, 9 - quantity_exponent)) *
2753 pow(10, quantity_exponent);
2754 }
2755 }
2756 if (quantity_exponent > 0) {
2757 data->money->unnormalized_amount = strings::JoinStrings(
2758 " ", {data->money->unnormalized_amount, quantity});
2759 }
2760 }
2761
2762 *serialized_entity_data =
2763 PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2764 return true;
2765 }
2766
2767 bool Annotator::IsAnyModelEntityTypeEnabled(
2768 const EnabledEntityTypes& is_entity_type_enabled) const {
2769 if (model_->classification_feature_options() == nullptr ||
2770 model_->classification_feature_options()->collections() == nullptr) {
2771 return false;
2772 }
2773 for (int i = 0;
2774 i < model_->classification_feature_options()->collections()->size();
2775 i++) {
2776 if (is_entity_type_enabled(model_->classification_feature_options()
2777 ->collections()
2778 ->Get(i)
2779 ->str())) {
2780 return true;
2781 }
2782 }
2783 return false;
2784 }
2785
2786 bool Annotator::IsAnyRegexEntityTypeEnabled(
2787 const EnabledEntityTypes& is_entity_type_enabled) const {
2788 if (model_->regex_model() == nullptr ||
2789 model_->regex_model()->patterns() == nullptr) {
2790 return false;
2791 }
2792 for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2793 if (is_entity_type_enabled(model_->regex_model()
2794 ->patterns()
2795 ->Get(i)
2796 ->collection_name()
2797 ->str())) {
2798 return true;
2799 }
2800 }
2801 return false;
2802 }
2803
2804 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2805 const EnabledEntityTypes& is_entity_type_enabled) const {
2806 if (pod_ner_annotator_ == nullptr) {
2807 return false;
2808 }
2809
2810 for (const std::string& collection :
2811 pod_ner_annotator_->GetSupportedCollections()) {
2812 if (is_entity_type_enabled(collection)) {
2813 return true;
2814 }
2815 }
2816 return false;
2817 }

bool Annotator::RegexChunk(const UnicodeText& context_unicode,
                           const std::vector<int>& rules,
                           bool is_serialized_entity_data_enabled,
                           const EnabledEntityTypes& enabled_entity_types,
                           const AnnotationUsecase& annotation_usecase,
                           std::vector<AnnotatedSpan>* result) const {
  for (int pattern_id : rules) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
        annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
      // No regex annotation type has been requested; skip regex annotation.
      continue;
    }
    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
    if (!matcher) {
      TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
                     << pattern_id;
      return false;
    }

    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      if (regex_pattern.config->verification_options()) {
        if (!VerifyRegexMatchCandidate(
                context_unicode.ToUTF8String(),
                regex_pattern.config->verification_options(),
                matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
          continue;
        }
      }

      std::string serialized_entity_data;
      if (is_serialized_entity_data_enabled) {
        if (!SerializedEntityDataFromRegexMatch(
                regex_pattern.config, matcher.get(), &serialized_entity_data)) {
          TC3_LOG(ERROR) << "Could not get entity data.";
          return false;
        }

        // Further parsing of the money amount. This is needed because regexes
        // cannot have empty groups that fill in entity data
        // (amount_decimal_part and quantity might be empty groups).
        if (regex_pattern.config->collection_name()->str() ==
            Collections::Money()) {
          if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
                                         regex_pattern.config,
                                         context_unicode)) {
            if (model_->version() >= 706) {
              // This way of parsing money entity data is enabled for models
              // v706 and newer, so errors are logged only for them
              // (b/156634162).
              TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
            }
          }
        }
      }

      result->emplace_back();

      // Selection/annotation regular expressions must contain a capturing
      // group that specifies the selection.
      result->back().span =
          ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);

      result->back().classification = {
          {regex_pattern.config->collection_name()->str(),
           regex_pattern.config->target_classification_score(),
           regex_pattern.config->priority_score()}};

      result->back().classification[0].serialized_entity_data =
          serialized_entity_data;
    }
  }
  return true;
}

bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                           tflite::Interpreter* selection_interpreter,
                           const CachedFeatures& cached_features,
                           std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span =
      IntersectTokenSpans(span_of_interest.Expand(
                              /*num_tokens_left=*/max_selection_span,
                              /*num_tokens_right=*/max_selection_span),
                          {0, num_tokens});
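  // For example (hypothetical values): with span_of_interest {5, 8},
  // max_selection_span 3 and num_tokens 10, the expanded span {2, 11} is
  // clamped to the inference span {2, 10}.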

  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });
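  // Note: sorting the reversed range in ascending score order leaves
  // scored_chunks ordered from highest score to lowest.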

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  std::vector<bool> token_used(inference_span.Size());
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  std::sort(chunks->begin(), chunks->end());

  return true;
}

namespace {
// Updates the value at the given key in the map to the maximum of the current
// value and the given value, or simply inserts the value if the key is not yet
// there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto it = map->find(key);
  if (it != map->end()) {
    it->second = std::max(it->second, value);
  } else {
    (*map)[key] = value;
  }
}
}  // namespace

bool Annotator::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
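    // Each selection label encodes a token span relative to the click
    // position. Candidates falling outside [0, num_tokens) are dropped, and
    // when the same span is reachable from several click positions, UpdateMax
    // keeps its highest score.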
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
            relative_token_span.first, relative_token_span.second);
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
  }

  return true;
}

bool Annotator::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ? max_selection_span + 1
                                   : 2 * max_selection_span + 1;
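  // E.g. (hypothetical value) with max_selection_span = 5, chunks are limited
  // to 6 tokens in the reduced output space and to 11 tokens otherwise.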
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(span_of_interest.Size());
  }

  // Prepare all chunk candidates in one batch. Each candidate:
  //   - is contained in the inference span,
  //   - has a non-empty intersection with the span of interest,
  //   - is at least one token long,
  //   - is not longer than the maximum chunk length.
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
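    // Starting the end index here (together with the bound on `start` above)
    // guarantees that every candidate overlaps the span of interest by at
    // least one token.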
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
        // Do not include single-token spans in the batch; add a zero score
        // for them directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}

bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
                              int64 reference_time_ms_utc,
                              const std::string& reference_timezone,
                              const std::string& locales, ModeFlag mode,
                              AnnotationUsecase annotation_usecase,
                              bool is_serialized_entity_data_enabled,
                              std::vector<AnnotatedSpan>* result) const {
  if (!datetime_parser_) {
    return true;
  }
  LocaleList locale_list = LocaleList::ParseFrom(locales);
  StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
      datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                              reference_timezone, locale_list, mode,
                              annotation_usecase,
                              /*anchor_start_end=*/false);
  if (!result_status.ok()) {
    return false;
  }

  for (const DatetimeParseResultSpan& datetime_span :
       result_status.ValueOrDie()) {
    AnnotatedSpan annotated_span;
    annotated_span.span = datetime_span.span;
    for (const DatetimeParseResult& parse_result : datetime_span.data) {
      annotated_span.classification.emplace_back(
          PickCollectionForDatetime(parse_result),
          datetime_span.target_classification_score,
          datetime_span.priority_score);
      annotated_span.classification.back().datetime_parse_result = parse_result;
      if (is_serialized_entity_data_enabled) {
        annotated_span.classification.back().serialized_entity_data =
            CreateDatetimeSerializedEntityData(parse_result);
      }
    }
    annotated_span.source = AnnotatedSpan::Source::DATETIME;
    result->push_back(std::move(annotated_span));
  }
  return true;
}

const Model* Annotator::model() const { return model_; }
const reflection::Schema* Annotator::entity_data_schema() const {
  return entity_data_schema_;
}

const Model* ViewModel(const void* buffer, int size) {
  if (!buffer) {
    return nullptr;
  }

  return LoadAndVerifyModel(buffer, size);
}

StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
    const std::string& id) const {
  if (!knowledge_engine_) {
    return Status(StatusCode::FAILED_PRECONDITION,
                  "knowledge_engine_ is nullptr");
  }
  return knowledge_engine_->LookUpEntity(id);
}

StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
    const std::string& mid_str, const std::string& property) const {
  if (!knowledge_engine_) {
    return Status(StatusCode::FAILED_PRECONDITION,
                  "knowledge_engine_ is nullptr");
  }
  return knowledge_engine_->LookUpEntityProperty(mid_str, property);
}

}  // namespace libtextclassifier3