• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator.h"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28 
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54 
55 namespace libtextclassifier3 {
56 
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58 
59 const std::string& Annotator::kPhoneCollection =
__anona60a5fcb0102() 60     *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anona60a5fcb0202() 62     *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anona60a5fcb0302() 64     *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anona60a5fcb0402() 66     *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anona60a5fcb0502() 68     *[]() { return new std::string("email"); }();
69 
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73   if (VerifyModelBuffer(verifier)) {
74     return GetModel(addr);
75   } else {
76     return nullptr;
77   }
78 }
79 
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81                                                     int size) {
82   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83   if (VerifyPersonNameModelBuffer(verifier)) {
84     return GetPersonNameModel(addr);
85   } else {
86     return nullptr;
87   }
88 }
89 
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93                                 std::unique_ptr<UniLib>* owned_lib) {
94   if (lib) {
95     return lib;
96   } else {
97     owned_lib->reset(new UniLib);
98     return owned_lib->get();
99   }
100 }
101 
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105   if (lib) {
106     return lib;
107   } else {
108     owned_lib->reset(new CalendarLib);
109     return owned_lib->get();
110   }
111 }
112 
113 // Returns whether the provided input is valid:
114 //   * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116   return (span.first >= 0 && span.first < span.second &&
117           span.second <= context.size_codepoints());
118 }
119 
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121     const flatbuffers::Vector<int32_t>* ints) {
122   if (ints == nullptr) {
123     return {};
124   }
125   std::unordered_set<char32> ints_set;
126   for (auto value : *ints) {
127     ints_set.insert(static_cast<char32>(value));
128   }
129   return ints_set;
130 }
131 
132 }  // namespace
133 
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135   if (!selection_interpreter_) {
136     TC3_CHECK(selection_executor_);
137     selection_interpreter_ = selection_executor_->CreateInterpreter();
138     if (!selection_interpreter_) {
139       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140     }
141   }
142   return selection_interpreter_.get();
143 }
144 
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146   if (!classification_interpreter_) {
147     TC3_CHECK(classification_executor_);
148     classification_interpreter_ = classification_executor_->CreateInterpreter();
149     if (!classification_interpreter_) {
150       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151     }
152   }
153   return classification_interpreter_.get();
154 }
155 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157     const char* buffer, int size, const UniLib* unilib,
158     const CalendarLib* calendarlib) {
159   const Model* model = LoadAndVerifyModel(buffer, size);
160   if (model == nullptr) {
161     return nullptr;
162   }
163 
164   auto classifier = std::unique_ptr<Annotator>(new Annotator());
165   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166   calendarlib =
167       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168   classifier->ValidateAndInitialize(model, unilib, calendarlib);
169   if (!classifier->IsInitialized()) {
170     return nullptr;
171   }
172 
173   return classifier;
174 }
175 
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177     const std::string& buffer, const UniLib* unilib,
178     const CalendarLib* calendarlib) {
179   auto classifier = std::unique_ptr<Annotator>(new Annotator());
180   classifier->owned_buffer_ = buffer;
181   const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182                                           classifier->owned_buffer_.size());
183   if (model == nullptr) {
184     return nullptr;
185   }
186   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187   calendarlib =
188       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189   classifier->ValidateAndInitialize(model, unilib, calendarlib);
190   if (!classifier->IsInitialized()) {
191     return nullptr;
192   }
193 
194   return classifier;
195 }
196 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199     const CalendarLib* calendarlib) {
200   if (!(*mmap)->handle().ok()) {
201     TC3_VLOG(1) << "Mmap failed.";
202     return nullptr;
203   }
204 
205   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206                                           (*mmap)->handle().num_bytes());
207   if (!model) {
208     TC3_LOG(ERROR) << "Model verification failed.";
209     return nullptr;
210   }
211 
212   auto classifier = std::unique_ptr<Annotator>(new Annotator());
213   classifier->mmap_ = std::move(*mmap);
214   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215   calendarlib =
216       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217   classifier->ValidateAndInitialize(model, unilib, calendarlib);
218   if (!classifier->IsInitialized()) {
219     return nullptr;
220   }
221 
222   return classifier;
223 }
224 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227     std::unique_ptr<CalendarLib> calendarlib) {
228   if (!(*mmap)->handle().ok()) {
229     TC3_VLOG(1) << "Mmap failed.";
230     return nullptr;
231   }
232 
233   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234                                           (*mmap)->handle().num_bytes());
235   if (model == nullptr) {
236     TC3_LOG(ERROR) << "Model verification failed.";
237     return nullptr;
238   }
239 
240   auto classifier = std::unique_ptr<Annotator>(new Annotator());
241   classifier->mmap_ = std::move(*mmap);
242   classifier->owned_unilib_ = std::move(unilib);
243   classifier->owned_calendarlib_ = std::move(calendarlib);
244   classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245                                     classifier->owned_calendarlib_.get());
246   if (!classifier->IsInitialized()) {
247     return nullptr;
248   }
249 
250   return classifier;
251 }
252 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254     int fd, int offset, int size, const UniLib* unilib,
255     const CalendarLib* calendarlib) {
256   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257   return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259 
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262     std::unique_ptr<CalendarLib> calendarlib) {
263   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266 
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270   return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272 
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274     int fd, std::unique_ptr<UniLib> unilib,
275     std::unique_ptr<CalendarLib> calendarlib) {
276   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279 
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281                                                const UniLib* unilib,
282                                                const CalendarLib* calendarlib) {
283   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284   return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288     const std::string& path, std::unique_ptr<UniLib> unilib,
289     std::unique_ptr<CalendarLib> calendarlib) {
290   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293 
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295                                       const CalendarLib* calendarlib) {
296   model_ = model;
297   unilib_ = unilib;
298   calendarlib_ = calendarlib;
299 
300   initialized_ = false;
301 
302   if (model_ == nullptr) {
303     TC3_LOG(ERROR) << "No model specified.";
304     return;
305   }
306 
307   const bool model_enabled_for_annotation =
308       (model_->triggering_options() != nullptr &&
309        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310   const bool model_enabled_for_classification =
311       (model_->triggering_options() != nullptr &&
312        (model_->triggering_options()->enabled_modes() &
313         ModeFlag_CLASSIFICATION));
314   const bool model_enabled_for_selection =
315       (model_->triggering_options() != nullptr &&
316        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317 
318   // Annotation requires the selection model.
319   if (model_enabled_for_annotation || model_enabled_for_selection) {
320     if (!model_->selection_options()) {
321       TC3_LOG(ERROR) << "No selection options.";
322       return;
323     }
324     if (!model_->selection_feature_options()) {
325       TC3_LOG(ERROR) << "No selection feature options.";
326       return;
327     }
328     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330       return;
331     }
332     if (!model_->selection_model()) {
333       TC3_LOG(ERROR) << "No selection model.";
334       return;
335     }
336     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337     if (!selection_executor_) {
338       TC3_LOG(ERROR) << "Could not initialize selection executor.";
339       return;
340     }
341   }
342 
343   // Even if the annotation mode is not enabled (for the neural network model),
344   // the selection feature processor is needed to tokenize the text for other
345   // models.
346   if (model_->selection_feature_options()) {
347     selection_feature_processor_.reset(
348         new FeatureProcessor(model_->selection_feature_options(), unilib_));
349   }
350 
351   // Annotation requires the classification model for conflict resolution and
352   // scoring.
353   // Selection requires the classification model for conflict resolution.
354   if (model_enabled_for_annotation || model_enabled_for_classification ||
355       model_enabled_for_selection) {
356     if (!model_->classification_options()) {
357       TC3_LOG(ERROR) << "No classification options.";
358       return;
359     }
360 
361     if (!model_->classification_feature_options()) {
362       TC3_LOG(ERROR) << "No classification feature options.";
363       return;
364     }
365 
366     if (!model_->classification_feature_options()
367              ->bounds_sensitive_features()) {
368       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
369       return;
370     }
371     if (!model_->classification_model()) {
372       TC3_LOG(ERROR) << "No clf model.";
373       return;
374     }
375 
376     classification_executor_ =
377         ModelExecutor::FromBuffer(model_->classification_model());
378     if (!classification_executor_) {
379       TC3_LOG(ERROR) << "Could not initialize classification executor.";
380       return;
381     }
382 
383     classification_feature_processor_.reset(new FeatureProcessor(
384         model_->classification_feature_options(), unilib_));
385   }
386 
387   // The embeddings need to be specified if the model is to be used for
388   // classification or selection.
389   if (model_enabled_for_annotation || model_enabled_for_classification ||
390       model_enabled_for_selection) {
391     if (!model_->embedding_model()) {
392       TC3_LOG(ERROR) << "No embedding model.";
393       return;
394     }
395 
396     // Check that the embedding size of the selection and classification model
397     // matches, as they are using the same embeddings.
398     if (model_enabled_for_selection &&
399         (model_->selection_feature_options()->embedding_size() !=
400              model_->classification_feature_options()->embedding_size() ||
401          model_->selection_feature_options()->embedding_quantization_bits() !=
402              model_->classification_feature_options()
403                  ->embedding_quantization_bits())) {
404       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
405       return;
406     }
407 
408     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
409         model_->embedding_model(),
410         model_->classification_feature_options()->embedding_size(),
411         model_->classification_feature_options()->embedding_quantization_bits(),
412         model_->embedding_pruning_mask());
413     if (!embedding_executor_) {
414       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
415       return;
416     }
417   }
418 
419   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
420   if (model_->regex_model()) {
421     if (!InitializeRegexModel(decompressor.get())) {
422       TC3_LOG(ERROR) << "Could not initialize regex model.";
423       return;
424     }
425   }
426 
427   if (model_->datetime_grammar_model()) {
428     if (model_->datetime_grammar_model()->rules()) {
429       analyzer_ = std::make_unique<grammar::Analyzer>(
430           unilib_, model_->datetime_grammar_model()->rules());
431       datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
432       datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
433           *analyzer_, *datetime_grounder_,
434           /*target_classification_score=*/1.0,
435           /*priority_score=*/1.0,
436           model_->datetime_grammar_model()->enabled_modes());
437     }
438   } else if (model_->datetime_model()) {
439     datetime_parser_ = RegexDatetimeParser::Instance(
440         model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
441     if (!datetime_parser_) {
442       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
443       return;
444     }
445   }
446 
447   if (model_->output_options()) {
448     if (model_->output_options()->filtered_collections_annotation()) {
449       for (const auto collection :
450            *model_->output_options()->filtered_collections_annotation()) {
451         filtered_collections_annotation_.insert(collection->str());
452       }
453     }
454     if (model_->output_options()->filtered_collections_classification()) {
455       for (const auto collection :
456            *model_->output_options()->filtered_collections_classification()) {
457         filtered_collections_classification_.insert(collection->str());
458       }
459     }
460     if (model_->output_options()->filtered_collections_selection()) {
461       for (const auto collection :
462            *model_->output_options()->filtered_collections_selection()) {
463         filtered_collections_selection_.insert(collection->str());
464       }
465     }
466   }
467 
468   if (model_->number_annotator_options() &&
469       model_->number_annotator_options()->enabled()) {
470     number_annotator_.reset(
471         new NumberAnnotator(model_->number_annotator_options(), unilib_));
472   }
473 
474   if (model_->money_parsing_options()) {
475     money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
476         model_->money_parsing_options()->separators());
477   }
478 
479   if (model_->duration_annotator_options() &&
480       model_->duration_annotator_options()->enabled()) {
481     duration_annotator_.reset(
482         new DurationAnnotator(model_->duration_annotator_options(),
483                               selection_feature_processor_.get(), unilib_));
484   }
485 
486   if (model_->grammar_model()) {
487     grammar_annotator_.reset(new GrammarAnnotator(
488         unilib_, model_->grammar_model(), entity_data_builder_.get()));
489   }
490 
491   // The following #ifdef is here to aid quality evaluation of a situation, when
492   // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
493   // it.
494 #if !defined(TC3_DISABLE_POD_NER)
495   if (model_->pod_ner_model()) {
496     pod_ner_annotator_ =
497         PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
498   }
499 #endif
500 
501   if (model_->vocab_model()) {
502     vocab_annotator_ = VocabAnnotator::Create(
503         model_->vocab_model(), *selection_feature_processor_, *unilib_);
504   }
505 
506   if (model_->entity_data_schema()) {
507     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
508         model_->entity_data_schema()->Data(),
509         model_->entity_data_schema()->size());
510     if (entity_data_schema_ == nullptr) {
511       TC3_LOG(ERROR) << "Could not load entity data schema data.";
512       return;
513     }
514 
515     entity_data_builder_.reset(
516         new MutableFlatbufferBuilder(entity_data_schema_));
517   } else {
518     entity_data_schema_ = nullptr;
519     entity_data_builder_ = nullptr;
520   }
521 
522   if (model_->triggering_locales() &&
523       !ParseLocales(model_->triggering_locales()->c_str(),
524                     &model_triggering_locales_)) {
525     TC3_LOG(ERROR) << "Could not parse model supported locales.";
526     return;
527   }
528 
529   if (model_->triggering_options() != nullptr &&
530       model_->triggering_options()->locales() != nullptr &&
531       !ParseLocales(model_->triggering_options()->locales()->c_str(),
532                     &ml_model_triggering_locales_)) {
533     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
534     return;
535   }
536 
537   if (model_->triggering_options() != nullptr &&
538       model_->triggering_options()->dictionary_locales() != nullptr &&
539       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
540                     &dictionary_locales_)) {
541     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
542     return;
543   }
544 
545   if (model_->conflict_resolution_options() != nullptr) {
546     prioritize_longest_annotation_ =
547         model_->conflict_resolution_options()->prioritize_longest_annotation();
548     do_conflict_resolution_in_raw_mode_ =
549         model_->conflict_resolution_options()
550             ->do_conflict_resolution_in_raw_mode();
551   }
552 
553 #ifdef TC3_EXPERIMENTAL
554   TC3_LOG(WARNING) << "Enabling experimental annotators.";
555   InitializeExperimentalAnnotators();
556 #endif
557 
558   initialized_ = true;
559 }
560 
InitializeRegexModel(ZlibDecompressor * decompressor)561 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
562   if (!model_->regex_model()->patterns()) {
563     return true;
564   }
565 
566   // Initialize pattern recognizers.
567   int regex_pattern_id = 0;
568   for (const auto regex_pattern : *model_->regex_model()->patterns()) {
569     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
570         UncompressMakeRegexPattern(
571             *unilib_, regex_pattern->pattern(),
572             regex_pattern->compressed_pattern(),
573             model_->regex_model()->lazy_regex_compilation(), decompressor);
574     if (!compiled_pattern) {
575       TC3_LOG(INFO) << "Failed to load regex pattern";
576       return false;
577     }
578 
579     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
580       annotation_regex_patterns_.push_back(regex_pattern_id);
581     }
582     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
583       classification_regex_patterns_.push_back(regex_pattern_id);
584     }
585     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
586       selection_regex_patterns_.push_back(regex_pattern_id);
587     }
588     regex_patterns_.push_back({
589         regex_pattern,
590         std::move(compiled_pattern),
591     });
592     ++regex_pattern_id;
593   }
594 
595   return true;
596 }
597 
InitializeKnowledgeEngine(const std::string & serialized_config)598 bool Annotator::InitializeKnowledgeEngine(
599     const std::string& serialized_config) {
600   std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
601   if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
602     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
603     return false;
604   }
605   if (model_->triggering_options() != nullptr) {
606     knowledge_engine->SetPriorityScore(
607         model_->triggering_options()->knowledge_priority_score());
608     knowledge_engine->SetEnabledModes(
609         model_->triggering_options()->knowledge_enabled_modes());
610   }
611   knowledge_engine_ = std::move(knowledge_engine);
612   return true;
613 }
614 
InitializeContactEngine(const std::string & serialized_config)615 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
616   std::unique_ptr<ContactEngine> contact_engine(
617       new ContactEngine(selection_feature_processor_.get(), unilib_,
618                         model_->contact_annotator_options()));
619   if (!contact_engine->Initialize(serialized_config)) {
620     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
621     return false;
622   }
623   contact_engine_ = std::move(contact_engine);
624   return true;
625 }
626 
CleanUpContactEngine()627 void Annotator::CleanUpContactEngine() {
628   if (contact_engine_ == nullptr) {
629     TC3_LOG(INFO)
630         << "Attempting to clean up contact engine that does not exist.";
631     return;
632   }
633   contact_engine_->CleanUp();
634 }
635 
InitializeInstalledAppEngine(const std::string & serialized_config)636 bool Annotator::InitializeInstalledAppEngine(
637     const std::string& serialized_config) {
638   std::unique_ptr<InstalledAppEngine> installed_app_engine(
639       new InstalledAppEngine(
640           selection_feature_processor_.get(), unilib_,
641           model_->triggering_options()->installed_app_enabled_modes()));
642   if (!installed_app_engine->Initialize(serialized_config)) {
643     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
644     return false;
645   }
646   installed_app_engine_ = std::move(installed_app_engine);
647   return true;
648 }
649 
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)650 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
651   if (lang_id == nullptr) {
652     return false;
653   }
654 
655   lang_id_ = lang_id;
656   if (lang_id_ != nullptr && model_->translate_annotator_options() &&
657       model_->translate_annotator_options()->enabled()) {
658     translate_annotator_.reset(new TranslateAnnotator(
659         model_->translate_annotator_options(), lang_id_, unilib_));
660   } else {
661     translate_annotator_.reset(nullptr);
662   }
663   return true;
664 }
665 
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)666 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
667                                                             int size) {
668   const PersonNameModel* person_name_model =
669       LoadAndVerifyPersonNameModel(buffer, size);
670 
671   if (person_name_model == nullptr) {
672     TC3_LOG(ERROR) << "Person name model verification failed.";
673     return false;
674   }
675 
676   if (!person_name_model->enabled()) {
677     return true;
678   }
679 
680   std::unique_ptr<PersonNameEngine> person_name_engine(
681       new PersonNameEngine(selection_feature_processor_.get(), unilib_));
682   if (!person_name_engine->Initialize(person_name_model)) {
683     TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
684     return false;
685   }
686   person_name_engine_ = std::move(person_name_engine);
687   return true;
688 }
689 
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)690 bool Annotator::InitializePersonNameEngineFromScopedMmap(
691     const ScopedMmap& mmap) {
692   if (!mmap.handle().ok()) {
693     TC3_LOG(ERROR) << "Mmap for person name model failed.";
694     return false;
695   }
696 
697   return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
698                                                      mmap.handle().num_bytes());
699 }
700 
InitializePersonNameEngineFromPath(const std::string & path)701 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
702   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
703   return InitializePersonNameEngineFromScopedMmap(*mmap);
704 }
705 
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)706 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
707                                                              int size) {
708   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
709   return InitializePersonNameEngineFromScopedMmap(*mmap);
710 }
711 
InitializeExperimentalAnnotators()712 bool Annotator::InitializeExperimentalAnnotators() {
713   if (ExperimentalAnnotator::IsEnabled()) {
714     experimental_annotator_.reset(new ExperimentalAnnotator(
715         model_->experimental_model(), *selection_feature_processor_, *unilib_));
716     return true;
717   }
718   return false;
719 }
720 
721 namespace internal {
722 // Helper function, which if the initial 'span' contains only white-spaces,
723 // moves the selection to a single-codepoint selection on a left or right side
724 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)725 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
726                                             const UnicodeText& context_unicode,
727                                             const UniLib& unilib) {
728   TC3_CHECK(span.IsValid() && !span.IsEmpty());
729 
730   UnicodeText::const_iterator it;
731 
732   // Check that the current selection is all whitespaces.
733   it = context_unicode.begin();
734   std::advance(it, span.first);
735   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
736     if (!unilib.IsWhitespace(*it)) {
737       return span;
738     }
739   }
740 
741   // Try moving left.
742   CodepointSpan result = span;
743   it = context_unicode.begin();
744   std::advance(it, span.first);
745   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
746     --result.first;
747     --it;
748   }
749   result.second = result.first + 1;
750   if (!unilib.IsWhitespace(*it)) {
751     return result;
752   }
753 
754   // If moving left didn't find a non-whitespace character, just return the
755   // original span.
756   return span;
757 }
758 }  // namespace internal
759 
FilteredForAnnotation(const AnnotatedSpan & span) const760 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
761   return !span.classification.empty() &&
762          filtered_collections_annotation_.find(
763              span.classification[0].collection) !=
764              filtered_collections_annotation_.end();
765 }
766 
FilteredForClassification(const ClassificationResult & classification) const767 bool Annotator::FilteredForClassification(
768     const ClassificationResult& classification) const {
769   return filtered_collections_classification_.find(classification.collection) !=
770          filtered_collections_classification_.end();
771 }
772 
FilteredForSelection(const AnnotatedSpan & span) const773 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
774   return !span.classification.empty() &&
775          filtered_collections_selection_.find(
776              span.classification[0].collection) !=
777              filtered_collections_selection_.end();
778 }
779 
780 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)781 inline bool ClassifiedAsOther(
782     const std::vector<ClassificationResult>& classification) {
783   return !classification.empty() &&
784          classification[0].collection == Collections::Other();
785 }
786 
787 }  // namespace
788 
GetPriorityScore(const std::vector<ClassificationResult> & classification) const789 float Annotator::GetPriorityScore(
790     const std::vector<ClassificationResult>& classification) const {
791   if (!classification.empty() && !ClassifiedAsOther(classification)) {
792     return classification[0].priority_score;
793   } else {
794     if (model_->triggering_options() != nullptr) {
795       return model_->triggering_options()->other_collection_priority_score();
796     } else {
797       return -1000.0;
798     }
799   }
800 }
801 
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const802 bool Annotator::VerifyRegexMatchCandidate(
803     const std::string& context, const VerificationOptions* verification_options,
804     const std::string& match, const UniLib::RegexMatcher* matcher) const {
805   if (verification_options == nullptr) {
806     return true;
807   }
808   if (verification_options->verify_luhn_checksum() &&
809       !VerifyLuhnChecksum(match)) {
810     return false;
811   }
812   const int lua_verifier = verification_options->lua_verifier();
813   if (lua_verifier >= 0) {
814     if (model_->regex_model()->lua_verifier() == nullptr ||
815         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
816       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
817       return false;
818     }
819     return VerifyMatch(
820         context, matcher,
821         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
822   }
823   return true;
824 }
825 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const826 CodepointSpan Annotator::SuggestSelection(
827     const std::string& context, CodepointSpan click_indices,
828     const SelectionOptions& options) const {
829   if (context.size() > std::numeric_limits<int>::max()) {
830     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
831     return {};
832   }
833 
834   CodepointSpan original_click_indices = click_indices;
835   if (!initialized_) {
836     TC3_LOG(ERROR) << "Not initialized";
837     return original_click_indices;
838   }
839   if (options.annotation_usecase !=
840       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
841     TC3_LOG(WARNING)
842         << "Invoking SuggestSelection, which is not supported in RAW mode.";
843     return original_click_indices;
844   }
845   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
846     return original_click_indices;
847   }
848 
849   std::vector<Locale> detected_text_language_tags;
850   if (!ParseLocales(options.detected_text_language_tags,
851                     &detected_text_language_tags)) {
852     TC3_LOG(WARNING)
853         << "Failed to parse the detected_text_language_tags in options: "
854         << options.detected_text_language_tags;
855   }
856   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
857                                     model_triggering_locales_,
858                                     /*default_value=*/true)) {
859     return original_click_indices;
860   }
861 
862   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
863                                                         /*do_copy=*/false);
864 
865   if (!unilib_->IsValidUtf8(context_unicode)) {
866     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
867     return original_click_indices;
868   }
869 
870   if (!IsValidSpanInput(context_unicode, click_indices)) {
871     TC3_VLOG(1)
872         << "Trying to run SuggestSelection with invalid input, indices: "
873         << click_indices.first << " " << click_indices.second;
874     return original_click_indices;
875   }
876 
877   if (model_->snap_whitespace_selections()) {
878     // We want to expand a purely white-space selection to a multi-selection it
879     // would've been part of. But with this feature disabled we would do a no-
880     // op, because no token is found. Therefore, we need to modify the
881     // 'click_indices' a bit to include a part of the token, so that the click-
882     // finding logic finds the clicked token correctly. This modification is
883     // done by the following function. Note, that it's enough to check the left
884     // side of the current selection, because if the white-space is a part of a
885     // multi-selection, necessarily both tokens - on the left and the right
886     // sides need to be selected. Thus snapping only to the left is sufficient
887     // (there's a check at the bottom that makes sure that if we snap to the
888     // left token but the result does not contain the initial white-space,
889     // returns the original indices).
890     click_indices = internal::SnapLeftIfWhitespaceSelection(
891         click_indices, context_unicode, *unilib_);
892   }
893 
894   Annotations candidates;
895   // As we process a single string of context, the candidates will only
896   // contain one vector of AnnotatedSpan.
897   candidates.annotated_spans.resize(1);
898   InterpreterManager interpreter_manager(selection_executor_.get(),
899                                          classification_executor_.get());
900   std::vector<Token> tokens;
901   if (!ModelSuggestSelection(context_unicode, click_indices,
902                              detected_text_language_tags, &interpreter_manager,
903                              &tokens, &candidates.annotated_spans[0])) {
904     TC3_LOG(ERROR) << "Model suggest selection failed.";
905     return original_click_indices;
906   }
907   const std::unordered_set<std::string> set;
908   const EnabledEntityTypes is_entity_type_enabled(set);
909   if (!RegexChunk(context_unicode, selection_regex_patterns_,
910                   /*is_serialized_entity_data_enabled=*/false,
911                   is_entity_type_enabled, options.annotation_usecase,
912                   &candidates.annotated_spans[0])) {
913     TC3_LOG(ERROR) << "Regex suggest selection failed.";
914     return original_click_indices;
915   }
916   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
917                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
918                      options.locales, ModeFlag_SELECTION,
919                      options.annotation_usecase,
920                      /*is_serialized_entity_data_enabled=*/false,
921                      &candidates.annotated_spans[0])) {
922     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
923     return original_click_indices;
924   }
925   if (knowledge_engine_ != nullptr &&
926       !knowledge_engine_
927            ->Chunk(context, options.annotation_usecase,
928                    options.location_context, Permissions(),
929                    AnnotateMode::kEntityAnnotation, ModeFlag_SELECTION,
930                    &candidates)
931            .ok()) {
932     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
933     return original_click_indices;
934   }
935   if (contact_engine_ != nullptr &&
936       !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
937                               &candidates.annotated_spans[0])) {
938     TC3_LOG(ERROR) << "Contact suggest selection failed.";
939     return original_click_indices;
940   }
941   if (installed_app_engine_ != nullptr &&
942       !installed_app_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
943                                     &candidates.annotated_spans[0])) {
944     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
945     return original_click_indices;
946   }
947   if (number_annotator_ != nullptr &&
948       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
949                                   ModeFlag_SELECTION,
950                                   &candidates.annotated_spans[0])) {
951     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
952     return original_click_indices;
953   }
954   if (duration_annotator_ != nullptr &&
955       !duration_annotator_->FindAll(
956           context_unicode, tokens, options.annotation_usecase,
957           ModeFlag_SELECTION, &candidates.annotated_spans[0])) {
958     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
959     return original_click_indices;
960   }
961   if (person_name_engine_ != nullptr &&
962       !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_SELECTION,
963                                   &candidates.annotated_spans[0])) {
964     TC3_LOG(ERROR) << "Person name suggest selection failed.";
965     return original_click_indices;
966   }
967 
968   AnnotatedSpan grammar_suggested_span;
969   if (grammar_annotator_ != nullptr &&
970       grammar_annotator_->SuggestSelection(detected_text_language_tags,
971                                            context_unicode, click_indices,
972                                            &grammar_suggested_span)) {
973     candidates.annotated_spans[0].push_back(grammar_suggested_span);
974   }
975 
976   AnnotatedSpan pod_ner_suggested_span;
977   if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
978       pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
979                                            &pod_ner_suggested_span)) {
980     candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
981   }
982 
983   if (experimental_annotator_ != nullptr &&
984       (model_->triggering_options()->experimental_enabled_modes() &
985        ModeFlag_SELECTION)) {
986     candidates.annotated_spans[0].push_back(
987         experimental_annotator_->SuggestSelection(context_unicode,
988                                                   click_indices));
989   }
990 
991   // Sort candidates according to their position in the input, so that the next
992   // code can assume that any connected component of overlapping spans forms a
993   // contiguous block.
994   std::stable_sort(candidates.annotated_spans[0].begin(),
995                    candidates.annotated_spans[0].end(),
996                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
997                      return a.span.first < b.span.first;
998                    });
999 
1000   std::vector<int> candidate_indices;
1001   if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
1002                         detected_text_language_tags, options,
1003                         &interpreter_manager, &candidate_indices)) {
1004     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1005     return original_click_indices;
1006   }
1007 
1008   std::stable_sort(
1009       candidate_indices.begin(), candidate_indices.end(),
1010       [this, &candidates](int a, int b) {
1011         return GetPriorityScore(
1012                    candidates.annotated_spans[0][a].classification) >
1013                GetPriorityScore(
1014                    candidates.annotated_spans[0][b].classification);
1015       });
1016 
1017   for (const int i : candidate_indices) {
1018     if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
1019         SpansOverlap(candidates.annotated_spans[0][i].span,
1020                      original_click_indices)) {
1021       // Run model classification if not present but requested and there's a
1022       // classification collection filter specified.
1023       if (candidates.annotated_spans[0][i].classification.empty() &&
1024           model_->selection_options()->always_classify_suggested_selection() &&
1025           !filtered_collections_selection_.empty()) {
1026         if (!ModelClassifyText(context, /*cached_tokens=*/{},
1027                                detected_text_language_tags,
1028                                candidates.annotated_spans[0][i].span, options,
1029                                &interpreter_manager,
1030                                /*embedding_cache=*/nullptr,
1031                                &candidates.annotated_spans[0][i].classification,
1032                                /*tokens=*/nullptr)) {
1033           return original_click_indices;
1034         }
1035       }
1036 
1037       // Ignore if span classification is filtered.
1038       if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1039         return original_click_indices;
1040       }
1041 
1042       // We return a suggested span contains the original span.
1043       // This compensates for "select all" selection that may come from
1044       // other apps. See http://b/179890518.
1045       if (SpanContains(candidates.annotated_spans[0][i].span,
1046                        original_click_indices)) {
1047         return candidates.annotated_spans[0][i].span;
1048       }
1049     }
1050   }
1051 
1052   return original_click_indices;
1053 }
1054 
1055 namespace {
1056 // Helper function that returns the index of the first candidate that
1057 // transitively does not overlap with the candidate on 'start_index'. If the end
1058 // of 'candidates' is reached, it returns the index that points right behind the
1059 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1060 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1061                                  int start_index) {
1062   int first_non_overlapping = start_index + 1;
1063   CodepointSpan conflicting_span = candidates[start_index].span;
1064   while (
1065       first_non_overlapping < candidates.size() &&
1066       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1067     // Grow the span to include the current one.
1068     conflicting_span.second = std::max(
1069         conflicting_span.second, candidates[first_non_overlapping].span.second);
1070 
1071     ++first_non_overlapping;
1072   }
1073   return first_non_overlapping;
1074 }
1075 }  // namespace
1076 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1077 bool Annotator::ResolveConflicts(
1078     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1079     const std::vector<Token>& cached_tokens,
1080     const std::vector<Locale>& detected_text_language_tags,
1081     const BaseOptions& options, InterpreterManager* interpreter_manager,
1082     std::vector<int>* result) const {
1083   result->clear();
1084   result->reserve(candidates.size());
1085   for (int i = 0; i < candidates.size();) {
1086     int first_non_overlapping =
1087         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1088 
1089     const bool conflict_found = first_non_overlapping != (i + 1);
1090     if (conflict_found) {
1091       std::vector<int> candidate_indices;
1092       if (!ResolveConflict(context, cached_tokens, candidates,
1093                            detected_text_language_tags, i,
1094                            first_non_overlapping, options, interpreter_manager,
1095                            &candidate_indices)) {
1096         return false;
1097       }
1098       result->insert(result->end(), candidate_indices.begin(),
1099                      candidate_indices.end());
1100     } else {
1101       result->push_back(i);
1102     }
1103 
1104     // Skip over the whole conflicting group/go to next candidate.
1105     i = first_non_overlapping;
1106   }
1107   return true;
1108 }
1109 
1110 namespace {
1111 // Returns true, if the given two sources do conflict in given annotation
1112 // usecase.
1113 //  - In SMART usecase, all sources do conflict, because there's only 1 possible
1114 //  annotation for a given span.
1115 //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1116 //  and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1117 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1118                        const AnnotatedSpan::Source source1,
1119                        const AnnotatedSpan::Source source2) {
1120   uint32 source_mask =
1121       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1122 
1123   switch (annotation_usecase) {
1124     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1125       // In the SMART mode, all annotations conflict.
1126       return true;
1127 
1128     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1129       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1130       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1131       // hours" (duration).
1132       if ((source_mask &
1133            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1134           (source_mask &
1135            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1136         return false;
1137       }
1138 
1139       // A KNOWLEDGE entity does not conflict with anything.
1140       if ((source_mask &
1141            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1142         return false;
1143       }
1144 
1145       // A PERSONNAME entity does not conflict with anything.
1146       if ((source_mask &
1147            (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1148         return false;
1149       }
1150 
1151       // Entities from other sources can conflict.
1152       return true;
1153   }
1154 }
1155 }  // namespace
1156 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1157 bool Annotator::ResolveConflict(
1158     const std::string& context, const std::vector<Token>& cached_tokens,
1159     const std::vector<AnnotatedSpan>& candidates,
1160     const std::vector<Locale>& detected_text_language_tags, int start_index,
1161     int end_index, const BaseOptions& options,
1162     InterpreterManager* interpreter_manager,
1163     std::vector<int>* chosen_indices) const {
1164   std::vector<int> conflicting_indices;
1165   std::unordered_map<int, std::pair<float, int>> scores_lengths;
1166   for (int i = start_index; i < end_index; ++i) {
1167     conflicting_indices.push_back(i);
1168     if (!candidates[i].classification.empty()) {
1169       scores_lengths[i] = {
1170           GetPriorityScore(candidates[i].classification),
1171           candidates[i].span.second - candidates[i].span.first};
1172       continue;
1173     }
1174 
1175     // OPTIMIZATION: So that we don't have to classify all the ML model
1176     // spans apriori, we wait until we get here, when they conflict with
1177     // something and we need the actual classification scores. So if the
1178     // candidate conflicts and comes from the model, we need to run a
1179     // classification to determine its priority:
1180     std::vector<ClassificationResult> classification;
1181     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1182                            candidates[i].span, options, interpreter_manager,
1183                            /*embedding_cache=*/nullptr, &classification,
1184                            /*tokens=*/nullptr)) {
1185       return false;
1186     }
1187 
1188     if (!classification.empty()) {
1189       scores_lengths[i] = {
1190           GetPriorityScore(classification),
1191           candidates[i].span.second - candidates[i].span.first};
1192     }
1193   }
1194 
1195   std::stable_sort(
1196       conflicting_indices.begin(), conflicting_indices.end(),
1197       [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1198         if (scores_lengths[i].first == scores_lengths[j].first &&
1199             prioritize_longest_annotation_) {
1200           return scores_lengths[i].second > scores_lengths[j].second;
1201         }
1202         return scores_lengths[i].first > scores_lengths[j].first;
1203       });
1204 
1205   // Here we keep a set of indices that were chosen, per-source, to enable
1206   // effective computation.
1207   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1208       chosen_indices_for_source_map;
1209 
1210   // Greedily place the candidates if they don't conflict with the already
1211   // placed ones.
1212   for (int i = 0; i < conflicting_indices.size(); ++i) {
1213     const int considered_candidate = conflicting_indices[i];
1214 
1215     // See if there is a conflict between the candidate and all already placed
1216     // candidates.
1217     bool conflict = false;
1218     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1219     for (auto& source_set_pair : chosen_indices_for_source_map) {
1220       if (source_set_pair.first == candidates[considered_candidate].source) {
1221         chosen_indices_for_source_ptr = &source_set_pair.second;
1222       }
1223 
1224       const bool needs_conflict_resolution =
1225           options.annotation_usecase ==
1226               AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1227           (options.annotation_usecase ==
1228                AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1229            do_conflict_resolution_in_raw_mode_);
1230       if (needs_conflict_resolution &&
1231           DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1232                             candidates[considered_candidate].source) &&
1233           DoesCandidateConflict(considered_candidate, candidates,
1234                                 source_set_pair.second)) {
1235         conflict = true;
1236         break;
1237       }
1238     }
1239 
1240     // Skip the candidate if a conflict was found.
1241     if (conflict) {
1242       continue;
1243     }
1244 
1245     // If the set of indices for the current source doesn't exist yet,
1246     // initialize it.
1247     if (chosen_indices_for_source_ptr == nullptr) {
1248       SortedIntSet new_set([&candidates](int a, int b) {
1249         return candidates[a].span.first < candidates[b].span.first;
1250       });
1251       chosen_indices_for_source_map[candidates[considered_candidate].source] =
1252           std::move(new_set);
1253       chosen_indices_for_source_ptr =
1254           &chosen_indices_for_source_map[candidates[considered_candidate]
1255                                              .source];
1256     }
1257 
1258     // Place the candidate to the output and to the per-source conflict set.
1259     chosen_indices->push_back(considered_candidate);
1260     chosen_indices_for_source_ptr->insert(considered_candidate);
1261   }
1262 
1263   std::stable_sort(chosen_indices->begin(), chosen_indices->end());
1264 
1265   return true;
1266 }
1267 
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1268 bool Annotator::ModelSuggestSelection(
1269     const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1270     const std::vector<Locale>& detected_text_language_tags,
1271     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1272     std::vector<AnnotatedSpan>* result) const {
1273   if (model_->triggering_options() == nullptr ||
1274       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1275     return true;
1276   }
1277 
1278   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1279                                     ml_model_triggering_locales_,
1280                                     /*default_value=*/true)) {
1281     return true;
1282   }
1283 
1284   int click_pos;
1285   *tokens = selection_feature_processor_->Tokenize(context_unicode);
1286   const auto [click_begin, click_end] =
1287       CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1288   selection_feature_processor_->RetokenizeAndFindClick(
1289       context_unicode, click_begin, click_end, click_indices,
1290       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1291       tokens, &click_pos);
1292   if (click_pos == kInvalidIndex) {
1293     TC3_VLOG(1) << "Could not calculate the click position.";
1294     return false;
1295   }
1296 
1297   const int symmetry_context_size =
1298       model_->selection_options()->symmetry_context_size();
1299   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1300       bounds_sensitive_features = selection_feature_processor_->GetOptions()
1301                                       ->bounds_sensitive_features();
1302 
1303   // The symmetry context span is the clicked token with symmetry_context_size
1304   // tokens on either side.
1305   const TokenSpan symmetry_context_span =
1306       IntersectTokenSpans(TokenSpan(click_pos).Expand(
1307                               /*num_tokens_left=*/symmetry_context_size,
1308                               /*num_tokens_right=*/symmetry_context_size),
1309                           AllOf(*tokens));
1310 
1311   // Compute the extraction span based on the model type.
1312   TokenSpan extraction_span;
1313   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1314     // The extraction span is the symmetry context span expanded to include
1315     // max_selection_span tokens on either side, which is how far a selection
1316     // can stretch from the click, plus a relevant number of tokens outside of
1317     // the bounds of the selection.
1318     const int max_selection_span =
1319         selection_feature_processor_->GetOptions()->max_selection_span();
1320     extraction_span = symmetry_context_span.Expand(
1321         /*num_tokens_left=*/max_selection_span +
1322             bounds_sensitive_features->num_tokens_before(),
1323         /*num_tokens_right=*/max_selection_span +
1324             bounds_sensitive_features->num_tokens_after());
1325   } else {
1326     // The extraction span is the symmetry context span expanded to include
1327     // context_size tokens on either side.
1328     const int context_size =
1329         selection_feature_processor_->GetOptions()->context_size();
1330     extraction_span = symmetry_context_span.Expand(
1331         /*num_tokens_left=*/context_size,
1332         /*num_tokens_right=*/context_size);
1333   }
1334   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1335 
1336   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1337           *tokens, extraction_span)) {
1338     return true;
1339   }
1340 
1341   std::unique_ptr<CachedFeatures> cached_features;
1342   if (!selection_feature_processor_->ExtractFeatures(
1343           *tokens, extraction_span,
1344           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1345           embedding_executor_.get(),
1346           /*embedding_cache=*/nullptr,
1347           selection_feature_processor_->EmbeddingSize() +
1348               selection_feature_processor_->DenseFeaturesCount(),
1349           &cached_features)) {
1350     TC3_LOG(ERROR) << "Could not extract features.";
1351     return false;
1352   }
1353 
1354   // Produce selection model candidates.
1355   std::vector<TokenSpan> chunks;
1356   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1357                   interpreter_manager->SelectionInterpreter(), *cached_features,
1358                   &chunks)) {
1359     TC3_LOG(ERROR) << "Could not chunk.";
1360     return false;
1361   }
1362 
1363   for (const TokenSpan& chunk : chunks) {
1364     AnnotatedSpan candidate;
1365     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1366         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1367     if (model_->selection_options()->strip_unpaired_brackets()) {
1368       candidate.span =
1369           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1370     }
1371 
1372     // Only output non-empty spans.
1373     if (candidate.span.first != candidate.span.second) {
1374       result->push_back(candidate);
1375     }
1376   }
1377   return true;
1378 }
1379 
1380 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1381 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1382                                     const CodepointSpan& selection_indices,
1383                                     TokenSpan tokens_around_selection_to_copy) {
1384   const auto first_selection_token = std::upper_bound(
1385       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1386       [](int selection_start, const Token& token) {
1387         return selection_start < token.end;
1388       });
1389   const auto last_selection_token = std::lower_bound(
1390       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1391       [](const Token& token, int selection_end) {
1392         return token.start < selection_end;
1393       });
1394 
1395   const int64 first_token = std::max(
1396       static_cast<int64>(0),
1397       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1398                          tokens_around_selection_to_copy.first));
1399   const int64 last_token = std::min(
1400       static_cast<int64>(cached_tokens.size()),
1401       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1402                          tokens_around_selection_to_copy.second));
1403 
1404   std::vector<Token> tokens;
1405   tokens.reserve(last_token - first_token);
1406   for (int i = first_token; i < last_token; ++i) {
1407     tokens.push_back(cached_tokens[i]);
1408   }
1409   return tokens;
1410 }
1411 }  // namespace internal
1412 
ClassifyTextUpperBoundNeededTokens() const1413 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1414   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1415       bounds_sensitive_features =
1416           classification_feature_processor_->GetOptions()
1417               ->bounds_sensitive_features();
1418   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1419     // The extraction span is the selection span expanded to include a relevant
1420     // number of tokens outside of the bounds of the selection.
1421     return {bounds_sensitive_features->num_tokens_before(),
1422             bounds_sensitive_features->num_tokens_after()};
1423   } else {
1424     // The extraction span is the clicked token with context_size tokens on
1425     // either side.
1426     const int context_size =
1427         selection_feature_processor_->GetOptions()->context_size();
1428     return {context_size, context_size};
1429   }
1430 }
1431 
1432 namespace {
1433 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1434 void SortClassificationResults(
1435     std::vector<ClassificationResult>* classification_results) {
1436   std::stable_sort(
1437       classification_results->begin(), classification_results->end(),
1438       [](const ClassificationResult& a, const ClassificationResult& b) {
1439         return a.score > b.score;
1440       });
1441 }
1442 }  // namespace
1443 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1444 bool Annotator::ModelClassifyText(
1445     const std::string& context, const std::vector<Token>& cached_tokens,
1446     const std::vector<Locale>& detected_text_language_tags,
1447     const CodepointSpan& selection_indices, const BaseOptions& options,
1448     InterpreterManager* interpreter_manager,
1449     FeatureProcessor::EmbeddingCache* embedding_cache,
1450     std::vector<ClassificationResult>* classification_results,
1451     std::vector<Token>* tokens) const {
1452   const UnicodeText context_unicode =
1453       UTF8ToUnicodeText(context, /*do_copy=*/false);
1454   const auto [span_begin, span_end] =
1455       CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1456   return ModelClassifyText(context_unicode, cached_tokens,
1457                            detected_text_language_tags, span_begin, span_end,
1458                            /*line=*/nullptr, selection_indices, options,
1459                            interpreter_manager, embedding_cache,
1460                            classification_results, tokens);
1461 }
1462 
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1463 bool Annotator::ModelClassifyText(
1464     const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1465     const std::vector<Locale>& detected_text_language_tags,
1466     const UnicodeText::const_iterator& span_begin,
1467     const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1468     const CodepointSpan& selection_indices, const BaseOptions& options,
1469     InterpreterManager* interpreter_manager,
1470     FeatureProcessor::EmbeddingCache* embedding_cache,
1471     std::vector<ClassificationResult>* classification_results,
1472     std::vector<Token>* tokens) const {
1473   if (model_->triggering_options() == nullptr ||
1474       !(model_->triggering_options()->enabled_modes() &
1475         ModeFlag_CLASSIFICATION)) {
1476     return true;
1477   }
1478 
1479   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1480                                     ml_model_triggering_locales_,
1481                                     /*default_value=*/true)) {
1482     return true;
1483   }
1484 
1485   std::vector<Token> local_tokens;
1486   if (tokens == nullptr) {
1487     tokens = &local_tokens;
1488   }
1489 
1490   if (cached_tokens.empty()) {
1491     *tokens = classification_feature_processor_->Tokenize(context_unicode);
1492   } else {
1493     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1494                                          ClassifyTextUpperBoundNeededTokens());
1495   }
1496 
1497   int click_pos;
1498   classification_feature_processor_->RetokenizeAndFindClick(
1499       context_unicode, span_begin, span_end, selection_indices,
1500       classification_feature_processor_->GetOptions()
1501           ->only_use_line_with_click(),
1502       tokens, &click_pos);
1503   const TokenSpan selection_token_span =
1504       CodepointSpanToTokenSpan(*tokens, selection_indices);
1505   const int selection_num_tokens = selection_token_span.Size();
1506   if (model_->classification_options()->max_num_tokens() > 0 &&
1507       model_->classification_options()->max_num_tokens() <
1508           selection_num_tokens) {
1509     *classification_results = {{Collections::Other(), 1.0}};
1510     return true;
1511   }
1512 
1513   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1514       bounds_sensitive_features =
1515           classification_feature_processor_->GetOptions()
1516               ->bounds_sensitive_features();
1517   if (selection_token_span.first == kInvalidIndex ||
1518       selection_token_span.second == kInvalidIndex) {
1519     TC3_LOG(ERROR) << "Could not determine span.";
1520     return false;
1521   }
1522 
1523   // Compute the extraction span based on the model type.
1524   TokenSpan extraction_span;
1525   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1526     // The extraction span is the selection span expanded to include a relevant
1527     // number of tokens outside of the bounds of the selection.
1528     extraction_span = selection_token_span.Expand(
1529         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1530         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1531   } else {
1532     if (click_pos == kInvalidIndex) {
1533       TC3_LOG(ERROR) << "Couldn't choose a click position.";
1534       return false;
1535     }
1536     // The extraction span is the clicked token with context_size tokens on
1537     // either side.
1538     const int context_size =
1539         classification_feature_processor_->GetOptions()->context_size();
1540     extraction_span = TokenSpan(click_pos).Expand(
1541         /*num_tokens_left=*/context_size,
1542         /*num_tokens_right=*/context_size);
1543   }
1544   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1545 
1546   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1547           *tokens, extraction_span)) {
1548     *classification_results = {{Collections::Other(), 1.0}};
1549     return true;
1550   }
1551 
1552   std::unique_ptr<CachedFeatures> cached_features;
1553   if (!classification_feature_processor_->ExtractFeatures(
1554           *tokens, extraction_span, selection_indices,
1555           embedding_executor_.get(), embedding_cache,
1556           classification_feature_processor_->EmbeddingSize() +
1557               classification_feature_processor_->DenseFeaturesCount(),
1558           &cached_features)) {
1559     TC3_LOG(ERROR) << "Could not extract features.";
1560     return false;
1561   }
1562 
1563   std::vector<float> features;
1564   features.reserve(cached_features->OutputFeaturesSize());
1565   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1566     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1567                                                           &features);
1568   } else {
1569     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1570   }
1571 
1572   TensorView<float> logits = classification_executor_->ComputeLogits(
1573       TensorView<float>(features.data(),
1574                         {1, static_cast<int>(features.size())}),
1575       interpreter_manager->ClassificationInterpreter());
1576   if (!logits.is_valid()) {
1577     TC3_LOG(ERROR) << "Couldn't compute logits.";
1578     return false;
1579   }
1580 
1581   if (logits.dims() != 2 || logits.dim(0) != 1 ||
1582       logits.dim(1) != classification_feature_processor_->NumCollections()) {
1583     TC3_LOG(ERROR) << "Mismatching output";
1584     return false;
1585   }
1586 
1587   const std::vector<float> scores =
1588       ComputeSoftmax(logits.data(), logits.dim(1));
1589 
1590   if (scores.empty()) {
1591     *classification_results = {{Collections::Other(), 1.0}};
1592     return true;
1593   }
1594 
1595   const int best_score_index =
1596       std::max_element(scores.begin(), scores.end()) - scores.begin();
1597   const std::string top_collection =
1598       classification_feature_processor_->LabelToCollection(best_score_index);
1599 
1600   // Sanity checks.
1601   if (top_collection == Collections::Phone()) {
1602     const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1603     if (digit_count <
1604             model_->classification_options()->phone_min_num_digits() ||
1605         digit_count >
1606             model_->classification_options()->phone_max_num_digits()) {
1607       *classification_results = {{Collections::Other(), 1.0}};
1608       return true;
1609     }
1610   } else if (top_collection == Collections::Address()) {
1611     if (selection_num_tokens <
1612         model_->classification_options()->address_min_num_tokens()) {
1613       *classification_results = {{Collections::Other(), 1.0}};
1614       return true;
1615     }
1616   } else if (top_collection == Collections::Dictionary()) {
1617     if ((options.use_vocab_annotator && vocab_annotator_) ||
1618         !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1619                                       dictionary_locales_,
1620                                       /*default_value=*/false)) {
1621       *classification_results = {{Collections::Other(), 1.0}};
1622       return true;
1623     }
1624   }
1625   *classification_results = {{top_collection, /*arg_score=*/1.0,
1626                               /*arg_priority_score=*/scores[best_score_index]}};
1627 
1628   // For some entities, we might want to clamp the priority score, for better
1629   // conflict resolution between entities.
1630   if (model_->triggering_options() != nullptr &&
1631       model_->triggering_options()->collection_to_priority() != nullptr) {
1632     if (auto entry =
1633             model_->triggering_options()->collection_to_priority()->LookupByKey(
1634                 top_collection.c_str())) {
1635       (*classification_results)[0].priority_score *= entry->value();
1636     }
1637   }
1638   return true;
1639 }
1640 
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1641 bool Annotator::RegexClassifyText(
1642     const std::string& context, const CodepointSpan& selection_indices,
1643     std::vector<ClassificationResult>* classification_result) const {
1644   const std::string selection_text =
1645       UTF8ToUnicodeText(context, /*do_copy=*/false)
1646           .UTF8Substring(selection_indices.first, selection_indices.second);
1647   const UnicodeText selection_text_unicode(
1648       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1649 
1650   // Check whether any of the regular expressions match.
1651   for (const int pattern_id : classification_regex_patterns_) {
1652     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1653     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1654         regex_pattern.pattern->Matcher(selection_text_unicode);
1655     int status = UniLib::RegexMatcher::kNoError;
1656     bool matches;
1657     if (regex_pattern.config->use_approximate_matching()) {
1658       matches = matcher->ApproximatelyMatches(&status);
1659     } else {
1660       matches = matcher->Matches(&status);
1661     }
1662     if (status != UniLib::RegexMatcher::kNoError) {
1663       return false;
1664     }
1665     if (matches && VerifyRegexMatchCandidate(
1666                        context, regex_pattern.config->verification_options(),
1667                        selection_text, matcher.get())) {
1668       classification_result->push_back(
1669           {regex_pattern.config->collection_name()->str(),
1670            regex_pattern.config->target_classification_score(),
1671            regex_pattern.config->priority_score()});
1672       if (!SerializedEntityDataFromRegexMatch(
1673               regex_pattern.config, matcher.get(),
1674               &classification_result->back().serialized_entity_data)) {
1675         TC3_LOG(ERROR) << "Could not get entity data.";
1676         return false;
1677       }
1678     }
1679   }
1680 
1681   return true;
1682 }
1683 
1684 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1685 std::string PickCollectionForDatetime(
1686     const DatetimeParseResult& datetime_parse_result) {
1687   switch (datetime_parse_result.granularity) {
1688     case GRANULARITY_HOUR:
1689     case GRANULARITY_MINUTE:
1690     case GRANULARITY_SECOND:
1691       return Collections::DateTime();
1692     default:
1693       return Collections::Date();
1694   }
1695 }
1696 
1697 }  // namespace
1698 
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1699 bool Annotator::DatetimeClassifyText(
1700     const std::string& context, const CodepointSpan& selection_indices,
1701     const ClassificationOptions& options,
1702     std::vector<ClassificationResult>* classification_results) const {
1703   if (!datetime_parser_) {
1704     return true;
1705   }
1706 
1707   const std::string selection_text =
1708       UTF8ToUnicodeText(context, /*do_copy=*/false)
1709           .UTF8Substring(selection_indices.first, selection_indices.second);
1710 
1711   LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1712   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1713       datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1714                               options.reference_timezone, locale_list,
1715                               ModeFlag_CLASSIFICATION,
1716                               options.annotation_usecase,
1717                               /*anchor_start_end=*/true);
1718   if (!result_status.ok()) {
1719     TC3_LOG(ERROR) << "Error during parsing datetime.";
1720     return false;
1721   }
1722 
1723   for (const DatetimeParseResultSpan& datetime_span :
1724        result_status.ValueOrDie()) {
1725     // Only consider the result valid if the selection and extracted datetime
1726     // spans exactly match.
1727     if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1728                       datetime_span.span.second + selection_indices.first) ==
1729         selection_indices) {
1730       for (const DatetimeParseResult& parse_result : datetime_span.data) {
1731         classification_results->emplace_back(
1732             PickCollectionForDatetime(parse_result),
1733             datetime_span.target_classification_score);
1734         classification_results->back().datetime_parse_result = parse_result;
1735         classification_results->back().serialized_entity_data =
1736             CreateDatetimeSerializedEntityData(parse_result);
1737         classification_results->back().priority_score =
1738             datetime_span.priority_score;
1739       }
1740       return true;
1741     }
1742   }
1743   return true;
1744 }
1745 
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1746 std::vector<ClassificationResult> Annotator::ClassifyText(
1747     const std::string& context, const CodepointSpan& selection_indices,
1748     const ClassificationOptions& options) const {
1749   if (context.size() > std::numeric_limits<int>::max()) {
1750     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1751     return {};
1752   }
1753   if (!initialized_) {
1754     TC3_LOG(ERROR) << "Not initialized";
1755     return {};
1756   }
1757   if (options.annotation_usecase !=
1758       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1759     TC3_LOG(WARNING)
1760         << "Invoking ClassifyText, which is not supported in RAW mode.";
1761     return {};
1762   }
1763   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1764     return {};
1765   }
1766 
1767   std::vector<Locale> detected_text_language_tags;
1768   if (!ParseLocales(options.detected_text_language_tags,
1769                     &detected_text_language_tags)) {
1770     TC3_LOG(WARNING)
1771         << "Failed to parse the detected_text_language_tags in options: "
1772         << options.detected_text_language_tags;
1773   }
1774   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1775                                     model_triggering_locales_,
1776                                     /*default_value=*/true)) {
1777     return {};
1778   }
1779 
1780   const UnicodeText context_unicode =
1781       UTF8ToUnicodeText(context, /*do_copy=*/false);
1782 
1783   if (!unilib_->IsValidUtf8(context_unicode)) {
1784     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1785     return {};
1786   }
1787 
1788   if (!IsValidSpanInput(context_unicode, selection_indices)) {
1789     TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1790                 << selection_indices.first << " " << selection_indices.second;
1791     return {};
1792   }
1793 
1794   // We'll accumulate a list of candidates, and pick the best candidate in the
1795   // end.
1796   std::vector<AnnotatedSpan> candidates;
1797 
1798   // Try the knowledge engine.
1799   // TODO(b/126579108): Propagate error status.
1800   ClassificationResult knowledge_result;
1801   if (knowledge_engine_ &&
1802       knowledge_engine_
1803           ->ClassifyText(context, selection_indices, options.annotation_usecase,
1804                          options.location_context, Permissions(),
1805                          &knowledge_result)
1806           .ok()) {
1807     candidates.push_back({selection_indices, {knowledge_result}});
1808     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1809   }
1810 
1811   AddContactMetadataToKnowledgeClassificationResults(&candidates);
1812 
1813   // Try the contact engine.
1814   // TODO(b/126579108): Propagate error status.
1815   ClassificationResult contact_result;
1816   if (contact_engine_ && contact_engine_->ClassifyText(
1817                              context, selection_indices, &contact_result)) {
1818     candidates.push_back({selection_indices, {contact_result}});
1819   }
1820 
1821   // Try the person name engine.
1822   ClassificationResult person_name_result;
1823   if (person_name_engine_ &&
1824       person_name_engine_->ClassifyText(context, selection_indices,
1825                                         &person_name_result)) {
1826     candidates.push_back({selection_indices, {person_name_result}});
1827     candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1828   }
1829 
1830   // Try the installed app engine.
1831   // TODO(b/126579108): Propagate error status.
1832   ClassificationResult installed_app_result;
1833   if (installed_app_engine_ &&
1834       installed_app_engine_->ClassifyText(context, selection_indices,
1835                                           &installed_app_result)) {
1836     candidates.push_back({selection_indices, {installed_app_result}});
1837   }
1838 
1839   // Try the regular expression models.
1840   std::vector<ClassificationResult> regex_results;
1841   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1842     return {};
1843   }
1844   for (const ClassificationResult& result : regex_results) {
1845     candidates.push_back({selection_indices, {result}});
1846   }
1847 
1848   // Try the date model.
1849   //
1850   // DatetimeClassifyText only returns the first result, which can however have
1851   // more interpretations. They are inserted in the candidates as a single
1852   // AnnotatedSpan, so that they get treated together by the conflict resolution
1853   // algorithm.
1854   std::vector<ClassificationResult> datetime_results;
1855   if (!DatetimeClassifyText(context, selection_indices, options,
1856                             &datetime_results)) {
1857     return {};
1858   }
1859   if (!datetime_results.empty()) {
1860     candidates.push_back({selection_indices, std::move(datetime_results)});
1861     candidates.back().source = AnnotatedSpan::Source::DATETIME;
1862   }
1863 
1864   // Try the number annotator.
1865   // TODO(b/126579108): Propagate error status.
1866   ClassificationResult number_annotator_result;
1867   if (number_annotator_ &&
1868       number_annotator_->ClassifyText(context_unicode, selection_indices,
1869                                       options.annotation_usecase,
1870                                       &number_annotator_result)) {
1871     candidates.push_back({selection_indices, {number_annotator_result}});
1872   }
1873 
1874   // Try the duration annotator.
1875   ClassificationResult duration_annotator_result;
1876   if (duration_annotator_ &&
1877       duration_annotator_->ClassifyText(context_unicode, selection_indices,
1878                                         options.annotation_usecase,
1879                                         &duration_annotator_result)) {
1880     candidates.push_back({selection_indices, {duration_annotator_result}});
1881     candidates.back().source = AnnotatedSpan::Source::DURATION;
1882   }
1883 
1884   // Try the translate annotator.
1885   ClassificationResult translate_annotator_result;
1886   if (translate_annotator_ &&
1887       translate_annotator_->ClassifyText(context_unicode, selection_indices,
1888                                          options.user_familiar_language_tags,
1889                                          &translate_annotator_result)) {
1890     candidates.push_back({selection_indices, {translate_annotator_result}});
1891   }
1892 
1893   // Try the grammar model.
1894   ClassificationResult grammar_annotator_result;
1895   if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1896                                 detected_text_language_tags, context_unicode,
1897                                 selection_indices, &grammar_annotator_result)) {
1898     candidates.push_back({selection_indices, {grammar_annotator_result}});
1899   }
1900 
1901   ClassificationResult pod_ner_annotator_result;
1902   if (pod_ner_annotator_ && options.use_pod_ner &&
1903       pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1904                                        &pod_ner_annotator_result)) {
1905     candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1906   }
1907 
1908   ClassificationResult vocab_annotator_result;
1909   if (vocab_annotator_ && options.use_vocab_annotator &&
1910       vocab_annotator_->ClassifyText(
1911           context_unicode, selection_indices, detected_text_language_tags,
1912           options.trigger_dictionary_on_beginner_words,
1913           &vocab_annotator_result)) {
1914     candidates.push_back({selection_indices, {vocab_annotator_result}});
1915   }
1916 
1917   if (experimental_annotator_ &&
1918       (model_->triggering_options()->experimental_enabled_modes() &
1919        ModeFlag_CLASSIFICATION)) {
1920     experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1921                                           candidates);
1922   }
1923 
1924   // Try the ML model.
1925   //
1926   // The output of the model is considered as an exclusive 1-of-N choice. That's
1927   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1928   // span for each candidate, like e.g. the regex model.
1929   InterpreterManager interpreter_manager(selection_executor_.get(),
1930                                          classification_executor_.get());
1931   std::vector<ClassificationResult> model_results;
1932   std::vector<Token> tokens;
1933   if (!ModelClassifyText(
1934           context, /*cached_tokens=*/{}, detected_text_language_tags,
1935           selection_indices, options, &interpreter_manager,
1936           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1937     return {};
1938   }
1939   if (!model_results.empty()) {
1940     candidates.push_back({selection_indices, std::move(model_results)});
1941   }
1942 
1943   std::vector<int> candidate_indices;
1944   if (!ResolveConflicts(candidates, context, tokens,
1945                         detected_text_language_tags, options,
1946                         &interpreter_manager, &candidate_indices)) {
1947     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1948     return {};
1949   }
1950 
1951   std::vector<ClassificationResult> results;
1952   for (const int i : candidate_indices) {
1953     for (const ClassificationResult& result : candidates[i].classification) {
1954       if (!FilteredForClassification(result)) {
1955         results.push_back(result);
1956       }
1957     }
1958   }
1959 
1960   // Sort results according to score.
1961   std::stable_sort(
1962       results.begin(), results.end(),
1963       [](const ClassificationResult& a, const ClassificationResult& b) {
1964         return a.score > b.score;
1965       });
1966 
1967   if (results.empty()) {
1968     results = {{Collections::Other(), 1.0}};
1969   }
1970   return results;
1971 }
1972 
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1973 bool Annotator::ModelAnnotate(
1974     const std::string& context,
1975     const std::vector<Locale>& detected_text_language_tags,
1976     const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1977     std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1978   bool skip_model_annotatation = false;
1979   if (model_->triggering_options() == nullptr ||
1980       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1981     skip_model_annotatation = true;
1982   }
1983   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1984                                     ml_model_triggering_locales_,
1985                                     /*default_value=*/true)) {
1986     skip_model_annotatation = true;
1987   }
1988 
1989   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1990                                                         /*do_copy=*/false);
1991   std::vector<UnicodeTextRange> lines;
1992   if (!selection_feature_processor_ ||
1993       !selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1994     lines.push_back({context_unicode.begin(), context_unicode.end()});
1995   } else {
1996     lines = selection_feature_processor_->SplitContext(
1997         context_unicode, selection_feature_processor_->GetOptions()
1998                              ->use_pipe_character_for_newline());
1999   }
2000 
2001   const float min_annotate_confidence =
2002       (model_->triggering_options() != nullptr
2003            ? model_->triggering_options()->min_annotate_confidence()
2004            : 0.f);
2005 
2006   for (const UnicodeTextRange& line : lines) {
2007     const std::string line_str =
2008         UnicodeText::UTF8Substring(line.first, line.second);
2009 
2010     std::vector<Token> line_tokens;
2011     line_tokens = selection_feature_processor_->Tokenize(line_str);
2012 
2013     selection_feature_processor_->RetokenizeAndFindClick(
2014         line_str, {0, std::distance(line.first, line.second)},
2015         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
2016         &line_tokens,
2017         /*click_pos=*/nullptr);
2018     const TokenSpan full_line_span = {
2019         0, static_cast<TokenIndex>(line_tokens.size())};
2020 
2021     tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2022 
2023     if (skip_model_annotatation) {
2024       // We do not annotate, we only output the tokens.
2025       continue;
2026     }
2027 
2028     // TODO(zilka): Add support for greater granularity of this check.
2029     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
2030             line_tokens, full_line_span)) {
2031       continue;
2032     }
2033 
2034     std::unique_ptr<CachedFeatures> cached_features;
2035     if (!selection_feature_processor_->ExtractFeatures(
2036             line_tokens, full_line_span,
2037             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2038             embedding_executor_.get(),
2039             /*embedding_cache=*/nullptr,
2040             selection_feature_processor_->EmbeddingSize() +
2041                 selection_feature_processor_->DenseFeaturesCount(),
2042             &cached_features)) {
2043       TC3_LOG(ERROR) << "Could not extract features.";
2044       return false;
2045     }
2046 
2047     std::vector<TokenSpan> local_chunks;
2048     if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2049                     interpreter_manager->SelectionInterpreter(),
2050                     *cached_features, &local_chunks)) {
2051       TC3_LOG(ERROR) << "Could not chunk.";
2052       return false;
2053     }
2054 
2055     const int offset = std::distance(context_unicode.begin(), line.first);
2056     if (local_chunks.empty()) {
2057       continue;
2058     }
2059     const UnicodeText line_unicode =
2060         UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2061     std::vector<UnicodeText::const_iterator> line_codepoints =
2062         line_unicode.Codepoints();
2063     line_codepoints.push_back(line_unicode.end());
2064 
2065     FeatureProcessor::EmbeddingCache embedding_cache;
2066     for (const TokenSpan& chunk : local_chunks) {
2067       CodepointSpan codepoint_span =
2068           TokenSpanToCodepointSpan(line_tokens, chunk);
2069       if (!codepoint_span.IsValid() ||
2070           codepoint_span.second > line_codepoints.size()) {
2071         continue;
2072       }
2073       codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2074           /*span_begin=*/line_codepoints[codepoint_span.first],
2075           /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span);
2076       if (model_->selection_options()->strip_unpaired_brackets()) {
2077         codepoint_span = StripUnpairedBrackets(
2078             /*span_begin=*/line_codepoints[codepoint_span.first],
2079             /*span_end=*/line_codepoints[codepoint_span.second], codepoint_span,
2080             *unilib_);
2081       }
2082 
2083       // Skip empty spans.
2084       if (codepoint_span.first != codepoint_span.second) {
2085         std::vector<ClassificationResult> classification;
2086         if (!ModelClassifyText(
2087                 line_unicode, line_tokens, detected_text_language_tags,
2088                 /*span_begin=*/line_codepoints[codepoint_span.first],
2089                 /*span_end=*/line_codepoints[codepoint_span.second], &line,
2090                 codepoint_span, options, interpreter_manager, &embedding_cache,
2091                 &classification, /*tokens=*/nullptr)) {
2092           TC3_LOG(ERROR) << "Could not classify text: "
2093                          << (codepoint_span.first + offset) << " "
2094                          << (codepoint_span.second + offset);
2095           return false;
2096         }
2097 
2098         // Do not include the span if it's classified as "other".
2099         if (!classification.empty() && !ClassifiedAsOther(classification) &&
2100             classification[0].score >= min_annotate_confidence) {
2101           AnnotatedSpan result_span;
2102           result_span.span = {codepoint_span.first + offset,
2103                               codepoint_span.second + offset};
2104           result_span.classification = std::move(classification);
2105           result->push_back(std::move(result_span));
2106         }
2107       }
2108     }
2109   }
2110   return true;
2111 }
2112 
SelectionFeatureProcessorForTests() const2113 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2114   return selection_feature_processor_.get();
2115 }
2116 
ClassificationFeatureProcessorForTests() const2117 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2118     const {
2119   return classification_feature_processor_.get();
2120 }
2121 
DatetimeParserForTests() const2122 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2123   return datetime_parser_.get();
2124 }
2125 
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2126 void Annotator::RemoveNotEnabledEntityTypes(
2127     const EnabledEntityTypes& is_entity_type_enabled,
2128     std::vector<AnnotatedSpan>* annotated_spans) const {
2129   for (AnnotatedSpan& annotated_span : *annotated_spans) {
2130     std::vector<ClassificationResult>& classifications =
2131         annotated_span.classification;
2132     classifications.erase(
2133         std::remove_if(classifications.begin(), classifications.end(),
2134                        [&is_entity_type_enabled](
2135                            const ClassificationResult& classification_result) {
2136                          return !is_entity_type_enabled(
2137                              classification_result.collection);
2138                        }),
2139         classifications.end());
2140   }
2141   annotated_spans->erase(
2142       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2143                      [](const AnnotatedSpan& annotated_span) {
2144                        return annotated_span.classification.empty();
2145                      }),
2146       annotated_spans->end());
2147 }
2148 
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2149 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2150     std::vector<AnnotatedSpan>* candidates) const {
2151   if (candidates == nullptr || contact_engine_ == nullptr) {
2152     return;
2153   }
2154   for (auto& candidate : *candidates) {
2155     for (auto& classification_result : candidate.classification) {
2156       contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2157           &classification_result);
2158     }
2159   }
2160 }
2161 
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2162 Status Annotator::AnnotateSingleInput(
2163     const std::string& context, const AnnotationOptions& options,
2164     std::vector<AnnotatedSpan>* candidates) const {
2165   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2166     return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2167   }
2168 
2169   const UnicodeText context_unicode =
2170       UTF8ToUnicodeText(context, /*do_copy=*/false);
2171 
2172   std::vector<Locale> detected_text_language_tags;
2173   if (!ParseLocales(options.detected_text_language_tags,
2174                     &detected_text_language_tags)) {
2175     TC3_LOG(WARNING)
2176         << "Failed to parse the detected_text_language_tags in options: "
2177         << options.detected_text_language_tags;
2178   }
2179   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2180                                     model_triggering_locales_,
2181                                     /*default_value=*/true)) {
2182     return Status(
2183         StatusCode::UNAVAILABLE,
2184         "The detected language tags are not in the supported locales.");
2185   }
2186 
2187   InterpreterManager interpreter_manager(selection_executor_.get(),
2188                                          classification_executor_.get());
2189 
2190   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2191   const bool is_raw_usecase =
2192       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2193 
2194   // Annotate with the selection model.
2195   const bool model_annotations_enabled =
2196       !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2197   std::vector<Token> tokens;
2198   if (model_annotations_enabled &&
2199       !ModelAnnotate(context, detected_text_language_tags, options,
2200                      &interpreter_manager, &tokens, candidates)) {
2201     return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2202   } else if (!model_annotations_enabled) {
2203     // If the ML model didn't run, we need to tokenize to support the other
2204     // annotators that depend on the tokens.
2205     // Optimization could be made to only do this when an annotator that uses
2206     // the tokens is enabled, but it's unclear if the added complexity is worth
2207     // it.
2208     if (selection_feature_processor_ != nullptr) {
2209       tokens = selection_feature_processor_->Tokenize(context_unicode);
2210     }
2211   }
2212 
2213   // Annotate with the regular expression models.
2214   const bool regex_annotations_enabled =
2215       !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2216   if (regex_annotations_enabled &&
2217       !RegexChunk(
2218           UTF8ToUnicodeText(context, /*do_copy=*/false),
2219           annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2220           is_entity_type_enabled, options.annotation_usecase, candidates)) {
2221     return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2222   }
2223 
2224   // Annotate with the datetime model.
2225   // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2226   // relatively slow for some clients.
2227   if ((is_entity_type_enabled(Collections::Date()) ||
2228        is_entity_type_enabled(Collections::DateTime())) &&
2229       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2230                      options.reference_time_ms_utc, options.reference_timezone,
2231                      options.locales, ModeFlag_ANNOTATION,
2232                      options.annotation_usecase,
2233                      options.is_serialized_entity_data_enabled, candidates)) {
2234     return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2235   }
2236 
2237   // Annotate with the contact engine.
2238   const bool contact_annotations_enabled =
2239       !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2240   if (contact_annotations_enabled && contact_engine_ &&
2241       !contact_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2242                               candidates)) {
2243     return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2244   }
2245 
2246   // Annotate with the installed app engine.
2247   const bool app_annotations_enabled =
2248       !is_raw_usecase || is_entity_type_enabled(Collections::App());
2249   if (app_annotations_enabled && installed_app_engine_ &&
2250       !installed_app_engine_->Chunk(context_unicode, tokens,
2251                                     ModeFlag_ANNOTATION, candidates)) {
2252     return Status(StatusCode::INTERNAL,
2253                   "Couldn't run installed app engine Chunk.");
2254   }
2255 
2256   // Annotate with the number annotator.
2257   const bool number_annotations_enabled =
2258       !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2259                           is_entity_type_enabled(Collections::Percentage()));
2260   if (number_annotations_enabled && number_annotator_ != nullptr &&
2261       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2262                                   ModeFlag_ANNOTATION, candidates)) {
2263     return Status(StatusCode::INTERNAL,
2264                   "Couldn't run number annotator FindAll.");
2265   }
2266 
2267   // Annotate with the duration annotator.
2268   const bool duration_annotations_enabled =
2269       !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2270   if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2271       !duration_annotator_->FindAll(context_unicode, tokens,
2272                                     options.annotation_usecase,
2273                                     ModeFlag_ANNOTATION, candidates)) {
2274     return Status(StatusCode::INTERNAL,
2275                   "Couldn't run duration annotator FindAll.");
2276   }
2277 
2278   // Annotate with the person name engine.
2279   const bool person_annotations_enabled =
2280       !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2281   if (person_annotations_enabled && person_name_engine_ &&
2282       !person_name_engine_->Chunk(context_unicode, tokens, ModeFlag_ANNOTATION,
2283                                   candidates)) {
2284     return Status(StatusCode::INTERNAL,
2285                   "Couldn't run person name engine Chunk.");
2286   }
2287 
2288   // Annotate with the grammar annotators.
2289   if (grammar_annotator_ != nullptr &&
2290       !grammar_annotator_->Annotate(detected_text_language_tags,
2291                                     context_unicode, candidates)) {
2292     return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2293   }
2294 
2295   // Annotate with the POD NER annotator.
2296   const bool pod_ner_annotations_enabled =
2297       !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2298   if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2299       options.use_pod_ner &&
2300       !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2301     return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2302   }
2303 
2304   // Annotate with the vocab annotator.
2305   const bool vocab_annotations_enabled =
2306       !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2307   if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2308       options.use_vocab_annotator &&
2309       !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2310                                   options.trigger_dictionary_on_beginner_words,
2311                                   candidates)) {
2312     return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2313   }
2314 
2315   // Annotate with the experimental annotator.
2316   if (experimental_annotator_ != nullptr &&
2317       (model_->triggering_options()->experimental_enabled_modes() &
2318        ModeFlag_ANNOTATION) &&
2319       !experimental_annotator_->Annotate(context_unicode, candidates)) {
2320     return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2321   }
2322 
2323   // Sort candidates according to their position in the input, so that the next
2324   // code can assume that any connected component of overlapping spans forms a
2325   // contiguous block.
2326   // Also sort them according to the end position and collection, so that the
2327   // deduplication code below can assume that same spans and classifications
2328   // form contiguous blocks.
2329   std::stable_sort(candidates->begin(), candidates->end(),
2330                    [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2331                      if (a.span.first != b.span.first) {
2332                        return a.span.first < b.span.first;
2333                      }
2334 
2335                      if (a.span.second != b.span.second) {
2336                        return a.span.second < b.span.second;
2337                      }
2338 
2339                      return a.classification[0].collection <
2340                             b.classification[0].collection;
2341                    });
2342 
2343   std::vector<int> candidate_indices;
2344   if (!ResolveConflicts(*candidates, context, tokens,
2345                         detected_text_language_tags, options,
2346                         &interpreter_manager, &candidate_indices)) {
2347     return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2348   }
2349 
2350   // Remove candidates that overlap exactly and have the same collection.
2351   // This can e.g. happen for phone coming from both ML model and regex.
2352   candidate_indices.erase(
2353       std::unique(candidate_indices.begin(), candidate_indices.end(),
2354                   [&candidates](const int a_index, const int b_index) {
2355                     const AnnotatedSpan& a = (*candidates)[a_index];
2356                     const AnnotatedSpan& b = (*candidates)[b_index];
2357                     return a.span == b.span &&
2358                            a.classification[0].collection ==
2359                                b.classification[0].collection;
2360                   }),
2361       candidate_indices.end());
2362 
2363   std::vector<AnnotatedSpan> result;
2364   result.reserve(candidate_indices.size());
2365   for (const int i : candidate_indices) {
2366     if ((*candidates)[i].classification.empty() ||
2367         ClassifiedAsOther((*candidates)[i].classification) ||
2368         FilteredForAnnotation((*candidates)[i])) {
2369       continue;
2370     }
2371     result.push_back(std::move((*candidates)[i]));
2372   }
2373 
2374   // We generate all candidates and remove them later (with the exception of
2375   // date/time/duration entities) because there are complex interdependencies
2376   // between the entity types. E.g., the TLD of an email can be interpreted as a
2377   // URL, but most likely a user of the API does not want such annotations if
2378   // "url" is enabled and "email" is not.
2379   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2380 
2381   for (AnnotatedSpan& annotated_span : result) {
2382     SortClassificationResults(&annotated_span.classification);
2383   }
2384   *candidates = result;
2385   return Status::OK;
2386 }
2387 
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2388 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2389     const std::vector<InputFragment>& string_fragments,
2390     const AnnotationOptions& options) const {
2391   Annotations annotation_candidates;
2392   annotation_candidates.annotated_spans.resize(string_fragments.size());
2393 
2394   std::vector<std::string> text_to_annotate;
2395   text_to_annotate.reserve(string_fragments.size());
2396   std::vector<FragmentMetadata> fragment_metadata;
2397   fragment_metadata.reserve(string_fragments.size());
2398   for (const auto& string_fragment : string_fragments) {
2399     text_to_annotate.push_back(string_fragment.text);
2400     fragment_metadata.push_back(
2401         {.relative_bounding_box_top = string_fragment.bounding_box_top,
2402          .relative_bounding_box_height = string_fragment.bounding_box_height});
2403   }
2404 
2405   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2406   const bool is_raw_usecase =
2407       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2408 
2409   const bool knowledge_engine_annotations_enabled =
2410       !is_raw_usecase || is_entity_type_enabled(Collections::Entity());
2411   // KnowledgeEngine is special, because it supports annotation of multiple
2412   // fragments at once.
2413   if (knowledge_engine_annotations_enabled && knowledge_engine_ &&
2414       !knowledge_engine_
2415            ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2416                                 options.annotation_usecase,
2417                                 options.location_context, options.permissions,
2418                                 options.annotate_mode, ModeFlag_ANNOTATION,
2419                                 &annotation_candidates)
2420            .ok()) {
2421     return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2422   }
2423   // The annotator engines shouldn't change the number of annotation vectors.
2424   if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2425     TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2426                    << " texts to annotate but generated a different number of  "
2427                       "lists of annotations:"
2428                    << annotation_candidates.annotated_spans.size();
2429     return Status(StatusCode::INTERNAL,
2430                   "Number of annotation candidates differs from "
2431                   "number of texts to annotate.");
2432   }
2433 
2434   // As an optimization, if the only annotated type is Entity, we skip all the
2435   // other annotators than the KnowledgeEngine. This only happens in the raw
2436   // mode, to make sure it does not affect the result.
2437   if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2438       options.entity_types.size() == 1 &&
2439       *options.entity_types.begin() == Collections::Entity()) {
2440     return annotation_candidates;
2441   }
2442 
2443   // Other annotators run on each fragment independently.
2444   for (int i = 0; i < text_to_annotate.size(); ++i) {
2445     AnnotationOptions annotation_options = options;
2446     if (string_fragments[i].datetime_options.has_value()) {
2447       DatetimeOptions reference_datetime =
2448           string_fragments[i].datetime_options.value();
2449       annotation_options.reference_time_ms_utc =
2450           reference_datetime.reference_time_ms_utc;
2451       annotation_options.reference_timezone =
2452           reference_datetime.reference_timezone;
2453     }
2454 
2455     AddContactMetadataToKnowledgeClassificationResults(
2456         &annotation_candidates.annotated_spans[i]);
2457 
2458     Status annotation_status =
2459         AnnotateSingleInput(text_to_annotate[i], annotation_options,
2460                             &annotation_candidates.annotated_spans[i]);
2461     if (!annotation_status.ok()) {
2462       return annotation_status;
2463     }
2464   }
2465   return annotation_candidates;
2466 }
2467 
Annotate(const std::string & context,const AnnotationOptions & options) const2468 std::vector<AnnotatedSpan> Annotator::Annotate(
2469     const std::string& context, const AnnotationOptions& options) const {
2470   if (context.size() > std::numeric_limits<int>::max()) {
2471     TC3_LOG(ERROR) << "Rejecting too long input.";
2472     return {};
2473   }
2474 
2475   const UnicodeText context_unicode =
2476       UTF8ToUnicodeText(context, /*do_copy=*/false);
2477   if (!unilib_->IsValidUtf8(context_unicode)) {
2478     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2479     return {};
2480   }
2481 
2482   std::vector<InputFragment> string_fragments;
2483   string_fragments.push_back({.text = context});
2484   StatusOr<Annotations> annotations =
2485       AnnotateStructuredInput(string_fragments, options);
2486   if (!annotations.ok()) {
2487     TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2488                    << annotations.status().error_message();
2489     return {};
2490   }
2491   return annotations.ValueOrDie().annotated_spans[0];
2492 }
2493 
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2494 CodepointSpan Annotator::ComputeSelectionBoundaries(
2495     const UniLib::RegexMatcher* match,
2496     const RegexModel_::Pattern* config) const {
2497   if (config->capturing_group() == nullptr) {
2498     // Use first capturing group to specify the selection.
2499     int status = UniLib::RegexMatcher::kNoError;
2500     const CodepointSpan result = {match->Start(1, &status),
2501                                   match->End(1, &status)};
2502     if (status != UniLib::RegexMatcher::kNoError) {
2503       return {kInvalidIndex, kInvalidIndex};
2504     }
2505     return result;
2506   }
2507 
2508   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2509   const int num_groups = config->capturing_group()->size();
2510   for (int i = 0; i < num_groups; i++) {
2511     if (!config->capturing_group()->Get(i)->extend_selection()) {
2512       continue;
2513     }
2514 
2515     int status = UniLib::RegexMatcher::kNoError;
2516     // Check match and adjust bounds.
2517     const int group_start = match->Start(i, &status);
2518     const int group_end = match->End(i, &status);
2519     if (status != UniLib::RegexMatcher::kNoError) {
2520       return {kInvalidIndex, kInvalidIndex};
2521     }
2522     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2523       continue;
2524     }
2525     if (result.first == kInvalidIndex) {
2526       result = {group_start, group_end};
2527     } else {
2528       result.first = std::min(result.first, group_start);
2529       result.second = std::max(result.second, group_end);
2530     }
2531   }
2532   return result;
2533 }
2534 
HasEntityData(const RegexModel_::Pattern * pattern) const2535 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2536   if (pattern->serialized_entity_data() != nullptr ||
2537       pattern->entity_data() != nullptr) {
2538     return true;
2539   }
2540   if (pattern->capturing_group() != nullptr) {
2541     for (const CapturingGroup* group : *pattern->capturing_group()) {
2542       if (group->entity_field_path() != nullptr) {
2543         return true;
2544       }
2545       if (group->serialized_entity_data() != nullptr ||
2546           group->entity_data() != nullptr) {
2547         return true;
2548       }
2549     }
2550   }
2551   return false;
2552 }
2553 
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2554 bool Annotator::SerializedEntityDataFromRegexMatch(
2555     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2556     std::string* serialized_entity_data) const {
2557   if (!HasEntityData(pattern)) {
2558     serialized_entity_data->clear();
2559     return true;
2560   }
2561   TC3_CHECK(entity_data_builder_ != nullptr);
2562 
2563   std::unique_ptr<MutableFlatbuffer> entity_data =
2564       entity_data_builder_->NewRoot();
2565 
2566   TC3_CHECK(entity_data != nullptr);
2567 
2568   // Set fixed entity data.
2569   if (pattern->serialized_entity_data() != nullptr) {
2570     entity_data->MergeFromSerializedFlatbuffer(
2571         StringPiece(pattern->serialized_entity_data()->c_str(),
2572                     pattern->serialized_entity_data()->size()));
2573   }
2574   if (pattern->entity_data() != nullptr) {
2575     entity_data->MergeFrom(
2576         reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2577   }
2578 
2579   // Add entity data from rule capturing groups.
2580   if (pattern->capturing_group() != nullptr) {
2581     const int num_groups = pattern->capturing_group()->size();
2582     for (int i = 0; i < num_groups; i++) {
2583       const CapturingGroup* group = pattern->capturing_group()->Get(i);
2584 
2585       // Check whether the group matched.
2586       Optional<std::string> group_match_text =
2587           GetCapturingGroupText(matcher, /*group_id=*/i);
2588       if (!group_match_text.has_value()) {
2589         continue;
2590       }
2591 
2592       // Set fixed entity data from capturing group match.
2593       if (group->serialized_entity_data() != nullptr) {
2594         entity_data->MergeFromSerializedFlatbuffer(
2595             StringPiece(group->serialized_entity_data()->c_str(),
2596                         group->serialized_entity_data()->size()));
2597       }
2598       if (group->entity_data() != nullptr) {
2599         entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2600             pattern->entity_data()));
2601       }
2602 
2603       // Set entity field from capturing group text.
2604       if (group->entity_field_path() != nullptr) {
2605         UnicodeText normalized_group_match_text =
2606             UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2607 
2608         // Apply normalization if specified.
2609         if (group->normalization_options() != nullptr) {
2610           normalized_group_match_text =
2611               NormalizeText(*unilib_, group->normalization_options(),
2612                             normalized_group_match_text);
2613         }
2614 
2615         if (!entity_data->ParseAndSet(
2616                 group->entity_field_path(),
2617                 normalized_group_match_text.ToUTF8String())) {
2618           TC3_LOG(ERROR)
2619               << "Could not set entity data from rule capturing group.";
2620           return false;
2621         }
2622       }
2623     }
2624   }
2625 
2626   *serialized_entity_data = entity_data->Serialize();
2627   return true;
2628 }
2629 
// Returns the codepoints of `amount` that precede `it_decimal_separator`
// (or all of them when that iterator equals amount.end()), with every
// codepoint contained in `decimal_separators` removed.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    // Use the set's own hashed lookup rather than a linear std::find scan.
    if (decimal_separators.find(static_cast<char32>(*it)) ==
        decimal_separators.end()) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2644 
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2645 void Annotator::GetMoneyQuantityFromCapturingGroup(
2646     const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2647     const UnicodeText& context_unicode, std::string* quantity,
2648     int* exponent) const {
2649   if (config->capturing_group() == nullptr) {
2650     *exponent = 0;
2651     return;
2652   }
2653 
2654   const int num_groups = config->capturing_group()->size();
2655   for (int i = 0; i < num_groups; i++) {
2656     int status = UniLib::RegexMatcher::kNoError;
2657     const int group_start = match->Start(i, &status);
2658     const int group_end = match->End(i, &status);
2659     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2660       continue;
2661     }
2662 
2663     *quantity =
2664         unilib_
2665             ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2666                                                  group_end, /*do_copy=*/false))
2667             .ToUTF8String();
2668 
2669     if (auto entry = model_->money_parsing_options()
2670                          ->quantities_name_to_exponent()
2671                          ->LookupByKey((*quantity).c_str())) {
2672       *exponent = entry->value();
2673       return;
2674     }
2675   }
2676   *exponent = 0;
2677 }
2678 
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2679 bool Annotator::ParseAndFillInMoneyAmount(
2680     std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2681     const RegexModel_::Pattern* config,
2682     const UnicodeText& context_unicode) const {
2683   std::unique_ptr<EntityDataT> data =
2684       LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2685           *serialized_entity_data);
2686   if (data == nullptr) {
2687     if (model_->version() >= 706) {
2688       // This way of parsing money entity data is enabled for models newer than
2689       // v706, consequently logging errors only for them (b/156634162).
2690       TC3_LOG(ERROR)
2691           << "Data field is null when trying to parse Money Entity Data";
2692     }
2693     return false;
2694   }
2695   if (data->money->unnormalized_amount.empty()) {
2696     if (model_->version() >= 706) {
2697       // This way of parsing money entity data is enabled for models newer than
2698       // v706, consequently logging errors only for them (b/156634162).
2699       TC3_LOG(ERROR)
2700           << "Data unnormalized_amount is empty when trying to parse "
2701              "Money Entity Data";
2702     }
2703     return false;
2704   }
2705 
2706   UnicodeText amount =
2707       UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2708   int separator_back_index = 0;
2709   auto it_decimal_separator = --amount.end();
2710   for (; it_decimal_separator != amount.begin();
2711        --it_decimal_separator, ++separator_back_index) {
2712     if (std::find(money_separators_.begin(), money_separators_.end(),
2713                   static_cast<char32>(*it_decimal_separator)) !=
2714         money_separators_.end()) {
2715       break;
2716     }
2717   }
2718 
2719   // If there are 3 digits after the last separator, we consider that a
2720   // thousands separator => the number is an int (e.g. 1.234 is considered int).
2721   // If there is no separator in number, also that number is an int.
2722   if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2723     it_decimal_separator = amount.end();
2724   }
2725 
2726   if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2727                                                  it_decimal_separator),
2728                            &data->money->amount_whole_part)) {
2729     TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2730                       "amount: "
2731                    << data->money->unnormalized_amount;
2732     return false;
2733   }
2734 
2735   if (it_decimal_separator == amount.end()) {
2736     data->money->amount_decimal_part = 0;
2737     data->money->nanos = 0;
2738   } else {
2739     const int amount_codepoints_size = amount.size_codepoints();
2740     const UnicodeText decimal_part = UnicodeText::Substring(
2741         amount, amount_codepoints_size - separator_back_index,
2742         amount_codepoints_size, /*do_copy=*/false);
2743     if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2744       TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2745                         "the amount: "
2746                      << data->money->unnormalized_amount;
2747       return false;
2748     }
2749     data->money->nanos = data->money->amount_decimal_part *
2750                          pow(10, 9 - decimal_part.size_codepoints());
2751   }
2752 
2753   if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2754       nullptr) {
2755     int quantity_exponent;
2756     std::string quantity;
2757     GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2758                                        &quantity, &quantity_exponent);
2759     if (quantity_exponent > 0 && quantity_exponent <= 9) {
2760       const double amount_whole_part =
2761           data->money->amount_whole_part * pow(10, quantity_exponent) +
2762           data->money->nanos / pow(10, 9 - quantity_exponent);
2763       // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2764       // (and `std::numeric_limits<int>::max()` to
2765       // `std::numeric_limits<int64>::max()`).
2766       if (amount_whole_part < std::numeric_limits<int>::max()) {
2767         data->money->amount_whole_part = amount_whole_part;
2768         data->money->nanos = data->money->nanos %
2769                              static_cast<int>(pow(10, 9 - quantity_exponent)) *
2770                              pow(10, quantity_exponent);
2771       }
2772     }
2773     if (quantity_exponent > 0) {
2774       data->money->unnormalized_amount = strings::JoinStrings(
2775           " ", {data->money->unnormalized_amount, quantity});
2776     }
2777   }
2778 
2779   *serialized_entity_data =
2780       PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2781   return true;
2782 }
2783 
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2784 bool Annotator::IsAnyModelEntityTypeEnabled(
2785     const EnabledEntityTypes& is_entity_type_enabled) const {
2786   if (model_->classification_feature_options() == nullptr ||
2787       model_->classification_feature_options()->collections() == nullptr) {
2788     return false;
2789   }
2790   for (int i = 0;
2791        i < model_->classification_feature_options()->collections()->size();
2792        i++) {
2793     if (is_entity_type_enabled(model_->classification_feature_options()
2794                                    ->collections()
2795                                    ->Get(i)
2796                                    ->str())) {
2797       return true;
2798     }
2799   }
2800   return false;
2801 }
2802 
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2803 bool Annotator::IsAnyRegexEntityTypeEnabled(
2804     const EnabledEntityTypes& is_entity_type_enabled) const {
2805   if (model_->regex_model() == nullptr ||
2806       model_->regex_model()->patterns() == nullptr) {
2807     return false;
2808   }
2809   for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2810     if (is_entity_type_enabled(model_->regex_model()
2811                                    ->patterns()
2812                                    ->Get(i)
2813                                    ->collection_name()
2814                                    ->str())) {
2815       return true;
2816     }
2817   }
2818   return false;
2819 }
2820 
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2821 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2822     const EnabledEntityTypes& is_entity_type_enabled) const {
2823   if (pod_ner_annotator_ == nullptr) {
2824     return false;
2825   }
2826 
2827   for (const std::string& collection :
2828        pod_ner_annotator_->GetSupportedCollections()) {
2829     if (is_entity_type_enabled(collection)) {
2830       return true;
2831     }
2832   }
2833   return false;
2834 }
2835 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2836 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2837                            const std::vector<int>& rules,
2838                            bool is_serialized_entity_data_enabled,
2839                            const EnabledEntityTypes& enabled_entity_types,
2840                            const AnnotationUsecase& annotation_usecase,
2841                            std::vector<AnnotatedSpan>* result) const {
2842   for (int pattern_id : rules) {
2843     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2844     if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2845         annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2846       // No regex annotation type has been requested, skip regex annotation.
2847       continue;
2848     }
2849     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2850     if (!matcher) {
2851       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2852                      << pattern_id;
2853       return false;
2854     }
2855 
2856     int status = UniLib::RegexMatcher::kNoError;
2857     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2858       if (regex_pattern.config->verification_options()) {
2859         if (!VerifyRegexMatchCandidate(
2860                 context_unicode.ToUTF8String(),
2861                 regex_pattern.config->verification_options(),
2862                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2863           continue;
2864         }
2865       }
2866 
2867       std::string serialized_entity_data;
2868       if (is_serialized_entity_data_enabled) {
2869         if (!SerializedEntityDataFromRegexMatch(
2870                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2871           TC3_LOG(ERROR) << "Could not get entity data.";
2872           return false;
2873         }
2874 
2875         // Further parsing of money amount. Need this since regexes cannot have
2876         // empty groups that fill in entity data (amount_decimal_part and
2877         // quantity might be empty groups).
2878         if (regex_pattern.config->collection_name()->str() ==
2879             Collections::Money()) {
2880           if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2881                                          regex_pattern.config,
2882                                          context_unicode)) {
2883             if (model_->version() >= 706) {
2884               // This way of parsing money entity data is enabled for models
2885               // newer than v706 => logging errors only for them (b/156634162).
2886               TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2887             }
2888           }
2889         }
2890       }
2891 
2892       result->emplace_back();
2893 
2894       // Selection/annotation regular expressions need to specify a capturing
2895       // group specifying the selection.
2896       result->back().span =
2897           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2898 
2899       result->back().classification = {
2900           {regex_pattern.config->collection_name()->str(),
2901            regex_pattern.config->target_classification_score(),
2902            regex_pattern.config->priority_score()}};
2903 
2904       result->back().classification[0].serialized_entity_data =
2905           serialized_entity_data;
2906     }
2907   }
2908   return true;
2909 }
2910 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2911 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2912                            tflite::Interpreter* selection_interpreter,
2913                            const CachedFeatures& cached_features,
2914                            std::vector<TokenSpan>* chunks) const {
2915   const int max_selection_span =
2916       selection_feature_processor_->GetOptions()->max_selection_span();
2917   // The inference span is the span of interest expanded to include
2918   // max_selection_span tokens on either side, which is how far a selection can
2919   // stretch from the click.
2920   const TokenSpan inference_span =
2921       IntersectTokenSpans(span_of_interest.Expand(
2922                               /*num_tokens_left=*/max_selection_span,
2923                               /*num_tokens_right=*/max_selection_span),
2924                           {0, num_tokens});
2925 
2926   std::vector<ScoredChunk> scored_chunks;
2927   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2928       selection_feature_processor_->GetOptions()
2929           ->bounds_sensitive_features()
2930           ->enabled()) {
2931     if (!ModelBoundsSensitiveScoreChunks(
2932             num_tokens, span_of_interest, inference_span, cached_features,
2933             selection_interpreter, &scored_chunks)) {
2934       return false;
2935     }
2936   } else {
2937     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2938                                       cached_features, selection_interpreter,
2939                                       &scored_chunks)) {
2940       return false;
2941     }
2942   }
2943   std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
2944                    [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2945                      return lhs.score < rhs.score;
2946                    });
2947 
2948   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2949   // them greedily as long as they do not overlap with any previously picked
2950   // chunks.
2951   std::vector<bool> token_used(inference_span.Size());
2952   chunks->clear();
2953   for (const ScoredChunk& scored_chunk : scored_chunks) {
2954     bool feasible = true;
2955     for (int i = scored_chunk.token_span.first;
2956          i < scored_chunk.token_span.second; ++i) {
2957       if (token_used[i - inference_span.first]) {
2958         feasible = false;
2959         break;
2960       }
2961     }
2962 
2963     if (!feasible) {
2964       continue;
2965     }
2966 
2967     for (int i = scored_chunk.token_span.first;
2968          i < scored_chunk.token_span.second; ++i) {
2969       token_used[i - inference_span.first] = true;
2970     }
2971 
2972     chunks->push_back(scored_chunk.token_span);
2973   }
2974 
2975   std::stable_sort(chunks->begin(), chunks->end());
2976 
2977   return true;
2978 }
2979 
namespace {
// Stores `value` under `key` in `map` when the key is absent; when the key is
// already present, keeps the larger of the stored value and `value`.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  // `insert` leaves an existing entry untouched and reports whether the key
  // was newly added, so one call covers both the lookup and the insertion.
  const auto insertion = map->insert(typename Map::value_type(key, value));
  const bool key_was_present = !insertion.second;
  if (key_was_present) {
    auto& stored_value = insertion.first->second;
    stored_value = std::max(stored_value, value);
  }
}
}  // namespace
2994 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2995 bool Annotator::ModelClickContextScoreChunks(
2996     int num_tokens, const TokenSpan& span_of_interest,
2997     const CachedFeatures& cached_features,
2998     tflite::Interpreter* selection_interpreter,
2999     std::vector<ScoredChunk>* scored_chunks) const {
3000   const int max_batch_size = model_->selection_options()->batch_size();
3001 
3002   std::vector<float> all_features;
3003   std::map<TokenSpan, float> chunk_scores;
3004   for (int batch_start = span_of_interest.first;
3005        batch_start < span_of_interest.second; batch_start += max_batch_size) {
3006     const int batch_end =
3007         std::min(batch_start + max_batch_size, span_of_interest.second);
3008 
3009     // Prepare features for the whole batch.
3010     all_features.clear();
3011     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3012     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3013       cached_features.AppendClickContextFeaturesForClick(click_pos,
3014                                                          &all_features);
3015     }
3016 
3017     // Run batched inference.
3018     const int batch_size = batch_end - batch_start;
3019     const int features_size = cached_features.OutputFeaturesSize();
3020     TensorView<float> logits = selection_executor_->ComputeLogits(
3021         TensorView<float>(all_features.data(), {batch_size, features_size}),
3022         selection_interpreter);
3023     if (!logits.is_valid()) {
3024       TC3_LOG(ERROR) << "Couldn't compute logits.";
3025       return false;
3026     }
3027     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3028         logits.dim(1) !=
3029             selection_feature_processor_->GetSelectionLabelCount()) {
3030       TC3_LOG(ERROR) << "Mismatching output.";
3031       return false;
3032     }
3033 
3034     // Save results.
3035     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3036       const std::vector<float> scores = ComputeSoftmax(
3037           logits.data() + logits.dim(1) * (click_pos - batch_start),
3038           logits.dim(1));
3039       for (int j = 0;
3040            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3041         TokenSpan relative_token_span;
3042         if (!selection_feature_processor_->LabelToTokenSpan(
3043                 j, &relative_token_span)) {
3044           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3045           return false;
3046         }
3047         const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3048             relative_token_span.first, relative_token_span.second);
3049         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3050           UpdateMax(&chunk_scores, candidate_span, scores[j]);
3051         }
3052       }
3053     }
3054   }
3055 
3056   scored_chunks->clear();
3057   scored_chunks->reserve(chunk_scores.size());
3058   for (const auto& entry : chunk_scores) {
3059     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3060   }
3061 
3062   return true;
3063 }
3064 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3065 bool Annotator::ModelBoundsSensitiveScoreChunks(
3066     int num_tokens, const TokenSpan& span_of_interest,
3067     const TokenSpan& inference_span, const CachedFeatures& cached_features,
3068     tflite::Interpreter* selection_interpreter,
3069     std::vector<ScoredChunk>* scored_chunks) const {
3070   const int max_selection_span =
3071       selection_feature_processor_->GetOptions()->max_selection_span();
3072   const int max_chunk_length = selection_feature_processor_->GetOptions()
3073                                        ->selection_reduced_output_space()
3074                                    ? max_selection_span + 1
3075                                    : 2 * max_selection_span + 1;
3076   const bool score_single_token_spans_as_zero =
3077       selection_feature_processor_->GetOptions()
3078           ->bounds_sensitive_features()
3079           ->score_single_token_spans_as_zero();
3080 
3081   scored_chunks->clear();
3082   if (score_single_token_spans_as_zero) {
3083     scored_chunks->reserve(span_of_interest.Size());
3084   }
3085 
3086   // Prepare all chunk candidates into one batch:
3087   //   - Are contained in the inference span
3088   //   - Have a non-empty intersection with the span of interest
3089   //   - Are at least one token long
3090   //   - Are not longer than the maximum chunk length
3091   std::vector<TokenSpan> candidate_spans;
3092   for (int start = inference_span.first; start < span_of_interest.second;
3093        ++start) {
3094     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3095     for (int end = leftmost_end_index;
3096          end <= inference_span.second && end - start <= max_chunk_length;
3097          ++end) {
3098       const TokenSpan candidate_span = {start, end};
3099       if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3100         // Do not include the single token span in the batch, add a zero score
3101         // for it directly to the output.
3102         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3103       } else {
3104         candidate_spans.push_back(candidate_span);
3105       }
3106     }
3107   }
3108 
3109   const int max_batch_size = model_->selection_options()->batch_size();
3110 
3111   std::vector<float> all_features;
3112   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3113   for (int batch_start = 0; batch_start < candidate_spans.size();
3114        batch_start += max_batch_size) {
3115     const int batch_end = std::min(batch_start + max_batch_size,
3116                                    static_cast<int>(candidate_spans.size()));
3117 
3118     // Prepare features for the whole batch.
3119     all_features.clear();
3120     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3121     for (int i = batch_start; i < batch_end; ++i) {
3122       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3123                                                            &all_features);
3124     }
3125 
3126     // Run batched inference.
3127     const int batch_size = batch_end - batch_start;
3128     const int features_size = cached_features.OutputFeaturesSize();
3129     TensorView<float> logits = selection_executor_->ComputeLogits(
3130         TensorView<float>(all_features.data(), {batch_size, features_size}),
3131         selection_interpreter);
3132     if (!logits.is_valid()) {
3133       TC3_LOG(ERROR) << "Couldn't compute logits.";
3134       return false;
3135     }
3136     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3137         logits.dim(1) != 1) {
3138       TC3_LOG(ERROR) << "Mismatching output.";
3139       return false;
3140     }
3141 
3142     // Save results.
3143     for (int i = batch_start; i < batch_end; ++i) {
3144       scored_chunks->push_back(
3145           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3146     }
3147   }
3148 
3149   return true;
3150 }
3151 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3152 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3153                               int64 reference_time_ms_utc,
3154                               const std::string& reference_timezone,
3155                               const std::string& locales, ModeFlag mode,
3156                               AnnotationUsecase annotation_usecase,
3157                               bool is_serialized_entity_data_enabled,
3158                               std::vector<AnnotatedSpan>* result) const {
3159   if (!datetime_parser_) {
3160     return true;
3161   }
3162   LocaleList locale_list = LocaleList::ParseFrom(locales);
3163   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3164       datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3165                               reference_timezone, locale_list, mode,
3166                               annotation_usecase,
3167                               /*anchor_start_end=*/false);
3168   if (!result_status.ok()) {
3169     return false;
3170   }
3171 
3172   for (const DatetimeParseResultSpan& datetime_span :
3173        result_status.ValueOrDie()) {
3174     AnnotatedSpan annotated_span;
3175     annotated_span.span = datetime_span.span;
3176     for (const DatetimeParseResult& parse_result : datetime_span.data) {
3177       annotated_span.classification.emplace_back(
3178           PickCollectionForDatetime(parse_result),
3179           datetime_span.target_classification_score,
3180           datetime_span.priority_score);
3181       annotated_span.classification.back().datetime_parse_result = parse_result;
3182       if (is_serialized_entity_data_enabled) {
3183         annotated_span.classification.back().serialized_entity_data =
3184             CreateDatetimeSerializedEntityData(parse_result);
3185       }
3186     }
3187     annotated_span.source = AnnotatedSpan::Source::DATETIME;
3188     result->push_back(std::move(annotated_span));
3189   }
3190   return true;
3191 }
3192 
model() const3193 const Model* Annotator::model() const { return model_; }
entity_data_schema() const3194 const reflection::Schema* Annotator::entity_data_schema() const {
3195   return entity_data_schema_;
3196 }
3197 
ViewModel(const void * buffer,int size)3198 const Model* ViewModel(const void* buffer, int size) {
3199   if (!buffer) {
3200     return nullptr;
3201   }
3202 
3203   return LoadAndVerifyModel(buffer, size);
3204 }
3205 
LookUpKnowledgeEntity(const std::string & id) const3206 StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3207     const std::string& id) const {
3208   if (!knowledge_engine_) {
3209     return Status(StatusCode::FAILED_PRECONDITION,
3210                   "knowledge_engine_ is nullptr");
3211   }
3212   return knowledge_engine_->LookUpEntity(id);
3213 }
3214 
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3215 StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3216     const std::string& mid_str, const std::string& property) const {
3217   if (!knowledge_engine_) {
3218     return Status(StatusCode::FAILED_PRECONDITION,
3219                   "knowledge_engine_ is nullptr");
3220   }
3221   return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3222 }
3223 
3224 }  // namespace libtextclassifier3
3225