• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator.h"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28 
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54 
55 namespace libtextclassifier3 {
56 
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58 
59 const std::string& Annotator::kPhoneCollection =
__anon17cae9720102() 60     *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anon17cae9720202() 62     *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anon17cae9720302() 64     *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anon17cae9720402() 66     *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anon17cae9720502() 68     *[]() { return new std::string("email"); }();
69 
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73   if (VerifyModelBuffer(verifier)) {
74     return GetModel(addr);
75   } else {
76     return nullptr;
77   }
78 }
79 
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81                                                     int size) {
82   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83   if (VerifyPersonNameModelBuffer(verifier)) {
84     return GetPersonNameModel(addr);
85   } else {
86     return nullptr;
87   }
88 }
89 
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93                                 std::unique_ptr<UniLib>* owned_lib) {
94   if (lib) {
95     return lib;
96   } else {
97     owned_lib->reset(new UniLib);
98     return owned_lib->get();
99   }
100 }
101 
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105   if (lib) {
106     return lib;
107   } else {
108     owned_lib->reset(new CalendarLib);
109     return owned_lib->get();
110   }
111 }
112 
113 // Returns whether the provided input is valid:
114 //   * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116   return (span.first >= 0 && span.first < span.second &&
117           span.second <= context.size_codepoints());
118 }
119 
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121     const flatbuffers::Vector<int32_t>* ints) {
122   if (ints == nullptr) {
123     return {};
124   }
125   std::unordered_set<char32> ints_set;
126   for (auto value : *ints) {
127     ints_set.insert(static_cast<char32>(value));
128   }
129   return ints_set;
130 }
131 
132 }  // namespace
133 
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135   if (!selection_interpreter_) {
136     TC3_CHECK(selection_executor_);
137     selection_interpreter_ = selection_executor_->CreateInterpreter();
138     if (!selection_interpreter_) {
139       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140     }
141   }
142   return selection_interpreter_.get();
143 }
144 
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146   if (!classification_interpreter_) {
147     TC3_CHECK(classification_executor_);
148     classification_interpreter_ = classification_executor_->CreateInterpreter();
149     if (!classification_interpreter_) {
150       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151     }
152   }
153   return classification_interpreter_.get();
154 }
155 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157     const char* buffer, int size, const UniLib* unilib,
158     const CalendarLib* calendarlib) {
159   const Model* model = LoadAndVerifyModel(buffer, size);
160   if (model == nullptr) {
161     return nullptr;
162   }
163 
164   auto classifier = std::unique_ptr<Annotator>(new Annotator());
165   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166   calendarlib =
167       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168   classifier->ValidateAndInitialize(model, unilib, calendarlib);
169   if (!classifier->IsInitialized()) {
170     return nullptr;
171   }
172 
173   return classifier;
174 }
175 
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177     const std::string& buffer, const UniLib* unilib,
178     const CalendarLib* calendarlib) {
179   auto classifier = std::unique_ptr<Annotator>(new Annotator());
180   classifier->owned_buffer_ = buffer;
181   const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182                                           classifier->owned_buffer_.size());
183   if (model == nullptr) {
184     return nullptr;
185   }
186   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187   calendarlib =
188       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189   classifier->ValidateAndInitialize(model, unilib, calendarlib);
190   if (!classifier->IsInitialized()) {
191     return nullptr;
192   }
193 
194   return classifier;
195 }
196 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199     const CalendarLib* calendarlib) {
200   if (!(*mmap)->handle().ok()) {
201     TC3_VLOG(1) << "Mmap failed.";
202     return nullptr;
203   }
204 
205   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206                                           (*mmap)->handle().num_bytes());
207   if (!model) {
208     TC3_LOG(ERROR) << "Model verification failed.";
209     return nullptr;
210   }
211 
212   auto classifier = std::unique_ptr<Annotator>(new Annotator());
213   classifier->mmap_ = std::move(*mmap);
214   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215   calendarlib =
216       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217   classifier->ValidateAndInitialize(model, unilib, calendarlib);
218   if (!classifier->IsInitialized()) {
219     return nullptr;
220   }
221 
222   return classifier;
223 }
224 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227     std::unique_ptr<CalendarLib> calendarlib) {
228   if (!(*mmap)->handle().ok()) {
229     TC3_VLOG(1) << "Mmap failed.";
230     return nullptr;
231   }
232 
233   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234                                           (*mmap)->handle().num_bytes());
235   if (model == nullptr) {
236     TC3_LOG(ERROR) << "Model verification failed.";
237     return nullptr;
238   }
239 
240   auto classifier = std::unique_ptr<Annotator>(new Annotator());
241   classifier->mmap_ = std::move(*mmap);
242   classifier->owned_unilib_ = std::move(unilib);
243   classifier->owned_calendarlib_ = std::move(calendarlib);
244   classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245                                     classifier->owned_calendarlib_.get());
246   if (!classifier->IsInitialized()) {
247     return nullptr;
248   }
249 
250   return classifier;
251 }
252 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254     int fd, int offset, int size, const UniLib* unilib,
255     const CalendarLib* calendarlib) {
256   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257   return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259 
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262     std::unique_ptr<CalendarLib> calendarlib) {
263   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266 
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270   return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272 
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274     int fd, std::unique_ptr<UniLib> unilib,
275     std::unique_ptr<CalendarLib> calendarlib) {
276   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279 
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281                                                const UniLib* unilib,
282                                                const CalendarLib* calendarlib) {
283   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284   return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288     const std::string& path, std::unique_ptr<UniLib> unilib,
289     std::unique_ptr<CalendarLib> calendarlib) {
290   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293 
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295                                       const CalendarLib* calendarlib) {
296   model_ = model;
297   unilib_ = unilib;
298   calendarlib_ = calendarlib;
299 
300   initialized_ = false;
301 
302   if (model_ == nullptr) {
303     TC3_LOG(ERROR) << "No model specified.";
304     return;
305   }
306 
307   const bool model_enabled_for_annotation =
308       (model_->triggering_options() != nullptr &&
309        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310   const bool model_enabled_for_classification =
311       (model_->triggering_options() != nullptr &&
312        (model_->triggering_options()->enabled_modes() &
313         ModeFlag_CLASSIFICATION));
314   const bool model_enabled_for_selection =
315       (model_->triggering_options() != nullptr &&
316        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317 
318   // Annotation requires the selection model.
319   if (model_enabled_for_annotation || model_enabled_for_selection) {
320     if (!model_->selection_options()) {
321       TC3_LOG(ERROR) << "No selection options.";
322       return;
323     }
324     if (!model_->selection_feature_options()) {
325       TC3_LOG(ERROR) << "No selection feature options.";
326       return;
327     }
328     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330       return;
331     }
332     if (!model_->selection_model()) {
333       TC3_LOG(ERROR) << "No selection model.";
334       return;
335     }
336     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337     if (!selection_executor_) {
338       TC3_LOG(ERROR) << "Could not initialize selection executor.";
339       return;
340     }
341   }
342 
343   // Even if the annotation mode is not enabled (for the neural network model),
344   // the selection feature processor is needed to tokenize the text for other
345   // models.
346   if (model_->selection_feature_options()) {
347     selection_feature_processor_.reset(
348         new FeatureProcessor(model_->selection_feature_options(), unilib_));
349   }
350 
351   // Annotation requires the classification model for conflict resolution and
352   // scoring.
353   // Selection requires the classification model for conflict resolution.
354   if (model_enabled_for_annotation || model_enabled_for_classification ||
355       model_enabled_for_selection) {
356     if (!model_->classification_options()) {
357       TC3_LOG(ERROR) << "No classification options.";
358       return;
359     }
360 
361     if (!model_->classification_feature_options()) {
362       TC3_LOG(ERROR) << "No classification feature options.";
363       return;
364     }
365 
366     if (!model_->classification_feature_options()
367              ->bounds_sensitive_features()) {
368       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
369       return;
370     }
371     if (!model_->classification_model()) {
372       TC3_LOG(ERROR) << "No clf model.";
373       return;
374     }
375 
376     classification_executor_ =
377         ModelExecutor::FromBuffer(model_->classification_model());
378     if (!classification_executor_) {
379       TC3_LOG(ERROR) << "Could not initialize classification executor.";
380       return;
381     }
382 
383     classification_feature_processor_.reset(new FeatureProcessor(
384         model_->classification_feature_options(), unilib_));
385   }
386 
387   // The embeddings need to be specified if the model is to be used for
388   // classification or selection.
389   if (model_enabled_for_annotation || model_enabled_for_classification ||
390       model_enabled_for_selection) {
391     if (!model_->embedding_model()) {
392       TC3_LOG(ERROR) << "No embedding model.";
393       return;
394     }
395 
396     // Check that the embedding size of the selection and classification model
397     // matches, as they are using the same embeddings.
398     if (model_enabled_for_selection &&
399         (model_->selection_feature_options()->embedding_size() !=
400              model_->classification_feature_options()->embedding_size() ||
401          model_->selection_feature_options()->embedding_quantization_bits() !=
402              model_->classification_feature_options()
403                  ->embedding_quantization_bits())) {
404       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
405       return;
406     }
407 
408     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
409         model_->embedding_model(),
410         model_->classification_feature_options()->embedding_size(),
411         model_->classification_feature_options()->embedding_quantization_bits(),
412         model_->embedding_pruning_mask());
413     if (!embedding_executor_) {
414       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
415       return;
416     }
417   }
418 
419   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
420   if (model_->regex_model()) {
421     if (!InitializeRegexModel(decompressor.get())) {
422       TC3_LOG(ERROR) << "Could not initialize regex model.";
423       return;
424     }
425   }
426 
427   if (model_->datetime_grammar_model()) {
428     if (model_->datetime_grammar_model()->rules()) {
429       analyzer_ = std::make_unique<grammar::Analyzer>(
430           unilib_, model_->datetime_grammar_model()->rules());
431       datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
432       datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
433           *analyzer_, *datetime_grounder_,
434           /*target_classification_score=*/1.0,
435           /*priority_score=*/1.0);
436     }
437   } else if (model_->datetime_model()) {
438     datetime_parser_ = RegexDatetimeParser::Instance(
439         model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
440     if (!datetime_parser_) {
441       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
442       return;
443     }
444   }
445 
446   if (model_->output_options()) {
447     if (model_->output_options()->filtered_collections_annotation()) {
448       for (const auto collection :
449            *model_->output_options()->filtered_collections_annotation()) {
450         filtered_collections_annotation_.insert(collection->str());
451       }
452     }
453     if (model_->output_options()->filtered_collections_classification()) {
454       for (const auto collection :
455            *model_->output_options()->filtered_collections_classification()) {
456         filtered_collections_classification_.insert(collection->str());
457       }
458     }
459     if (model_->output_options()->filtered_collections_selection()) {
460       for (const auto collection :
461            *model_->output_options()->filtered_collections_selection()) {
462         filtered_collections_selection_.insert(collection->str());
463       }
464     }
465   }
466 
467   if (model_->number_annotator_options() &&
468       model_->number_annotator_options()->enabled()) {
469     number_annotator_.reset(
470         new NumberAnnotator(model_->number_annotator_options(), unilib_));
471   }
472 
473   if (model_->money_parsing_options()) {
474     money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
475         model_->money_parsing_options()->separators());
476   }
477 
478   if (model_->duration_annotator_options() &&
479       model_->duration_annotator_options()->enabled()) {
480     duration_annotator_.reset(
481         new DurationAnnotator(model_->duration_annotator_options(),
482                               selection_feature_processor_.get(), unilib_));
483   }
484 
485   if (model_->grammar_model()) {
486     grammar_annotator_.reset(new GrammarAnnotator(
487         unilib_, model_->grammar_model(), entity_data_builder_.get()));
488   }
489 
490   // The following #ifdef is here to aid quality evaluation of a situation, when
491   // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
492   // it.
493 #if !defined(TC3_DISABLE_POD_NER)
494   if (model_->pod_ner_model()) {
495     pod_ner_annotator_ =
496         PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
497   }
498 #endif
499 
500   if (model_->vocab_model()) {
501     vocab_annotator_ = VocabAnnotator::Create(
502         model_->vocab_model(), *selection_feature_processor_, *unilib_);
503   }
504 
505   if (model_->entity_data_schema()) {
506     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
507         model_->entity_data_schema()->Data(),
508         model_->entity_data_schema()->size());
509     if (entity_data_schema_ == nullptr) {
510       TC3_LOG(ERROR) << "Could not load entity data schema data.";
511       return;
512     }
513 
514     entity_data_builder_.reset(
515         new MutableFlatbufferBuilder(entity_data_schema_));
516   } else {
517     entity_data_schema_ = nullptr;
518     entity_data_builder_ = nullptr;
519   }
520 
521   if (model_->triggering_locales() &&
522       !ParseLocales(model_->triggering_locales()->c_str(),
523                     &model_triggering_locales_)) {
524     TC3_LOG(ERROR) << "Could not parse model supported locales.";
525     return;
526   }
527 
528   if (model_->triggering_options() != nullptr &&
529       model_->triggering_options()->locales() != nullptr &&
530       !ParseLocales(model_->triggering_options()->locales()->c_str(),
531                     &ml_model_triggering_locales_)) {
532     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
533     return;
534   }
535 
536   if (model_->triggering_options() != nullptr &&
537       model_->triggering_options()->dictionary_locales() != nullptr &&
538       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
539                     &dictionary_locales_)) {
540     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
541     return;
542   }
543 
544   if (model_->conflict_resolution_options() != nullptr) {
545     prioritize_longest_annotation_ =
546         model_->conflict_resolution_options()->prioritize_longest_annotation();
547     do_conflict_resolution_in_raw_mode_ =
548         model_->conflict_resolution_options()
549             ->do_conflict_resolution_in_raw_mode();
550   }
551 
552 #ifdef TC3_EXPERIMENTAL
553   TC3_LOG(WARNING) << "Enabling experimental annotators.";
554   InitializeExperimentalAnnotators();
555 #endif
556 
557   initialized_ = true;
558 }
559 
InitializeRegexModel(ZlibDecompressor * decompressor)560 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
561   if (!model_->regex_model()->patterns()) {
562     return true;
563   }
564 
565   // Initialize pattern recognizers.
566   int regex_pattern_id = 0;
567   for (const auto regex_pattern : *model_->regex_model()->patterns()) {
568     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
569         UncompressMakeRegexPattern(
570             *unilib_, regex_pattern->pattern(),
571             regex_pattern->compressed_pattern(),
572             model_->regex_model()->lazy_regex_compilation(), decompressor);
573     if (!compiled_pattern) {
574       TC3_LOG(INFO) << "Failed to load regex pattern";
575       return false;
576     }
577 
578     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
579       annotation_regex_patterns_.push_back(regex_pattern_id);
580     }
581     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
582       classification_regex_patterns_.push_back(regex_pattern_id);
583     }
584     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
585       selection_regex_patterns_.push_back(regex_pattern_id);
586     }
587     regex_patterns_.push_back({
588         regex_pattern,
589         std::move(compiled_pattern),
590     });
591     ++regex_pattern_id;
592   }
593 
594   return true;
595 }
596 
InitializeKnowledgeEngine(const std::string & serialized_config)597 bool Annotator::InitializeKnowledgeEngine(
598     const std::string& serialized_config) {
599   std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
600   if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
601     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
602     return false;
603   }
604   if (model_->triggering_options() != nullptr) {
605     knowledge_engine->SetPriorityScore(
606         model_->triggering_options()->knowledge_priority_score());
607   }
608   knowledge_engine_ = std::move(knowledge_engine);
609   return true;
610 }
611 
InitializeContactEngine(const std::string & serialized_config)612 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
613   std::unique_ptr<ContactEngine> contact_engine(
614       new ContactEngine(selection_feature_processor_.get(), unilib_,
615                         model_->contact_annotator_options()));
616   if (!contact_engine->Initialize(serialized_config)) {
617     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
618     return false;
619   }
620   contact_engine_ = std::move(contact_engine);
621   return true;
622 }
623 
InitializeInstalledAppEngine(const std::string & serialized_config)624 bool Annotator::InitializeInstalledAppEngine(
625     const std::string& serialized_config) {
626   std::unique_ptr<InstalledAppEngine> installed_app_engine(
627       new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
628   if (!installed_app_engine->Initialize(serialized_config)) {
629     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
630     return false;
631   }
632   installed_app_engine_ = std::move(installed_app_engine);
633   return true;
634 }
635 
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)636 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
637   if (lang_id == nullptr) {
638     return false;
639   }
640 
641   lang_id_ = lang_id;
642   if (lang_id_ != nullptr && model_->translate_annotator_options() &&
643       model_->translate_annotator_options()->enabled()) {
644     translate_annotator_.reset(new TranslateAnnotator(
645         model_->translate_annotator_options(), lang_id_, unilib_));
646   } else {
647     translate_annotator_.reset(nullptr);
648   }
649   return true;
650 }
651 
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)652 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
653                                                             int size) {
654   const PersonNameModel* person_name_model =
655       LoadAndVerifyPersonNameModel(buffer, size);
656 
657   if (person_name_model == nullptr) {
658     TC3_LOG(ERROR) << "Person name model verification failed.";
659     return false;
660   }
661 
662   if (!person_name_model->enabled()) {
663     return true;
664   }
665 
666   std::unique_ptr<PersonNameEngine> person_name_engine(
667       new PersonNameEngine(selection_feature_processor_.get(), unilib_));
668   if (!person_name_engine->Initialize(person_name_model)) {
669     TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
670     return false;
671   }
672   person_name_engine_ = std::move(person_name_engine);
673   return true;
674 }
675 
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)676 bool Annotator::InitializePersonNameEngineFromScopedMmap(
677     const ScopedMmap& mmap) {
678   if (!mmap.handle().ok()) {
679     TC3_LOG(ERROR) << "Mmap for person name model failed.";
680     return false;
681   }
682 
683   return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
684                                                      mmap.handle().num_bytes());
685 }
686 
InitializePersonNameEngineFromPath(const std::string & path)687 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
688   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
689   return InitializePersonNameEngineFromScopedMmap(*mmap);
690 }
691 
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)692 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
693                                                              int size) {
694   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
695   return InitializePersonNameEngineFromScopedMmap(*mmap);
696 }
697 
InitializeExperimentalAnnotators()698 bool Annotator::InitializeExperimentalAnnotators() {
699   if (ExperimentalAnnotator::IsEnabled()) {
700     experimental_annotator_.reset(new ExperimentalAnnotator(
701         model_->experimental_model(), *selection_feature_processor_, *unilib_));
702     return true;
703   }
704   return false;
705 }
706 
707 namespace internal {
708 // Helper function, which if the initial 'span' contains only white-spaces,
709 // moves the selection to a single-codepoint selection on a left or right side
710 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)711 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
712                                             const UnicodeText& context_unicode,
713                                             const UniLib& unilib) {
714   TC3_CHECK(span.IsValid() && !span.IsEmpty());
715 
716   UnicodeText::const_iterator it;
717 
718   // Check that the current selection is all whitespaces.
719   it = context_unicode.begin();
720   std::advance(it, span.first);
721   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
722     if (!unilib.IsWhitespace(*it)) {
723       return span;
724     }
725   }
726 
727   // Try moving left.
728   CodepointSpan result = span;
729   it = context_unicode.begin();
730   std::advance(it, span.first);
731   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
732     --result.first;
733     --it;
734   }
735   result.second = result.first + 1;
736   if (!unilib.IsWhitespace(*it)) {
737     return result;
738   }
739 
740   // If moving left didn't find a non-whitespace character, just return the
741   // original span.
742   return span;
743 }
744 }  // namespace internal
745 
FilteredForAnnotation(const AnnotatedSpan & span) const746 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
747   return !span.classification.empty() &&
748          filtered_collections_annotation_.find(
749              span.classification[0].collection) !=
750              filtered_collections_annotation_.end();
751 }
752 
FilteredForClassification(const ClassificationResult & classification) const753 bool Annotator::FilteredForClassification(
754     const ClassificationResult& classification) const {
755   return filtered_collections_classification_.find(classification.collection) !=
756          filtered_collections_classification_.end();
757 }
758 
FilteredForSelection(const AnnotatedSpan & span) const759 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
760   return !span.classification.empty() &&
761          filtered_collections_selection_.find(
762              span.classification[0].collection) !=
763              filtered_collections_selection_.end();
764 }
765 
766 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)767 inline bool ClassifiedAsOther(
768     const std::vector<ClassificationResult>& classification) {
769   return !classification.empty() &&
770          classification[0].collection == Collections::Other();
771 }
772 
773 }  // namespace
774 
GetPriorityScore(const std::vector<ClassificationResult> & classification) const775 float Annotator::GetPriorityScore(
776     const std::vector<ClassificationResult>& classification) const {
777   if (!classification.empty() && !ClassifiedAsOther(classification)) {
778     return classification[0].priority_score;
779   } else {
780     if (model_->triggering_options() != nullptr) {
781       return model_->triggering_options()->other_collection_priority_score();
782     } else {
783       return -1000.0;
784     }
785   }
786 }
787 
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const788 bool Annotator::VerifyRegexMatchCandidate(
789     const std::string& context, const VerificationOptions* verification_options,
790     const std::string& match, const UniLib::RegexMatcher* matcher) const {
791   if (verification_options == nullptr) {
792     return true;
793   }
794   if (verification_options->verify_luhn_checksum() &&
795       !VerifyLuhnChecksum(match)) {
796     return false;
797   }
798   const int lua_verifier = verification_options->lua_verifier();
799   if (lua_verifier >= 0) {
800     if (model_->regex_model()->lua_verifier() == nullptr ||
801         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
802       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
803       return false;
804     }
805     return VerifyMatch(
806         context, matcher,
807         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
808   }
809   return true;
810 }
811 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const812 CodepointSpan Annotator::SuggestSelection(
813     const std::string& context, CodepointSpan click_indices,
814     const SelectionOptions& options) const {
815   if (context.size() > std::numeric_limits<int>::max()) {
816     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
817     return {};
818   }
819 
820   CodepointSpan original_click_indices = click_indices;
821   if (!initialized_) {
822     TC3_LOG(ERROR) << "Not initialized";
823     return original_click_indices;
824   }
825   if (options.annotation_usecase !=
826       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
827     TC3_LOG(WARNING)
828         << "Invoking SuggestSelection, which is not supported in RAW mode.";
829     return original_click_indices;
830   }
831   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
832     return original_click_indices;
833   }
834 
835   std::vector<Locale> detected_text_language_tags;
836   if (!ParseLocales(options.detected_text_language_tags,
837                     &detected_text_language_tags)) {
838     TC3_LOG(WARNING)
839         << "Failed to parse the detected_text_language_tags in options: "
840         << options.detected_text_language_tags;
841   }
842   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
843                                     model_triggering_locales_,
844                                     /*default_value=*/true)) {
845     return original_click_indices;
846   }
847 
848   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
849                                                         /*do_copy=*/false);
850 
851   if (!unilib_->IsValidUtf8(context_unicode)) {
852     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
853     return original_click_indices;
854   }
855 
856   if (!IsValidSpanInput(context_unicode, click_indices)) {
857     TC3_VLOG(1)
858         << "Trying to run SuggestSelection with invalid input, indices: "
859         << click_indices.first << " " << click_indices.second;
860     return original_click_indices;
861   }
862 
863   if (model_->snap_whitespace_selections()) {
864     // We want to expand a purely white-space selection to a multi-selection it
865     // would've been part of. But with this feature disabled we would do a no-
866     // op, because no token is found. Therefore, we need to modify the
867     // 'click_indices' a bit to include a part of the token, so that the click-
868     // finding logic finds the clicked token correctly. This modification is
869     // done by the following function. Note, that it's enough to check the left
870     // side of the current selection, because if the white-space is a part of a
871     // multi-selection, necessarily both tokens - on the left and the right
872     // sides need to be selected. Thus snapping only to the left is sufficient
873     // (there's a check at the bottom that makes sure that if we snap to the
874     // left token but the result does not contain the initial white-space,
875     // returns the original indices).
876     click_indices = internal::SnapLeftIfWhitespaceSelection(
877         click_indices, context_unicode, *unilib_);
878   }
879 
880   Annotations candidates;
881   // As we process a single string of context, the candidates will only
882   // contain one vector of AnnotatedSpan.
883   candidates.annotated_spans.resize(1);
884   InterpreterManager interpreter_manager(selection_executor_.get(),
885                                          classification_executor_.get());
886   std::vector<Token> tokens;
887   if (!ModelSuggestSelection(context_unicode, click_indices,
888                              detected_text_language_tags, &interpreter_manager,
889                              &tokens, &candidates.annotated_spans[0])) {
890     TC3_LOG(ERROR) << "Model suggest selection failed.";
891     return original_click_indices;
892   }
893   const std::unordered_set<std::string> set;
894   const EnabledEntityTypes is_entity_type_enabled(set);
895   if (!RegexChunk(context_unicode, selection_regex_patterns_,
896                   /*is_serialized_entity_data_enabled=*/false,
897                   is_entity_type_enabled, options.annotation_usecase,
898                   &candidates.annotated_spans[0])) {
899     TC3_LOG(ERROR) << "Regex suggest selection failed.";
900     return original_click_indices;
901   }
902   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
903                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
904                      options.locales, ModeFlag_SELECTION,
905                      options.annotation_usecase,
906                      /*is_serialized_entity_data_enabled=*/false,
907                      &candidates.annotated_spans[0])) {
908     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
909     return original_click_indices;
910   }
911   if (knowledge_engine_ != nullptr &&
912       !knowledge_engine_
913            ->Chunk(context, options.annotation_usecase,
914                    options.location_context, Permissions(),
915                    AnnotateMode::kEntityAnnotation, &candidates)
916            .ok()) {
917     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
918     return original_click_indices;
919   }
920   if (contact_engine_ != nullptr &&
921       !contact_engine_->Chunk(context_unicode, tokens,
922                               &candidates.annotated_spans[0])) {
923     TC3_LOG(ERROR) << "Contact suggest selection failed.";
924     return original_click_indices;
925   }
926   if (installed_app_engine_ != nullptr &&
927       !installed_app_engine_->Chunk(context_unicode, tokens,
928                                     &candidates.annotated_spans[0])) {
929     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
930     return original_click_indices;
931   }
932   if (number_annotator_ != nullptr &&
933       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
934                                   &candidates.annotated_spans[0])) {
935     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
936     return original_click_indices;
937   }
938   if (duration_annotator_ != nullptr &&
939       !duration_annotator_->FindAll(context_unicode, tokens,
940                                     options.annotation_usecase,
941                                     &candidates.annotated_spans[0])) {
942     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
943     return original_click_indices;
944   }
945   if (person_name_engine_ != nullptr &&
946       !person_name_engine_->Chunk(context_unicode, tokens,
947                                   &candidates.annotated_spans[0])) {
948     TC3_LOG(ERROR) << "Person name suggest selection failed.";
949     return original_click_indices;
950   }
951 
952   AnnotatedSpan grammar_suggested_span;
953   if (grammar_annotator_ != nullptr &&
954       grammar_annotator_->SuggestSelection(detected_text_language_tags,
955                                            context_unicode, click_indices,
956                                            &grammar_suggested_span)) {
957     candidates.annotated_spans[0].push_back(grammar_suggested_span);
958   }
959 
960   AnnotatedSpan pod_ner_suggested_span;
961   if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
962       pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
963                                            &pod_ner_suggested_span)) {
964     candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
965   }
966 
967   if (experimental_annotator_ != nullptr) {
968     candidates.annotated_spans[0].push_back(
969         experimental_annotator_->SuggestSelection(context_unicode,
970                                                   click_indices));
971   }
972 
973   // Sort candidates according to their position in the input, so that the next
974   // code can assume that any connected component of overlapping spans forms a
975   // contiguous block.
976   std::stable_sort(candidates.annotated_spans[0].begin(),
977                    candidates.annotated_spans[0].end(),
978                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
979                      return a.span.first < b.span.first;
980                    });
981 
982   std::vector<int> candidate_indices;
983   if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
984                         detected_text_language_tags, options,
985                         &interpreter_manager, &candidate_indices)) {
986     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
987     return original_click_indices;
988   }
989 
990   std::stable_sort(
991       candidate_indices.begin(), candidate_indices.end(),
992       [this, &candidates](int a, int b) {
993         return GetPriorityScore(
994                    candidates.annotated_spans[0][a].classification) >
995                GetPriorityScore(
996                    candidates.annotated_spans[0][b].classification);
997       });
998 
999   for (const int i : candidate_indices) {
1000     if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1001         SpansOverlap(candidates.annotated_spans[0][i].span,
1002                      original_click_indices)) {
1003       // Run model classification if not present but requested and there's a
1004       // classification collection filter specified.
1005       if (candidates.annotated_spans[0][i].classification.empty() &&
1006           model_->selection_options()->always_classify_suggested_selection() &&
1007           !filtered_collections_selection_.empty()) {
1008         if (!ModelClassifyText(context, /*cached_tokens=*/{},
1009                                detected_text_language_tags,
1010                                candidates.annotated_spans[0][i].span, options,
1011                                &interpreter_manager,
1012                                /*embedding_cache=*/nullptr,
1013                                &candidates.annotated_spans[0][i].classification,
1014                                /*tokens=*/nullptr)) {
1015           return original_click_indices;
1016         }
1017       }
1018 
1019       // Ignore if span classification is filtered.
1020       if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1021         return original_click_indices;
1022       }
1023 
1024       // We return a suggested span contains the original span.
1025       // This compensates for "select all" selection that may come from
1026       // other apps. See http://b/179890518.
1027       if (SpanContains(candidates.annotated_spans[0][i].span,
1028                        original_click_indices)) {
1029         return candidates.annotated_spans[0][i].span;
1030       }
1031     }
1032   }
1033 
1034   return original_click_indices;
1035 }
1036 
1037 namespace {
1038 // Helper function that returns the index of the first candidate that
1039 // transitively does not overlap with the candidate on 'start_index'. If the end
1040 // of 'candidates' is reached, it returns the index that points right behind the
1041 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1042 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1043                                  int start_index) {
1044   int first_non_overlapping = start_index + 1;
1045   CodepointSpan conflicting_span = candidates[start_index].span;
1046   while (
1047       first_non_overlapping < candidates.size() &&
1048       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1049     // Grow the span to include the current one.
1050     conflicting_span.second = std::max(
1051         conflicting_span.second, candidates[first_non_overlapping].span.second);
1052 
1053     ++first_non_overlapping;
1054   }
1055   return first_non_overlapping;
1056 }
1057 }  // namespace
1058 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1059 bool Annotator::ResolveConflicts(
1060     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1061     const std::vector<Token>& cached_tokens,
1062     const std::vector<Locale>& detected_text_language_tags,
1063     const BaseOptions& options, InterpreterManager* interpreter_manager,
1064     std::vector<int>* result) const {
1065   result->clear();
1066   result->reserve(candidates.size());
1067   for (int i = 0; i < candidates.size();) {
1068     int first_non_overlapping =
1069         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1070 
1071     const bool conflict_found = first_non_overlapping != (i + 1);
1072     if (conflict_found) {
1073       std::vector<int> candidate_indices;
1074       if (!ResolveConflict(context, cached_tokens, candidates,
1075                            detected_text_language_tags, i,
1076                            first_non_overlapping, options, interpreter_manager,
1077                            &candidate_indices)) {
1078         return false;
1079       }
1080       result->insert(result->end(), candidate_indices.begin(),
1081                      candidate_indices.end());
1082     } else {
1083       result->push_back(i);
1084     }
1085 
1086     // Skip over the whole conflicting group/go to next candidate.
1087     i = first_non_overlapping;
1088   }
1089   return true;
1090 }
1091 
1092 namespace {
1093 // Returns true, if the given two sources do conflict in given annotation
1094 // usecase.
1095 //  - In SMART usecase, all sources do conflict, because there's only 1 possible
1096 //  annotation for a given span.
1097 //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1098 //  and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1099 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1100                        const AnnotatedSpan::Source source1,
1101                        const AnnotatedSpan::Source source2) {
1102   uint32 source_mask =
1103       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1104 
1105   switch (annotation_usecase) {
1106     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1107       // In the SMART mode, all annotations conflict.
1108       return true;
1109 
1110     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1111       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1112       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1113       // hours" (duration).
1114       if ((source_mask &
1115            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1116           (source_mask &
1117            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1118         return false;
1119       }
1120 
1121       // A KNOWLEDGE entity does not conflict with anything.
1122       if ((source_mask &
1123            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1124         return false;
1125       }
1126 
1127       // A PERSONNAME entity does not conflict with anything.
1128       if ((source_mask &
1129            (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1130         return false;
1131       }
1132 
1133       // Entities from other sources can conflict.
1134       return true;
1135   }
1136 }
1137 }  // namespace
1138 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1139 bool Annotator::ResolveConflict(
1140     const std::string& context, const std::vector<Token>& cached_tokens,
1141     const std::vector<AnnotatedSpan>& candidates,
1142     const std::vector<Locale>& detected_text_language_tags, int start_index,
1143     int end_index, const BaseOptions& options,
1144     InterpreterManager* interpreter_manager,
1145     std::vector<int>* chosen_indices) const {
1146   std::vector<int> conflicting_indices;
1147   std::unordered_map<int, std::pair<float, int>> scores_lengths;
1148   for (int i = start_index; i < end_index; ++i) {
1149     conflicting_indices.push_back(i);
1150     if (!candidates[i].classification.empty()) {
1151       scores_lengths[i] = {
1152           GetPriorityScore(candidates[i].classification),
1153           candidates[i].span.second - candidates[i].span.first};
1154       continue;
1155     }
1156 
1157     // OPTIMIZATION: So that we don't have to classify all the ML model
1158     // spans apriori, we wait until we get here, when they conflict with
1159     // something and we need the actual classification scores. So if the
1160     // candidate conflicts and comes from the model, we need to run a
1161     // classification to determine its priority:
1162     std::vector<ClassificationResult> classification;
1163     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1164                            candidates[i].span, options, interpreter_manager,
1165                            /*embedding_cache=*/nullptr, &classification,
1166                            /*tokens=*/nullptr)) {
1167       return false;
1168     }
1169 
1170     if (!classification.empty()) {
1171       scores_lengths[i] = {
1172           GetPriorityScore(classification),
1173           candidates[i].span.second - candidates[i].span.first};
1174     }
1175   }
1176 
1177   std::stable_sort(
1178       conflicting_indices.begin(), conflicting_indices.end(),
1179       [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1180         if (scores_lengths[i].first == scores_lengths[j].first &&
1181             prioritize_longest_annotation_) {
1182           return scores_lengths[i].second > scores_lengths[j].second;
1183         }
1184         return scores_lengths[i].first > scores_lengths[j].first;
1185       });
1186 
1187   // Here we keep a set of indices that were chosen, per-source, to enable
1188   // effective computation.
1189   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1190       chosen_indices_for_source_map;
1191 
1192   // Greedily place the candidates if they don't conflict with the already
1193   // placed ones.
1194   for (int i = 0; i < conflicting_indices.size(); ++i) {
1195     const int considered_candidate = conflicting_indices[i];
1196 
1197     // See if there is a conflict between the candidate and all already placed
1198     // candidates.
1199     bool conflict = false;
1200     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1201     for (auto& source_set_pair : chosen_indices_for_source_map) {
1202       if (source_set_pair.first == candidates[considered_candidate].source) {
1203         chosen_indices_for_source_ptr = &source_set_pair.second;
1204       }
1205 
1206       const bool needs_conflict_resolution =
1207           options.annotation_usecase ==
1208               AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1209           (options.annotation_usecase ==
1210                AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1211            do_conflict_resolution_in_raw_mode_);
1212       if (needs_conflict_resolution &&
1213           DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1214                             candidates[considered_candidate].source) &&
1215           DoesCandidateConflict(considered_candidate, candidates,
1216                                 source_set_pair.second)) {
1217         conflict = true;
1218         break;
1219       }
1220     }
1221 
1222     // Skip the candidate if a conflict was found.
1223     if (conflict) {
1224       continue;
1225     }
1226 
1227     // If the set of indices for the current source doesn't exist yet,
1228     // initialize it.
1229     if (chosen_indices_for_source_ptr == nullptr) {
1230       SortedIntSet new_set([&candidates](int a, int b) {
1231         return candidates[a].span.first < candidates[b].span.first;
1232       });
1233       chosen_indices_for_source_map[candidates[considered_candidate].source] =
1234           std::move(new_set);
1235       chosen_indices_for_source_ptr =
1236           &chosen_indices_for_source_map[candidates[considered_candidate]
1237                                              .source];
1238     }
1239 
1240     // Place the candidate to the output and to the per-source conflict set.
1241     chosen_indices->push_back(considered_candidate);
1242     chosen_indices_for_source_ptr->insert(considered_candidate);
1243   }
1244 
1245   std::stable_sort(chosen_indices->begin(), chosen_indices->end());
1246 
1247   return true;
1248 }
1249 
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1250 bool Annotator::ModelSuggestSelection(
1251     const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1252     const std::vector<Locale>& detected_text_language_tags,
1253     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1254     std::vector<AnnotatedSpan>* result) const {
1255   if (model_->triggering_options() == nullptr ||
1256       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1257     return true;
1258   }
1259 
1260   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1261                                     ml_model_triggering_locales_,
1262                                     /*default_value=*/true)) {
1263     return true;
1264   }
1265 
1266   int click_pos;
1267   *tokens = selection_feature_processor_->Tokenize(context_unicode);
1268   const auto [click_begin, click_end] =
1269       CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1270   selection_feature_processor_->RetokenizeAndFindClick(
1271       context_unicode, click_begin, click_end, click_indices,
1272       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1273       tokens, &click_pos);
1274   if (click_pos == kInvalidIndex) {
1275     TC3_VLOG(1) << "Could not calculate the click position.";
1276     return false;
1277   }
1278 
1279   const int symmetry_context_size =
1280       model_->selection_options()->symmetry_context_size();
1281   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1282       bounds_sensitive_features = selection_feature_processor_->GetOptions()
1283                                       ->bounds_sensitive_features();
1284 
1285   // The symmetry context span is the clicked token with symmetry_context_size
1286   // tokens on either side.
1287   const TokenSpan symmetry_context_span =
1288       IntersectTokenSpans(TokenSpan(click_pos).Expand(
1289                               /*num_tokens_left=*/symmetry_context_size,
1290                               /*num_tokens_right=*/symmetry_context_size),
1291                           AllOf(*tokens));
1292 
1293   // Compute the extraction span based on the model type.
1294   TokenSpan extraction_span;
1295   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1296     // The extraction span is the symmetry context span expanded to include
1297     // max_selection_span tokens on either side, which is how far a selection
1298     // can stretch from the click, plus a relevant number of tokens outside of
1299     // the bounds of the selection.
1300     const int max_selection_span =
1301         selection_feature_processor_->GetOptions()->max_selection_span();
1302     extraction_span = symmetry_context_span.Expand(
1303         /*num_tokens_left=*/max_selection_span +
1304             bounds_sensitive_features->num_tokens_before(),
1305         /*num_tokens_right=*/max_selection_span +
1306             bounds_sensitive_features->num_tokens_after());
1307   } else {
1308     // The extraction span is the symmetry context span expanded to include
1309     // context_size tokens on either side.
1310     const int context_size =
1311         selection_feature_processor_->GetOptions()->context_size();
1312     extraction_span = symmetry_context_span.Expand(
1313         /*num_tokens_left=*/context_size,
1314         /*num_tokens_right=*/context_size);
1315   }
1316   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1317 
1318   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1319           *tokens, extraction_span)) {
1320     return true;
1321   }
1322 
1323   std::unique_ptr<CachedFeatures> cached_features;
1324   if (!selection_feature_processor_->ExtractFeatures(
1325           *tokens, extraction_span,
1326           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1327           embedding_executor_.get(),
1328           /*embedding_cache=*/nullptr,
1329           selection_feature_processor_->EmbeddingSize() +
1330               selection_feature_processor_->DenseFeaturesCount(),
1331           &cached_features)) {
1332     TC3_LOG(ERROR) << "Could not extract features.";
1333     return false;
1334   }
1335 
1336   // Produce selection model candidates.
1337   std::vector<TokenSpan> chunks;
1338   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1339                   interpreter_manager->SelectionInterpreter(), *cached_features,
1340                   &chunks)) {
1341     TC3_LOG(ERROR) << "Could not chunk.";
1342     return false;
1343   }
1344 
1345   for (const TokenSpan& chunk : chunks) {
1346     AnnotatedSpan candidate;
1347     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1348         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1349     if (model_->selection_options()->strip_unpaired_brackets()) {
1350       candidate.span =
1351           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1352     }
1353 
1354     // Only output non-empty spans.
1355     if (candidate.span.first != candidate.span.second) {
1356       result->push_back(candidate);
1357     }
1358   }
1359   return true;
1360 }
1361 
1362 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1363 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1364                                     const CodepointSpan& selection_indices,
1365                                     TokenSpan tokens_around_selection_to_copy) {
1366   const auto first_selection_token = std::upper_bound(
1367       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1368       [](int selection_start, const Token& token) {
1369         return selection_start < token.end;
1370       });
1371   const auto last_selection_token = std::lower_bound(
1372       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1373       [](const Token& token, int selection_end) {
1374         return token.start < selection_end;
1375       });
1376 
1377   const int64 first_token = std::max(
1378       static_cast<int64>(0),
1379       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1380                          tokens_around_selection_to_copy.first));
1381   const int64 last_token = std::min(
1382       static_cast<int64>(cached_tokens.size()),
1383       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1384                          tokens_around_selection_to_copy.second));
1385 
1386   std::vector<Token> tokens;
1387   tokens.reserve(last_token - first_token);
1388   for (int i = first_token; i < last_token; ++i) {
1389     tokens.push_back(cached_tokens[i]);
1390   }
1391   return tokens;
1392 }
1393 }  // namespace internal
1394 
ClassifyTextUpperBoundNeededTokens() const1395 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1396   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1397       bounds_sensitive_features =
1398           classification_feature_processor_->GetOptions()
1399               ->bounds_sensitive_features();
1400   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1401     // The extraction span is the selection span expanded to include a relevant
1402     // number of tokens outside of the bounds of the selection.
1403     return {bounds_sensitive_features->num_tokens_before(),
1404             bounds_sensitive_features->num_tokens_after()};
1405   } else {
1406     // The extraction span is the clicked token with context_size tokens on
1407     // either side.
1408     const int context_size =
1409         selection_feature_processor_->GetOptions()->context_size();
1410     return {context_size, context_size};
1411   }
1412 }
1413 
1414 namespace {
1415 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1416 void SortClassificationResults(
1417     std::vector<ClassificationResult>* classification_results) {
1418   std::stable_sort(
1419       classification_results->begin(), classification_results->end(),
1420       [](const ClassificationResult& a, const ClassificationResult& b) {
1421         return a.score > b.score;
1422       });
1423 }
1424 }  // namespace
1425 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1426 bool Annotator::ModelClassifyText(
1427     const std::string& context, const std::vector<Token>& cached_tokens,
1428     const std::vector<Locale>& detected_text_language_tags,
1429     const CodepointSpan& selection_indices, const BaseOptions& options,
1430     InterpreterManager* interpreter_manager,
1431     FeatureProcessor::EmbeddingCache* embedding_cache,
1432     std::vector<ClassificationResult>* classification_results,
1433     std::vector<Token>* tokens) const {
1434   const UnicodeText context_unicode =
1435       UTF8ToUnicodeText(context, /*do_copy=*/false);
1436   const auto [span_begin, span_end] =
1437       CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1438   return ModelClassifyText(context_unicode, cached_tokens,
1439                            detected_text_language_tags, span_begin, span_end,
1440                            /*line=*/nullptr, selection_indices, options,
1441                            interpreter_manager, embedding_cache,
1442                            classification_results, tokens);
1443 }
1444 
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1445 bool Annotator::ModelClassifyText(
1446     const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1447     const std::vector<Locale>& detected_text_language_tags,
1448     const UnicodeText::const_iterator& span_begin,
1449     const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1450     const CodepointSpan& selection_indices, const BaseOptions& options,
1451     InterpreterManager* interpreter_manager,
1452     FeatureProcessor::EmbeddingCache* embedding_cache,
1453     std::vector<ClassificationResult>* classification_results,
1454     std::vector<Token>* tokens) const {
1455   if (model_->triggering_options() == nullptr ||
1456       !(model_->triggering_options()->enabled_modes() &
1457         ModeFlag_CLASSIFICATION)) {
1458     return true;
1459   }
1460 
1461   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1462                                     ml_model_triggering_locales_,
1463                                     /*default_value=*/true)) {
1464     return true;
1465   }
1466 
1467   std::vector<Token> local_tokens;
1468   if (tokens == nullptr) {
1469     tokens = &local_tokens;
1470   }
1471 
1472   if (cached_tokens.empty()) {
1473     *tokens = classification_feature_processor_->Tokenize(context_unicode);
1474   } else {
1475     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1476                                          ClassifyTextUpperBoundNeededTokens());
1477   }
1478 
1479   int click_pos;
1480   classification_feature_processor_->RetokenizeAndFindClick(
1481       context_unicode, span_begin, span_end, selection_indices,
1482       classification_feature_processor_->GetOptions()
1483           ->only_use_line_with_click(),
1484       tokens, &click_pos);
1485   const TokenSpan selection_token_span =
1486       CodepointSpanToTokenSpan(*tokens, selection_indices);
1487   const int selection_num_tokens = selection_token_span.Size();
1488   if (model_->classification_options()->max_num_tokens() > 0 &&
1489       model_->classification_options()->max_num_tokens() <
1490           selection_num_tokens) {
1491     *classification_results = {{Collections::Other(), 1.0}};
1492     return true;
1493   }
1494 
1495   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1496       bounds_sensitive_features =
1497           classification_feature_processor_->GetOptions()
1498               ->bounds_sensitive_features();
1499   if (selection_token_span.first == kInvalidIndex ||
1500       selection_token_span.second == kInvalidIndex) {
1501     TC3_LOG(ERROR) << "Could not determine span.";
1502     return false;
1503   }
1504 
1505   // Compute the extraction span based on the model type.
1506   TokenSpan extraction_span;
1507   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1508     // The extraction span is the selection span expanded to include a relevant
1509     // number of tokens outside of the bounds of the selection.
1510     extraction_span = selection_token_span.Expand(
1511         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1512         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1513   } else {
1514     if (click_pos == kInvalidIndex) {
1515       TC3_LOG(ERROR) << "Couldn't choose a click position.";
1516       return false;
1517     }
1518     // The extraction span is the clicked token with context_size tokens on
1519     // either side.
1520     const int context_size =
1521         classification_feature_processor_->GetOptions()->context_size();
1522     extraction_span = TokenSpan(click_pos).Expand(
1523         /*num_tokens_left=*/context_size,
1524         /*num_tokens_right=*/context_size);
1525   }
1526   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1527 
1528   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1529           *tokens, extraction_span)) {
1530     *classification_results = {{Collections::Other(), 1.0}};
1531     return true;
1532   }
1533 
1534   std::unique_ptr<CachedFeatures> cached_features;
1535   if (!classification_feature_processor_->ExtractFeatures(
1536           *tokens, extraction_span, selection_indices,
1537           embedding_executor_.get(), embedding_cache,
1538           classification_feature_processor_->EmbeddingSize() +
1539               classification_feature_processor_->DenseFeaturesCount(),
1540           &cached_features)) {
1541     TC3_LOG(ERROR) << "Could not extract features.";
1542     return false;
1543   }
1544 
1545   std::vector<float> features;
1546   features.reserve(cached_features->OutputFeaturesSize());
1547   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1548     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1549                                                           &features);
1550   } else {
1551     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1552   }
1553 
1554   TensorView<float> logits = classification_executor_->ComputeLogits(
1555       TensorView<float>(features.data(),
1556                         {1, static_cast<int>(features.size())}),
1557       interpreter_manager->ClassificationInterpreter());
1558   if (!logits.is_valid()) {
1559     TC3_LOG(ERROR) << "Couldn't compute logits.";
1560     return false;
1561   }
1562 
1563   if (logits.dims() != 2 || logits.dim(0) != 1 ||
1564       logits.dim(1) != classification_feature_processor_->NumCollections()) {
1565     TC3_LOG(ERROR) << "Mismatching output";
1566     return false;
1567   }
1568 
1569   const std::vector<float> scores =
1570       ComputeSoftmax(logits.data(), logits.dim(1));
1571 
1572   if (scores.empty()) {
1573     *classification_results = {{Collections::Other(), 1.0}};
1574     return true;
1575   }
1576 
1577   const int best_score_index =
1578       std::max_element(scores.begin(), scores.end()) - scores.begin();
1579   const std::string top_collection =
1580       classification_feature_processor_->LabelToCollection(best_score_index);
1581 
1582   // Sanity checks.
1583   if (top_collection == Collections::Phone()) {
1584     const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1585     if (digit_count <
1586             model_->classification_options()->phone_min_num_digits() ||
1587         digit_count >
1588             model_->classification_options()->phone_max_num_digits()) {
1589       *classification_results = {{Collections::Other(), 1.0}};
1590       return true;
1591     }
1592   } else if (top_collection == Collections::Address()) {
1593     if (selection_num_tokens <
1594         model_->classification_options()->address_min_num_tokens()) {
1595       *classification_results = {{Collections::Other(), 1.0}};
1596       return true;
1597     }
1598   } else if (top_collection == Collections::Dictionary()) {
1599     if ((options.use_vocab_annotator && vocab_annotator_) ||
1600         !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1601                                       dictionary_locales_,
1602                                       /*default_value=*/false)) {
1603       *classification_results = {{Collections::Other(), 1.0}};
1604       return true;
1605     }
1606   }
1607   *classification_results = {{top_collection, /*arg_score=*/1.0,
1608                               /*arg_priority_score=*/scores[best_score_index]}};
1609 
1610   // For some entities, we might want to clamp the priority score, for better
1611   // conflict resolution between entities.
1612   if (model_->triggering_options() != nullptr &&
1613       model_->triggering_options()->collection_to_priority() != nullptr) {
1614     if (auto entry =
1615             model_->triggering_options()->collection_to_priority()->LookupByKey(
1616                 top_collection.c_str())) {
1617       (*classification_results)[0].priority_score *= entry->value();
1618     }
1619   }
1620   return true;
1621 }
1622 
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1623 bool Annotator::RegexClassifyText(
1624     const std::string& context, const CodepointSpan& selection_indices,
1625     std::vector<ClassificationResult>* classification_result) const {
1626   const std::string selection_text =
1627       UTF8ToUnicodeText(context, /*do_copy=*/false)
1628           .UTF8Substring(selection_indices.first, selection_indices.second);
1629   const UnicodeText selection_text_unicode(
1630       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1631 
1632   // Check whether any of the regular expressions match.
1633   for (const int pattern_id : classification_regex_patterns_) {
1634     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1635     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1636         regex_pattern.pattern->Matcher(selection_text_unicode);
1637     int status = UniLib::RegexMatcher::kNoError;
1638     bool matches;
1639     if (regex_pattern.config->use_approximate_matching()) {
1640       matches = matcher->ApproximatelyMatches(&status);
1641     } else {
1642       matches = matcher->Matches(&status);
1643     }
1644     if (status != UniLib::RegexMatcher::kNoError) {
1645       return false;
1646     }
1647     if (matches && VerifyRegexMatchCandidate(
1648                        context, regex_pattern.config->verification_options(),
1649                        selection_text, matcher.get())) {
1650       classification_result->push_back(
1651           {regex_pattern.config->collection_name()->str(),
1652            regex_pattern.config->target_classification_score(),
1653            regex_pattern.config->priority_score()});
1654       if (!SerializedEntityDataFromRegexMatch(
1655               regex_pattern.config, matcher.get(),
1656               &classification_result->back().serialized_entity_data)) {
1657         TC3_LOG(ERROR) << "Could not get entity data.";
1658         return false;
1659       }
1660     }
1661   }
1662 
1663   return true;
1664 }
1665 
1666 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1667 std::string PickCollectionForDatetime(
1668     const DatetimeParseResult& datetime_parse_result) {
1669   switch (datetime_parse_result.granularity) {
1670     case GRANULARITY_HOUR:
1671     case GRANULARITY_MINUTE:
1672     case GRANULARITY_SECOND:
1673       return Collections::DateTime();
1674     default:
1675       return Collections::Date();
1676   }
1677 }
1678 
1679 }  // namespace
1680 
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1681 bool Annotator::DatetimeClassifyText(
1682     const std::string& context, const CodepointSpan& selection_indices,
1683     const ClassificationOptions& options,
1684     std::vector<ClassificationResult>* classification_results) const {
1685   if (!datetime_parser_) {
1686     return true;
1687   }
1688 
1689   const std::string selection_text =
1690       UTF8ToUnicodeText(context, /*do_copy=*/false)
1691           .UTF8Substring(selection_indices.first, selection_indices.second);
1692 
1693   LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1694   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1695       datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1696                               options.reference_timezone, locale_list,
1697                               ModeFlag_CLASSIFICATION,
1698                               options.annotation_usecase,
1699                               /*anchor_start_end=*/true);
1700   if (!result_status.ok()) {
1701     TC3_LOG(ERROR) << "Error during parsing datetime.";
1702     return false;
1703   }
1704 
1705   for (const DatetimeParseResultSpan& datetime_span :
1706        result_status.ValueOrDie()) {
1707     // Only consider the result valid if the selection and extracted datetime
1708     // spans exactly match.
1709     if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1710                       datetime_span.span.second + selection_indices.first) ==
1711         selection_indices) {
1712       for (const DatetimeParseResult& parse_result : datetime_span.data) {
1713         classification_results->emplace_back(
1714             PickCollectionForDatetime(parse_result),
1715             datetime_span.target_classification_score);
1716         classification_results->back().datetime_parse_result = parse_result;
1717         classification_results->back().serialized_entity_data =
1718             CreateDatetimeSerializedEntityData(parse_result);
1719         classification_results->back().priority_score =
1720             datetime_span.priority_score;
1721       }
1722       return true;
1723     }
1724   }
1725   return true;
1726 }
1727 
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1728 std::vector<ClassificationResult> Annotator::ClassifyText(
1729     const std::string& context, const CodepointSpan& selection_indices,
1730     const ClassificationOptions& options) const {
1731   if (context.size() > std::numeric_limits<int>::max()) {
1732     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1733     return {};
1734   }
1735   if (!initialized_) {
1736     TC3_LOG(ERROR) << "Not initialized";
1737     return {};
1738   }
1739   if (options.annotation_usecase !=
1740       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1741     TC3_LOG(WARNING)
1742         << "Invoking ClassifyText, which is not supported in RAW mode.";
1743     return {};
1744   }
1745   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1746     return {};
1747   }
1748 
1749   std::vector<Locale> detected_text_language_tags;
1750   if (!ParseLocales(options.detected_text_language_tags,
1751                     &detected_text_language_tags)) {
1752     TC3_LOG(WARNING)
1753         << "Failed to parse the detected_text_language_tags in options: "
1754         << options.detected_text_language_tags;
1755   }
1756   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1757                                     model_triggering_locales_,
1758                                     /*default_value=*/true)) {
1759     return {};
1760   }
1761 
1762   const UnicodeText context_unicode =
1763       UTF8ToUnicodeText(context, /*do_copy=*/false);
1764 
1765   if (!unilib_->IsValidUtf8(context_unicode)) {
1766     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1767     return {};
1768   }
1769 
1770   if (!IsValidSpanInput(context_unicode, selection_indices)) {
1771     TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1772                 << selection_indices.first << " " << selection_indices.second;
1773     return {};
1774   }
1775 
1776   // We'll accumulate a list of candidates, and pick the best candidate in the
1777   // end.
1778   std::vector<AnnotatedSpan> candidates;
1779 
1780   // Try the knowledge engine.
1781   // TODO(b/126579108): Propagate error status.
1782   ClassificationResult knowledge_result;
1783   if (knowledge_engine_ &&
1784       knowledge_engine_
1785           ->ClassifyText(context, selection_indices, options.annotation_usecase,
1786                          options.location_context, Permissions(),
1787                          &knowledge_result)
1788           .ok()) {
1789     candidates.push_back({selection_indices, {knowledge_result}});
1790     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1791   }
1792 
1793   AddContactMetadataToKnowledgeClassificationResults(&candidates);
1794 
1795   // Try the contact engine.
1796   // TODO(b/126579108): Propagate error status.
1797   ClassificationResult contact_result;
1798   if (contact_engine_ && contact_engine_->ClassifyText(
1799                              context, selection_indices, &contact_result)) {
1800     candidates.push_back({selection_indices, {contact_result}});
1801   }
1802 
1803   // Try the person name engine.
1804   ClassificationResult person_name_result;
1805   if (person_name_engine_ &&
1806       person_name_engine_->ClassifyText(context, selection_indices,
1807                                         &person_name_result)) {
1808     candidates.push_back({selection_indices, {person_name_result}});
1809     candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1810   }
1811 
1812   // Try the installed app engine.
1813   // TODO(b/126579108): Propagate error status.
1814   ClassificationResult installed_app_result;
1815   if (installed_app_engine_ &&
1816       installed_app_engine_->ClassifyText(context, selection_indices,
1817                                           &installed_app_result)) {
1818     candidates.push_back({selection_indices, {installed_app_result}});
1819   }
1820 
1821   // Try the regular expression models.
1822   std::vector<ClassificationResult> regex_results;
1823   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1824     return {};
1825   }
1826   for (const ClassificationResult& result : regex_results) {
1827     candidates.push_back({selection_indices, {result}});
1828   }
1829 
1830   // Try the date model.
1831   //
1832   // DatetimeClassifyText only returns the first result, which can however have
1833   // more interpretations. They are inserted in the candidates as a single
1834   // AnnotatedSpan, so that they get treated together by the conflict resolution
1835   // algorithm.
1836   std::vector<ClassificationResult> datetime_results;
1837   if (!DatetimeClassifyText(context, selection_indices, options,
1838                             &datetime_results)) {
1839     return {};
1840   }
1841   if (!datetime_results.empty()) {
1842     candidates.push_back({selection_indices, std::move(datetime_results)});
1843     candidates.back().source = AnnotatedSpan::Source::DATETIME;
1844   }
1845 
1846   // Try the number annotator.
1847   // TODO(b/126579108): Propagate error status.
1848   ClassificationResult number_annotator_result;
1849   if (number_annotator_ &&
1850       number_annotator_->ClassifyText(context_unicode, selection_indices,
1851                                       options.annotation_usecase,
1852                                       &number_annotator_result)) {
1853     candidates.push_back({selection_indices, {number_annotator_result}});
1854   }
1855 
1856   // Try the duration annotator.
1857   ClassificationResult duration_annotator_result;
1858   if (duration_annotator_ &&
1859       duration_annotator_->ClassifyText(context_unicode, selection_indices,
1860                                         options.annotation_usecase,
1861                                         &duration_annotator_result)) {
1862     candidates.push_back({selection_indices, {duration_annotator_result}});
1863     candidates.back().source = AnnotatedSpan::Source::DURATION;
1864   }
1865 
1866   // Try the translate annotator.
1867   ClassificationResult translate_annotator_result;
1868   if (translate_annotator_ &&
1869       translate_annotator_->ClassifyText(context_unicode, selection_indices,
1870                                          options.user_familiar_language_tags,
1871                                          &translate_annotator_result)) {
1872     candidates.push_back({selection_indices, {translate_annotator_result}});
1873   }
1874 
1875   // Try the grammar model.
1876   ClassificationResult grammar_annotator_result;
1877   if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1878                                 detected_text_language_tags, context_unicode,
1879                                 selection_indices, &grammar_annotator_result)) {
1880     candidates.push_back({selection_indices, {grammar_annotator_result}});
1881   }
1882 
1883   ClassificationResult pod_ner_annotator_result;
1884   if (pod_ner_annotator_ && options.use_pod_ner &&
1885       pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1886                                        &pod_ner_annotator_result)) {
1887     candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1888   }
1889 
1890   ClassificationResult vocab_annotator_result;
1891   if (vocab_annotator_ && options.use_vocab_annotator &&
1892       vocab_annotator_->ClassifyText(
1893           context_unicode, selection_indices, detected_text_language_tags,
1894           options.trigger_dictionary_on_beginner_words,
1895           &vocab_annotator_result)) {
1896     candidates.push_back({selection_indices, {vocab_annotator_result}});
1897   }
1898 
1899   if (experimental_annotator_) {
1900     experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1901                                           candidates);
1902   }
1903 
1904   // Try the ML model.
1905   //
1906   // The output of the model is considered as an exclusive 1-of-N choice. That's
1907   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1908   // span for each candidate, like e.g. the regex model.
1909   InterpreterManager interpreter_manager(selection_executor_.get(),
1910                                          classification_executor_.get());
1911   std::vector<ClassificationResult> model_results;
1912   std::vector<Token> tokens;
1913   if (!ModelClassifyText(
1914           context, /*cached_tokens=*/{}, detected_text_language_tags,
1915           selection_indices, options, &interpreter_manager,
1916           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1917     return {};
1918   }
1919   if (!model_results.empty()) {
1920     candidates.push_back({selection_indices, std::move(model_results)});
1921   }
1922 
1923   std::vector<int> candidate_indices;
1924   if (!ResolveConflicts(candidates, context, tokens,
1925                         detected_text_language_tags, options,
1926                         &interpreter_manager, &candidate_indices)) {
1927     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1928     return {};
1929   }
1930 
1931   std::vector<ClassificationResult> results;
1932   for (const int i : candidate_indices) {
1933     for (const ClassificationResult& result : candidates[i].classification) {
1934       if (!FilteredForClassification(result)) {
1935         results.push_back(result);
1936       }
1937     }
1938   }
1939 
1940   // Sort results according to score.
1941   std::stable_sort(
1942       results.begin(), results.end(),
1943       [](const ClassificationResult& a, const ClassificationResult& b) {
1944         return a.score > b.score;
1945       });
1946 
1947   if (results.empty()) {
1948     results = {{Collections::Other(), 1.0}};
1949   }
1950   return results;
1951 }
1952 
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1953 bool Annotator::ModelAnnotate(
1954     const std::string& context,
1955     const std::vector<Locale>& detected_text_language_tags,
1956     const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1957     std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1958   bool skip_model_annotatation = false;
1959   if (model_->triggering_options() == nullptr ||
1960       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1961     skip_model_annotatation = true;
1962   }
1963   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1964                                     ml_model_triggering_locales_,
1965                                     /*default_value=*/true)) {
1966     skip_model_annotatation = true;
1967   }
1968 
1969   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1970                                                         /*do_copy=*/false);
1971   std::vector<UnicodeTextRange> lines;
1972   if (!selection_feature_processor_ ||
1973       !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1974     lines.push_back({context_unicode.begin(), context_unicode.end()});
1975   } else {
1976     lines = selection_feature_processor_->SplitContext(
1977         context_unicode, selection_feature_processor_->GetOptions()
1978                              ->use_pipe_character_for_newline());
1979   }
1980 
1981   const float min_annotate_confidence =
1982       (model_->triggering_options() != nullptr
1983            ? model_->triggering_options()->min_annotate_confidence()
1984            : 0.f);
1985 
1986   for (const UnicodeTextRange& line : lines) {
1987     const std::string line_str =
1988         UnicodeText::UTF8Substring(line.first, line.second);
1989 
1990     std::vector<Token> line_tokens;
1991     line_tokens = selection_feature_processor_->Tokenize(line_str);
1992 
1993     selection_feature_processor_->RetokenizeAndFindClick(
1994         line_str, {0, std::distance(line.first, line.second)},
1995         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1996         &line_tokens,
1997         /*click_pos=*/nullptr);
1998     const TokenSpan full_line_span = {
1999         0, static_cast<TokenIndex>(line_tokens.size())};
2000 
2001     tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2002 
2003     if (skip_model_annotatation) {
2004       // We do not annotate, we only output the tokens.
2005       continue;
2006     }
2007 
2008     // TODO(zilka): Add support for greater granularity of this check.
2009     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2010             line_tokens, full_line_span)) {
2011       continue;
2012     }
2013 
2014     std::unique_ptr<CachedFeatures> cached_features;
2015     if (!selection_feature_processor_->ExtractFeatures(
2016             line_tokens, full_line_span,
2017             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2018             embedding_executor_.get(),
2019             /*embedding_cache=*/nullptr,
2020             selection_feature_processor_->EmbeddingSize() +
2021                 selection_feature_processor_->DenseFeaturesCount(),
2022             &cached_features)) {
2023       TC3_LOG(ERROR) << "Could not extract features.";
2024       return false;
2025     }
2026 
2027     std::vector<TokenSpan> local_chunks;
2028     if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2029                     interpreter_manager->SelectionInterpreter(),
2030                     *cached_features, &local_chunks)) {
2031       TC3_LOG(ERROR) << "Could not chunk.";
2032       return false;
2033     }
2034 
2035     const int offset = std::distance(context_unicode.begin(), line.first);
2036     if (local_chunks.empty()) {
2037       continue;
2038     }
2039     const UnicodeText line_unicode =
2040         UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2041     std::vector<UnicodeText::const_iterator> line_codepoints =
2042         line_unicode.Codepoints();
2043     line_codepoints.push_back(line_unicode.end());
2044 
2045     FeatureProcessor::EmbeddingCache embedding_cache;
2046     for (const TokenSpan& chunk : local_chunks) {
2047       CodepointSpan codepoint_span =
2048           TokenSpanToCodepointSpan(line_tokens, chunk);
2049       if (!codepoint_span.IsValid() ||
2050           codepoint_span.second > line_codepoints.size()) {
2051         continue;
2052       }
2053       codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2054           /*span_begin=*/line_codepoints[codepoint_span.first],
2055           /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
2056       if (model_->selection_options()->strip_unpaired_brackets()) {
2057         codepoint_span = StripUnpairedBrackets(
2058             /*span_begin=*/line_codepoints[codepoint_span.first],
2059             /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
2060             *unilib_);
2061       }
2062 
2063       // Skip empty spans.
2064       if (codepoint_span.first != codepoint_span.second) {
2065         std::vector<ClassificationResult> classification;
2066         if (!ModelClassifyText(
2067                 line_unicode, line_tokens, detected_text_language_tags,
2068                 /*span_begin=*/line_codepoints[codepoint_span.first],
2069                 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2070                 codepoint_span, options, interpreter_manager, &embedding_cache,
2071                 &classification, /*tokens=*/nullptr)) {
2072           TC3_LOG(ERROR) << "Could not classify text: "
2073                          << (codepoint_span.first + offset) << " "
2074                          << (codepoint_span.second + offset);
2075           return false;
2076         }
2077 
2078         // Do not include the span if it's classified as "other".
2079         if (!classification.empty() && !ClassifiedAsOther(classification) &&
2080             classification[0].score >= min_annotate_confidence) {
2081           AnnotatedSpan result_span;
2082           result_span.span = {codepoint_span.first + offset,
2083                               codepoint_span.second + offset};
2084           result_span.classification = std::move(classification);
2085           result->push_back(std::move(result_span));
2086         }
2087       }
2088     }
2089   }
2090   return true;
2091 }
2092 
SelectionFeatureProcessorForTests() const2093 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2094   return selection_feature_processor_.get();
2095 }
2096 
ClassificationFeatureProcessorForTests() const2097 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2098     const {
2099   return classification_feature_processor_.get();
2100 }
2101 
DatetimeParserForTests() const2102 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2103   return datetime_parser_.get();
2104 }
2105 
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2106 void Annotator::RemoveNotEnabledEntityTypes(
2107     const EnabledEntityTypes& is_entity_type_enabled,
2108     std::vector<AnnotatedSpan>* annotated_spans) const {
2109   for (AnnotatedSpan& annotated_span : *annotated_spans) {
2110     std::vector<ClassificationResult>& classifications =
2111         annotated_span.classification;
2112     classifications.erase(
2113         std::remove_if(classifications.begin(), classifications.end(),
2114                        [&is_entity_type_enabled](
2115                            const ClassificationResult& classification_result) {
2116                          return !is_entity_type_enabled(
2117                              classification_result.collection);
2118                        }),
2119         classifications.end());
2120   }
2121   annotated_spans->erase(
2122       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2123                      [](const AnnotatedSpan& annotated_span) {
2124                        return annotated_span.classification.empty();
2125                      }),
2126       annotated_spans->end());
2127 }
2128 
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2129 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2130     std::vector<AnnotatedSpan>* candidates) const {
2131   if (candidates == nullptr || contact_engine_ == nullptr) {
2132     return;
2133   }
2134   for (auto& candidate : *candidates) {
2135     for (auto& classification_result : candidate.classification) {
2136       contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2137           &classification_result);
2138     }
2139   }
2140 }
2141 
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2142 Status Annotator::AnnotateSingleInput(
2143     const std::string& context, const AnnotationOptions& options,
2144     std::vector<AnnotatedSpan>* candidates) const {
2145   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2146     return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2147   }
2148 
2149   const UnicodeText context_unicode =
2150       UTF8ToUnicodeText(context, /*do_copy=*/false);
2151 
2152   std::vector<Locale> detected_text_language_tags;
2153   if (!ParseLocales(options.detected_text_language_tags,
2154                     &detected_text_language_tags)) {
2155     TC3_LOG(WARNING)
2156         << "Failed to parse the detected_text_language_tags in options: "
2157         << options.detected_text_language_tags;
2158   }
2159   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2160                                     model_triggering_locales_,
2161                                     /*default_value=*/true)) {
2162     return Status(
2163         StatusCode::UNAVAILABLE,
2164         "The detected language tags are not in the supported locales.");
2165   }
2166 
2167   InterpreterManager interpreter_manager(selection_executor_.get(),
2168                                          classification_executor_.get());
2169 
2170   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2171   const bool is_raw_usecase =
2172       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2173 
2174   // Annotate with the selection model.
2175   const bool model_annotations_enabled =
2176       !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2177   std::vector<Token> tokens;
2178   if (model_annotations_enabled &&
2179       !ModelAnnotate(context, detected_text_language_tags, options,
2180                      &interpreter_manager, &tokens, candidates)) {
2181     return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2182   } else if (!model_annotations_enabled) {
2183     // If the ML model didn't run, we need to tokenize to support the other
2184     // annotators that depend on the tokens.
2185     // Optimization could be made to only do this when an annotator that uses
2186     // the tokens is enabled, but it's unclear if the added complexity is worth
2187     // it.
2188     if (selection_feature_processor_ != nullptr) {
2189       tokens = selection_feature_processor_->Tokenize(context_unicode);
2190     }
2191   }
2192 
2193   // Annotate with the regular expression models.
2194   const bool regex_annotations_enabled =
2195       !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2196   if (regex_annotations_enabled &&
2197       !RegexChunk(
2198           UTF8ToUnicodeText(context, /*do_copy=*/false),
2199           annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2200           is_entity_type_enabled, options.annotation_usecase, candidates)) {
2201     return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2202   }
2203 
2204   // Annotate with the datetime model.
2205   // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2206   // relatively slow for some clients.
2207   if ((is_entity_type_enabled(Collections::Date()) ||
2208        is_entity_type_enabled(Collections::DateTime())) &&
2209       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2210                      options.reference_time_ms_utc, options.reference_timezone,
2211                      options.locales, ModeFlag_ANNOTATION,
2212                      options.annotation_usecase,
2213                      options.is_serialized_entity_data_enabled, candidates)) {
2214     return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2215   }
2216 
2217   // Annotate with the contact engine.
2218   const bool contact_annotations_enabled =
2219       !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2220   if (contact_annotations_enabled && contact_engine_ &&
2221       !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2222     return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2223   }
2224 
2225   // Annotate with the installed app engine.
2226   const bool app_annotations_enabled =
2227       !is_raw_usecase || is_entity_type_enabled(Collections::App());
2228   if (app_annotations_enabled && installed_app_engine_ &&
2229       !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2230     return Status(StatusCode::INTERNAL,
2231                   "Couldn't run installed app engine Chunk.");
2232   }
2233 
2234   // Annotate with the number annotator.
2235   const bool number_annotations_enabled =
2236       !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2237                           is_entity_type_enabled(Collections::Percentage()));
2238   if (number_annotations_enabled && number_annotator_ != nullptr &&
2239       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2240                                   candidates)) {
2241     return Status(StatusCode::INTERNAL,
2242                   "Couldn't run number annotator FindAll.");
2243   }
2244 
2245   // Annotate with the duration annotator.
2246   const bool duration_annotations_enabled =
2247       !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2248   if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2249       !duration_annotator_->FindAll(context_unicode, tokens,
2250                                     options.annotation_usecase, candidates)) {
2251     return Status(StatusCode::INTERNAL,
2252                   "Couldn't run duration annotator FindAll.");
2253   }
2254 
2255   // Annotate with the person name engine.
2256   const bool person_annotations_enabled =
2257       !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2258   if (person_annotations_enabled && person_name_engine_ &&
2259       !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2260     return Status(StatusCode::INTERNAL,
2261                   "Couldn't run person name engine Chunk.");
2262   }
2263 
2264   // Annotate with the grammar annotators.
2265   if (grammar_annotator_ != nullptr &&
2266       !grammar_annotator_->Annotate(detected_text_language_tags,
2267                                     context_unicode, candidates)) {
2268     return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2269   }
2270 
2271   // Annotate with the POD NER annotator.
2272   const bool pod_ner_annotations_enabled =
2273       !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2274   if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2275       options.use_pod_ner &&
2276       !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2277     return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2278   }
2279 
2280   // Annotate with the vocab annotator.
2281   const bool vocab_annotations_enabled =
2282       !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2283   if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2284       options.use_vocab_annotator &&
2285       !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2286                                   options.trigger_dictionary_on_beginner_words,
2287                                   candidates)) {
2288     return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2289   }
2290 
2291   // Annotate with the experimental annotator.
2292   if (experimental_annotator_ != nullptr &&
2293       !experimental_annotator_->Annotate(context_unicode, candidates)) {
2294     return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2295   }
2296 
2297   // Sort candidates according to their position in the input, so that the next
2298   // code can assume that any connected component of overlapping spans forms a
2299   // contiguous block.
2300   // Also sort them according to the end position and collection, so that the
2301   // deduplication code below can assume that same spans and classifications
2302   // form contiguous blocks.
2303   std::stable_sort(candidates->begin(), candidates->end(),
2304                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2305                      if (a.span.first != b.span.first) {
2306                        return a.span.first < b.span.first;
2307                      }
2308 
2309                      if (a.span.second != b.span.second) {
2310                        return a.span.second < b.span.second;
2311                      }
2312 
2313                      return a.classification[0].collection <
2314                             b.classification[0].collection;
2315                    });
2316 
2317   std::vector<int> candidate_indices;
2318   if (!ResolveConflicts(*candidates, context, tokens,
2319                         detected_text_language_tags, options,
2320                         &interpreter_manager, &candidate_indices)) {
2321     return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2322   }
2323 
2324   // Remove candidates that overlap exactly and have the same collection.
2325   // This can e.g. happen for phone coming from both ML model and regex.
2326   candidate_indices.erase(
2327       std::unique(candidate_indices.begin(), candidate_indices.end(),
2328                   [&candidates](const int a_index, const int b_index) {
2329                     const AnnotatedSpan& a = (*candidates)[a_index];
2330                     const AnnotatedSpan& b = (*candidates)[b_index];
2331                     return a.span == b.span &&
2332                            a.classification[0].collection ==
2333                                b.classification[0].collection;
2334                   }),
2335       candidate_indices.end());
2336 
2337   std::vector<AnnotatedSpan> result;
2338   result.reserve(candidate_indices.size());
2339   for (const int i : candidate_indices) {
2340     if ((*candidates)[i].classification.empty() ||
2341         ClassifiedAsOther((*candidates)[i].classification) ||
2342         FilteredForAnnotation((*candidates)[i])) {
2343       continue;
2344     }
2345     result.push_back(std::move((*candidates)[i]));
2346   }
2347 
2348   // We generate all candidates and remove them later (with the exception of
2349   // date/time/duration entities) because there are complex interdependencies
2350   // between the entity types. E.g., the TLD of an email can be interpreted as a
2351   // URL, but most likely a user of the API does not want such annotations if
2352   // "url" is enabled and "email" is not.
2353   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2354 
2355   for (AnnotatedSpan& annotated_span : result) {
2356     SortClassificationResults(&annotated_span.classification);
2357   }
2358   *candidates = result;
2359   return Status::OK;
2360 }
2361 
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2362 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2363     const std::vector<InputFragment>& string_fragments,
2364     const AnnotationOptions& options) const {
2365   Annotations annotation_candidates;
2366   annotation_candidates.annotated_spans.resize(string_fragments.size());
2367 
2368   std::vector<std::string> text_to_annotate;
2369   text_to_annotate.reserve(string_fragments.size());
2370   std::vector<FragmentMetadata> fragment_metadata;
2371   fragment_metadata.reserve(string_fragments.size());
2372   for (const auto& string_fragment : string_fragments) {
2373     text_to_annotate.push_back(string_fragment.text);
2374     fragment_metadata.push_back(
2375         {.relative_bounding_box_top = string_fragment.bounding_box_top,
2376          .relative_bounding_box_height = string_fragment.bounding_box_height});
2377   }
2378 
2379   // KnowledgeEngine is special, because it supports annotation of multiple
2380   // fragments at once.
2381   if (knowledge_engine_ &&
2382       !knowledge_engine_
2383            ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2384                                 options.annotation_usecase,
2385                                 options.location_context, options.permissions,
2386                                 options.annotate_mode, &annotation_candidates)
2387            .ok()) {
2388     return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2389   }
2390   // The annotator engines shouldn't change the number of annotation vectors.
2391   if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2392     TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2393                    << " texts to annotate but generated a different number of  "
2394                       "lists of annotations:"
2395                    << annotation_candidates.annotated_spans.size();
2396     return Status(StatusCode::INTERNAL,
2397                   "Number of annotation candidates differs from "
2398                   "number of texts to annotate.");
2399   }
2400 
2401   // As an optimization, if the only annotated type is Entity, we skip all the
2402   // other annotators than the KnowledgeEngine. This only happens in the raw
2403   // mode, to make sure it does not affect the result.
2404   if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2405       options.entity_types.size() == 1 &&
2406       *options.entity_types.begin() == Collections::Entity()) {
2407     return annotation_candidates;
2408   }
2409 
2410   // Other annotators run on each fragment independently.
2411   for (int i = 0; i < text_to_annotate.size(); ++i) {
2412     AnnotationOptions annotation_options = options;
2413     if (string_fragments[i].datetime_options.has_value()) {
2414       DatetimeOptions reference_datetime =
2415           string_fragments[i].datetime_options.value();
2416       annotation_options.reference_time_ms_utc =
2417           reference_datetime.reference_time_ms_utc;
2418       annotation_options.reference_timezone =
2419           reference_datetime.reference_timezone;
2420     }
2421 
2422     AddContactMetadataToKnowledgeClassificationResults(
2423         &annotation_candidates.annotated_spans[i]);
2424 
2425     Status annotation_status =
2426         AnnotateSingleInput(text_to_annotate[i], annotation_options,
2427                             &annotation_candidates.annotated_spans[i]);
2428     if (!annotation_status.ok()) {
2429       return annotation_status;
2430     }
2431   }
2432   return annotation_candidates;
2433 }
2434 
Annotate(const std::string & context,const AnnotationOptions & options) const2435 std::vector<AnnotatedSpan> Annotator::Annotate(
2436     const std::string& context, const AnnotationOptions& options) const {
2437   if (context.size() > std::numeric_limits<int>::max()) {
2438     TC3_LOG(ERROR) << "Rejecting too long input.";
2439     return {};
2440   }
2441 
2442   const UnicodeText context_unicode =
2443       UTF8ToUnicodeText(context, /*do_copy=*/false);
2444   if (!unilib_->IsValidUtf8(context_unicode)) {
2445     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2446     return {};
2447   }
2448 
2449   std::vector<InputFragment> string_fragments;
2450   string_fragments.push_back({.text = context});
2451   StatusOr<Annotations> annotations =
2452       AnnotateStructuredInput(string_fragments, options);
2453   if (!annotations.ok()) {
2454     TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2455                    << annotations.status().error_message();
2456     return {};
2457   }
2458   return annotations.ValueOrDie().annotated_spans[0];
2459 }
2460 
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2461 CodepointSpan Annotator::ComputeSelectionBoundaries(
2462     const UniLib::RegexMatcher* match,
2463     const RegexModel_::Pattern* config) const {
2464   if (config->capturing_group() == nullptr) {
2465     // Use first capturing group to specify the selection.
2466     int status = UniLib::RegexMatcher::kNoError;
2467     const CodepointSpan result = {match->Start(1, &status),
2468                                   match->End(1, &status)};
2469     if (status != UniLib::RegexMatcher::kNoError) {
2470       return {kInvalidIndex, kInvalidIndex};
2471     }
2472     return result;
2473   }
2474 
2475   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2476   const int num_groups = config->capturing_group()->size();
2477   for (int i = 0; i < num_groups; i++) {
2478     if (!config->capturing_group()->Get(i)->extend_selection()) {
2479       continue;
2480     }
2481 
2482     int status = UniLib::RegexMatcher::kNoError;
2483     // Check match and adjust bounds.
2484     const int group_start = match->Start(i, &status);
2485     const int group_end = match->End(i, &status);
2486     if (status != UniLib::RegexMatcher::kNoError) {
2487       return {kInvalidIndex, kInvalidIndex};
2488     }
2489     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2490       continue;
2491     }
2492     if (result.first == kInvalidIndex) {
2493       result = {group_start, group_end};
2494     } else {
2495       result.first = std::min(result.first, group_start);
2496       result.second = std::max(result.second, group_end);
2497     }
2498   }
2499   return result;
2500 }
2501 
HasEntityData(const RegexModel_::Pattern * pattern) const2502 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2503   if (pattern->serialized_entity_data() != nullptr ||
2504       pattern->entity_data() != nullptr) {
2505     return true;
2506   }
2507   if (pattern->capturing_group() != nullptr) {
2508     for (const CapturingGroup* group : *pattern->capturing_group()) {
2509       if (group->entity_field_path() != nullptr) {
2510         return true;
2511       }
2512       if (group->serialized_entity_data() != nullptr ||
2513           group->entity_data() != nullptr) {
2514         return true;
2515       }
2516     }
2517   }
2518   return false;
2519 }
2520 
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2521 bool Annotator::SerializedEntityDataFromRegexMatch(
2522     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2523     std::string* serialized_entity_data) const {
2524   if (!HasEntityData(pattern)) {
2525     serialized_entity_data->clear();
2526     return true;
2527   }
2528   TC3_CHECK(entity_data_builder_ != nullptr);
2529 
2530   std::unique_ptr<MutableFlatbuffer> entity_data =
2531       entity_data_builder_->NewRoot();
2532 
2533   TC3_CHECK(entity_data != nullptr);
2534 
2535   // Set fixed entity data.
2536   if (pattern->serialized_entity_data() != nullptr) {
2537     entity_data->MergeFromSerializedFlatbuffer(
2538         StringPiece(pattern->serialized_entity_data()->c_str(),
2539                     pattern->serialized_entity_data()->size()));
2540   }
2541   if (pattern->entity_data() != nullptr) {
2542     entity_data->MergeFrom(
2543         reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2544   }
2545 
2546   // Add entity data from rule capturing groups.
2547   if (pattern->capturing_group() != nullptr) {
2548     const int num_groups = pattern->capturing_group()->size();
2549     for (int i = 0; i < num_groups; i++) {
2550       const CapturingGroup* group = pattern->capturing_group()->Get(i);
2551 
2552       // Check whether the group matched.
2553       Optional<std::string> group_match_text =
2554           GetCapturingGroupText(matcher, /*group_id=*/i);
2555       if (!group_match_text.has_value()) {
2556         continue;
2557       }
2558 
2559       // Set fixed entity data from capturing group match.
2560       if (group->serialized_entity_data() != nullptr) {
2561         entity_data->MergeFromSerializedFlatbuffer(
2562             StringPiece(group->serialized_entity_data()->c_str(),
2563                         group->serialized_entity_data()->size()));
2564       }
2565       if (group->entity_data() != nullptr) {
2566         entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2567             pattern->entity_data()));
2568       }
2569 
2570       // Set entity field from capturing group text.
2571       if (group->entity_field_path() != nullptr) {
2572         UnicodeText normalized_group_match_text =
2573             UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2574 
2575         // Apply normalization if specified.
2576         if (group->normalization_options() != nullptr) {
2577           normalized_group_match_text =
2578               NormalizeText(*unilib_, group->normalization_options(),
2579                             normalized_group_match_text);
2580         }
2581 
2582         if (!entity_data->ParseAndSet(
2583                 group->entity_field_path(),
2584                 normalized_group_match_text.ToUTF8String())) {
2585           TC3_LOG(ERROR)
2586               << "Could not set entity data from rule capturing group.";
2587           return false;
2588         }
2589       }
2590     }
2591   }
2592 
2593   *serialized_entity_data = entity_data->Serialize();
2594   return true;
2595 }
2596 
// Returns the codepoints of `amount` before `it_decimal_separator` (or the
// whole of `amount` if that iterator is amount.end()), with every codepoint
// contained in `decimal_separators` removed.
//
// Despite its name, `decimal_separators` is the full set of separators to
// strip. Membership is tested with the set's own hashed lookup rather than a
// linear scan.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    if (decimal_separators.count(static_cast<char32>(*it)) == 0) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2611 
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2612 void Annotator::GetMoneyQuantityFromCapturingGroup(
2613     const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2614     const UnicodeText& context_unicode, std::string* quantity,
2615     int* exponent) const {
2616   if (config->capturing_group() == nullptr) {
2617     *exponent = 0;
2618     return;
2619   }
2620 
2621   const int num_groups = config->capturing_group()->size();
2622   for (int i = 0; i < num_groups; i++) {
2623     int status = UniLib::RegexMatcher::kNoError;
2624     const int group_start = match->Start(i, &status);
2625     const int group_end = match->End(i, &status);
2626     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2627       continue;
2628     }
2629 
2630     *quantity =
2631         unilib_
2632             ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2633                                                  group_end, /*do_copy=*/false))
2634             .ToUTF8String();
2635 
2636     if (auto entry = model_->money_parsing_options()
2637                          ->quantities_name_to_exponent()
2638                          ->LookupByKey((*quantity).c_str())) {
2639       *exponent = entry->value();
2640       return;
2641     }
2642   }
2643   *exponent = 0;
2644 }
2645 
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2646 bool Annotator::ParseAndFillInMoneyAmount(
2647     std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2648     const RegexModel_::Pattern* config,
2649     const UnicodeText& context_unicode) const {
2650   std::unique_ptr<EntityDataT> data =
2651       LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2652           *serialized_entity_data);
2653   if (data == nullptr) {
2654     if (model_->version() >= 706) {
2655       // This way of parsing money entity data is enabled for models newer than
2656       // v706, consequently logging errors only for them (b/156634162).
2657       TC3_LOG(ERROR)
2658           << "Data field is null when trying to parse Money Entity Data";
2659     }
2660     return false;
2661   }
2662   if (data->money->unnormalized_amount.empty()) {
2663     if (model_->version() >= 706) {
2664       // This way of parsing money entity data is enabled for models newer than
2665       // v706, consequently logging errors only for them (b/156634162).
2666       TC3_LOG(ERROR)
2667           << "Data unnormalized_amount is empty when trying to parse "
2668              "Money Entity Data";
2669     }
2670     return false;
2671   }
2672 
2673   UnicodeText amount =
2674       UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2675   int separator_back_index = 0;
2676   auto it_decimal_separator = --amount.end();
2677   for (; it_decimal_separator != amount.begin();
2678        --it_decimal_separator, ++separator_back_index) {
2679     if (std::find(money_separators_.begin(), money_separators_.end(),
2680                   static_cast<char32>(*it_decimal_separator)) !=
2681         money_separators_.end()) {
2682       break;
2683     }
2684   }
2685 
2686   // If there are 3 digits after the last separator, we consider that a
2687   // thousands separator => the number is an int (e.g. 1.234 is considered int).
2688   // If there is no separator in number, also that number is an int.
2689   if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2690     it_decimal_separator = amount.end();
2691   }
2692 
2693   if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2694                                                  it_decimal_separator),
2695                            &data->money->amount_whole_part)) {
2696     TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2697                       "amount: "
2698                    << data->money->unnormalized_amount;
2699     return false;
2700   }
2701 
2702   if (it_decimal_separator == amount.end()) {
2703     data->money->amount_decimal_part = 0;
2704     data->money->nanos = 0;
2705   } else {
2706     const int amount_codepoints_size = amount.size_codepoints();
2707     const UnicodeText decimal_part = UnicodeText::Substring(
2708         amount, amount_codepoints_size - separator_back_index,
2709         amount_codepoints_size, /*do_copy=*/false);
2710     if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2711       TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2712                         "the amount: "
2713                      << data->money->unnormalized_amount;
2714       return false;
2715     }
2716     data->money->nanos = data->money->amount_decimal_part *
2717                          pow(10, 9 - decimal_part.size_codepoints());
2718   }
2719 
2720   if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2721       nullptr) {
2722     int quantity_exponent;
2723     std::string quantity;
2724     GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2725                                        &quantity, &quantity_exponent);
2726     if (quantity_exponent > 0 && quantity_exponent <= 9) {
2727       const double amount_whole_part =
2728           data->money->amount_whole_part * pow(10, quantity_exponent) +
2729           data->money->nanos / pow(10, 9 - quantity_exponent);
2730       // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2731       // (and `std::numeric_limits<int>::max()` to
2732       // `std::numeric_limits<int64>::max()`).
2733       if (amount_whole_part < std::numeric_limits<int>::max()) {
2734         data->money->amount_whole_part = amount_whole_part;
2735         data->money->nanos = data->money->nanos %
2736                              static_cast<int>(pow(10, 9 - quantity_exponent)) *
2737                              pow(10, quantity_exponent);
2738       }
2739     }
2740     if (quantity_exponent > 0) {
2741       data->money->unnormalized_amount = strings::JoinStrings(
2742           " ", {data->money->unnormalized_amount, quantity});
2743     }
2744   }
2745 
2746   *serialized_entity_data =
2747       PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2748   return true;
2749 }
2750 
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2751 bool Annotator::IsAnyModelEntityTypeEnabled(
2752     const EnabledEntityTypes& is_entity_type_enabled) const {
2753   if (model_->classification_feature_options() == nullptr ||
2754       model_->classification_feature_options()->collections() == nullptr) {
2755     return false;
2756   }
2757   for (int i = 0;
2758        i < model_->classification_feature_options()->collections()->size();
2759        i++) {
2760     if (is_entity_type_enabled(model_->classification_feature_options()
2761                                    ->collections()
2762                                    ->Get(i)
2763                                    ->str())) {
2764       return true;
2765     }
2766   }
2767   return false;
2768 }
2769 
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2770 bool Annotator::IsAnyRegexEntityTypeEnabled(
2771     const EnabledEntityTypes& is_entity_type_enabled) const {
2772   if (model_->regex_model() == nullptr ||
2773       model_->regex_model()->patterns() == nullptr) {
2774     return false;
2775   }
2776   for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2777     if (is_entity_type_enabled(model_->regex_model()
2778                                    ->patterns()
2779                                    ->Get(i)
2780                                    ->collection_name()
2781                                    ->str())) {
2782       return true;
2783     }
2784   }
2785   return false;
2786 }
2787 
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2788 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2789     const EnabledEntityTypes& is_entity_type_enabled) const {
2790   if (pod_ner_annotator_ == nullptr) {
2791     return false;
2792   }
2793 
2794   for (const std::string& collection :
2795        pod_ner_annotator_->GetSupportedCollections()) {
2796     if (is_entity_type_enabled(collection)) {
2797       return true;
2798     }
2799   }
2800   return false;
2801 }
2802 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2803 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2804                            const std::vector<int>& rules,
2805                            bool is_serialized_entity_data_enabled,
2806                            const EnabledEntityTypes& enabled_entity_types,
2807                            const AnnotationUsecase& annotation_usecase,
2808                            std::vector<AnnotatedSpan>* result) const {
2809   for (int pattern_id : rules) {
2810     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2811     if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2812         annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2813       // No regex annotation type has been requested, skip regex annotation.
2814       continue;
2815     }
2816     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2817     if (!matcher) {
2818       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2819                      << pattern_id;
2820       return false;
2821     }
2822 
2823     int status = UniLib::RegexMatcher::kNoError;
2824     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2825       if (regex_pattern.config->verification_options()) {
2826         if (!VerifyRegexMatchCandidate(
2827                 context_unicode.ToUTF8String(),
2828                 regex_pattern.config->verification_options(),
2829                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2830           continue;
2831         }
2832       }
2833 
2834       std::string serialized_entity_data;
2835       if (is_serialized_entity_data_enabled) {
2836         if (!SerializedEntityDataFromRegexMatch(
2837                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2838           TC3_LOG(ERROR) << "Could not get entity data.";
2839           return false;
2840         }
2841 
2842         // Further parsing of money amount. Need this since regexes cannot have
2843         // empty groups that fill in entity data (amount_decimal_part and
2844         // quantity might be empty groups).
2845         if (regex_pattern.config->collection_name()->str() ==
2846             Collections::Money()) {
2847           if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2848                                          regex_pattern.config,
2849                                          context_unicode)) {
2850             if (model_->version() >= 706) {
2851               // This way of parsing money entity data is enabled for models
2852               // newer than v706 => logging errors only for them (b/156634162).
2853               TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2854             }
2855           }
2856         }
2857       }
2858 
2859       result->emplace_back();
2860 
2861       // Selection/annotation regular expressions need to specify a capturing
2862       // group specifying the selection.
2863       result->back().span =
2864           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2865 
2866       result->back().classification = {
2867           {regex_pattern.config->collection_name()->str(),
2868            regex_pattern.config->target_classification_score(),
2869            regex_pattern.config->priority_score()}};
2870 
2871       result->back().classification[0].serialized_entity_data =
2872           serialized_entity_data;
2873     }
2874   }
2875   return true;
2876 }
2877 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2878 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2879                            tflite::Interpreter* selection_interpreter,
2880                            const CachedFeatures& cached_features,
2881                            std::vector<TokenSpan>* chunks) const {
2882   const int max_selection_span =
2883       selection_feature_processor_->GetOptions()->max_selection_span();
2884   // The inference span is the span of interest expanded to include
2885   // max_selection_span tokens on either side, which is how far a selection can
2886   // stretch from the click.
2887   const TokenSpan inference_span =
2888       IntersectTokenSpans(span_of_interest.Expand(
2889                               /*num_tokens_left=*/max_selection_span,
2890                               /*num_tokens_right=*/max_selection_span),
2891                           {0, num_tokens});
2892 
2893   std::vector<ScoredChunk> scored_chunks;
2894   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2895       selection_feature_processor_->GetOptions()
2896           ->bounds_sensitive_features()
2897           ->enabled()) {
2898     if (!ModelBoundsSensitiveScoreChunks(
2899             num_tokens, span_of_interest, inference_span, cached_features,
2900             selection_interpreter, &scored_chunks)) {
2901       return false;
2902     }
2903   } else {
2904     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2905                                       cached_features, selection_interpreter,
2906                                       &scored_chunks)) {
2907       return false;
2908     }
2909   }
2910   std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
2911                    [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2912                      return lhs.score < rhs.score;
2913                    });
2914 
2915   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2916   // them greedily as long as they do not overlap with any previously picked
2917   // chunks.
2918   std::vector<bool> token_used(inference_span.Size());
2919   chunks->clear();
2920   for (const ScoredChunk& scored_chunk : scored_chunks) {
2921     bool feasible = true;
2922     for (int i = scored_chunk.token_span.first;
2923          i < scored_chunk.token_span.second; ++i) {
2924       if (token_used[i - inference_span.first]) {
2925         feasible = false;
2926         break;
2927       }
2928     }
2929 
2930     if (!feasible) {
2931       continue;
2932     }
2933 
2934     for (int i = scored_chunk.token_span.first;
2935          i < scored_chunk.token_span.second; ++i) {
2936       token_used[i - inference_span.first] = true;
2937     }
2938 
2939     chunks->push_back(scored_chunk.token_span);
2940   }
2941 
2942   std::stable_sort(chunks->begin(), chunks->end());
2943 
2944   return true;
2945 }
2946 
namespace {
// Stores `value` under `key` in `*map`. If the key already has a value, the
// larger of the two is kept (the existing value wins unless it compares less
// than `value`, matching std::max semantics).
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  auto found = map->find(key);
  if (found == map->end()) {
    (*map)[key] = value;
    return;
  }
  if (found->second < value) {
    found->second = value;
  }
}
}  // namespace
2961 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2962 bool Annotator::ModelClickContextScoreChunks(
2963     int num_tokens, const TokenSpan& span_of_interest,
2964     const CachedFeatures& cached_features,
2965     tflite::Interpreter* selection_interpreter,
2966     std::vector<ScoredChunk>* scored_chunks) const {
2967   const int max_batch_size = model_->selection_options()->batch_size();
2968 
2969   std::vector<float> all_features;
2970   std::map<TokenSpan, float> chunk_scores;
2971   for (int batch_start = span_of_interest.first;
2972        batch_start < span_of_interest.second; batch_start += max_batch_size) {
2973     const int batch_end =
2974         std::min(batch_start + max_batch_size, span_of_interest.second);
2975 
2976     // Prepare features for the whole batch.
2977     all_features.clear();
2978     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2979     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2980       cached_features.AppendClickContextFeaturesForClick(click_pos,
2981                                                          &all_features);
2982     }
2983 
2984     // Run batched inference.
2985     const int batch_size = batch_end - batch_start;
2986     const int features_size = cached_features.OutputFeaturesSize();
2987     TensorView<float> logits = selection_executor_->ComputeLogits(
2988         TensorView<float>(all_features.data(), {batch_size, features_size}),
2989         selection_interpreter);
2990     if (!logits.is_valid()) {
2991       TC3_LOG(ERROR) << "Couldn't compute logits.";
2992       return false;
2993     }
2994     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2995         logits.dim(1) !=
2996             selection_feature_processor_->GetSelectionLabelCount()) {
2997       TC3_LOG(ERROR) << "Mismatching output.";
2998       return false;
2999     }
3000 
3001     // Save results.
3002     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3003       const std::vector<float> scores = ComputeSoftmax(
3004           logits.data() + logits.dim(1) * (click_pos - batch_start),
3005           logits.dim(1));
3006       for (int j = 0;
3007            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3008         TokenSpan relative_token_span;
3009         if (!selection_feature_processor_->LabelToTokenSpan(
3010                 j, &relative_token_span)) {
3011           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3012           return false;
3013         }
3014         const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3015             relative_token_span.first, relative_token_span.second);
3016         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3017           UpdateMax(&chunk_scores, candidate_span, scores[j]);
3018         }
3019       }
3020     }
3021   }
3022 
3023   scored_chunks->clear();
3024   scored_chunks->reserve(chunk_scores.size());
3025   for (const auto& entry : chunk_scores) {
3026     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3027   }
3028 
3029   return true;
3030 }
3031 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3032 bool Annotator::ModelBoundsSensitiveScoreChunks(
3033     int num_tokens, const TokenSpan& span_of_interest,
3034     const TokenSpan& inference_span, const CachedFeatures& cached_features,
3035     tflite::Interpreter* selection_interpreter,
3036     std::vector<ScoredChunk>* scored_chunks) const {
3037   const int max_selection_span =
3038       selection_feature_processor_->GetOptions()->max_selection_span();
3039   const int max_chunk_length = selection_feature_processor_->GetOptions()
3040                                        ->selection_reduced_output_space()
3041                                    ? max_selection_span + 1
3042                                    : 2 * max_selection_span + 1;
3043   const bool score_single_token_spans_as_zero =
3044       selection_feature_processor_->GetOptions()
3045           ->bounds_sensitive_features()
3046           ->score_single_token_spans_as_zero();
3047 
3048   scored_chunks->clear();
3049   if (score_single_token_spans_as_zero) {
3050     scored_chunks->reserve(span_of_interest.Size());
3051   }
3052 
3053   // Prepare all chunk candidates into one batch:
3054   //   - Are contained in the inference span
3055   //   - Have a non-empty intersection with the span of interest
3056   //   - Are at least one token long
3057   //   - Are not longer than the maximum chunk length
3058   std::vector<TokenSpan> candidate_spans;
3059   for (int start = inference_span.first; start < span_of_interest.second;
3060        ++start) {
3061     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3062     for (int end = leftmost_end_index;
3063          end <= inference_span.second && end - start <= max_chunk_length;
3064          ++end) {
3065       const TokenSpan candidate_span = {start, end};
3066       if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3067         // Do not include the single token span in the batch, add a zero score
3068         // for it directly to the output.
3069         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3070       } else {
3071         candidate_spans.push_back(candidate_span);
3072       }
3073     }
3074   }
3075 
3076   const int max_batch_size = model_->selection_options()->batch_size();
3077 
3078   std::vector<float> all_features;
3079   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3080   for (int batch_start = 0; batch_start < candidate_spans.size();
3081        batch_start += max_batch_size) {
3082     const int batch_end = std::min(batch_start + max_batch_size,
3083                                    static_cast<int>(candidate_spans.size()));
3084 
3085     // Prepare features for the whole batch.
3086     all_features.clear();
3087     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3088     for (int i = batch_start; i < batch_end; ++i) {
3089       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3090                                                            &all_features);
3091     }
3092 
3093     // Run batched inference.
3094     const int batch_size = batch_end - batch_start;
3095     const int features_size = cached_features.OutputFeaturesSize();
3096     TensorView<float> logits = selection_executor_->ComputeLogits(
3097         TensorView<float>(all_features.data(), {batch_size, features_size}),
3098         selection_interpreter);
3099     if (!logits.is_valid()) {
3100       TC3_LOG(ERROR) << "Couldn't compute logits.";
3101       return false;
3102     }
3103     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3104         logits.dim(1) != 1) {
3105       TC3_LOG(ERROR) << "Mismatching output.";
3106       return false;
3107     }
3108 
3109     // Save results.
3110     for (int i = batch_start; i < batch_end; ++i) {
3111       scored_chunks->push_back(
3112           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3113     }
3114   }
3115 
3116   return true;
3117 }
3118 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3119 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3120                               int64 reference_time_ms_utc,
3121                               const std::string& reference_timezone,
3122                               const std::string& locales, ModeFlag mode,
3123                               AnnotationUsecase annotation_usecase,
3124                               bool is_serialized_entity_data_enabled,
3125                               std::vector<AnnotatedSpan>* result) const {
3126   if (!datetime_parser_) {
3127     return true;
3128   }
3129   LocaleList locale_list = LocaleList::ParseFrom(locales);
3130   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3131       datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3132                               reference_timezone, locale_list, mode,
3133                               annotation_usecase,
3134                               /*anchor_start_end=*/false);
3135   if (!result_status.ok()) {
3136     return false;
3137   }
3138 
3139   for (const DatetimeParseResultSpan& datetime_span :
3140        result_status.ValueOrDie()) {
3141     AnnotatedSpan annotated_span;
3142     annotated_span.span = datetime_span.span;
3143     for (const DatetimeParseResult& parse_result : datetime_span.data) {
3144       annotated_span.classification.emplace_back(
3145           PickCollectionForDatetime(parse_result),
3146           datetime_span.target_classification_score,
3147           datetime_span.priority_score);
3148       annotated_span.classification.back().datetime_parse_result = parse_result;
3149       if (is_serialized_entity_data_enabled) {
3150         annotated_span.classification.back().serialized_entity_data =
3151             CreateDatetimeSerializedEntityData(parse_result);
3152       }
3153     }
3154     annotated_span.source = AnnotatedSpan::Source::DATETIME;
3155     result->push_back(std::move(annotated_span));
3156   }
3157   return true;
3158 }
3159 
model() const3160 const Model* Annotator::model() const { return model_; }
entity_data_schema() const3161 const reflection::Schema* Annotator::entity_data_schema() const {
3162   return entity_data_schema_;
3163 }
3164 
ViewModel(const void * buffer,int size)3165 const Model* ViewModel(const void* buffer, int size) {
3166   if (!buffer) {
3167     return nullptr;
3168   }
3169 
3170   return LoadAndVerifyModel(buffer, size);
3171 }
3172 
LookUpKnowledgeEntity(const std::string & id) const3173 StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3174     const std::string& id) const {
3175   if (!knowledge_engine_) {
3176     return Status(StatusCode::FAILED_PRECONDITION,
3177                   "knowledge_engine_ is nullptr");
3178   }
3179   return knowledge_engine_->LookUpEntity(id);
3180 }
3181 
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3182 StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3183     const std::string& mid_str, const std::string& property) const {
3184   if (!knowledge_engine_) {
3185     return Status(StatusCode::FAILED_PRECONDITION,
3186                   "knowledge_engine_ is nullptr");
3187   }
3188   return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3189 }
3190 
3191 }  // namespace libtextclassifier3
3192