• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator.h"
18 
19 #include <algorithm>
20 #include <cmath>
21 #include <cstddef>
22 #include <iterator>
23 #include <limits>
24 #include <numeric>
25 #include <string>
26 #include <unordered_map>
27 #include <vector>
28 
29 #include "annotator/collections.h"
30 #include "annotator/datetime/grammar-parser.h"
31 #include "annotator/datetime/regex-parser.h"
32 #include "annotator/flatbuffer-utils.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "annotator/model_generated.h"
35 #include "annotator/types.h"
36 #include "utils/base/logging.h"
37 #include "utils/base/status.h"
38 #include "utils/base/statusor.h"
39 #include "utils/calendar/calendar.h"
40 #include "utils/checksum.h"
41 #include "utils/grammar/analyzer.h"
42 #include "utils/i18n/locale-list.h"
43 #include "utils/i18n/locale.h"
44 #include "utils/math/softmax.h"
45 #include "utils/normalization.h"
46 #include "utils/optional.h"
47 #include "utils/regex-match.h"
48 #include "utils/strings/append.h"
49 #include "utils/strings/numbers.h"
50 #include "utils/strings/split.h"
51 #include "utils/utf8/unicodetext.h"
52 #include "utils/utf8/unilib-common.h"
53 #include "utils/zlib/zlib_regex.h"
54 
55 namespace libtextclassifier3 {
56 
57 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
58 
59 const std::string& Annotator::kPhoneCollection =
__anon0684102d0102() 60     *[]() { return new std::string("phone"); }();
61 const std::string& Annotator::kAddressCollection =
__anon0684102d0202() 62     *[]() { return new std::string("address"); }();
63 const std::string& Annotator::kDateCollection =
__anon0684102d0302() 64     *[]() { return new std::string("date"); }();
65 const std::string& Annotator::kUrlCollection =
__anon0684102d0402() 66     *[]() { return new std::string("url"); }();
67 const std::string& Annotator::kEmailCollection =
__anon0684102d0502() 68     *[]() { return new std::string("email"); }();
69 
70 namespace {
LoadAndVerifyModel(const void * addr,int size)71 const Model* LoadAndVerifyModel(const void* addr, int size) {
72   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
73   if (VerifyModelBuffer(verifier)) {
74     return GetModel(addr);
75   } else {
76     return nullptr;
77   }
78 }
79 
LoadAndVerifyPersonNameModel(const void * addr,int size)80 const PersonNameModel* LoadAndVerifyPersonNameModel(const void* addr,
81                                                     int size) {
82   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
83   if (VerifyPersonNameModelBuffer(verifier)) {
84     return GetPersonNameModel(addr);
85   } else {
86     return nullptr;
87   }
88 }
89 
90 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
91 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)92 const UniLib* MaybeCreateUnilib(const UniLib* lib,
93                                 std::unique_ptr<UniLib>* owned_lib) {
94   if (lib) {
95     return lib;
96   } else {
97     owned_lib->reset(new UniLib);
98     return owned_lib->get();
99   }
100 }
101 
102 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)103 const CalendarLib* MaybeCreateCalendarlib(
104     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
105   if (lib) {
106     return lib;
107   } else {
108     owned_lib->reset(new CalendarLib);
109     return owned_lib->get();
110   }
111 }
112 
113 // Returns whether the provided input is valid:
114 //   * Sane span indices.
IsValidSpanInput(const UnicodeText & context,const CodepointSpan & span)115 bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
116   return (span.first >= 0 && span.first < span.second &&
117           span.second <= context.size_codepoints());
118 }
119 
FlatbuffersIntVectorToChar32UnorderedSet(const flatbuffers::Vector<int32_t> * ints)120 std::unordered_set<char32> FlatbuffersIntVectorToChar32UnorderedSet(
121     const flatbuffers::Vector<int32_t>* ints) {
122   if (ints == nullptr) {
123     return {};
124   }
125   std::unordered_set<char32> ints_set;
126   for (auto value : *ints) {
127     ints_set.insert(static_cast<char32>(value));
128   }
129   return ints_set;
130 }
131 
132 }  // namespace
133 
SelectionInterpreter()134 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
135   if (!selection_interpreter_) {
136     TC3_CHECK(selection_executor_);
137     selection_interpreter_ = selection_executor_->CreateInterpreter();
138     if (!selection_interpreter_) {
139       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
140     }
141   }
142   return selection_interpreter_.get();
143 }
144 
ClassificationInterpreter()145 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
146   if (!classification_interpreter_) {
147     TC3_CHECK(classification_executor_);
148     classification_interpreter_ = classification_executor_->CreateInterpreter();
149     if (!classification_interpreter_) {
150       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
151     }
152   }
153   return classification_interpreter_.get();
154 }
155 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)156 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
157     const char* buffer, int size, const UniLib* unilib,
158     const CalendarLib* calendarlib) {
159   const Model* model = LoadAndVerifyModel(buffer, size);
160   if (model == nullptr) {
161     return nullptr;
162   }
163 
164   auto classifier = std::unique_ptr<Annotator>(new Annotator());
165   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
166   calendarlib =
167       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
168   classifier->ValidateAndInitialize(model, unilib, calendarlib);
169   if (!classifier->IsInitialized()) {
170     return nullptr;
171   }
172 
173   return classifier;
174 }
175 
FromString(const std::string & buffer,const UniLib * unilib,const CalendarLib * calendarlib)176 std::unique_ptr<Annotator> Annotator::FromString(
177     const std::string& buffer, const UniLib* unilib,
178     const CalendarLib* calendarlib) {
179   auto classifier = std::unique_ptr<Annotator>(new Annotator());
180   classifier->owned_buffer_ = buffer;
181   const Model* model = LoadAndVerifyModel(classifier->owned_buffer_.data(),
182                                           classifier->owned_buffer_.size());
183   if (model == nullptr) {
184     return nullptr;
185   }
186   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
187   calendarlib =
188       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
189   classifier->ValidateAndInitialize(model, unilib, calendarlib);
190   if (!classifier->IsInitialized()) {
191     return nullptr;
192   }
193 
194   return classifier;
195 }
196 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)197 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
198     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
199     const CalendarLib* calendarlib) {
200   if (!(*mmap)->handle().ok()) {
201     TC3_VLOG(1) << "Mmap failed.";
202     return nullptr;
203   }
204 
205   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
206                                           (*mmap)->handle().num_bytes());
207   if (!model) {
208     TC3_LOG(ERROR) << "Model verification failed.";
209     return nullptr;
210   }
211 
212   auto classifier = std::unique_ptr<Annotator>(new Annotator());
213   classifier->mmap_ = std::move(*mmap);
214   unilib = MaybeCreateUnilib(unilib, &classifier->owned_unilib_);
215   calendarlib =
216       MaybeCreateCalendarlib(calendarlib, &classifier->owned_calendarlib_);
217   classifier->ValidateAndInitialize(model, unilib, calendarlib);
218   if (!classifier->IsInitialized()) {
219     return nullptr;
220   }
221 
222   return classifier;
223 }
224 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)225 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
226     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
227     std::unique_ptr<CalendarLib> calendarlib) {
228   if (!(*mmap)->handle().ok()) {
229     TC3_VLOG(1) << "Mmap failed.";
230     return nullptr;
231   }
232 
233   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
234                                           (*mmap)->handle().num_bytes());
235   if (model == nullptr) {
236     TC3_LOG(ERROR) << "Model verification failed.";
237     return nullptr;
238   }
239 
240   auto classifier = std::unique_ptr<Annotator>(new Annotator());
241   classifier->mmap_ = std::move(*mmap);
242   classifier->owned_unilib_ = std::move(unilib);
243   classifier->owned_calendarlib_ = std::move(calendarlib);
244   classifier->ValidateAndInitialize(model, classifier->owned_unilib_.get(),
245                                     classifier->owned_calendarlib_.get());
246   if (!classifier->IsInitialized()) {
247     return nullptr;
248   }
249 
250   return classifier;
251 }
252 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)253 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
254     int fd, int offset, int size, const UniLib* unilib,
255     const CalendarLib* calendarlib) {
256   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
257   return FromScopedMmap(&mmap, unilib, calendarlib);
258 }
259 
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)260 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
261     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
262     std::unique_ptr<CalendarLib> calendarlib) {
263   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
264   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
265 }
266 
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)267 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
268     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
269   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
270   return FromScopedMmap(&mmap, unilib, calendarlib);
271 }
272 
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)273 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
274     int fd, std::unique_ptr<UniLib> unilib,
275     std::unique_ptr<CalendarLib> calendarlib) {
276   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
277   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
278 }
279 
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)280 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
281                                                const UniLib* unilib,
282                                                const CalendarLib* calendarlib) {
283   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
284   return FromScopedMmap(&mmap, unilib, calendarlib);
285 }
286 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)287 std::unique_ptr<Annotator> Annotator::FromPath(
288     const std::string& path, std::unique_ptr<UniLib> unilib,
289     std::unique_ptr<CalendarLib> calendarlib) {
290   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
291   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
292 }
293 
ValidateAndInitialize(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)294 void Annotator::ValidateAndInitialize(const Model* model, const UniLib* unilib,
295                                       const CalendarLib* calendarlib) {
296   model_ = model;
297   unilib_ = unilib;
298   calendarlib_ = calendarlib;
299 
300   initialized_ = false;
301 
302   if (model_ == nullptr) {
303     TC3_LOG(ERROR) << "No model specified.";
304     return;
305   }
306 
307   const bool model_enabled_for_annotation =
308       (model_->triggering_options() != nullptr &&
309        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
310   const bool model_enabled_for_classification =
311       (model_->triggering_options() != nullptr &&
312        (model_->triggering_options()->enabled_modes() &
313         ModeFlag_CLASSIFICATION));
314   const bool model_enabled_for_selection =
315       (model_->triggering_options() != nullptr &&
316        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
317 
318   // Annotation requires the selection model.
319   if (model_enabled_for_annotation || model_enabled_for_selection) {
320     if (!model_->selection_options()) {
321       TC3_LOG(ERROR) << "No selection options.";
322       return;
323     }
324     if (!model_->selection_feature_options()) {
325       TC3_LOG(ERROR) << "No selection feature options.";
326       return;
327     }
328     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
329       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
330       return;
331     }
332     if (!model_->selection_model()) {
333       TC3_LOG(ERROR) << "No selection model.";
334       return;
335     }
336     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
337     if (!selection_executor_) {
338       TC3_LOG(ERROR) << "Could not initialize selection executor.";
339       return;
340     }
341     selection_feature_processor_.reset(
342         new FeatureProcessor(model_->selection_feature_options(), unilib_));
343   }
344 
345   // Annotation requires the classification model for conflict resolution and
346   // scoring.
347   // Selection requires the classification model for conflict resolution.
348   if (model_enabled_for_annotation || model_enabled_for_classification ||
349       model_enabled_for_selection) {
350     if (!model_->classification_options()) {
351       TC3_LOG(ERROR) << "No classification options.";
352       return;
353     }
354 
355     if (!model_->classification_feature_options()) {
356       TC3_LOG(ERROR) << "No classification feature options.";
357       return;
358     }
359 
360     if (!model_->classification_feature_options()
361              ->bounds_sensitive_features()) {
362       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
363       return;
364     }
365     if (!model_->classification_model()) {
366       TC3_LOG(ERROR) << "No clf model.";
367       return;
368     }
369 
370     classification_executor_ =
371         ModelExecutor::FromBuffer(model_->classification_model());
372     if (!classification_executor_) {
373       TC3_LOG(ERROR) << "Could not initialize classification executor.";
374       return;
375     }
376 
377     classification_feature_processor_.reset(new FeatureProcessor(
378         model_->classification_feature_options(), unilib_));
379   }
380 
381   // The embeddings need to be specified if the model is to be used for
382   // classification or selection.
383   if (model_enabled_for_annotation || model_enabled_for_classification ||
384       model_enabled_for_selection) {
385     if (!model_->embedding_model()) {
386       TC3_LOG(ERROR) << "No embedding model.";
387       return;
388     }
389 
390     // Check that the embedding size of the selection and classification model
391     // matches, as they are using the same embeddings.
392     if (model_enabled_for_selection &&
393         (model_->selection_feature_options()->embedding_size() !=
394              model_->classification_feature_options()->embedding_size() ||
395          model_->selection_feature_options()->embedding_quantization_bits() !=
396              model_->classification_feature_options()
397                  ->embedding_quantization_bits())) {
398       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
399       return;
400     }
401 
402     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
403         model_->embedding_model(),
404         model_->classification_feature_options()->embedding_size(),
405         model_->classification_feature_options()->embedding_quantization_bits(),
406         model_->embedding_pruning_mask());
407     if (!embedding_executor_) {
408       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
409       return;
410     }
411   }
412 
413   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
414   if (model_->regex_model()) {
415     if (!InitializeRegexModel(decompressor.get())) {
416       TC3_LOG(ERROR) << "Could not initialize regex model.";
417       return;
418     }
419   }
420 
421   if (model_->datetime_grammar_model()) {
422     if (model_->datetime_grammar_model()->rules()) {
423       analyzer_ = std::make_unique<grammar::Analyzer>(
424           unilib_, model_->datetime_grammar_model()->rules());
425       datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_);
426       datetime_parser_ = std::make_unique<GrammarDatetimeParser>(
427           *analyzer_, *datetime_grounder_,
428           /*target_classification_score=*/1.0,
429           /*priority_score=*/1.0);
430     }
431   } else if (model_->datetime_model()) {
432     datetime_parser_ = RegexDatetimeParser::Instance(
433         model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
434     if (!datetime_parser_) {
435       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
436       return;
437     }
438   }
439 
440   if (model_->output_options()) {
441     if (model_->output_options()->filtered_collections_annotation()) {
442       for (const auto collection :
443            *model_->output_options()->filtered_collections_annotation()) {
444         filtered_collections_annotation_.insert(collection->str());
445       }
446     }
447     if (model_->output_options()->filtered_collections_classification()) {
448       for (const auto collection :
449            *model_->output_options()->filtered_collections_classification()) {
450         filtered_collections_classification_.insert(collection->str());
451       }
452     }
453     if (model_->output_options()->filtered_collections_selection()) {
454       for (const auto collection :
455            *model_->output_options()->filtered_collections_selection()) {
456         filtered_collections_selection_.insert(collection->str());
457       }
458     }
459   }
460 
461   if (model_->number_annotator_options() &&
462       model_->number_annotator_options()->enabled()) {
463     number_annotator_.reset(
464         new NumberAnnotator(model_->number_annotator_options(), unilib_));
465   }
466 
467   if (model_->money_parsing_options()) {
468     money_separators_ = FlatbuffersIntVectorToChar32UnorderedSet(
469         model_->money_parsing_options()->separators());
470   }
471 
472   if (model_->duration_annotator_options() &&
473       model_->duration_annotator_options()->enabled()) {
474     duration_annotator_.reset(
475         new DurationAnnotator(model_->duration_annotator_options(),
476                               selection_feature_processor_.get(), unilib_));
477   }
478 
479   if (model_->grammar_model()) {
480     grammar_annotator_.reset(new GrammarAnnotator(
481         unilib_, model_->grammar_model(), entity_data_builder_.get()));
482   }
483 
484   // The following #ifdef is here to aid quality evaluation of a situation, when
485   // a POD NER kill switch in AiAi is invoked, when a model that has POD NER in
486   // it.
487 #if !defined(TC3_DISABLE_POD_NER)
488   if (model_->pod_ner_model()) {
489     pod_ner_annotator_ =
490         PodNerAnnotator::Create(model_->pod_ner_model(), *unilib_);
491   }
492 #endif
493 
494   if (model_->vocab_model()) {
495     vocab_annotator_ = VocabAnnotator::Create(
496         model_->vocab_model(), *selection_feature_processor_, *unilib_);
497   }
498 
499   if (model_->entity_data_schema()) {
500     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
501         model_->entity_data_schema()->Data(),
502         model_->entity_data_schema()->size());
503     if (entity_data_schema_ == nullptr) {
504       TC3_LOG(ERROR) << "Could not load entity data schema data.";
505       return;
506     }
507 
508     entity_data_builder_.reset(
509         new MutableFlatbufferBuilder(entity_data_schema_));
510   } else {
511     entity_data_schema_ = nullptr;
512     entity_data_builder_ = nullptr;
513   }
514 
515   if (model_->triggering_locales() &&
516       !ParseLocales(model_->triggering_locales()->c_str(),
517                     &model_triggering_locales_)) {
518     TC3_LOG(ERROR) << "Could not parse model supported locales.";
519     return;
520   }
521 
522   if (model_->triggering_options() != nullptr &&
523       model_->triggering_options()->locales() != nullptr &&
524       !ParseLocales(model_->triggering_options()->locales()->c_str(),
525                     &ml_model_triggering_locales_)) {
526     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
527     return;
528   }
529 
530   if (model_->triggering_options() != nullptr &&
531       model_->triggering_options()->dictionary_locales() != nullptr &&
532       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
533                     &dictionary_locales_)) {
534     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
535     return;
536   }
537 
538   if (model_->conflict_resolution_options() != nullptr) {
539     prioritize_longest_annotation_ =
540         model_->conflict_resolution_options()->prioritize_longest_annotation();
541     do_conflict_resolution_in_raw_mode_ =
542         model_->conflict_resolution_options()
543             ->do_conflict_resolution_in_raw_mode();
544   }
545 
546 #ifdef TC3_EXPERIMENTAL
547   TC3_LOG(WARNING) << "Enabling experimental annotators.";
548   InitializeExperimentalAnnotators();
549 #endif
550 
551   initialized_ = true;
552 }
553 
InitializeRegexModel(ZlibDecompressor * decompressor)554 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
555   if (!model_->regex_model()->patterns()) {
556     return true;
557   }
558 
559   // Initialize pattern recognizers.
560   int regex_pattern_id = 0;
561   for (const auto regex_pattern : *model_->regex_model()->patterns()) {
562     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
563         UncompressMakeRegexPattern(
564             *unilib_, regex_pattern->pattern(),
565             regex_pattern->compressed_pattern(),
566             model_->regex_model()->lazy_regex_compilation(), decompressor);
567     if (!compiled_pattern) {
568       TC3_LOG(INFO) << "Failed to load regex pattern";
569       return false;
570     }
571 
572     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
573       annotation_regex_patterns_.push_back(regex_pattern_id);
574     }
575     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
576       classification_regex_patterns_.push_back(regex_pattern_id);
577     }
578     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
579       selection_regex_patterns_.push_back(regex_pattern_id);
580     }
581     regex_patterns_.push_back({
582         regex_pattern,
583         std::move(compiled_pattern),
584     });
585     ++regex_pattern_id;
586   }
587 
588   return true;
589 }
590 
InitializeKnowledgeEngine(const std::string & serialized_config)591 bool Annotator::InitializeKnowledgeEngine(
592     const std::string& serialized_config) {
593   std::unique_ptr<KnowledgeEngine> knowledge_engine(new KnowledgeEngine());
594   if (!knowledge_engine->Initialize(serialized_config, unilib_)) {
595     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
596     return false;
597   }
598   if (model_->triggering_options() != nullptr) {
599     knowledge_engine->SetPriorityScore(
600         model_->triggering_options()->knowledge_priority_score());
601   }
602   knowledge_engine_ = std::move(knowledge_engine);
603   return true;
604 }
605 
InitializeContactEngine(const std::string & serialized_config)606 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
607   std::unique_ptr<ContactEngine> contact_engine(
608       new ContactEngine(selection_feature_processor_.get(), unilib_,
609                         model_->contact_annotator_options()));
610   if (!contact_engine->Initialize(serialized_config)) {
611     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
612     return false;
613   }
614   contact_engine_ = std::move(contact_engine);
615   return true;
616 }
617 
InitializeInstalledAppEngine(const std::string & serialized_config)618 bool Annotator::InitializeInstalledAppEngine(
619     const std::string& serialized_config) {
620   std::unique_ptr<InstalledAppEngine> installed_app_engine(
621       new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
622   if (!installed_app_engine->Initialize(serialized_config)) {
623     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
624     return false;
625   }
626   installed_app_engine_ = std::move(installed_app_engine);
627   return true;
628 }
629 
SetLangId(const libtextclassifier3::mobile::lang_id::LangId * lang_id)630 bool Annotator::SetLangId(const libtextclassifier3::mobile::lang_id::LangId* lang_id) {
631   if (lang_id == nullptr) {
632     return false;
633   }
634 
635   lang_id_ = lang_id;
636   if (lang_id_ != nullptr && model_->translate_annotator_options() &&
637       model_->translate_annotator_options()->enabled()) {
638     translate_annotator_.reset(new TranslateAnnotator(
639         model_->translate_annotator_options(), lang_id_, unilib_));
640   } else {
641     translate_annotator_.reset(nullptr);
642   }
643   return true;
644 }
645 
InitializePersonNameEngineFromUnownedBuffer(const void * buffer,int size)646 bool Annotator::InitializePersonNameEngineFromUnownedBuffer(const void* buffer,
647                                                             int size) {
648   const PersonNameModel* person_name_model =
649       LoadAndVerifyPersonNameModel(buffer, size);
650 
651   if (person_name_model == nullptr) {
652     TC3_LOG(ERROR) << "Person name model verification failed.";
653     return false;
654   }
655 
656   if (!person_name_model->enabled()) {
657     return true;
658   }
659 
660   std::unique_ptr<PersonNameEngine> person_name_engine(
661       new PersonNameEngine(selection_feature_processor_.get(), unilib_));
662   if (!person_name_engine->Initialize(person_name_model)) {
663     TC3_LOG(ERROR) << "Failed to initialize the person name engine.";
664     return false;
665   }
666   person_name_engine_ = std::move(person_name_engine);
667   return true;
668 }
669 
InitializePersonNameEngineFromScopedMmap(const ScopedMmap & mmap)670 bool Annotator::InitializePersonNameEngineFromScopedMmap(
671     const ScopedMmap& mmap) {
672   if (!mmap.handle().ok()) {
673     TC3_LOG(ERROR) << "Mmap for person name model failed.";
674     return false;
675   }
676 
677   return InitializePersonNameEngineFromUnownedBuffer(mmap.handle().start(),
678                                                      mmap.handle().num_bytes());
679 }
680 
InitializePersonNameEngineFromPath(const std::string & path)681 bool Annotator::InitializePersonNameEngineFromPath(const std::string& path) {
682   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
683   return InitializePersonNameEngineFromScopedMmap(*mmap);
684 }
685 
InitializePersonNameEngineFromFileDescriptor(int fd,int offset,int size)686 bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
687                                                              int size) {
688   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
689   return InitializePersonNameEngineFromScopedMmap(*mmap);
690 }
691 
InitializeExperimentalAnnotators()692 bool Annotator::InitializeExperimentalAnnotators() {
693   if (ExperimentalAnnotator::IsEnabled()) {
694     experimental_annotator_.reset(new ExperimentalAnnotator(
695         model_->experimental_model(), *selection_feature_processor_, *unilib_));
696     return true;
697   }
698   return false;
699 }
700 
701 namespace internal {
702 // Helper function, which if the initial 'span' contains only white-spaces,
703 // moves the selection to a single-codepoint selection on a left or right side
704 // of this space.
SnapLeftIfWhitespaceSelection(const CodepointSpan & span,const UnicodeText & context_unicode,const UniLib & unilib)705 CodepointSpan SnapLeftIfWhitespaceSelection(const CodepointSpan& span,
706                                             const UnicodeText& context_unicode,
707                                             const UniLib& unilib) {
708   TC3_CHECK(span.IsValid() && !span.IsEmpty());
709 
710   UnicodeText::const_iterator it;
711 
712   // Check that the current selection is all whitespaces.
713   it = context_unicode.begin();
714   std::advance(it, span.first);
715   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
716     if (!unilib.IsWhitespace(*it)) {
717       return span;
718     }
719   }
720 
721   // Try moving left.
722   CodepointSpan result = span;
723   it = context_unicode.begin();
724   std::advance(it, span.first);
725   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
726     --result.first;
727     --it;
728   }
729   result.second = result.first + 1;
730   if (!unilib.IsWhitespace(*it)) {
731     return result;
732   }
733 
734   // If moving left didn't find a non-whitespace character, just return the
735   // original span.
736   return span;
737 }
738 }  // namespace internal
739 
FilteredForAnnotation(const AnnotatedSpan & span) const740 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
741   return !span.classification.empty() &&
742          filtered_collections_annotation_.find(
743              span.classification[0].collection) !=
744              filtered_collections_annotation_.end();
745 }
746 
FilteredForClassification(const ClassificationResult & classification) const747 bool Annotator::FilteredForClassification(
748     const ClassificationResult& classification) const {
749   return filtered_collections_classification_.find(classification.collection) !=
750          filtered_collections_classification_.end();
751 }
752 
FilteredForSelection(const AnnotatedSpan & span) const753 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
754   return !span.classification.empty() &&
755          filtered_collections_selection_.find(
756              span.classification[0].collection) !=
757              filtered_collections_selection_.end();
758 }
759 
760 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)761 inline bool ClassifiedAsOther(
762     const std::vector<ClassificationResult>& classification) {
763   return !classification.empty() &&
764          classification[0].collection == Collections::Other();
765 }
766 
767 }  // namespace
768 
GetPriorityScore(const std::vector<ClassificationResult> & classification) const769 float Annotator::GetPriorityScore(
770     const std::vector<ClassificationResult>& classification) const {
771   if (!classification.empty() && !ClassifiedAsOther(classification)) {
772     return classification[0].priority_score;
773   } else {
774     if (model_->triggering_options() != nullptr) {
775       return model_->triggering_options()->other_collection_priority_score();
776     } else {
777       return -1000.0;
778     }
779   }
780 }
781 
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const782 bool Annotator::VerifyRegexMatchCandidate(
783     const std::string& context, const VerificationOptions* verification_options,
784     const std::string& match, const UniLib::RegexMatcher* matcher) const {
785   if (verification_options == nullptr) {
786     return true;
787   }
788   if (verification_options->verify_luhn_checksum() &&
789       !VerifyLuhnChecksum(match)) {
790     return false;
791   }
792   const int lua_verifier = verification_options->lua_verifier();
793   if (lua_verifier >= 0) {
794     if (model_->regex_model()->lua_verifier() == nullptr ||
795         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
796       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
797       return false;
798     }
799     return VerifyMatch(
800         context, matcher,
801         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
802   }
803   return true;
804 }
805 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const806 CodepointSpan Annotator::SuggestSelection(
807     const std::string& context, CodepointSpan click_indices,
808     const SelectionOptions& options) const {
809   if (context.size() > std::numeric_limits<int>::max()) {
810     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
811     return {};
812   }
813 
814   CodepointSpan original_click_indices = click_indices;
815   if (!initialized_) {
816     TC3_LOG(ERROR) << "Not initialized";
817     return original_click_indices;
818   }
819   if (options.annotation_usecase !=
820       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
821     TC3_LOG(WARNING)
822         << "Invoking SuggestSelection, which is not supported in RAW mode.";
823     return original_click_indices;
824   }
825   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
826     return original_click_indices;
827   }
828 
829   std::vector<Locale> detected_text_language_tags;
830   if (!ParseLocales(options.detected_text_language_tags,
831                     &detected_text_language_tags)) {
832     TC3_LOG(WARNING)
833         << "Failed to parse the detected_text_language_tags in options: "
834         << options.detected_text_language_tags;
835   }
836   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
837                                     model_triggering_locales_,
838                                     /*default_value=*/true)) {
839     return original_click_indices;
840   }
841 
842   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
843                                                         /*do_copy=*/false);
844 
845   if (!unilib_->IsValidUtf8(context_unicode)) {
846     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
847     return original_click_indices;
848   }
849 
850   if (!IsValidSpanInput(context_unicode, click_indices)) {
851     TC3_VLOG(1)
852         << "Trying to run SuggestSelection with invalid input, indices: "
853         << click_indices.first << " " << click_indices.second;
854     return original_click_indices;
855   }
856 
857   if (model_->snap_whitespace_selections()) {
858     // We want to expand a purely white-space selection to a multi-selection it
859     // would've been part of. But with this feature disabled we would do a no-
860     // op, because no token is found. Therefore, we need to modify the
861     // 'click_indices' a bit to include a part of the token, so that the click-
862     // finding logic finds the clicked token correctly. This modification is
863     // done by the following function. Note, that it's enough to check the left
864     // side of the current selection, because if the white-space is a part of a
865     // multi-selection, necessarily both tokens - on the left and the right
866     // sides need to be selected. Thus snapping only to the left is sufficient
867     // (there's a check at the bottom that makes sure that if we snap to the
868     // left token but the result does not contain the initial white-space,
869     // returns the original indices).
870     click_indices = internal::SnapLeftIfWhitespaceSelection(
871         click_indices, context_unicode, *unilib_);
872   }
873 
874   Annotations candidates;
875   // As we process a single string of context, the candidates will only
876   // contain one vector of AnnotatedSpan.
877   candidates.annotated_spans.resize(1);
878   InterpreterManager interpreter_manager(selection_executor_.get(),
879                                          classification_executor_.get());
880   std::vector<Token> tokens;
881   if (!ModelSuggestSelection(context_unicode, click_indices,
882                              detected_text_language_tags, &interpreter_manager,
883                              &tokens, &candidates.annotated_spans[0])) {
884     TC3_LOG(ERROR) << "Model suggest selection failed.";
885     return original_click_indices;
886   }
887   const std::unordered_set<std::string> set;
888   const EnabledEntityTypes is_entity_type_enabled(set);
889   if (!RegexChunk(context_unicode, selection_regex_patterns_,
890                   /*is_serialized_entity_data_enabled=*/false,
891                   is_entity_type_enabled, options.annotation_usecase,
892                   &candidates.annotated_spans[0])) {
893     TC3_LOG(ERROR) << "Regex suggest selection failed.";
894     return original_click_indices;
895   }
896   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
897                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
898                      options.locales, ModeFlag_SELECTION,
899                      options.annotation_usecase,
900                      /*is_serialized_entity_data_enabled=*/false,
901                      &candidates.annotated_spans[0])) {
902     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
903     return original_click_indices;
904   }
905   if (knowledge_engine_ != nullptr &&
906       !knowledge_engine_
907            ->Chunk(context, options.annotation_usecase,
908                    options.location_context, Permissions(),
909                    AnnotateMode::kEntityAnnotation, &candidates)
910            .ok()) {
911     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
912     return original_click_indices;
913   }
914   if (contact_engine_ != nullptr &&
915       !contact_engine_->Chunk(context_unicode, tokens,
916                               &candidates.annotated_spans[0])) {
917     TC3_LOG(ERROR) << "Contact suggest selection failed.";
918     return original_click_indices;
919   }
920   if (installed_app_engine_ != nullptr &&
921       !installed_app_engine_->Chunk(context_unicode, tokens,
922                                     &candidates.annotated_spans[0])) {
923     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
924     return original_click_indices;
925   }
926   if (number_annotator_ != nullptr &&
927       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
928                                   &candidates.annotated_spans[0])) {
929     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
930     return original_click_indices;
931   }
932   if (duration_annotator_ != nullptr &&
933       !duration_annotator_->FindAll(context_unicode, tokens,
934                                     options.annotation_usecase,
935                                     &candidates.annotated_spans[0])) {
936     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
937     return original_click_indices;
938   }
939   if (person_name_engine_ != nullptr &&
940       !person_name_engine_->Chunk(context_unicode, tokens,
941                                   &candidates.annotated_spans[0])) {
942     TC3_LOG(ERROR) << "Person name suggest selection failed.";
943     return original_click_indices;
944   }
945 
946   AnnotatedSpan grammar_suggested_span;
947   if (grammar_annotator_ != nullptr &&
948       grammar_annotator_->SuggestSelection(detected_text_language_tags,
949                                            context_unicode, click_indices,
950                                            &grammar_suggested_span)) {
951     candidates.annotated_spans[0].push_back(grammar_suggested_span);
952   }
953 
954   AnnotatedSpan pod_ner_suggested_span;
955   if (pod_ner_annotator_ != nullptr && options.use_pod_ner &&
956       pod_ner_annotator_->SuggestSelection(context_unicode, click_indices,
957                                            &pod_ner_suggested_span)) {
958     candidates.annotated_spans[0].push_back(pod_ner_suggested_span);
959   }
960 
961   if (experimental_annotator_ != nullptr) {
962     candidates.annotated_spans[0].push_back(
963         experimental_annotator_->SuggestSelection(context_unicode,
964                                                   click_indices));
965   }
966 
967   // Sort candidates according to their position in the input, so that the next
968   // code can assume that any connected component of overlapping spans forms a
969   // contiguous block.
970   std::sort(candidates.annotated_spans[0].begin(),
971             candidates.annotated_spans[0].end(),
972             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
973               return a.span.first < b.span.first;
974             });
975 
976   std::vector<int> candidate_indices;
977   if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
978                         detected_text_language_tags, options,
979                         &interpreter_manager, &candidate_indices)) {
980     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
981     return original_click_indices;
982   }
983 
984   std::sort(candidate_indices.begin(), candidate_indices.end(),
985             [this, &candidates](int a, int b) {
986               return GetPriorityScore(
987                          candidates.annotated_spans[0][a].classification) >
988                      GetPriorityScore(
989                          candidates.annotated_spans[0][b].classification);
990             });
991 
992   for (const int i : candidate_indices) {
993     if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
994         SpansOverlap(candidates.annotated_spans[0][i].span,
995                      original_click_indices)) {
996       // Run model classification if not present but requested and there's a
997       // classification collection filter specified.
998       if (candidates.annotated_spans[0][i].classification.empty() &&
999           model_->selection_options()->always_classify_suggested_selection() &&
1000           !filtered_collections_selection_.empty()) {
1001         if (!ModelClassifyText(context, /*cached_tokens=*/{},
1002                                detected_text_language_tags,
1003                                candidates.annotated_spans[0][i].span, options,
1004                                &interpreter_manager,
1005                                /*embedding_cache=*/nullptr,
1006                                &candidates.annotated_spans[0][i].classification,
1007                                /*tokens=*/nullptr)) {
1008           return original_click_indices;
1009         }
1010       }
1011 
1012       // Ignore if span classification is filtered.
1013       if (FilteredForSelection(candidates.annotated_spans[0][i])) {
1014         return original_click_indices;
1015       }
1016 
1017       // We return a suggested span contains the original span.
1018       // This compensates for "select all" selection that may come from
1019       // other apps. See http://b/179890518.
1020       if (SpanContains(candidates.annotated_spans[0][i].span,
1021                        original_click_indices)) {
1022         return candidates.annotated_spans[0][i].span;
1023       }
1024     }
1025   }
1026 
1027   return original_click_indices;
1028 }
1029 
1030 namespace {
1031 // Helper function that returns the index of the first candidate that
1032 // transitively does not overlap with the candidate on 'start_index'. If the end
1033 // of 'candidates' is reached, it returns the index that points right behind the
1034 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)1035 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
1036                                  int start_index) {
1037   int first_non_overlapping = start_index + 1;
1038   CodepointSpan conflicting_span = candidates[start_index].span;
1039   while (
1040       first_non_overlapping < candidates.size() &&
1041       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
1042     // Grow the span to include the current one.
1043     conflicting_span.second = std::max(
1044         conflicting_span.second, candidates[first_non_overlapping].span.second);
1045 
1046     ++first_non_overlapping;
1047   }
1048   return first_non_overlapping;
1049 }
1050 }  // namespace
1051 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * result) const1052 bool Annotator::ResolveConflicts(
1053     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
1054     const std::vector<Token>& cached_tokens,
1055     const std::vector<Locale>& detected_text_language_tags,
1056     const BaseOptions& options, InterpreterManager* interpreter_manager,
1057     std::vector<int>* result) const {
1058   result->clear();
1059   result->reserve(candidates.size());
1060   for (int i = 0; i < candidates.size();) {
1061     int first_non_overlapping =
1062         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
1063 
1064     const bool conflict_found = first_non_overlapping != (i + 1);
1065     if (conflict_found) {
1066       std::vector<int> candidate_indices;
1067       if (!ResolveConflict(context, cached_tokens, candidates,
1068                            detected_text_language_tags, i,
1069                            first_non_overlapping, options, interpreter_manager,
1070                            &candidate_indices)) {
1071         return false;
1072       }
1073       result->insert(result->end(), candidate_indices.begin(),
1074                      candidate_indices.end());
1075     } else {
1076       result->push_back(i);
1077     }
1078 
1079     // Skip over the whole conflicting group/go to next candidate.
1080     i = first_non_overlapping;
1081   }
1082   return true;
1083 }
1084 
1085 namespace {
1086 // Returns true, if the given two sources do conflict in given annotation
1087 // usecase.
1088 //  - In SMART usecase, all sources do conflict, because there's only 1 possible
1089 //  annotation for a given span.
1090 //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
1091 //  and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)1092 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
1093                        const AnnotatedSpan::Source source1,
1094                        const AnnotatedSpan::Source source2) {
1095   uint32 source_mask =
1096       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
1097 
1098   switch (annotation_usecase) {
1099     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
1100       // In the SMART mode, all annotations conflict.
1101       return true;
1102 
1103     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
1104       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
1105       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
1106       // hours" (duration).
1107       if ((source_mask &
1108            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
1109           (source_mask &
1110            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
1111         return false;
1112       }
1113 
1114       // A KNOWLEDGE entity does not conflict with anything.
1115       if ((source_mask &
1116            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
1117         return false;
1118       }
1119 
1120       // A PERSONNAME entity does not conflict with anything.
1121       if ((source_mask &
1122            (1 << static_cast<int>(AnnotatedSpan::Source::PERSON_NAME)))) {
1123         return false;
1124       }
1125 
1126       // Entities from other sources can conflict.
1127       return true;
1128   }
1129 }
1130 }  // namespace
1131 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,const BaseOptions & options,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const1132 bool Annotator::ResolveConflict(
1133     const std::string& context, const std::vector<Token>& cached_tokens,
1134     const std::vector<AnnotatedSpan>& candidates,
1135     const std::vector<Locale>& detected_text_language_tags, int start_index,
1136     int end_index, const BaseOptions& options,
1137     InterpreterManager* interpreter_manager,
1138     std::vector<int>* chosen_indices) const {
1139   std::vector<int> conflicting_indices;
1140   std::unordered_map<int, std::pair<float, int>> scores_lengths;
1141   for (int i = start_index; i < end_index; ++i) {
1142     conflicting_indices.push_back(i);
1143     if (!candidates[i].classification.empty()) {
1144       scores_lengths[i] = {
1145           GetPriorityScore(candidates[i].classification),
1146           candidates[i].span.second - candidates[i].span.first};
1147       continue;
1148     }
1149 
1150     // OPTIMIZATION: So that we don't have to classify all the ML model
1151     // spans apriori, we wait until we get here, when they conflict with
1152     // something and we need the actual classification scores. So if the
1153     // candidate conflicts and comes from the model, we need to run a
1154     // classification to determine its priority:
1155     std::vector<ClassificationResult> classification;
1156     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1157                            candidates[i].span, options, interpreter_manager,
1158                            /*embedding_cache=*/nullptr, &classification,
1159                            /*tokens=*/nullptr)) {
1160       return false;
1161     }
1162 
1163     if (!classification.empty()) {
1164       scores_lengths[i] = {
1165           GetPriorityScore(classification),
1166           candidates[i].span.second - candidates[i].span.first};
1167     }
1168   }
1169 
1170   std::sort(
1171       conflicting_indices.begin(), conflicting_indices.end(),
1172       [this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
1173         if (scores_lengths[i].first == scores_lengths[j].first &&
1174             prioritize_longest_annotation_) {
1175           return scores_lengths[i].second > scores_lengths[j].second;
1176         }
1177         return scores_lengths[i].first > scores_lengths[j].first;
1178       });
1179 
1180   // Here we keep a set of indices that were chosen, per-source, to enable
1181   // effective computation.
1182   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
1183       chosen_indices_for_source_map;
1184 
1185   // Greedily place the candidates if they don't conflict with the already
1186   // placed ones.
1187   for (int i = 0; i < conflicting_indices.size(); ++i) {
1188     const int considered_candidate = conflicting_indices[i];
1189 
1190     // See if there is a conflict between the candidate and all already placed
1191     // candidates.
1192     bool conflict = false;
1193     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
1194     for (auto& source_set_pair : chosen_indices_for_source_map) {
1195       if (source_set_pair.first == candidates[considered_candidate].source) {
1196         chosen_indices_for_source_ptr = &source_set_pair.second;
1197       }
1198 
1199       const bool needs_conflict_resolution =
1200           options.annotation_usecase ==
1201               AnnotationUsecase_ANNOTATION_USECASE_SMART ||
1202           (options.annotation_usecase ==
1203                AnnotationUsecase_ANNOTATION_USECASE_RAW &&
1204            do_conflict_resolution_in_raw_mode_);
1205       if (needs_conflict_resolution &&
1206           DoSourcesConflict(options.annotation_usecase, source_set_pair.first,
1207                             candidates[considered_candidate].source) &&
1208           DoesCandidateConflict(considered_candidate, candidates,
1209                                 source_set_pair.second)) {
1210         conflict = true;
1211         break;
1212       }
1213     }
1214 
1215     // Skip the candidate if a conflict was found.
1216     if (conflict) {
1217       continue;
1218     }
1219 
1220     // If the set of indices for the current source doesn't exist yet,
1221     // initialize it.
1222     if (chosen_indices_for_source_ptr == nullptr) {
1223       SortedIntSet new_set([&candidates](int a, int b) {
1224         return candidates[a].span.first < candidates[b].span.first;
1225       });
1226       chosen_indices_for_source_map[candidates[considered_candidate].source] =
1227           std::move(new_set);
1228       chosen_indices_for_source_ptr =
1229           &chosen_indices_for_source_map[candidates[considered_candidate]
1230                                              .source];
1231     }
1232 
1233     // Place the candidate to the output and to the per-source conflict set.
1234     chosen_indices->push_back(considered_candidate);
1235     chosen_indices_for_source_ptr->insert(considered_candidate);
1236   }
1237 
1238   std::sort(chosen_indices->begin(), chosen_indices->end());
1239 
1240   return true;
1241 }
1242 
ModelSuggestSelection(const UnicodeText & context_unicode,const CodepointSpan & click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1243 bool Annotator::ModelSuggestSelection(
1244     const UnicodeText& context_unicode, const CodepointSpan& click_indices,
1245     const std::vector<Locale>& detected_text_language_tags,
1246     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1247     std::vector<AnnotatedSpan>* result) const {
1248   if (model_->triggering_options() == nullptr ||
1249       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1250     return true;
1251   }
1252 
1253   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1254                                     ml_model_triggering_locales_,
1255                                     /*default_value=*/true)) {
1256     return true;
1257   }
1258 
1259   int click_pos;
1260   *tokens = selection_feature_processor_->Tokenize(context_unicode);
1261   const auto [click_begin, click_end] =
1262       CodepointSpanToUnicodeTextRange(context_unicode, click_indices);
1263   selection_feature_processor_->RetokenizeAndFindClick(
1264       context_unicode, click_begin, click_end, click_indices,
1265       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1266       tokens, &click_pos);
1267   if (click_pos == kInvalidIndex) {
1268     TC3_VLOG(1) << "Could not calculate the click position.";
1269     return false;
1270   }
1271 
1272   const int symmetry_context_size =
1273       model_->selection_options()->symmetry_context_size();
1274   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1275       bounds_sensitive_features = selection_feature_processor_->GetOptions()
1276                                       ->bounds_sensitive_features();
1277 
1278   // The symmetry context span is the clicked token with symmetry_context_size
1279   // tokens on either side.
1280   const TokenSpan symmetry_context_span =
1281       IntersectTokenSpans(TokenSpan(click_pos).Expand(
1282                               /*num_tokens_left=*/symmetry_context_size,
1283                               /*num_tokens_right=*/symmetry_context_size),
1284                           AllOf(*tokens));
1285 
1286   // Compute the extraction span based on the model type.
1287   TokenSpan extraction_span;
1288   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1289     // The extraction span is the symmetry context span expanded to include
1290     // max_selection_span tokens on either side, which is how far a selection
1291     // can stretch from the click, plus a relevant number of tokens outside of
1292     // the bounds of the selection.
1293     const int max_selection_span =
1294         selection_feature_processor_->GetOptions()->max_selection_span();
1295     extraction_span = symmetry_context_span.Expand(
1296         /*num_tokens_left=*/max_selection_span +
1297             bounds_sensitive_features->num_tokens_before(),
1298         /*num_tokens_right=*/max_selection_span +
1299             bounds_sensitive_features->num_tokens_after());
1300   } else {
1301     // The extraction span is the symmetry context span expanded to include
1302     // context_size tokens on either side.
1303     const int context_size =
1304         selection_feature_processor_->GetOptions()->context_size();
1305     extraction_span = symmetry_context_span.Expand(
1306         /*num_tokens_left=*/context_size,
1307         /*num_tokens_right=*/context_size);
1308   }
1309   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1310 
1311   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1312           *tokens, extraction_span)) {
1313     return true;
1314   }
1315 
1316   std::unique_ptr<CachedFeatures> cached_features;
1317   if (!selection_feature_processor_->ExtractFeatures(
1318           *tokens, extraction_span,
1319           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1320           embedding_executor_.get(),
1321           /*embedding_cache=*/nullptr,
1322           selection_feature_processor_->EmbeddingSize() +
1323               selection_feature_processor_->DenseFeaturesCount(),
1324           &cached_features)) {
1325     TC3_LOG(ERROR) << "Could not extract features.";
1326     return false;
1327   }
1328 
1329   // Produce selection model candidates.
1330   std::vector<TokenSpan> chunks;
1331   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1332                   interpreter_manager->SelectionInterpreter(), *cached_features,
1333                   &chunks)) {
1334     TC3_LOG(ERROR) << "Could not chunk.";
1335     return false;
1336   }
1337 
1338   for (const TokenSpan& chunk : chunks) {
1339     AnnotatedSpan candidate;
1340     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1341         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1342     if (model_->selection_options()->strip_unpaired_brackets()) {
1343       candidate.span =
1344           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1345     }
1346 
1347     // Only output non-empty spans.
1348     if (candidate.span.first != candidate.span.second) {
1349       result->push_back(candidate);
1350     }
1351   }
1352   return true;
1353 }
1354 
1355 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,const CodepointSpan & selection_indices,TokenSpan tokens_around_selection_to_copy)1356 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1357                                     const CodepointSpan& selection_indices,
1358                                     TokenSpan tokens_around_selection_to_copy) {
1359   const auto first_selection_token = std::upper_bound(
1360       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1361       [](int selection_start, const Token& token) {
1362         return selection_start < token.end;
1363       });
1364   const auto last_selection_token = std::lower_bound(
1365       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1366       [](const Token& token, int selection_end) {
1367         return token.start < selection_end;
1368       });
1369 
1370   const int64 first_token = std::max(
1371       static_cast<int64>(0),
1372       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1373                          tokens_around_selection_to_copy.first));
1374   const int64 last_token = std::min(
1375       static_cast<int64>(cached_tokens.size()),
1376       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1377                          tokens_around_selection_to_copy.second));
1378 
1379   std::vector<Token> tokens;
1380   tokens.reserve(last_token - first_token);
1381   for (int i = first_token; i < last_token; ++i) {
1382     tokens.push_back(cached_tokens[i]);
1383   }
1384   return tokens;
1385 }
1386 }  // namespace internal
1387 
ClassifyTextUpperBoundNeededTokens() const1388 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1389   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1390       bounds_sensitive_features =
1391           classification_feature_processor_->GetOptions()
1392               ->bounds_sensitive_features();
1393   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1394     // The extraction span is the selection span expanded to include a relevant
1395     // number of tokens outside of the bounds of the selection.
1396     return {bounds_sensitive_features->num_tokens_before(),
1397             bounds_sensitive_features->num_tokens_after()};
1398   } else {
1399     // The extraction span is the clicked token with context_size tokens on
1400     // either side.
1401     const int context_size =
1402         selection_feature_processor_->GetOptions()->context_size();
1403     return {context_size, context_size};
1404   }
1405 }
1406 
1407 namespace {
1408 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1409 void SortClassificationResults(
1410     std::vector<ClassificationResult>* classification_results) {
1411   std::sort(classification_results->begin(), classification_results->end(),
1412             [](const ClassificationResult& a, const ClassificationResult& b) {
1413               return a.score > b.score;
1414             });
1415 }
1416 }  // namespace
1417 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1418 bool Annotator::ModelClassifyText(
1419     const std::string& context, const std::vector<Token>& cached_tokens,
1420     const std::vector<Locale>& detected_text_language_tags,
1421     const CodepointSpan& selection_indices, const BaseOptions& options,
1422     InterpreterManager* interpreter_manager,
1423     FeatureProcessor::EmbeddingCache* embedding_cache,
1424     std::vector<ClassificationResult>* classification_results,
1425     std::vector<Token>* tokens) const {
1426   const UnicodeText context_unicode =
1427       UTF8ToUnicodeText(context, /*do_copy=*/false);
1428   const auto [span_begin, span_end] =
1429       CodepointSpanToUnicodeTextRange(context_unicode, selection_indices);
1430   return ModelClassifyText(context_unicode, cached_tokens,
1431                            detected_text_language_tags, span_begin, span_end,
1432                            /*line=*/nullptr, selection_indices, options,
1433                            interpreter_manager, embedding_cache,
1434                            classification_results, tokens);
1435 }
1436 
ModelClassifyText(const UnicodeText & context_unicode,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const UnicodeTextRange * line,const CodepointSpan & selection_indices,const BaseOptions & options,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1437 bool Annotator::ModelClassifyText(
1438     const UnicodeText& context_unicode, const std::vector<Token>& cached_tokens,
1439     const std::vector<Locale>& detected_text_language_tags,
1440     const UnicodeText::const_iterator& span_begin,
1441     const UnicodeText::const_iterator& span_end, const UnicodeTextRange* line,
1442     const CodepointSpan& selection_indices, const BaseOptions& options,
1443     InterpreterManager* interpreter_manager,
1444     FeatureProcessor::EmbeddingCache* embedding_cache,
1445     std::vector<ClassificationResult>* classification_results,
1446     std::vector<Token>* tokens) const {
1447   if (model_->triggering_options() == nullptr ||
1448       !(model_->triggering_options()->enabled_modes() &
1449         ModeFlag_CLASSIFICATION)) {
1450     return true;
1451   }
1452 
1453   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1454                                     ml_model_triggering_locales_,
1455                                     /*default_value=*/true)) {
1456     return true;
1457   }
1458 
1459   std::vector<Token> local_tokens;
1460   if (tokens == nullptr) {
1461     tokens = &local_tokens;
1462   }
1463 
1464   if (cached_tokens.empty()) {
1465     *tokens = classification_feature_processor_->Tokenize(context_unicode);
1466   } else {
1467     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1468                                          ClassifyTextUpperBoundNeededTokens());
1469   }
1470 
1471   int click_pos;
1472   classification_feature_processor_->RetokenizeAndFindClick(
1473       context_unicode, span_begin, span_end, selection_indices,
1474       classification_feature_processor_->GetOptions()
1475           ->only_use_line_with_click(),
1476       tokens, &click_pos);
1477   const TokenSpan selection_token_span =
1478       CodepointSpanToTokenSpan(*tokens, selection_indices);
1479   const int selection_num_tokens = selection_token_span.Size();
1480   if (model_->classification_options()->max_num_tokens() > 0 &&
1481       model_->classification_options()->max_num_tokens() <
1482           selection_num_tokens) {
1483     *classification_results = {{Collections::Other(), 1.0}};
1484     return true;
1485   }
1486 
1487   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1488       bounds_sensitive_features =
1489           classification_feature_processor_->GetOptions()
1490               ->bounds_sensitive_features();
1491   if (selection_token_span.first == kInvalidIndex ||
1492       selection_token_span.second == kInvalidIndex) {
1493     TC3_LOG(ERROR) << "Could not determine span.";
1494     return false;
1495   }
1496 
1497   // Compute the extraction span based on the model type.
1498   TokenSpan extraction_span;
1499   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1500     // The extraction span is the selection span expanded to include a relevant
1501     // number of tokens outside of the bounds of the selection.
1502     extraction_span = selection_token_span.Expand(
1503         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1504         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1505   } else {
1506     if (click_pos == kInvalidIndex) {
1507       TC3_LOG(ERROR) << "Couldn't choose a click position.";
1508       return false;
1509     }
1510     // The extraction span is the clicked token with context_size tokens on
1511     // either side.
1512     const int context_size =
1513         classification_feature_processor_->GetOptions()->context_size();
1514     extraction_span = TokenSpan(click_pos).Expand(
1515         /*num_tokens_left=*/context_size,
1516         /*num_tokens_right=*/context_size);
1517   }
1518   extraction_span = IntersectTokenSpans(extraction_span, AllOf(*tokens));
1519 
1520   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1521           *tokens, extraction_span)) {
1522     *classification_results = {{Collections::Other(), 1.0}};
1523     return true;
1524   }
1525 
1526   std::unique_ptr<CachedFeatures> cached_features;
1527   if (!classification_feature_processor_->ExtractFeatures(
1528           *tokens, extraction_span, selection_indices,
1529           embedding_executor_.get(), embedding_cache,
1530           classification_feature_processor_->EmbeddingSize() +
1531               classification_feature_processor_->DenseFeaturesCount(),
1532           &cached_features)) {
1533     TC3_LOG(ERROR) << "Could not extract features.";
1534     return false;
1535   }
1536 
1537   std::vector<float> features;
1538   features.reserve(cached_features->OutputFeaturesSize());
1539   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1540     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1541                                                           &features);
1542   } else {
1543     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1544   }
1545 
1546   TensorView<float> logits = classification_executor_->ComputeLogits(
1547       TensorView<float>(features.data(),
1548                         {1, static_cast<int>(features.size())}),
1549       interpreter_manager->ClassificationInterpreter());
1550   if (!logits.is_valid()) {
1551     TC3_LOG(ERROR) << "Couldn't compute logits.";
1552     return false;
1553   }
1554 
1555   if (logits.dims() != 2 || logits.dim(0) != 1 ||
1556       logits.dim(1) != classification_feature_processor_->NumCollections()) {
1557     TC3_LOG(ERROR) << "Mismatching output";
1558     return false;
1559   }
1560 
1561   const std::vector<float> scores =
1562       ComputeSoftmax(logits.data(), logits.dim(1));
1563 
1564   if (scores.empty()) {
1565     *classification_results = {{Collections::Other(), 1.0}};
1566     return true;
1567   }
1568 
1569   const int best_score_index =
1570       std::max_element(scores.begin(), scores.end()) - scores.begin();
1571   const std::string top_collection =
1572       classification_feature_processor_->LabelToCollection(best_score_index);
1573 
1574   // Sanity checks.
1575   if (top_collection == Collections::Phone()) {
1576     const int digit_count = std::count_if(span_begin, span_end, IsDigit);
1577     if (digit_count <
1578             model_->classification_options()->phone_min_num_digits() ||
1579         digit_count >
1580             model_->classification_options()->phone_max_num_digits()) {
1581       *classification_results = {{Collections::Other(), 1.0}};
1582       return true;
1583     }
1584   } else if (top_collection == Collections::Address()) {
1585     if (selection_num_tokens <
1586         model_->classification_options()->address_min_num_tokens()) {
1587       *classification_results = {{Collections::Other(), 1.0}};
1588       return true;
1589     }
1590   } else if (top_collection == Collections::Dictionary()) {
1591     if ((options.use_vocab_annotator && vocab_annotator_) ||
1592         !Locale::IsAnyLocaleSupported(detected_text_language_tags,
1593                                       dictionary_locales_,
1594                                       /*default_value=*/false)) {
1595       *classification_results = {{Collections::Other(), 1.0}};
1596       return true;
1597     }
1598   }
1599   *classification_results = {{top_collection, /*arg_score=*/1.0,
1600                               /*arg_priority_score=*/scores[best_score_index]}};
1601 
1602   // For some entities, we might want to clamp the priority score, for better
1603   // conflict resolution between entities.
1604   if (model_->triggering_options() != nullptr &&
1605       model_->triggering_options()->collection_to_priority() != nullptr) {
1606     if (auto entry =
1607             model_->triggering_options()->collection_to_priority()->LookupByKey(
1608                 top_collection.c_str())) {
1609       (*classification_results)[0].priority_score *= entry->value();
1610     }
1611   }
1612   return true;
1613 }
1614 
RegexClassifyText(const std::string & context,const CodepointSpan & selection_indices,std::vector<ClassificationResult> * classification_result) const1615 bool Annotator::RegexClassifyText(
1616     const std::string& context, const CodepointSpan& selection_indices,
1617     std::vector<ClassificationResult>* classification_result) const {
1618   const std::string selection_text =
1619       UTF8ToUnicodeText(context, /*do_copy=*/false)
1620           .UTF8Substring(selection_indices.first, selection_indices.second);
1621   const UnicodeText selection_text_unicode(
1622       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1623 
1624   // Check whether any of the regular expressions match.
1625   for (const int pattern_id : classification_regex_patterns_) {
1626     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1627     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1628         regex_pattern.pattern->Matcher(selection_text_unicode);
1629     int status = UniLib::RegexMatcher::kNoError;
1630     bool matches;
1631     if (regex_pattern.config->use_approximate_matching()) {
1632       matches = matcher->ApproximatelyMatches(&status);
1633     } else {
1634       matches = matcher->Matches(&status);
1635     }
1636     if (status != UniLib::RegexMatcher::kNoError) {
1637       return false;
1638     }
1639     if (matches && VerifyRegexMatchCandidate(
1640                        context, regex_pattern.config->verification_options(),
1641                        selection_text, matcher.get())) {
1642       classification_result->push_back(
1643           {regex_pattern.config->collection_name()->str(),
1644            regex_pattern.config->target_classification_score(),
1645            regex_pattern.config->priority_score()});
1646       if (!SerializedEntityDataFromRegexMatch(
1647               regex_pattern.config, matcher.get(),
1648               &classification_result->back().serialized_entity_data)) {
1649         TC3_LOG(ERROR) << "Could not get entity data.";
1650         return false;
1651       }
1652     }
1653   }
1654 
1655   return true;
1656 }
1657 
1658 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1659 std::string PickCollectionForDatetime(
1660     const DatetimeParseResult& datetime_parse_result) {
1661   switch (datetime_parse_result.granularity) {
1662     case GRANULARITY_HOUR:
1663     case GRANULARITY_MINUTE:
1664     case GRANULARITY_SECOND:
1665       return Collections::DateTime();
1666     default:
1667       return Collections::Date();
1668   }
1669 }
1670 
1671 }  // namespace
1672 
DatetimeClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1673 bool Annotator::DatetimeClassifyText(
1674     const std::string& context, const CodepointSpan& selection_indices,
1675     const ClassificationOptions& options,
1676     std::vector<ClassificationResult>* classification_results) const {
1677   if (!datetime_parser_) {
1678     return true;
1679   }
1680 
1681   const std::string selection_text =
1682       UTF8ToUnicodeText(context, /*do_copy=*/false)
1683           .UTF8Substring(selection_indices.first, selection_indices.second);
1684 
1685   LocaleList locale_list = LocaleList::ParseFrom(options.locales);
1686   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
1687       datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1688                               options.reference_timezone, locale_list,
1689                               ModeFlag_CLASSIFICATION,
1690                               options.annotation_usecase,
1691                               /*anchor_start_end=*/true);
1692   if (!result_status.ok()) {
1693     TC3_LOG(ERROR) << "Error during parsing datetime.";
1694     return false;
1695   }
1696 
1697   for (const DatetimeParseResultSpan& datetime_span :
1698        result_status.ValueOrDie()) {
1699     // Only consider the result valid if the selection and extracted datetime
1700     // spans exactly match.
1701     if (CodepointSpan(datetime_span.span.first + selection_indices.first,
1702                       datetime_span.span.second + selection_indices.first) ==
1703         selection_indices) {
1704       for (const DatetimeParseResult& parse_result : datetime_span.data) {
1705         classification_results->emplace_back(
1706             PickCollectionForDatetime(parse_result),
1707             datetime_span.target_classification_score);
1708         classification_results->back().datetime_parse_result = parse_result;
1709         classification_results->back().serialized_entity_data =
1710             CreateDatetimeSerializedEntityData(parse_result);
1711         classification_results->back().priority_score =
1712             datetime_span.priority_score;
1713       }
1714       return true;
1715     }
1716   }
1717   return true;
1718 }
1719 
ClassifyText(const std::string & context,const CodepointSpan & selection_indices,const ClassificationOptions & options) const1720 std::vector<ClassificationResult> Annotator::ClassifyText(
1721     const std::string& context, const CodepointSpan& selection_indices,
1722     const ClassificationOptions& options) const {
1723   if (context.size() > std::numeric_limits<int>::max()) {
1724     TC3_LOG(ERROR) << "Rejecting too long input: " << context.size();
1725     return {};
1726   }
1727   if (!initialized_) {
1728     TC3_LOG(ERROR) << "Not initialized";
1729     return {};
1730   }
1731   if (options.annotation_usecase !=
1732       AnnotationUsecase_ANNOTATION_USECASE_SMART) {
1733     TC3_LOG(WARNING)
1734         << "Invoking ClassifyText, which is not supported in RAW mode.";
1735     return {};
1736   }
1737   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1738     return {};
1739   }
1740 
1741   std::vector<Locale> detected_text_language_tags;
1742   if (!ParseLocales(options.detected_text_language_tags,
1743                     &detected_text_language_tags)) {
1744     TC3_LOG(WARNING)
1745         << "Failed to parse the detected_text_language_tags in options: "
1746         << options.detected_text_language_tags;
1747   }
1748   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1749                                     model_triggering_locales_,
1750                                     /*default_value=*/true)) {
1751     return {};
1752   }
1753 
1754   const UnicodeText context_unicode =
1755       UTF8ToUnicodeText(context, /*do_copy=*/false);
1756 
1757   if (!unilib_->IsValidUtf8(context_unicode)) {
1758     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
1759     return {};
1760   }
1761 
1762   if (!IsValidSpanInput(context_unicode, selection_indices)) {
1763     TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
1764                 << selection_indices.first << " " << selection_indices.second;
1765     return {};
1766   }
1767 
1768   // We'll accumulate a list of candidates, and pick the best candidate in the
1769   // end.
1770   std::vector<AnnotatedSpan> candidates;
1771 
1772   // Try the knowledge engine.
1773   // TODO(b/126579108): Propagate error status.
1774   ClassificationResult knowledge_result;
1775   if (knowledge_engine_ &&
1776       knowledge_engine_
1777           ->ClassifyText(context, selection_indices, options.annotation_usecase,
1778                          options.location_context, Permissions(),
1779                          &knowledge_result)
1780           .ok()) {
1781     candidates.push_back({selection_indices, {knowledge_result}});
1782     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1783   }
1784 
1785   AddContactMetadataToKnowledgeClassificationResults(&candidates);
1786 
1787   // Try the contact engine.
1788   // TODO(b/126579108): Propagate error status.
1789   ClassificationResult contact_result;
1790   if (contact_engine_ && contact_engine_->ClassifyText(
1791                              context, selection_indices, &contact_result)) {
1792     candidates.push_back({selection_indices, {contact_result}});
1793   }
1794 
1795   // Try the person name engine.
1796   ClassificationResult person_name_result;
1797   if (person_name_engine_ &&
1798       person_name_engine_->ClassifyText(context, selection_indices,
1799                                         &person_name_result)) {
1800     candidates.push_back({selection_indices, {person_name_result}});
1801     candidates.back().source = AnnotatedSpan::Source::PERSON_NAME;
1802   }
1803 
1804   // Try the installed app engine.
1805   // TODO(b/126579108): Propagate error status.
1806   ClassificationResult installed_app_result;
1807   if (installed_app_engine_ &&
1808       installed_app_engine_->ClassifyText(context, selection_indices,
1809                                           &installed_app_result)) {
1810     candidates.push_back({selection_indices, {installed_app_result}});
1811   }
1812 
1813   // Try the regular expression models.
1814   std::vector<ClassificationResult> regex_results;
1815   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1816     return {};
1817   }
1818   for (const ClassificationResult& result : regex_results) {
1819     candidates.push_back({selection_indices, {result}});
1820   }
1821 
1822   // Try the date model.
1823   //
1824   // DatetimeClassifyText only returns the first result, which can however have
1825   // more interpretations. They are inserted in the candidates as a single
1826   // AnnotatedSpan, so that they get treated together by the conflict resolution
1827   // algorithm.
1828   std::vector<ClassificationResult> datetime_results;
1829   if (!DatetimeClassifyText(context, selection_indices, options,
1830                             &datetime_results)) {
1831     return {};
1832   }
1833   if (!datetime_results.empty()) {
1834     candidates.push_back({selection_indices, std::move(datetime_results)});
1835     candidates.back().source = AnnotatedSpan::Source::DATETIME;
1836   }
1837 
1838   // Try the number annotator.
1839   // TODO(b/126579108): Propagate error status.
1840   ClassificationResult number_annotator_result;
1841   if (number_annotator_ &&
1842       number_annotator_->ClassifyText(context_unicode, selection_indices,
1843                                       options.annotation_usecase,
1844                                       &number_annotator_result)) {
1845     candidates.push_back({selection_indices, {number_annotator_result}});
1846   }
1847 
1848   // Try the duration annotator.
1849   ClassificationResult duration_annotator_result;
1850   if (duration_annotator_ &&
1851       duration_annotator_->ClassifyText(context_unicode, selection_indices,
1852                                         options.annotation_usecase,
1853                                         &duration_annotator_result)) {
1854     candidates.push_back({selection_indices, {duration_annotator_result}});
1855     candidates.back().source = AnnotatedSpan::Source::DURATION;
1856   }
1857 
1858   // Try the translate annotator.
1859   ClassificationResult translate_annotator_result;
1860   if (translate_annotator_ &&
1861       translate_annotator_->ClassifyText(context_unicode, selection_indices,
1862                                          options.user_familiar_language_tags,
1863                                          &translate_annotator_result)) {
1864     candidates.push_back({selection_indices, {translate_annotator_result}});
1865   }
1866 
1867   // Try the grammar model.
1868   ClassificationResult grammar_annotator_result;
1869   if (grammar_annotator_ && grammar_annotator_->ClassifyText(
1870                                 detected_text_language_tags, context_unicode,
1871                                 selection_indices, &grammar_annotator_result)) {
1872     candidates.push_back({selection_indices, {grammar_annotator_result}});
1873   }
1874 
1875   ClassificationResult pod_ner_annotator_result;
1876   if (pod_ner_annotator_ && options.use_pod_ner &&
1877       pod_ner_annotator_->ClassifyText(context_unicode, selection_indices,
1878                                        &pod_ner_annotator_result)) {
1879     candidates.push_back({selection_indices, {pod_ner_annotator_result}});
1880   }
1881 
1882   ClassificationResult vocab_annotator_result;
1883   if (vocab_annotator_ && options.use_vocab_annotator &&
1884       vocab_annotator_->ClassifyText(
1885           context_unicode, selection_indices, detected_text_language_tags,
1886           options.trigger_dictionary_on_beginner_words,
1887           &vocab_annotator_result)) {
1888     candidates.push_back({selection_indices, {vocab_annotator_result}});
1889   }
1890 
1891   if (experimental_annotator_) {
1892     experimental_annotator_->ClassifyText(context_unicode, selection_indices,
1893                                           candidates);
1894   }
1895 
1896   // Try the ML model.
1897   //
1898   // The output of the model is considered as an exclusive 1-of-N choice. That's
1899   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1900   // span for each candidate, like e.g. the regex model.
1901   InterpreterManager interpreter_manager(selection_executor_.get(),
1902                                          classification_executor_.get());
1903   std::vector<ClassificationResult> model_results;
1904   std::vector<Token> tokens;
1905   if (!ModelClassifyText(
1906           context, /*cached_tokens=*/{}, detected_text_language_tags,
1907           selection_indices, options, &interpreter_manager,
1908           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1909     return {};
1910   }
1911   if (!model_results.empty()) {
1912     candidates.push_back({selection_indices, std::move(model_results)});
1913   }
1914 
1915   std::vector<int> candidate_indices;
1916   if (!ResolveConflicts(candidates, context, tokens,
1917                         detected_text_language_tags, options,
1918                         &interpreter_manager, &candidate_indices)) {
1919     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1920     return {};
1921   }
1922 
1923   std::vector<ClassificationResult> results;
1924   for (const int i : candidate_indices) {
1925     for (const ClassificationResult& result : candidates[i].classification) {
1926       if (!FilteredForClassification(result)) {
1927         results.push_back(result);
1928       }
1929     }
1930   }
1931 
1932   // Sort results according to score.
1933   std::sort(results.begin(), results.end(),
1934             [](const ClassificationResult& a, const ClassificationResult& b) {
1935               return a.score > b.score;
1936             });
1937 
1938   if (results.empty()) {
1939     results = {{Collections::Other(), 1.0}};
1940   }
1941   return results;
1942 }
1943 
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,const AnnotationOptions & options,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1944 bool Annotator::ModelAnnotate(
1945     const std::string& context,
1946     const std::vector<Locale>& detected_text_language_tags,
1947     const AnnotationOptions& options, InterpreterManager* interpreter_manager,
1948     std::vector<Token>* tokens, std::vector<AnnotatedSpan>* result) const {
1949   if (model_->triggering_options() == nullptr ||
1950       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1951     return true;
1952   }
1953 
1954   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1955                                     ml_model_triggering_locales_,
1956                                     /*default_value=*/true)) {
1957     return true;
1958   }
1959 
1960   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1961                                                         /*do_copy=*/false);
1962   std::vector<UnicodeTextRange> lines;
1963   if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1964     lines.push_back({context_unicode.begin(), context_unicode.end()});
1965   } else {
1966     lines = selection_feature_processor_->SplitContext(
1967         context_unicode, selection_feature_processor_->GetOptions()
1968                              ->use_pipe_character_for_newline());
1969   }
1970 
1971   const float min_annotate_confidence =
1972       (model_->triggering_options() != nullptr
1973            ? model_->triggering_options()->min_annotate_confidence()
1974            : 0.f);
1975 
1976   for (const UnicodeTextRange& line : lines) {
1977     FeatureProcessor::EmbeddingCache embedding_cache;
1978     const std::string line_str =
1979         UnicodeText::UTF8Substring(line.first, line.second);
1980 
1981     std::vector<Token> line_tokens;
1982     line_tokens = selection_feature_processor_->Tokenize(line_str);
1983 
1984     selection_feature_processor_->RetokenizeAndFindClick(
1985         line_str, {0, std::distance(line.first, line.second)},
1986         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1987         &line_tokens,
1988         /*click_pos=*/nullptr);
1989     const TokenSpan full_line_span = {
1990         0, static_cast<TokenIndex>(line_tokens.size())};
1991 
1992     // TODO(zilka): Add support for greater granularity of this check.
1993     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1994             line_tokens, full_line_span)) {
1995       continue;
1996     }
1997 
1998     std::unique_ptr<CachedFeatures> cached_features;
1999     if (!selection_feature_processor_->ExtractFeatures(
2000             line_tokens, full_line_span,
2001             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
2002             embedding_executor_.get(),
2003             /*embedding_cache=*/nullptr,
2004             selection_feature_processor_->EmbeddingSize() +
2005                 selection_feature_processor_->DenseFeaturesCount(),
2006             &cached_features)) {
2007       TC3_LOG(ERROR) << "Could not extract features.";
2008       return false;
2009     }
2010 
2011     std::vector<TokenSpan> local_chunks;
2012     if (!ModelChunk(line_tokens.size(), /*span_of_interest=*/full_line_span,
2013                     interpreter_manager->SelectionInterpreter(),
2014                     *cached_features, &local_chunks)) {
2015       TC3_LOG(ERROR) << "Could not chunk.";
2016       return false;
2017     }
2018 
2019     const int offset = std::distance(context_unicode.begin(), line.first);
2020     UnicodeText line_unicode;
2021     std::vector<UnicodeText::const_iterator> line_codepoints;
2022     if (options.enable_optimization) {
2023       if (local_chunks.empty()) {
2024         continue;
2025       }
2026       line_unicode = UTF8ToUnicodeText(line_str, /*do_copy=*/false);
2027       line_codepoints = line_unicode.Codepoints();
2028       line_codepoints.push_back(line_unicode.end());
2029     }
2030     for (const TokenSpan& chunk : local_chunks) {
2031       CodepointSpan codepoint_span =
2032           TokenSpanToCodepointSpan(line_tokens, chunk);
2033       if (options.enable_optimization) {
2034         if (!codepoint_span.IsValid() ||
2035             codepoint_span.second > line_codepoints.size()) {
2036           continue;
2037         }
2038         codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2039             /*span_begin=*/line_codepoints[codepoint_span.first],
2040             /*span_end=*/line_codepoints[codepoint_span.second],
2041             codepoint_span);
2042         if (model_->selection_options()->strip_unpaired_brackets()) {
2043           codepoint_span = StripUnpairedBrackets(
2044               /*span_begin=*/line_codepoints[codepoint_span.first],
2045               /*span_end=*/line_codepoints[codepoint_span.second],
2046               codepoint_span, *unilib_);
2047         }
2048       } else {
2049         codepoint_span = selection_feature_processor_->StripBoundaryCodepoints(
2050             line_str, codepoint_span);
2051         if (model_->selection_options()->strip_unpaired_brackets()) {
2052           codepoint_span =
2053               StripUnpairedBrackets(context_unicode, codepoint_span, *unilib_);
2054         }
2055       }
2056 
2057       // Skip empty spans.
2058       if (codepoint_span.first != codepoint_span.second) {
2059         std::vector<ClassificationResult> classification;
2060         if (options.enable_optimization) {
2061           if (!ModelClassifyText(
2062                   line_unicode, line_tokens, detected_text_language_tags,
2063                   /*span_begin=*/line_codepoints[codepoint_span.first],
2064                   /*span_end=*/line_codepoints[codepoint_span.second], &line,
2065                   codepoint_span, options, interpreter_manager,
2066                   &embedding_cache, &classification, /*tokens=*/nullptr)) {
2067             TC3_LOG(ERROR) << "Could not classify text: "
2068                            << (codepoint_span.first + offset) << " "
2069                            << (codepoint_span.second + offset);
2070             return false;
2071           }
2072         } else {
2073           if (!ModelClassifyText(line_str, line_tokens,
2074                                  detected_text_language_tags, codepoint_span,
2075                                  options, interpreter_manager, &embedding_cache,
2076                                  &classification, /*tokens=*/nullptr)) {
2077             TC3_LOG(ERROR) << "Could not classify text: "
2078                            << (codepoint_span.first + offset) << " "
2079                            << (codepoint_span.second + offset);
2080             return false;
2081           }
2082         }
2083 
2084         // Do not include the span if it's classified as "other".
2085         if (!classification.empty() && !ClassifiedAsOther(classification) &&
2086             classification[0].score >= min_annotate_confidence) {
2087           AnnotatedSpan result_span;
2088           result_span.span = {codepoint_span.first + offset,
2089                               codepoint_span.second + offset};
2090           result_span.classification = std::move(classification);
2091           result->push_back(std::move(result_span));
2092         }
2093       }
2094     }
2095 
2096     // If we are going line-by-line, we need to insert the tokens for each line.
2097     // But if not, we can optimize and just std::move the current line vector to
2098     // the output.
2099     if (selection_feature_processor_->GetOptions()
2100             ->only_use_line_with_click()) {
2101       tokens->insert(tokens->end(), line_tokens.begin(), line_tokens.end());
2102     } else {
2103       *tokens = std::move(line_tokens);
2104     }
2105   }
2106   return true;
2107 }
2108 
SelectionFeatureProcessorForTests() const2109 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
2110   return selection_feature_processor_.get();
2111 }
2112 
ClassificationFeatureProcessorForTests() const2113 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
2114     const {
2115   return classification_feature_processor_.get();
2116 }
2117 
DatetimeParserForTests() const2118 const DatetimeParser* Annotator::DatetimeParserForTests() const {
2119   return datetime_parser_.get();
2120 }
2121 
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const2122 void Annotator::RemoveNotEnabledEntityTypes(
2123     const EnabledEntityTypes& is_entity_type_enabled,
2124     std::vector<AnnotatedSpan>* annotated_spans) const {
2125   for (AnnotatedSpan& annotated_span : *annotated_spans) {
2126     std::vector<ClassificationResult>& classifications =
2127         annotated_span.classification;
2128     classifications.erase(
2129         std::remove_if(classifications.begin(), classifications.end(),
2130                        [&is_entity_type_enabled](
2131                            const ClassificationResult& classification_result) {
2132                          return !is_entity_type_enabled(
2133                              classification_result.collection);
2134                        }),
2135         classifications.end());
2136   }
2137   annotated_spans->erase(
2138       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
2139                      [](const AnnotatedSpan& annotated_span) {
2140                        return annotated_span.classification.empty();
2141                      }),
2142       annotated_spans->end());
2143 }
2144 
AddContactMetadataToKnowledgeClassificationResults(std::vector<AnnotatedSpan> * candidates) const2145 void Annotator::AddContactMetadataToKnowledgeClassificationResults(
2146     std::vector<AnnotatedSpan>* candidates) const {
2147   if (candidates == nullptr || contact_engine_ == nullptr) {
2148     return;
2149   }
2150   for (auto& candidate : *candidates) {
2151     for (auto& classification_result : candidate.classification) {
2152       contact_engine_->AddContactMetadataToKnowledgeClassificationResult(
2153           &classification_result);
2154     }
2155   }
2156 }
2157 
AnnotateSingleInput(const std::string & context,const AnnotationOptions & options,std::vector<AnnotatedSpan> * candidates) const2158 Status Annotator::AnnotateSingleInput(
2159     const std::string& context, const AnnotationOptions& options,
2160     std::vector<AnnotatedSpan>* candidates) const {
2161   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
2162     return Status(StatusCode::UNAVAILABLE, "Model annotation was not enabled.");
2163   }
2164 
2165   const UnicodeText context_unicode =
2166       UTF8ToUnicodeText(context, /*do_copy=*/false);
2167 
2168   std::vector<Locale> detected_text_language_tags;
2169   if (!ParseLocales(options.detected_text_language_tags,
2170                     &detected_text_language_tags)) {
2171     TC3_LOG(WARNING)
2172         << "Failed to parse the detected_text_language_tags in options: "
2173         << options.detected_text_language_tags;
2174   }
2175   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
2176                                     model_triggering_locales_,
2177                                     /*default_value=*/true)) {
2178     return Status(
2179         StatusCode::UNAVAILABLE,
2180         "The detected language tags are not in the supported locales.");
2181   }
2182 
2183   InterpreterManager interpreter_manager(selection_executor_.get(),
2184                                          classification_executor_.get());
2185 
2186   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
2187   const bool is_raw_usecase =
2188       options.annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW;
2189 
2190   // Annotate with the selection model.
2191   const bool model_annotations_enabled =
2192       !is_raw_usecase || IsAnyModelEntityTypeEnabled(is_entity_type_enabled);
2193   std::vector<Token> tokens;
2194   if (model_annotations_enabled &&
2195       !ModelAnnotate(context, detected_text_language_tags, options,
2196                      &interpreter_manager, &tokens, candidates)) {
2197     return Status(StatusCode::INTERNAL, "Couldn't run ModelAnnotate.");
2198   } else if (!model_annotations_enabled) {
2199     // If the ML model didn't run, we need to tokenize to support the other
2200     // annotators that depend on the tokens.
2201     // Optimization could be made to only do this when an annotator that uses
2202     // the tokens is enabled, but it's unclear if the added complexity is worth
2203     // it.
2204     if (selection_feature_processor_ != nullptr) {
2205       tokens = selection_feature_processor_->Tokenize(context_unicode);
2206     }
2207   }
2208 
2209   // Annotate with the regular expression models.
2210   const bool regex_annotations_enabled =
2211       !is_raw_usecase || IsAnyRegexEntityTypeEnabled(is_entity_type_enabled);
2212   if (regex_annotations_enabled &&
2213       !RegexChunk(
2214           UTF8ToUnicodeText(context, /*do_copy=*/false),
2215           annotation_regex_patterns_, options.is_serialized_entity_data_enabled,
2216           is_entity_type_enabled, options.annotation_usecase, candidates)) {
2217     return Status(StatusCode::INTERNAL, "Couldn't run RegexChunk.");
2218   }
2219 
2220   // Annotate with the datetime model.
2221   // NOTE: Datetime can be disabled even in the SMART usecase, because it's been
2222   // relatively slow for some clients.
2223   if ((is_entity_type_enabled(Collections::Date()) ||
2224        is_entity_type_enabled(Collections::DateTime())) &&
2225       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
2226                      options.reference_time_ms_utc, options.reference_timezone,
2227                      options.locales, ModeFlag_ANNOTATION,
2228                      options.annotation_usecase,
2229                      options.is_serialized_entity_data_enabled, candidates)) {
2230     return Status(StatusCode::INTERNAL, "Couldn't run DatetimeChunk.");
2231   }
2232 
2233   // Annotate with the contact engine.
2234   const bool contact_annotations_enabled =
2235       !is_raw_usecase || is_entity_type_enabled(Collections::Contact());
2236   if (contact_annotations_enabled && contact_engine_ &&
2237       !contact_engine_->Chunk(context_unicode, tokens, candidates)) {
2238     return Status(StatusCode::INTERNAL, "Couldn't run contact engine Chunk.");
2239   }
2240 
2241   // Annotate with the installed app engine.
2242   const bool app_annotations_enabled =
2243       !is_raw_usecase || is_entity_type_enabled(Collections::App());
2244   if (app_annotations_enabled && installed_app_engine_ &&
2245       !installed_app_engine_->Chunk(context_unicode, tokens, candidates)) {
2246     return Status(StatusCode::INTERNAL,
2247                   "Couldn't run installed app engine Chunk.");
2248   }
2249 
2250   // Annotate with the number annotator.
2251   const bool number_annotations_enabled =
2252       !is_raw_usecase || (is_entity_type_enabled(Collections::Number()) ||
2253                           is_entity_type_enabled(Collections::Percentage()));
2254   if (number_annotations_enabled && number_annotator_ != nullptr &&
2255       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
2256                                   candidates)) {
2257     return Status(StatusCode::INTERNAL,
2258                   "Couldn't run number annotator FindAll.");
2259   }
2260 
2261   // Annotate with the duration annotator.
2262   const bool duration_annotations_enabled =
2263       !is_raw_usecase || is_entity_type_enabled(Collections::Duration());
2264   if (duration_annotations_enabled && duration_annotator_ != nullptr &&
2265       !duration_annotator_->FindAll(context_unicode, tokens,
2266                                     options.annotation_usecase, candidates)) {
2267     return Status(StatusCode::INTERNAL,
2268                   "Couldn't run duration annotator FindAll.");
2269   }
2270 
2271   // Annotate with the person name engine.
2272   const bool person_annotations_enabled =
2273       !is_raw_usecase || is_entity_type_enabled(Collections::PersonName());
2274   if (person_annotations_enabled && person_name_engine_ &&
2275       !person_name_engine_->Chunk(context_unicode, tokens, candidates)) {
2276     return Status(StatusCode::INTERNAL,
2277                   "Couldn't run person name engine Chunk.");
2278   }
2279 
2280   // Annotate with the grammar annotators.
2281   if (grammar_annotator_ != nullptr &&
2282       !grammar_annotator_->Annotate(detected_text_language_tags,
2283                                     context_unicode, candidates)) {
2284     return Status(StatusCode::INTERNAL, "Couldn't run grammar annotators.");
2285   }
2286 
2287   // Annotate with the POD NER annotator.
2288   const bool pod_ner_annotations_enabled =
2289       !is_raw_usecase || IsAnyPodNerEntityTypeEnabled(is_entity_type_enabled);
2290   if (pod_ner_annotations_enabled && pod_ner_annotator_ != nullptr &&
2291       options.use_pod_ner &&
2292       !pod_ner_annotator_->Annotate(context_unicode, candidates)) {
2293     return Status(StatusCode::INTERNAL, "Couldn't run POD NER annotator.");
2294   }
2295 
2296   // Annotate with the vocab annotator.
2297   const bool vocab_annotations_enabled =
2298       !is_raw_usecase || is_entity_type_enabled(Collections::Dictionary());
2299   if (vocab_annotations_enabled && vocab_annotator_ != nullptr &&
2300       options.use_vocab_annotator &&
2301       !vocab_annotator_->Annotate(context_unicode, detected_text_language_tags,
2302                                   options.trigger_dictionary_on_beginner_words,
2303                                   candidates)) {
2304     return Status(StatusCode::INTERNAL, "Couldn't run vocab annotator.");
2305   }
2306 
2307   // Annotate with the experimental annotator.
2308   if (experimental_annotator_ != nullptr &&
2309       !experimental_annotator_->Annotate(context_unicode, candidates)) {
2310     return Status(StatusCode::INTERNAL, "Couldn't run experimental annotator.");
2311   }
2312 
2313   // Sort candidates according to their position in the input, so that the next
2314   // code can assume that any connected component of overlapping spans forms a
2315   // contiguous block.
2316   // Also sort them according to the end position and collection, so that the
2317   // deduplication code below can assume that same spans and classifications
2318   // form contiguous blocks.
2319   std::sort(candidates->begin(), candidates->end(),
2320             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
2321               if (a.span.first != b.span.first) {
2322                 return a.span.first < b.span.first;
2323               }
2324 
2325               if (a.span.second != b.span.second) {
2326                 return a.span.second < b.span.second;
2327               }
2328 
2329               return a.classification[0].collection <
2330                      b.classification[0].collection;
2331             });
2332 
2333   std::vector<int> candidate_indices;
2334   if (!ResolveConflicts(*candidates, context, tokens,
2335                         detected_text_language_tags, options,
2336                         &interpreter_manager, &candidate_indices)) {
2337     return Status(StatusCode::INTERNAL, "Couldn't resolve conflicts.");
2338   }
2339 
2340   // Remove candidates that overlap exactly and have the same collection.
2341   // This can e.g. happen for phone coming from both ML model and regex.
2342   candidate_indices.erase(
2343       std::unique(candidate_indices.begin(), candidate_indices.end(),
2344                   [&candidates](const int a_index, const int b_index) {
2345                     const AnnotatedSpan& a = (*candidates)[a_index];
2346                     const AnnotatedSpan& b = (*candidates)[b_index];
2347                     return a.span == b.span &&
2348                            a.classification[0].collection ==
2349                                b.classification[0].collection;
2350                   }),
2351       candidate_indices.end());
2352 
2353   std::vector<AnnotatedSpan> result;
2354   result.reserve(candidate_indices.size());
2355   for (const int i : candidate_indices) {
2356     if ((*candidates)[i].classification.empty() ||
2357         ClassifiedAsOther((*candidates)[i].classification) ||
2358         FilteredForAnnotation((*candidates)[i])) {
2359       continue;
2360     }
2361     result.push_back(std::move((*candidates)[i]));
2362   }
2363 
2364   // We generate all candidates and remove them later (with the exception of
2365   // date/time/duration entities) because there are complex interdependencies
2366   // between the entity types. E.g., the TLD of an email can be interpreted as a
2367   // URL, but most likely a user of the API does not want such annotations if
2368   // "url" is enabled and "email" is not.
2369   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
2370 
2371   for (AnnotatedSpan& annotated_span : result) {
2372     SortClassificationResults(&annotated_span.classification);
2373   }
2374   *candidates = result;
2375   return Status::OK;
2376 }
2377 
AnnotateStructuredInput(const std::vector<InputFragment> & string_fragments,const AnnotationOptions & options) const2378 StatusOr<Annotations> Annotator::AnnotateStructuredInput(
2379     const std::vector<InputFragment>& string_fragments,
2380     const AnnotationOptions& options) const {
2381   Annotations annotation_candidates;
2382   annotation_candidates.annotated_spans.resize(string_fragments.size());
2383 
2384   std::vector<std::string> text_to_annotate;
2385   text_to_annotate.reserve(string_fragments.size());
2386   std::vector<FragmentMetadata> fragment_metadata;
2387   fragment_metadata.reserve(string_fragments.size());
2388   for (const auto& string_fragment : string_fragments) {
2389     text_to_annotate.push_back(string_fragment.text);
2390     fragment_metadata.push_back(
2391         {.relative_bounding_box_top = string_fragment.bounding_box_top,
2392          .relative_bounding_box_height = string_fragment.bounding_box_height});
2393   }
2394 
2395   // KnowledgeEngine is special, because it supports annotation of multiple
2396   // fragments at once.
2397   if (knowledge_engine_ &&
2398       !knowledge_engine_
2399            ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
2400                                 options.annotation_usecase,
2401                                 options.location_context, options.permissions,
2402                                 options.annotate_mode, &annotation_candidates)
2403            .ok()) {
2404     return Status(StatusCode::INTERNAL, "Couldn't run knowledge engine Chunk.");
2405   }
2406   // The annotator engines shouldn't change the number of annotation vectors.
2407   if (annotation_candidates.annotated_spans.size() != text_to_annotate.size()) {
2408     TC3_LOG(ERROR) << "Received " << text_to_annotate.size()
2409                    << " texts to annotate but generated a different number of  "
2410                       "lists of annotations:"
2411                    << annotation_candidates.annotated_spans.size();
2412     return Status(StatusCode::INTERNAL,
2413                   "Number of annotation candidates differs from "
2414                   "number of texts to annotate.");
2415   }
2416 
2417   // As an optimization, if the only annotated type is Entity, we skip all the
2418   // other annotators than the KnowledgeEngine. This only happens in the raw
2419   // mode, to make sure it does not affect the result.
2420   if (options.annotation_usecase == ANNOTATION_USECASE_RAW &&
2421       options.entity_types.size() == 1 &&
2422       *options.entity_types.begin() == Collections::Entity()) {
2423     return annotation_candidates;
2424   }
2425 
2426   // Other annotators run on each fragment independently.
2427   for (int i = 0; i < text_to_annotate.size(); ++i) {
2428     AnnotationOptions annotation_options = options;
2429     if (string_fragments[i].datetime_options.has_value()) {
2430       DatetimeOptions reference_datetime =
2431           string_fragments[i].datetime_options.value();
2432       annotation_options.reference_time_ms_utc =
2433           reference_datetime.reference_time_ms_utc;
2434       annotation_options.reference_timezone =
2435           reference_datetime.reference_timezone;
2436     }
2437 
2438     AddContactMetadataToKnowledgeClassificationResults(
2439         &annotation_candidates.annotated_spans[i]);
2440 
2441     Status annotation_status =
2442         AnnotateSingleInput(text_to_annotate[i], annotation_options,
2443                             &annotation_candidates.annotated_spans[i]);
2444     if (!annotation_status.ok()) {
2445       return annotation_status;
2446     }
2447   }
2448   return annotation_candidates;
2449 }
2450 
Annotate(const std::string & context,const AnnotationOptions & options) const2451 std::vector<AnnotatedSpan> Annotator::Annotate(
2452     const std::string& context, const AnnotationOptions& options) const {
2453   if (context.size() > std::numeric_limits<int>::max()) {
2454     TC3_LOG(ERROR) << "Rejecting too long input.";
2455     return {};
2456   }
2457 
2458   const UnicodeText context_unicode =
2459       UTF8ToUnicodeText(context, /*do_copy=*/false);
2460   if (!unilib_->IsValidUtf8(context_unicode)) {
2461     TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
2462     return {};
2463   }
2464 
2465   std::vector<InputFragment> string_fragments;
2466   string_fragments.push_back({.text = context});
2467   StatusOr<Annotations> annotations =
2468       AnnotateStructuredInput(string_fragments, options);
2469   if (!annotations.ok()) {
2470     TC3_LOG(ERROR) << "Returned error when calling AnnotateStructuredInput: "
2471                    << annotations.status().error_message();
2472     return {};
2473   }
2474   return annotations.ValueOrDie().annotated_spans[0];
2475 }
2476 
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const2477 CodepointSpan Annotator::ComputeSelectionBoundaries(
2478     const UniLib::RegexMatcher* match,
2479     const RegexModel_::Pattern* config) const {
2480   if (config->capturing_group() == nullptr) {
2481     // Use first capturing group to specify the selection.
2482     int status = UniLib::RegexMatcher::kNoError;
2483     const CodepointSpan result = {match->Start(1, &status),
2484                                   match->End(1, &status)};
2485     if (status != UniLib::RegexMatcher::kNoError) {
2486       return {kInvalidIndex, kInvalidIndex};
2487     }
2488     return result;
2489   }
2490 
2491   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
2492   const int num_groups = config->capturing_group()->size();
2493   for (int i = 0; i < num_groups; i++) {
2494     if (!config->capturing_group()->Get(i)->extend_selection()) {
2495       continue;
2496     }
2497 
2498     int status = UniLib::RegexMatcher::kNoError;
2499     // Check match and adjust bounds.
2500     const int group_start = match->Start(i, &status);
2501     const int group_end = match->End(i, &status);
2502     if (status != UniLib::RegexMatcher::kNoError) {
2503       return {kInvalidIndex, kInvalidIndex};
2504     }
2505     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2506       continue;
2507     }
2508     if (result.first == kInvalidIndex) {
2509       result = {group_start, group_end};
2510     } else {
2511       result.first = std::min(result.first, group_start);
2512       result.second = std::max(result.second, group_end);
2513     }
2514   }
2515   return result;
2516 }
2517 
HasEntityData(const RegexModel_::Pattern * pattern) const2518 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
2519   if (pattern->serialized_entity_data() != nullptr ||
2520       pattern->entity_data() != nullptr) {
2521     return true;
2522   }
2523   if (pattern->capturing_group() != nullptr) {
2524     for (const CapturingGroup* group : *pattern->capturing_group()) {
2525       if (group->entity_field_path() != nullptr) {
2526         return true;
2527       }
2528       if (group->serialized_entity_data() != nullptr ||
2529           group->entity_data() != nullptr) {
2530         return true;
2531       }
2532     }
2533   }
2534   return false;
2535 }
2536 
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const2537 bool Annotator::SerializedEntityDataFromRegexMatch(
2538     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
2539     std::string* serialized_entity_data) const {
2540   if (!HasEntityData(pattern)) {
2541     serialized_entity_data->clear();
2542     return true;
2543   }
2544   TC3_CHECK(entity_data_builder_ != nullptr);
2545 
2546   std::unique_ptr<MutableFlatbuffer> entity_data =
2547       entity_data_builder_->NewRoot();
2548 
2549   TC3_CHECK(entity_data != nullptr);
2550 
2551   // Set fixed entity data.
2552   if (pattern->serialized_entity_data() != nullptr) {
2553     entity_data->MergeFromSerializedFlatbuffer(
2554         StringPiece(pattern->serialized_entity_data()->c_str(),
2555                     pattern->serialized_entity_data()->size()));
2556   }
2557   if (pattern->entity_data() != nullptr) {
2558     entity_data->MergeFrom(
2559         reinterpret_cast<const flatbuffers::Table*>(pattern->entity_data()));
2560   }
2561 
2562   // Add entity data from rule capturing groups.
2563   if (pattern->capturing_group() != nullptr) {
2564     const int num_groups = pattern->capturing_group()->size();
2565     for (int i = 0; i < num_groups; i++) {
2566       const CapturingGroup* group = pattern->capturing_group()->Get(i);
2567 
2568       // Check whether the group matched.
2569       Optional<std::string> group_match_text =
2570           GetCapturingGroupText(matcher, /*group_id=*/i);
2571       if (!group_match_text.has_value()) {
2572         continue;
2573       }
2574 
2575       // Set fixed entity data from capturing group match.
2576       if (group->serialized_entity_data() != nullptr) {
2577         entity_data->MergeFromSerializedFlatbuffer(
2578             StringPiece(group->serialized_entity_data()->c_str(),
2579                         group->serialized_entity_data()->size()));
2580       }
2581       if (group->entity_data() != nullptr) {
2582         entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
2583             pattern->entity_data()));
2584       }
2585 
2586       // Set entity field from capturing group text.
2587       if (group->entity_field_path() != nullptr) {
2588         UnicodeText normalized_group_match_text =
2589             UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
2590 
2591         // Apply normalization if specified.
2592         if (group->normalization_options() != nullptr) {
2593           normalized_group_match_text =
2594               NormalizeText(*unilib_, group->normalization_options(),
2595                             normalized_group_match_text);
2596         }
2597 
2598         if (!entity_data->ParseAndSet(
2599                 group->entity_field_path(),
2600                 normalized_group_match_text.ToUTF8String())) {
2601           TC3_LOG(ERROR)
2602               << "Could not set entity data from rule capturing group.";
2603           return false;
2604         }
2605       }
2606     }
2607   }
2608 
2609   *serialized_entity_data = entity_data->Serialize();
2610   return true;
2611 }
2612 
// Returns the codepoints of `amount` that come before `it_decimal_separator`
// (or the whole of `amount` if that iterator is `amount.end()`). Every
// codepoint contained in `decimal_separators` is dropped.
//
// Despite its name, `decimal_separators` is the full set of separator
// codepoints to strip from the whole part. It is looked up with the set's own
// hashed count() rather than a linear std::find scan.
UnicodeText RemoveMoneySeparators(
    const std::unordered_set<char32>& decimal_separators,
    const UnicodeText& amount,
    UnicodeText::const_iterator it_decimal_separator) {
  UnicodeText whole_amount;
  for (auto it = amount.begin();
       it != amount.end() && it != it_decimal_separator; ++it) {
    if (decimal_separators.count(static_cast<char32>(*it)) == 0) {
      whole_amount.push_back(*it);
    }
  }
  return whole_amount;
}
2627 
GetMoneyQuantityFromCapturingGroup(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode,std::string * quantity,int * exponent) const2628 void Annotator::GetMoneyQuantityFromCapturingGroup(
2629     const UniLib::RegexMatcher* match, const RegexModel_::Pattern* config,
2630     const UnicodeText& context_unicode, std::string* quantity,
2631     int* exponent) const {
2632   if (config->capturing_group() == nullptr) {
2633     *exponent = 0;
2634     return;
2635   }
2636 
2637   const int num_groups = config->capturing_group()->size();
2638   for (int i = 0; i < num_groups; i++) {
2639     int status = UniLib::RegexMatcher::kNoError;
2640     const int group_start = match->Start(i, &status);
2641     const int group_end = match->End(i, &status);
2642     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
2643       continue;
2644     }
2645 
2646     *quantity =
2647         unilib_
2648             ->ToLowerText(UnicodeText::Substring(context_unicode, group_start,
2649                                                  group_end, /*do_copy=*/false))
2650             .ToUTF8String();
2651 
2652     if (auto entry = model_->money_parsing_options()
2653                          ->quantities_name_to_exponent()
2654                          ->LookupByKey((*quantity).c_str())) {
2655       *exponent = entry->value();
2656       return;
2657     }
2658   }
2659   *exponent = 0;
2660 }
2661 
ParseAndFillInMoneyAmount(std::string * serialized_entity_data,const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config,const UnicodeText & context_unicode) const2662 bool Annotator::ParseAndFillInMoneyAmount(
2663     std::string* serialized_entity_data, const UniLib::RegexMatcher* match,
2664     const RegexModel_::Pattern* config,
2665     const UnicodeText& context_unicode) const {
2666   std::unique_ptr<EntityDataT> data =
2667       LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
2668           *serialized_entity_data);
2669   if (data == nullptr) {
2670     if (model_->version() >= 706) {
2671       // This way of parsing money entity data is enabled for models newer than
2672       // v706, consequently logging errors only for them (b/156634162).
2673       TC3_LOG(ERROR)
2674           << "Data field is null when trying to parse Money Entity Data";
2675     }
2676     return false;
2677   }
2678   if (data->money->unnormalized_amount.empty()) {
2679     if (model_->version() >= 706) {
2680       // This way of parsing money entity data is enabled for models newer than
2681       // v706, consequently logging errors only for them (b/156634162).
2682       TC3_LOG(ERROR)
2683           << "Data unnormalized_amount is empty when trying to parse "
2684              "Money Entity Data";
2685     }
2686     return false;
2687   }
2688 
2689   UnicodeText amount =
2690       UTF8ToUnicodeText(data->money->unnormalized_amount, /*do_copy=*/false);
2691   int separator_back_index = 0;
2692   auto it_decimal_separator = --amount.end();
2693   for (; it_decimal_separator != amount.begin();
2694        --it_decimal_separator, ++separator_back_index) {
2695     if (std::find(money_separators_.begin(), money_separators_.end(),
2696                   static_cast<char32>(*it_decimal_separator)) !=
2697         money_separators_.end()) {
2698       break;
2699     }
2700   }
2701 
2702   // If there are 3 digits after the last separator, we consider that a
2703   // thousands separator => the number is an int (e.g. 1.234 is considered int).
2704   // If there is no separator in number, also that number is an int.
2705   if (separator_back_index == 3 || it_decimal_separator == amount.begin()) {
2706     it_decimal_separator = amount.end();
2707   }
2708 
2709   if (!unilib_->ParseInt32(RemoveMoneySeparators(money_separators_, amount,
2710                                                  it_decimal_separator),
2711                            &data->money->amount_whole_part)) {
2712     TC3_LOG(ERROR) << "Could not parse the money whole part as int32 from the "
2713                       "amount: "
2714                    << data->money->unnormalized_amount;
2715     return false;
2716   }
2717 
2718   if (it_decimal_separator == amount.end()) {
2719     data->money->amount_decimal_part = 0;
2720     data->money->nanos = 0;
2721   } else {
2722     const int amount_codepoints_size = amount.size_codepoints();
2723     const UnicodeText decimal_part = UnicodeText::Substring(
2724         amount, amount_codepoints_size - separator_back_index,
2725         amount_codepoints_size, /*do_copy=*/false);
2726     if (!unilib_->ParseInt32(decimal_part, &data->money->amount_decimal_part)) {
2727       TC3_LOG(ERROR) << "Could not parse the money decimal part as int32 from "
2728                         "the amount: "
2729                      << data->money->unnormalized_amount;
2730       return false;
2731     }
2732     data->money->nanos = data->money->amount_decimal_part *
2733                          pow(10, 9 - decimal_part.size_codepoints());
2734   }
2735 
2736   if (model_->money_parsing_options()->quantities_name_to_exponent() !=
2737       nullptr) {
2738     int quantity_exponent;
2739     std::string quantity;
2740     GetMoneyQuantityFromCapturingGroup(match, config, context_unicode,
2741                                        &quantity, &quantity_exponent);
2742     if (quantity_exponent > 0 && quantity_exponent <= 9) {
2743       const double amount_whole_part =
2744           data->money->amount_whole_part * pow(10, quantity_exponent) +
2745           data->money->nanos / pow(10, 9 - quantity_exponent);
2746       // TODO(jacekj): Change type of `data->money->amount_whole_part` to int64
2747       // (and `std::numeric_limits<int>::max()` to
2748       // `std::numeric_limits<int64>::max()`).
2749       if (amount_whole_part < std::numeric_limits<int>::max()) {
2750         data->money->amount_whole_part = amount_whole_part;
2751         data->money->nanos = data->money->nanos %
2752                              static_cast<int>(pow(10, 9 - quantity_exponent)) *
2753                              pow(10, quantity_exponent);
2754       }
2755     }
2756     if (quantity_exponent > 0) {
2757       data->money->unnormalized_amount = strings::JoinStrings(
2758           " ", {data->money->unnormalized_amount, quantity});
2759     }
2760   }
2761 
2762   *serialized_entity_data =
2763       PackFlatbuffer<libtextclassifier3::EntityData>(data.get());
2764   return true;
2765 }
2766 
IsAnyModelEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2767 bool Annotator::IsAnyModelEntityTypeEnabled(
2768     const EnabledEntityTypes& is_entity_type_enabled) const {
2769   if (model_->classification_feature_options() == nullptr ||
2770       model_->classification_feature_options()->collections() == nullptr) {
2771     return false;
2772   }
2773   for (int i = 0;
2774        i < model_->classification_feature_options()->collections()->size();
2775        i++) {
2776     if (is_entity_type_enabled(model_->classification_feature_options()
2777                                    ->collections()
2778                                    ->Get(i)
2779                                    ->str())) {
2780       return true;
2781     }
2782   }
2783   return false;
2784 }
2785 
IsAnyRegexEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2786 bool Annotator::IsAnyRegexEntityTypeEnabled(
2787     const EnabledEntityTypes& is_entity_type_enabled) const {
2788   if (model_->regex_model() == nullptr ||
2789       model_->regex_model()->patterns() == nullptr) {
2790     return false;
2791   }
2792   for (int i = 0; i < model_->regex_model()->patterns()->size(); i++) {
2793     if (is_entity_type_enabled(model_->regex_model()
2794                                    ->patterns()
2795                                    ->Get(i)
2796                                    ->collection_name()
2797                                    ->str())) {
2798       return true;
2799     }
2800   }
2801   return false;
2802 }
2803 
IsAnyPodNerEntityTypeEnabled(const EnabledEntityTypes & is_entity_type_enabled) const2804 bool Annotator::IsAnyPodNerEntityTypeEnabled(
2805     const EnabledEntityTypes& is_entity_type_enabled) const {
2806   if (pod_ner_annotator_ == nullptr) {
2807     return false;
2808   }
2809 
2810   for (const std::string& collection :
2811        pod_ner_annotator_->GetSupportedCollections()) {
2812     if (is_entity_type_enabled(collection)) {
2813       return true;
2814     }
2815   }
2816   return false;
2817 }
2818 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,bool is_serialized_entity_data_enabled,const EnabledEntityTypes & enabled_entity_types,const AnnotationUsecase & annotation_usecase,std::vector<AnnotatedSpan> * result) const2819 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2820                            const std::vector<int>& rules,
2821                            bool is_serialized_entity_data_enabled,
2822                            const EnabledEntityTypes& enabled_entity_types,
2823                            const AnnotationUsecase& annotation_usecase,
2824                            std::vector<AnnotatedSpan>* result) const {
2825   for (int pattern_id : rules) {
2826     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2827     if (!enabled_entity_types(regex_pattern.config->collection_name()->str()) &&
2828         annotation_usecase == AnnotationUsecase_ANNOTATION_USECASE_RAW) {
2829       // No regex annotation type has been requested, skip regex annotation.
2830       continue;
2831     }
2832     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2833     if (!matcher) {
2834       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2835                      << pattern_id;
2836       return false;
2837     }
2838 
2839     int status = UniLib::RegexMatcher::kNoError;
2840     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2841       if (regex_pattern.config->verification_options()) {
2842         if (!VerifyRegexMatchCandidate(
2843                 context_unicode.ToUTF8String(),
2844                 regex_pattern.config->verification_options(),
2845                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2846           continue;
2847         }
2848       }
2849 
2850       std::string serialized_entity_data;
2851       if (is_serialized_entity_data_enabled) {
2852         if (!SerializedEntityDataFromRegexMatch(
2853                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2854           TC3_LOG(ERROR) << "Could not get entity data.";
2855           return false;
2856         }
2857 
2858         // Further parsing of money amount. Need this since regexes cannot have
2859         // empty groups that fill in entity data (amount_decimal_part and
2860         // quantity might be empty groups).
2861         if (regex_pattern.config->collection_name()->str() ==
2862             Collections::Money()) {
2863           if (!ParseAndFillInMoneyAmount(&serialized_entity_data, matcher.get(),
2864                                          regex_pattern.config,
2865                                          context_unicode)) {
2866             if (model_->version() >= 706) {
2867               // This way of parsing money entity data is enabled for models
2868               // newer than v706 => logging errors only for them (b/156634162).
2869               TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
2870             }
2871           }
2872         }
2873       }
2874 
2875       result->emplace_back();
2876 
2877       // Selection/annotation regular expressions need to specify a capturing
2878       // group specifying the selection.
2879       result->back().span =
2880           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2881 
2882       result->back().classification = {
2883           {regex_pattern.config->collection_name()->str(),
2884            regex_pattern.config->target_classification_score(),
2885            regex_pattern.config->priority_score()}};
2886 
2887       result->back().classification[0].serialized_entity_data =
2888           serialized_entity_data;
2889     }
2890   }
2891   return true;
2892 }
2893 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2894 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2895                            tflite::Interpreter* selection_interpreter,
2896                            const CachedFeatures& cached_features,
2897                            std::vector<TokenSpan>* chunks) const {
2898   const int max_selection_span =
2899       selection_feature_processor_->GetOptions()->max_selection_span();
2900   // The inference span is the span of interest expanded to include
2901   // max_selection_span tokens on either side, which is how far a selection can
2902   // stretch from the click.
2903   const TokenSpan inference_span =
2904       IntersectTokenSpans(span_of_interest.Expand(
2905                               /*num_tokens_left=*/max_selection_span,
2906                               /*num_tokens_right=*/max_selection_span),
2907                           {0, num_tokens});
2908 
2909   std::vector<ScoredChunk> scored_chunks;
2910   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2911       selection_feature_processor_->GetOptions()
2912           ->bounds_sensitive_features()
2913           ->enabled()) {
2914     if (!ModelBoundsSensitiveScoreChunks(
2915             num_tokens, span_of_interest, inference_span, cached_features,
2916             selection_interpreter, &scored_chunks)) {
2917       return false;
2918     }
2919   } else {
2920     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2921                                       cached_features, selection_interpreter,
2922                                       &scored_chunks)) {
2923       return false;
2924     }
2925   }
2926   std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2927             [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2928               return lhs.score < rhs.score;
2929             });
2930 
2931   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2932   // them greedily as long as they do not overlap with any previously picked
2933   // chunks.
2934   std::vector<bool> token_used(inference_span.Size());
2935   chunks->clear();
2936   for (const ScoredChunk& scored_chunk : scored_chunks) {
2937     bool feasible = true;
2938     for (int i = scored_chunk.token_span.first;
2939          i < scored_chunk.token_span.second; ++i) {
2940       if (token_used[i - inference_span.first]) {
2941         feasible = false;
2942         break;
2943       }
2944     }
2945 
2946     if (!feasible) {
2947       continue;
2948     }
2949 
2950     for (int i = scored_chunk.token_span.first;
2951          i < scored_chunk.token_span.second; ++i) {
2952       token_used[i - inference_span.first] = true;
2953     }
2954 
2955     chunks->push_back(scored_chunk.token_span);
2956   }
2957 
2958   std::sort(chunks->begin(), chunks->end());
2959 
2960   return true;
2961 }
2962 
namespace {
// Stores `value` under `key` in `map` when the key is absent. Otherwise it
// replaces the stored value with the larger of the stored value and `value`.
// A single insert() lookup handles both cases, instead of a find() followed
// by a second lookup through operator[].
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    // Key already present: keep the maximum of the old and new values.
    insertion.first->second = std::max(insertion.first->second, value);
  }
}
}  // namespace
2977 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2978 bool Annotator::ModelClickContextScoreChunks(
2979     int num_tokens, const TokenSpan& span_of_interest,
2980     const CachedFeatures& cached_features,
2981     tflite::Interpreter* selection_interpreter,
2982     std::vector<ScoredChunk>* scored_chunks) const {
2983   const int max_batch_size = model_->selection_options()->batch_size();
2984 
2985   std::vector<float> all_features;
2986   std::map<TokenSpan, float> chunk_scores;
2987   for (int batch_start = span_of_interest.first;
2988        batch_start < span_of_interest.second; batch_start += max_batch_size) {
2989     const int batch_end =
2990         std::min(batch_start + max_batch_size, span_of_interest.second);
2991 
2992     // Prepare features for the whole batch.
2993     all_features.clear();
2994     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2995     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2996       cached_features.AppendClickContextFeaturesForClick(click_pos,
2997                                                          &all_features);
2998     }
2999 
3000     // Run batched inference.
3001     const int batch_size = batch_end - batch_start;
3002     const int features_size = cached_features.OutputFeaturesSize();
3003     TensorView<float> logits = selection_executor_->ComputeLogits(
3004         TensorView<float>(all_features.data(), {batch_size, features_size}),
3005         selection_interpreter);
3006     if (!logits.is_valid()) {
3007       TC3_LOG(ERROR) << "Couldn't compute logits.";
3008       return false;
3009     }
3010     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3011         logits.dim(1) !=
3012             selection_feature_processor_->GetSelectionLabelCount()) {
3013       TC3_LOG(ERROR) << "Mismatching output.";
3014       return false;
3015     }
3016 
3017     // Save results.
3018     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
3019       const std::vector<float> scores = ComputeSoftmax(
3020           logits.data() + logits.dim(1) * (click_pos - batch_start),
3021           logits.dim(1));
3022       for (int j = 0;
3023            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
3024         TokenSpan relative_token_span;
3025         if (!selection_feature_processor_->LabelToTokenSpan(
3026                 j, &relative_token_span)) {
3027           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
3028           return false;
3029         }
3030         const TokenSpan candidate_span = TokenSpan(click_pos).Expand(
3031             relative_token_span.first, relative_token_span.second);
3032         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
3033           UpdateMax(&chunk_scores, candidate_span, scores[j]);
3034         }
3035       }
3036     }
3037   }
3038 
3039   scored_chunks->clear();
3040   scored_chunks->reserve(chunk_scores.size());
3041   for (const auto& entry : chunk_scores) {
3042     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
3043   }
3044 
3045   return true;
3046 }
3047 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const3048 bool Annotator::ModelBoundsSensitiveScoreChunks(
3049     int num_tokens, const TokenSpan& span_of_interest,
3050     const TokenSpan& inference_span, const CachedFeatures& cached_features,
3051     tflite::Interpreter* selection_interpreter,
3052     std::vector<ScoredChunk>* scored_chunks) const {
3053   const int max_selection_span =
3054       selection_feature_processor_->GetOptions()->max_selection_span();
3055   const int max_chunk_length = selection_feature_processor_->GetOptions()
3056                                        ->selection_reduced_output_space()
3057                                    ? max_selection_span + 1
3058                                    : 2 * max_selection_span + 1;
3059   const bool score_single_token_spans_as_zero =
3060       selection_feature_processor_->GetOptions()
3061           ->bounds_sensitive_features()
3062           ->score_single_token_spans_as_zero();
3063 
3064   scored_chunks->clear();
3065   if (score_single_token_spans_as_zero) {
3066     scored_chunks->reserve(span_of_interest.Size());
3067   }
3068 
3069   // Prepare all chunk candidates into one batch:
3070   //   - Are contained in the inference span
3071   //   - Have a non-empty intersection with the span of interest
3072   //   - Are at least one token long
3073   //   - Are not longer than the maximum chunk length
3074   std::vector<TokenSpan> candidate_spans;
3075   for (int start = inference_span.first; start < span_of_interest.second;
3076        ++start) {
3077     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
3078     for (int end = leftmost_end_index;
3079          end <= inference_span.second && end - start <= max_chunk_length;
3080          ++end) {
3081       const TokenSpan candidate_span = {start, end};
3082       if (score_single_token_spans_as_zero && candidate_span.Size() == 1) {
3083         // Do not include the single token span in the batch, add a zero score
3084         // for it directly to the output.
3085         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
3086       } else {
3087         candidate_spans.push_back(candidate_span);
3088       }
3089     }
3090   }
3091 
3092   const int max_batch_size = model_->selection_options()->batch_size();
3093 
3094   std::vector<float> all_features;
3095   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
3096   for (int batch_start = 0; batch_start < candidate_spans.size();
3097        batch_start += max_batch_size) {
3098     const int batch_end = std::min(batch_start + max_batch_size,
3099                                    static_cast<int>(candidate_spans.size()));
3100 
3101     // Prepare features for the whole batch.
3102     all_features.clear();
3103     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
3104     for (int i = batch_start; i < batch_end; ++i) {
3105       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
3106                                                            &all_features);
3107     }
3108 
3109     // Run batched inference.
3110     const int batch_size = batch_end - batch_start;
3111     const int features_size = cached_features.OutputFeaturesSize();
3112     TensorView<float> logits = selection_executor_->ComputeLogits(
3113         TensorView<float>(all_features.data(), {batch_size, features_size}),
3114         selection_interpreter);
3115     if (!logits.is_valid()) {
3116       TC3_LOG(ERROR) << "Couldn't compute logits.";
3117       return false;
3118     }
3119     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
3120         logits.dim(1) != 1) {
3121       TC3_LOG(ERROR) << "Mismatching output.";
3122       return false;
3123     }
3124 
3125     // Save results.
3126     for (int i = batch_start; i < batch_end; ++i) {
3127       scored_chunks->push_back(
3128           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
3129     }
3130   }
3131 
3132   return true;
3133 }
3134 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const3135 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
3136                               int64 reference_time_ms_utc,
3137                               const std::string& reference_timezone,
3138                               const std::string& locales, ModeFlag mode,
3139                               AnnotationUsecase annotation_usecase,
3140                               bool is_serialized_entity_data_enabled,
3141                               std::vector<AnnotatedSpan>* result) const {
3142   if (!datetime_parser_) {
3143     return true;
3144   }
3145   LocaleList locale_list = LocaleList::ParseFrom(locales);
3146   StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
3147       datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
3148                               reference_timezone, locale_list, mode,
3149                               annotation_usecase,
3150                               /*anchor_start_end=*/false);
3151   if (!result_status.ok()) {
3152     return false;
3153   }
3154 
3155   for (const DatetimeParseResultSpan& datetime_span :
3156        result_status.ValueOrDie()) {
3157     AnnotatedSpan annotated_span;
3158     annotated_span.span = datetime_span.span;
3159     for (const DatetimeParseResult& parse_result : datetime_span.data) {
3160       annotated_span.classification.emplace_back(
3161           PickCollectionForDatetime(parse_result),
3162           datetime_span.target_classification_score,
3163           datetime_span.priority_score);
3164       annotated_span.classification.back().datetime_parse_result = parse_result;
3165       if (is_serialized_entity_data_enabled) {
3166         annotated_span.classification.back().serialized_entity_data =
3167             CreateDatetimeSerializedEntityData(parse_result);
3168       }
3169     }
3170     annotated_span.source = AnnotatedSpan::Source::DATETIME;
3171     result->push_back(std::move(annotated_span));
3172   }
3173   return true;
3174 }
3175 
model() const3176 const Model* Annotator::model() const { return model_; }
entity_data_schema() const3177 const reflection::Schema* Annotator::entity_data_schema() const {
3178   return entity_data_schema_;
3179 }
3180 
ViewModel(const void * buffer,int size)3181 const Model* ViewModel(const void* buffer, int size) {
3182   if (!buffer) {
3183     return nullptr;
3184   }
3185 
3186   return LoadAndVerifyModel(buffer, size);
3187 }
3188 
LookUpKnowledgeEntity(const std::string & id) const3189 StatusOr<std::string> Annotator::LookUpKnowledgeEntity(
3190     const std::string& id) const {
3191   if (!knowledge_engine_) {
3192     return Status(StatusCode::FAILED_PRECONDITION,
3193                   "knowledge_engine_ is nullptr");
3194   }
3195   return knowledge_engine_->LookUpEntity(id);
3196 }
3197 
LookUpKnowledgeEntityProperty(const std::string & mid_str,const std::string & property) const3198 StatusOr<std::string> Annotator::LookUpKnowledgeEntityProperty(
3199     const std::string& mid_str, const std::string& property) const {
3200   if (!knowledge_engine_) {
3201     return Status(StatusCode::FAILED_PRECONDITION,
3202                   "knowledge_engine_ is nullptr");
3203   }
3204   return knowledge_engine_->LookUpEntityProperty(mid_str, property);
3205 }
3206 
3207 }  // namespace libtextclassifier3
3208