• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/annotator.h"
18 
19 #include <algorithm>
20 #include <cctype>
21 #include <cmath>
22 #include <iterator>
23 #include <numeric>
24 #include <unordered_map>
25 
26 #include "annotator/collections.h"
27 #include "annotator/model_generated.h"
28 #include "annotator/types.h"
29 #include "utils/base/logging.h"
30 #include "utils/checksum.h"
31 #include "utils/math/softmax.h"
32 #include "utils/regex-match.h"
33 #include "utils/utf8/unicodetext.h"
34 #include "utils/zlib/zlib_regex.h"
35 
36 
37 namespace libtextclassifier3 {
38 
39 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
40 
41 const std::string& Annotator::kPhoneCollection =
__anon03b269f90102() 42     *[]() { return new std::string("phone"); }();
43 const std::string& Annotator::kAddressCollection =
__anon03b269f90202() 44     *[]() { return new std::string("address"); }();
45 const std::string& Annotator::kDateCollection =
__anon03b269f90302() 46     *[]() { return new std::string("date"); }();
47 const std::string& Annotator::kUrlCollection =
__anon03b269f90402() 48     *[]() { return new std::string("url"); }();
49 const std::string& Annotator::kEmailCollection =
__anon03b269f90502() 50     *[]() { return new std::string("email"); }();
51 
52 namespace {
LoadAndVerifyModel(const void * addr,int size)53 const Model* LoadAndVerifyModel(const void* addr, int size) {
54   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
55   if (VerifyModelBuffer(verifier)) {
56     return GetModel(addr);
57   } else {
58     return nullptr;
59   }
60 }
61 
62 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
63 // create a new instance, assign ownership to owned_lib, and return it.
MaybeCreateUnilib(const UniLib * lib,std::unique_ptr<UniLib> * owned_lib)64 const UniLib* MaybeCreateUnilib(const UniLib* lib,
65                                 std::unique_ptr<UniLib>* owned_lib) {
66   if (lib) {
67     return lib;
68   } else {
69     owned_lib->reset(new UniLib);
70     return owned_lib->get();
71   }
72 }
73 
74 // As above, but for CalendarLib.
MaybeCreateCalendarlib(const CalendarLib * lib,std::unique_ptr<CalendarLib> * owned_lib)75 const CalendarLib* MaybeCreateCalendarlib(
76     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
77   if (lib) {
78     return lib;
79   } else {
80     owned_lib->reset(new CalendarLib);
81     return owned_lib->get();
82   }
83 }
84 
85 }  // namespace
86 
SelectionInterpreter()87 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
88   if (!selection_interpreter_) {
89     TC3_CHECK(selection_executor_);
90     selection_interpreter_ = selection_executor_->CreateInterpreter();
91     if (!selection_interpreter_) {
92       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
93     }
94   }
95   return selection_interpreter_.get();
96 }
97 
ClassificationInterpreter()98 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
99   if (!classification_interpreter_) {
100     TC3_CHECK(classification_executor_);
101     classification_interpreter_ = classification_executor_->CreateInterpreter();
102     if (!classification_interpreter_) {
103       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
104     }
105   }
106   return classification_interpreter_.get();
107 }
108 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib,const CalendarLib * calendarlib)109 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
110     const char* buffer, int size, const UniLib* unilib,
111     const CalendarLib* calendarlib) {
112   const Model* model = LoadAndVerifyModel(buffer, size);
113   if (model == nullptr) {
114     return nullptr;
115   }
116 
117   auto classifier =
118       std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
119   if (!classifier->IsInitialized()) {
120     return nullptr;
121   }
122 
123   return classifier;
124 }
125 
126 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib,const CalendarLib * calendarlib)127 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
128     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
129     const CalendarLib* calendarlib) {
130   if (!(*mmap)->handle().ok()) {
131     TC3_VLOG(1) << "Mmap failed.";
132     return nullptr;
133   }
134 
135   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
136                                           (*mmap)->handle().num_bytes());
137   if (!model) {
138     TC3_LOG(ERROR) << "Model verification failed.";
139     return nullptr;
140   }
141 
142   auto classifier = std::unique_ptr<Annotator>(
143       new Annotator(mmap, model, unilib, calendarlib));
144   if (!classifier->IsInitialized()) {
145     return nullptr;
146   }
147 
148   return classifier;
149 }
150 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)151 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
152     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
153     std::unique_ptr<CalendarLib> calendarlib) {
154   if (!(*mmap)->handle().ok()) {
155     TC3_VLOG(1) << "Mmap failed.";
156     return nullptr;
157   }
158 
159   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
160                                           (*mmap)->handle().num_bytes());
161   if (model == nullptr) {
162     TC3_LOG(ERROR) << "Model verification failed.";
163     return nullptr;
164   }
165 
166   auto classifier = std::unique_ptr<Annotator>(
167       new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
168   if (!classifier->IsInitialized()) {
169     return nullptr;
170   }
171 
172   return classifier;
173 }
174 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib,const CalendarLib * calendarlib)175 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
176     int fd, int offset, int size, const UniLib* unilib,
177     const CalendarLib* calendarlib) {
178   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
179   return FromScopedMmap(&mmap, unilib, calendarlib);
180 }
181 
FromFileDescriptor(int fd,int offset,int size,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)182 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
183     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
184     std::unique_ptr<CalendarLib> calendarlib) {
185   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
186   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
187 }
188 
FromFileDescriptor(int fd,const UniLib * unilib,const CalendarLib * calendarlib)189 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
190     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
191   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
192   return FromScopedMmap(&mmap, unilib, calendarlib);
193 }
194 
FromFileDescriptor(int fd,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)195 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
196     int fd, std::unique_ptr<UniLib> unilib,
197     std::unique_ptr<CalendarLib> calendarlib) {
198   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
199   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
200 }
201 
FromPath(const std::string & path,const UniLib * unilib,const CalendarLib * calendarlib)202 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
203                                                const UniLib* unilib,
204                                                const CalendarLib* calendarlib) {
205   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
206   return FromScopedMmap(&mmap, unilib, calendarlib);
207 }
208 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)209 std::unique_ptr<Annotator> Annotator::FromPath(
210     const std::string& path, std::unique_ptr<UniLib> unilib,
211     std::unique_ptr<CalendarLib> calendarlib) {
212   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
213   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
214 }
215 
Annotator(std::unique_ptr<ScopedMmap> * mmap,const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)216 Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
217                      const UniLib* unilib, const CalendarLib* calendarlib)
218     : model_(model),
219       mmap_(std::move(*mmap)),
220       owned_unilib_(nullptr),
221       unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
222       owned_calendarlib_(nullptr),
223       calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
224   ValidateAndInitialize();
225 }
226 
Annotator(std::unique_ptr<ScopedMmap> * mmap,const Model * model,std::unique_ptr<UniLib> unilib,std::unique_ptr<CalendarLib> calendarlib)227 Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
228                      std::unique_ptr<UniLib> unilib,
229                      std::unique_ptr<CalendarLib> calendarlib)
230     : model_(model),
231       mmap_(std::move(*mmap)),
232       owned_unilib_(std::move(unilib)),
233       unilib_(owned_unilib_.get()),
234       owned_calendarlib_(std::move(calendarlib)),
235       calendarlib_(owned_calendarlib_.get()) {
236   ValidateAndInitialize();
237 }
238 
Annotator(const Model * model,const UniLib * unilib,const CalendarLib * calendarlib)239 Annotator::Annotator(const Model* model, const UniLib* unilib,
240                      const CalendarLib* calendarlib)
241     : model_(model),
242       owned_unilib_(nullptr),
243       unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
244       owned_calendarlib_(nullptr),
245       calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
246   ValidateAndInitialize();
247 }
248 
ValidateAndInitialize()249 void Annotator::ValidateAndInitialize() {
250   initialized_ = false;
251 
252   if (model_ == nullptr) {
253     TC3_LOG(ERROR) << "No model specified.";
254     return;
255   }
256 
257   const bool model_enabled_for_annotation =
258       (model_->triggering_options() != nullptr &&
259        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
260   const bool model_enabled_for_classification =
261       (model_->triggering_options() != nullptr &&
262        (model_->triggering_options()->enabled_modes() &
263         ModeFlag_CLASSIFICATION));
264   const bool model_enabled_for_selection =
265       (model_->triggering_options() != nullptr &&
266        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
267 
268   // Annotation requires the selection model.
269   if (model_enabled_for_annotation || model_enabled_for_selection) {
270     if (!model_->selection_options()) {
271       TC3_LOG(ERROR) << "No selection options.";
272       return;
273     }
274     if (!model_->selection_feature_options()) {
275       TC3_LOG(ERROR) << "No selection feature options.";
276       return;
277     }
278     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
279       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
280       return;
281     }
282     if (!model_->selection_model()) {
283       TC3_LOG(ERROR) << "No selection model.";
284       return;
285     }
286     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
287     if (!selection_executor_) {
288       TC3_LOG(ERROR) << "Could not initialize selection executor.";
289       return;
290     }
291     selection_feature_processor_.reset(
292         new FeatureProcessor(model_->selection_feature_options(), unilib_));
293   }
294 
295   // Annotation requires the classification model for conflict resolution and
296   // scoring.
297   // Selection requires the classification model for conflict resolution.
298   if (model_enabled_for_annotation || model_enabled_for_classification ||
299       model_enabled_for_selection) {
300     if (!model_->classification_options()) {
301       TC3_LOG(ERROR) << "No classification options.";
302       return;
303     }
304 
305     if (!model_->classification_feature_options()) {
306       TC3_LOG(ERROR) << "No classification feature options.";
307       return;
308     }
309 
310     if (!model_->classification_feature_options()
311              ->bounds_sensitive_features()) {
312       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
313       return;
314     }
315     if (!model_->classification_model()) {
316       TC3_LOG(ERROR) << "No clf model.";
317       return;
318     }
319 
320     classification_executor_ =
321         ModelExecutor::FromBuffer(model_->classification_model());
322     if (!classification_executor_) {
323       TC3_LOG(ERROR) << "Could not initialize classification executor.";
324       return;
325     }
326 
327     classification_feature_processor_.reset(new FeatureProcessor(
328         model_->classification_feature_options(), unilib_));
329   }
330 
331   // The embeddings need to be specified if the model is to be used for
332   // classification or selection.
333   if (model_enabled_for_annotation || model_enabled_for_classification ||
334       model_enabled_for_selection) {
335     if (!model_->embedding_model()) {
336       TC3_LOG(ERROR) << "No embedding model.";
337       return;
338     }
339 
340     // Check that the embedding size of the selection and classification model
341     // matches, as they are using the same embeddings.
342     if (model_enabled_for_selection &&
343         (model_->selection_feature_options()->embedding_size() !=
344              model_->classification_feature_options()->embedding_size() ||
345          model_->selection_feature_options()->embedding_quantization_bits() !=
346              model_->classification_feature_options()
347                  ->embedding_quantization_bits())) {
348       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
349       return;
350     }
351 
352     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
353         model_->embedding_model(),
354         model_->classification_feature_options()->embedding_size(),
355         model_->classification_feature_options()->embedding_quantization_bits(),
356         model_->embedding_pruning_mask());
357     if (!embedding_executor_) {
358       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
359       return;
360     }
361   }
362 
363   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
364   if (model_->regex_model()) {
365     if (!InitializeRegexModel(decompressor.get())) {
366       TC3_LOG(ERROR) << "Could not initialize regex model.";
367       return;
368     }
369   }
370 
371   if (model_->datetime_model()) {
372     datetime_parser_ = DatetimeParser::Instance(
373         model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
374     if (!datetime_parser_) {
375       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
376       return;
377     }
378   }
379 
380   if (model_->output_options()) {
381     if (model_->output_options()->filtered_collections_annotation()) {
382       for (const auto collection :
383            *model_->output_options()->filtered_collections_annotation()) {
384         filtered_collections_annotation_.insert(collection->str());
385       }
386     }
387     if (model_->output_options()->filtered_collections_classification()) {
388       for (const auto collection :
389            *model_->output_options()->filtered_collections_classification()) {
390         filtered_collections_classification_.insert(collection->str());
391       }
392     }
393     if (model_->output_options()->filtered_collections_selection()) {
394       for (const auto collection :
395            *model_->output_options()->filtered_collections_selection()) {
396         filtered_collections_selection_.insert(collection->str());
397       }
398     }
399   }
400 
401   if (model_->number_annotator_options() &&
402       model_->number_annotator_options()->enabled()) {
403     if (selection_feature_processor_ == nullptr) {
404       TC3_LOG(ERROR)
405           << "Could not initialize NumberAnnotator without a feature processor";
406       return;
407     }
408 
409     number_annotator_.reset(
410         new NumberAnnotator(model_->number_annotator_options(),
411                             selection_feature_processor_.get()));
412   }
413 
414   if (model_->duration_annotator_options() &&
415       model_->duration_annotator_options()->enabled()) {
416     duration_annotator_.reset(
417         new DurationAnnotator(model_->duration_annotator_options(),
418                               selection_feature_processor_.get()));
419   }
420 
421   if (model_->entity_data_schema()) {
422     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
423         model_->entity_data_schema()->Data(),
424         model_->entity_data_schema()->size());
425     if (entity_data_schema_ == nullptr) {
426       TC3_LOG(ERROR) << "Could not load entity data schema data.";
427       return;
428     }
429 
430     entity_data_builder_.reset(
431         new ReflectiveFlatbufferBuilder(entity_data_schema_));
432   } else {
433     entity_data_schema_ = nullptr;
434     entity_data_builder_ = nullptr;
435   }
436 
437   if (model_->triggering_locales() &&
438       !ParseLocales(model_->triggering_locales()->c_str(),
439                     &model_triggering_locales_)) {
440     TC3_LOG(ERROR) << "Could not parse model supported locales.";
441     return;
442   }
443 
444   if (model_->triggering_options() != nullptr &&
445       model_->triggering_options()->locales() != nullptr &&
446       !ParseLocales(model_->triggering_options()->locales()->c_str(),
447                     &ml_model_triggering_locales_)) {
448     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
449     return;
450   }
451 
452   if (model_->triggering_options() != nullptr &&
453       model_->triggering_options()->dictionary_locales() != nullptr &&
454       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
455                     &dictionary_locales_)) {
456     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
457     return;
458   }
459 
460   initialized_ = true;
461 }
462 
InitializeRegexModel(ZlibDecompressor * decompressor)463 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
464   if (!model_->regex_model()->patterns()) {
465     return true;
466   }
467 
468   // Initialize pattern recognizers.
469   int regex_pattern_id = 0;
470   for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
471     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
472         UncompressMakeRegexPattern(
473             *unilib_, regex_pattern->pattern(),
474             regex_pattern->compressed_pattern(),
475             model_->regex_model()->lazy_regex_compilation(), decompressor);
476     if (!compiled_pattern) {
477       TC3_LOG(INFO) << "Failed to load regex pattern";
478       return false;
479     }
480 
481     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
482       annotation_regex_patterns_.push_back(regex_pattern_id);
483     }
484     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
485       classification_regex_patterns_.push_back(regex_pattern_id);
486     }
487     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
488       selection_regex_patterns_.push_back(regex_pattern_id);
489     }
490     regex_patterns_.push_back({
491         regex_pattern,
492         std::move(compiled_pattern),
493     });
494     ++regex_pattern_id;
495   }
496 
497   return true;
498 }
499 
InitializeKnowledgeEngine(const std::string & serialized_config)500 bool Annotator::InitializeKnowledgeEngine(
501     const std::string& serialized_config) {
502   std::unique_ptr<KnowledgeEngine> knowledge_engine(
503       new KnowledgeEngine(unilib_));
504   if (!knowledge_engine->Initialize(serialized_config)) {
505     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
506     return false;
507   }
508   knowledge_engine_ = std::move(knowledge_engine);
509   return true;
510 }
511 
InitializeContactEngine(const std::string & serialized_config)512 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
513   std::unique_ptr<ContactEngine> contact_engine(
514       new ContactEngine(selection_feature_processor_.get(), unilib_));
515   if (!contact_engine->Initialize(serialized_config)) {
516     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
517     return false;
518   }
519   contact_engine_ = std::move(contact_engine);
520   return true;
521 }
522 
InitializeInstalledAppEngine(const std::string & serialized_config)523 bool Annotator::InitializeInstalledAppEngine(
524     const std::string& serialized_config) {
525   std::unique_ptr<InstalledAppEngine> installed_app_engine(
526       new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
527   if (!installed_app_engine->Initialize(serialized_config)) {
528     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
529     return false;
530   }
531   installed_app_engine_ = std::move(installed_app_engine);
532   return true;
533 }
534 
535 namespace {
536 
CountDigits(const std::string & str,CodepointSpan selection_indices)537 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
538   int count = 0;
539   int i = 0;
540   const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
541   for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
542     if (i >= selection_indices.first && i < selection_indices.second &&
543         isdigit(*it)) {
544       ++count;
545     }
546   }
547   return count;
548 }
549 
550 }  // namespace
551 
552 namespace internal {
553 // Helper function, which if the initial 'span' contains only white-spaces,
554 // moves the selection to a single-codepoint selection on a left or right side
555 // of this space.
SnapLeftIfWhitespaceSelection(CodepointSpan span,const UnicodeText & context_unicode,const UniLib & unilib)556 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
557                                             const UnicodeText& context_unicode,
558                                             const UniLib& unilib) {
559   TC3_CHECK(ValidNonEmptySpan(span));
560 
561   UnicodeText::const_iterator it;
562 
563   // Check that the current selection is all whitespaces.
564   it = context_unicode.begin();
565   std::advance(it, span.first);
566   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
567     if (!unilib.IsWhitespace(*it)) {
568       return span;
569     }
570   }
571 
572   CodepointSpan result;
573 
574   // Try moving left.
575   result = span;
576   it = context_unicode.begin();
577   std::advance(it, span.first);
578   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
579     --result.first;
580     --it;
581   }
582   result.second = result.first + 1;
583   if (!unilib.IsWhitespace(*it)) {
584     return result;
585   }
586 
587   // If moving left didn't find a non-whitespace character, just return the
588   // original span.
589   return span;
590 }
591 }  // namespace internal
592 
FilteredForAnnotation(const AnnotatedSpan & span) const593 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
594   return !span.classification.empty() &&
595          filtered_collections_annotation_.find(
596              span.classification[0].collection) !=
597              filtered_collections_annotation_.end();
598 }
599 
FilteredForClassification(const ClassificationResult & classification) const600 bool Annotator::FilteredForClassification(
601     const ClassificationResult& classification) const {
602   return filtered_collections_classification_.find(classification.collection) !=
603          filtered_collections_classification_.end();
604 }
605 
FilteredForSelection(const AnnotatedSpan & span) const606 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
607   return !span.classification.empty() &&
608          filtered_collections_selection_.find(
609              span.classification[0].collection) !=
610              filtered_collections_selection_.end();
611 }
612 
613 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)614 inline bool ClassifiedAsOther(
615     const std::vector<ClassificationResult>& classification) {
616   return !classification.empty() &&
617          classification[0].collection == Collections::Other();
618 }
619 
GetPriorityScore(const std::vector<ClassificationResult> & classification)620 float GetPriorityScore(
621     const std::vector<ClassificationResult>& classification) {
622   if (!classification.empty() && !ClassifiedAsOther(classification)) {
623     return classification[0].priority_score;
624   } else {
625     return -1.0;
626   }
627 }
628 }  // namespace
629 
VerifyRegexMatchCandidate(const std::string & context,const VerificationOptions * verification_options,const std::string & match,const UniLib::RegexMatcher * matcher) const630 bool Annotator::VerifyRegexMatchCandidate(
631     const std::string& context, const VerificationOptions* verification_options,
632     const std::string& match, const UniLib::RegexMatcher* matcher) const {
633   if (verification_options == nullptr) {
634     return true;
635   }
636   if (verification_options->verify_luhn_checksum() &&
637       !VerifyLuhnChecksum(match)) {
638     return false;
639   }
640   const int lua_verifier = verification_options->lua_verifier();
641   if (lua_verifier >= 0) {
642     if (model_->regex_model()->lua_verifier() == nullptr ||
643         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
644       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
645       return false;
646     }
647     return VerifyMatch(
648         context, matcher,
649         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
650   }
651   return true;
652 }
653 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const654 CodepointSpan Annotator::SuggestSelection(
655     const std::string& context, CodepointSpan click_indices,
656     const SelectionOptions& options) const {
657   CodepointSpan original_click_indices = click_indices;
658   if (!initialized_) {
659     TC3_LOG(ERROR) << "Not initialized";
660     return original_click_indices;
661   }
662   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
663     return original_click_indices;
664   }
665 
666   std::vector<Locale> detected_text_language_tags;
667   if (!ParseLocales(options.detected_text_language_tags,
668                     &detected_text_language_tags)) {
669     TC3_LOG(WARNING)
670         << "Failed to parse the detected_text_language_tags in options: "
671         << options.detected_text_language_tags;
672   }
673   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
674                                     model_triggering_locales_,
675                                     /*default_value=*/true)) {
676     return original_click_indices;
677   }
678 
679   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
680                                                         /*do_copy=*/false);
681 
682   if (!context_unicode.is_valid()) {
683     return original_click_indices;
684   }
685 
686   const int context_codepoint_size = context_unicode.size_codepoints();
687 
688   if (click_indices.first < 0 || click_indices.second < 0 ||
689       click_indices.first >= context_codepoint_size ||
690       click_indices.second > context_codepoint_size ||
691       click_indices.first >= click_indices.second) {
692     TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
693                 << click_indices.first << " " << click_indices.second;
694     return original_click_indices;
695   }
696 
697   if (model_->snap_whitespace_selections()) {
698     // We want to expand a purely white-space selection to a multi-selection it
699     // would've been part of. But with this feature disabled we would do a no-
700     // op, because no token is found. Therefore, we need to modify the
701     // 'click_indices' a bit to include a part of the token, so that the click-
702     // finding logic finds the clicked token correctly. This modification is
703     // done by the following function. Note, that it's enough to check the left
704     // side of the current selection, because if the white-space is a part of a
705     // multi-selection, necessarily both tokens - on the left and the right
706     // sides need to be selected. Thus snapping only to the left is sufficient
707     // (there's a check at the bottom that makes sure that if we snap to the
708     // left token but the result does not contain the initial white-space,
709     // returns the original indices).
710     click_indices = internal::SnapLeftIfWhitespaceSelection(
711         click_indices, context_unicode, *unilib_);
712   }
713 
714   std::vector<AnnotatedSpan> candidates;
715   InterpreterManager interpreter_manager(selection_executor_.get(),
716                                          classification_executor_.get());
717   std::vector<Token> tokens;
718   if (!ModelSuggestSelection(context_unicode, click_indices,
719                              detected_text_language_tags, &interpreter_manager,
720                              &tokens, &candidates)) {
721     TC3_LOG(ERROR) << "Model suggest selection failed.";
722     return original_click_indices;
723   }
724   if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
725                   /*is_serialized_entity_data_enabled=*/false)) {
726     TC3_LOG(ERROR) << "Regex suggest selection failed.";
727     return original_click_indices;
728   }
729   if (!DatetimeChunk(
730           UTF8ToUnicodeText(context, /*do_copy=*/false),
731           /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
732           options.locales, ModeFlag_SELECTION, options.annotation_usecase,
733           /*is_serialized_entity_data_enabled=*/false, &candidates)) {
734     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
735     return original_click_indices;
736   }
737   if (knowledge_engine_ != nullptr &&
738       !knowledge_engine_->Chunk(context, &candidates)) {
739     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
740     return original_click_indices;
741   }
742   if (contact_engine_ != nullptr &&
743       !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
744     TC3_LOG(ERROR) << "Contact suggest selection failed.";
745     return original_click_indices;
746   }
747   if (installed_app_engine_ != nullptr &&
748       !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
749     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
750     return original_click_indices;
751   }
752   if (number_annotator_ != nullptr &&
753       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
754                                   &candidates)) {
755     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
756     return original_click_indices;
757   }
758   if (duration_annotator_ != nullptr &&
759       !duration_annotator_->FindAll(context_unicode, tokens,
760                                     options.annotation_usecase, &candidates)) {
761     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
762     return original_click_indices;
763   }
764 
765   // Sort candidates according to their position in the input, so that the next
766   // code can assume that any connected component of overlapping spans forms a
767   // contiguous block.
768   std::sort(candidates.begin(), candidates.end(),
769             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
770               return a.span.first < b.span.first;
771             });
772 
773   std::vector<int> candidate_indices;
774   if (!ResolveConflicts(candidates, context, tokens,
775                         detected_text_language_tags, options.annotation_usecase,
776                         &interpreter_manager, &candidate_indices)) {
777     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
778     return original_click_indices;
779   }
780 
781   std::sort(candidate_indices.begin(), candidate_indices.end(),
782             [&candidates](int a, int b) {
783               return GetPriorityScore(candidates[a].classification) >
784                      GetPriorityScore(candidates[b].classification);
785             });
786 
787   for (const int i : candidate_indices) {
788     if (SpansOverlap(candidates[i].span, click_indices) &&
789         SpansOverlap(candidates[i].span, original_click_indices)) {
790       // Run model classification if not present but requested and there's a
791       // classification collection filter specified.
792       if (candidates[i].classification.empty() &&
793           model_->selection_options()->always_classify_suggested_selection() &&
794           !filtered_collections_selection_.empty()) {
795         if (!ModelClassifyText(context, detected_text_language_tags,
796                                candidates[i].span, &interpreter_manager,
797                                /*embedding_cache=*/nullptr,
798                                &candidates[i].classification)) {
799           return original_click_indices;
800         }
801       }
802 
803       // Ignore if span classification is filtered.
804       if (FilteredForSelection(candidates[i])) {
805         return original_click_indices;
806       }
807 
808       return candidates[i].span;
809     }
810   }
811 
812   return original_click_indices;
813 }
814 
815 namespace {
816 // Helper function that returns the index of the first candidate that
817 // transitively does not overlap with the candidate on 'start_index'. If the end
818 // of 'candidates' is reached, it returns the index that points right behind the
819 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)820 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
821                                  int start_index) {
822   int first_non_overlapping = start_index + 1;
823   CodepointSpan conflicting_span = candidates[start_index].span;
824   while (
825       first_non_overlapping < candidates.size() &&
826       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
827     // Grow the span to include the current one.
828     conflicting_span.second = std::max(
829         conflicting_span.second, candidates[first_non_overlapping].span.second);
830 
831     ++first_non_overlapping;
832   }
833   return first_non_overlapping;
834 }
835 }  // namespace
836 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,AnnotationUsecase annotation_usecase,InterpreterManager * interpreter_manager,std::vector<int> * result) const837 bool Annotator::ResolveConflicts(
838     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
839     const std::vector<Token>& cached_tokens,
840     const std::vector<Locale>& detected_text_language_tags,
841     AnnotationUsecase annotation_usecase,
842     InterpreterManager* interpreter_manager, std::vector<int>* result) const {
843   result->clear();
844   result->reserve(candidates.size());
845   for (int i = 0; i < candidates.size();) {
846     int first_non_overlapping =
847         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
848 
849     const bool conflict_found = first_non_overlapping != (i + 1);
850     if (conflict_found) {
851       std::vector<int> candidate_indices;
852       if (!ResolveConflict(context, cached_tokens, candidates,
853                            detected_text_language_tags, i,
854                            first_non_overlapping, annotation_usecase,
855                            interpreter_manager, &candidate_indices)) {
856         return false;
857       }
858       result->insert(result->end(), candidate_indices.begin(),
859                      candidate_indices.end());
860     } else {
861       result->push_back(i);
862     }
863 
864     // Skip over the whole conflicting group/go to next candidate.
865     i = first_non_overlapping;
866   }
867   return true;
868 }
869 
870 namespace {
871 // Returns true, if the given two sources do conflict in given annotation
872 // usecase.
873 //  - In SMART usecase, all sources do conflict, because there's only 1 possible
874 //  annotation for a given span.
875 //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
876 //  and duration), while others not (e.g. duration and number).
DoSourcesConflict(AnnotationUsecase annotation_usecase,const AnnotatedSpan::Source source1,const AnnotatedSpan::Source source2)877 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
878                        const AnnotatedSpan::Source source1,
879                        const AnnotatedSpan::Source source2) {
880   uint32 source_mask =
881       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
882 
883   switch (annotation_usecase) {
884     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
885       // In the SMART mode, all annotations conflict.
886       return true;
887 
888     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
889       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
890       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
891       // hours" (duration).
892       if ((source_mask &
893            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
894           (source_mask &
895            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
896         return false;
897       }
898 
899       // A KNOWLEDGE entity does not conflict with anything.
900       if ((source_mask &
901            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
902         return false;
903       }
904 
905       // Entities from other sources can conflict.
906       return true;
907   }
908 }
909 }  // namespace
910 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,const std::vector<Locale> & detected_text_language_tags,int start_index,int end_index,AnnotationUsecase annotation_usecase,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const911 bool Annotator::ResolveConflict(
912     const std::string& context, const std::vector<Token>& cached_tokens,
913     const std::vector<AnnotatedSpan>& candidates,
914     const std::vector<Locale>& detected_text_language_tags, int start_index,
915     int end_index, AnnotationUsecase annotation_usecase,
916     InterpreterManager* interpreter_manager,
917     std::vector<int>* chosen_indices) const {
918   std::vector<int> conflicting_indices;
919   std::unordered_map<int, float> scores;
920   for (int i = start_index; i < end_index; ++i) {
921     conflicting_indices.push_back(i);
922     if (!candidates[i].classification.empty()) {
923       scores[i] = GetPriorityScore(candidates[i].classification);
924       continue;
925     }
926 
927     // OPTIMIZATION: So that we don't have to classify all the ML model
928     // spans apriori, we wait until we get here, when they conflict with
929     // something and we need the actual classification scores. So if the
930     // candidate conflicts and comes from the model, we need to run a
931     // classification to determine its priority:
932     std::vector<ClassificationResult> classification;
933     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
934                            candidates[i].span, interpreter_manager,
935                            /*embedding_cache=*/nullptr, &classification)) {
936       return false;
937     }
938 
939     if (!classification.empty()) {
940       scores[i] = GetPriorityScore(classification);
941     }
942   }
943 
944   std::sort(conflicting_indices.begin(), conflicting_indices.end(),
945             [&scores](int i, int j) { return scores[i] > scores[j]; });
946 
947   // Here we keep a set of indices that were chosen, per-source, to enable
948   // effective computation.
949   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
950       chosen_indices_for_source_map;
951 
952   // Greedily place the candidates if they don't conflict with the already
953   // placed ones.
954   for (int i = 0; i < conflicting_indices.size(); ++i) {
955     const int considered_candidate = conflicting_indices[i];
956 
957     // See if there is a conflict between the candidate and all already placed
958     // candidates.
959     bool conflict = false;
960     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
961     for (auto& source_set_pair : chosen_indices_for_source_map) {
962       if (source_set_pair.first == candidates[considered_candidate].source) {
963         chosen_indices_for_source_ptr = &source_set_pair.second;
964       }
965 
966       if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
967                             candidates[considered_candidate].source) &&
968           DoesCandidateConflict(considered_candidate, candidates,
969                                 source_set_pair.second)) {
970         conflict = true;
971         break;
972       }
973     }
974 
975     // Skip the candidate if a conflict was found.
976     if (conflict) {
977       continue;
978     }
979 
980     // If the set of indices for the current source doesn't exist yet,
981     // initialize it.
982     if (chosen_indices_for_source_ptr == nullptr) {
983       SortedIntSet new_set([&candidates](int a, int b) {
984         return candidates[a].span.first < candidates[b].span.first;
985       });
986       chosen_indices_for_source_map[candidates[considered_candidate].source] =
987           std::move(new_set);
988       chosen_indices_for_source_ptr =
989           &chosen_indices_for_source_map[candidates[considered_candidate]
990                                              .source];
991     }
992 
993     // Place the candidate to the output and to the per-source conflict set.
994     chosen_indices->push_back(considered_candidate);
995     chosen_indices_for_source_ptr->insert(considered_candidate);
996   }
997 
998   std::sort(chosen_indices->begin(), chosen_indices->end());
999 
1000   return true;
1001 }
1002 
ModelSuggestSelection(const UnicodeText & context_unicode,CodepointSpan click_indices,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1003 bool Annotator::ModelSuggestSelection(
1004     const UnicodeText& context_unicode, CodepointSpan click_indices,
1005     const std::vector<Locale>& detected_text_language_tags,
1006     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1007     std::vector<AnnotatedSpan>* result) const {
1008   if (model_->triggering_options() == nullptr ||
1009       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
1010     return true;
1011   }
1012 
1013   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1014                                     ml_model_triggering_locales_,
1015                                     /*default_value=*/true)) {
1016     return true;
1017   }
1018 
1019   int click_pos;
1020   *tokens = selection_feature_processor_->Tokenize(context_unicode);
1021   selection_feature_processor_->RetokenizeAndFindClick(
1022       context_unicode, click_indices,
1023       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1024       tokens, &click_pos);
1025   if (click_pos == kInvalidIndex) {
1026     TC3_VLOG(1) << "Could not calculate the click position.";
1027     return false;
1028   }
1029 
1030   const int symmetry_context_size =
1031       model_->selection_options()->symmetry_context_size();
1032   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1033       bounds_sensitive_features = selection_feature_processor_->GetOptions()
1034                                       ->bounds_sensitive_features();
1035 
1036   // The symmetry context span is the clicked token with symmetry_context_size
1037   // tokens on either side.
1038   const TokenSpan symmetry_context_span = IntersectTokenSpans(
1039       ExpandTokenSpan(SingleTokenSpan(click_pos),
1040                       /*num_tokens_left=*/symmetry_context_size,
1041                       /*num_tokens_right=*/symmetry_context_size),
1042       {0, tokens->size()});
1043 
1044   // Compute the extraction span based on the model type.
1045   TokenSpan extraction_span;
1046   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1047     // The extraction span is the symmetry context span expanded to include
1048     // max_selection_span tokens on either side, which is how far a selection
1049     // can stretch from the click, plus a relevant number of tokens outside of
1050     // the bounds of the selection.
1051     const int max_selection_span =
1052         selection_feature_processor_->GetOptions()->max_selection_span();
1053     extraction_span =
1054         ExpandTokenSpan(symmetry_context_span,
1055                         /*num_tokens_left=*/max_selection_span +
1056                             bounds_sensitive_features->num_tokens_before(),
1057                         /*num_tokens_right=*/max_selection_span +
1058                             bounds_sensitive_features->num_tokens_after());
1059   } else {
1060     // The extraction span is the symmetry context span expanded to include
1061     // context_size tokens on either side.
1062     const int context_size =
1063         selection_feature_processor_->GetOptions()->context_size();
1064     extraction_span = ExpandTokenSpan(symmetry_context_span,
1065                                       /*num_tokens_left=*/context_size,
1066                                       /*num_tokens_right=*/context_size);
1067   }
1068   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
1069 
1070   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1071           *tokens, extraction_span)) {
1072     return true;
1073   }
1074 
1075   std::unique_ptr<CachedFeatures> cached_features;
1076   if (!selection_feature_processor_->ExtractFeatures(
1077           *tokens, extraction_span,
1078           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1079           embedding_executor_.get(),
1080           /*embedding_cache=*/nullptr,
1081           selection_feature_processor_->EmbeddingSize() +
1082               selection_feature_processor_->DenseFeaturesCount(),
1083           &cached_features)) {
1084     TC3_LOG(ERROR) << "Could not extract features.";
1085     return false;
1086   }
1087 
1088   // Produce selection model candidates.
1089   std::vector<TokenSpan> chunks;
1090   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
1091                   interpreter_manager->SelectionInterpreter(), *cached_features,
1092                   &chunks)) {
1093     TC3_LOG(ERROR) << "Could not chunk.";
1094     return false;
1095   }
1096 
1097   for (const TokenSpan& chunk : chunks) {
1098     AnnotatedSpan candidate;
1099     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
1100         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
1101     if (model_->selection_options()->strip_unpaired_brackets()) {
1102       candidate.span =
1103           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
1104     }
1105 
1106     // Only output non-empty spans.
1107     if (candidate.span.first != candidate.span.second) {
1108       result->push_back(candidate);
1109     }
1110   }
1111   return true;
1112 }
1113 
ModelClassifyText(const std::string & context,const std::vector<Locale> & detected_text_language_tags,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results) const1114 bool Annotator::ModelClassifyText(
1115     const std::string& context,
1116     const std::vector<Locale>& detected_text_language_tags,
1117     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1118     FeatureProcessor::EmbeddingCache* embedding_cache,
1119     std::vector<ClassificationResult>* classification_results) const {
1120   return ModelClassifyText(context, {}, detected_text_language_tags,
1121                            selection_indices, interpreter_manager,
1122                            embedding_cache, classification_results);
1123 }
1124 
1125 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,CodepointSpan selection_indices,TokenSpan tokens_around_selection_to_copy)1126 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
1127                                     CodepointSpan selection_indices,
1128                                     TokenSpan tokens_around_selection_to_copy) {
1129   const auto first_selection_token = std::upper_bound(
1130       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
1131       [](int selection_start, const Token& token) {
1132         return selection_start < token.end;
1133       });
1134   const auto last_selection_token = std::lower_bound(
1135       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
1136       [](const Token& token, int selection_end) {
1137         return token.start < selection_end;
1138       });
1139 
1140   const int64 first_token = std::max(
1141       static_cast<int64>(0),
1142       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
1143                          tokens_around_selection_to_copy.first));
1144   const int64 last_token = std::min(
1145       static_cast<int64>(cached_tokens.size()),
1146       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
1147                          tokens_around_selection_to_copy.second));
1148 
1149   std::vector<Token> tokens;
1150   tokens.reserve(last_token - first_token);
1151   for (int i = first_token; i < last_token; ++i) {
1152     tokens.push_back(cached_tokens[i]);
1153   }
1154   return tokens;
1155 }
1156 }  // namespace internal
1157 
ClassifyTextUpperBoundNeededTokens() const1158 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
1159   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1160       bounds_sensitive_features =
1161           classification_feature_processor_->GetOptions()
1162               ->bounds_sensitive_features();
1163   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1164     // The extraction span is the selection span expanded to include a relevant
1165     // number of tokens outside of the bounds of the selection.
1166     return {bounds_sensitive_features->num_tokens_before(),
1167             bounds_sensitive_features->num_tokens_after()};
1168   } else {
1169     // The extraction span is the clicked token with context_size tokens on
1170     // either side.
1171     const int context_size =
1172         selection_feature_processor_->GetOptions()->context_size();
1173     return {context_size, context_size};
1174   }
1175 }
1176 
1177 namespace {
1178 // Sorts the classification results from high score to low score.
SortClassificationResults(std::vector<ClassificationResult> * classification_results)1179 void SortClassificationResults(
1180     std::vector<ClassificationResult>* classification_results) {
1181   std::sort(classification_results->begin(), classification_results->end(),
1182             [](const ClassificationResult& a, const ClassificationResult& b) {
1183               return a.score > b.score;
1184             });
1185 }
1186 }  // namespace
1187 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results) const1188 bool Annotator::ModelClassifyText(
1189     const std::string& context, const std::vector<Token>& cached_tokens,
1190     const std::vector<Locale>& detected_text_language_tags,
1191     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1192     FeatureProcessor::EmbeddingCache* embedding_cache,
1193     std::vector<ClassificationResult>* classification_results) const {
1194   std::vector<Token> tokens;
1195   return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
1196                            selection_indices, interpreter_manager,
1197                            embedding_cache, classification_results, &tokens);
1198 }
1199 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<Locale> & detected_text_language_tags,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results,std::vector<Token> * tokens) const1200 bool Annotator::ModelClassifyText(
1201     const std::string& context, const std::vector<Token>& cached_tokens,
1202     const std::vector<Locale>& detected_text_language_tags,
1203     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
1204     FeatureProcessor::EmbeddingCache* embedding_cache,
1205     std::vector<ClassificationResult>* classification_results,
1206     std::vector<Token>* tokens) const {
1207   if (model_->triggering_options() == nullptr ||
1208       !(model_->triggering_options()->enabled_modes() &
1209         ModeFlag_CLASSIFICATION)) {
1210     return true;
1211   }
1212 
1213   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1214                                     ml_model_triggering_locales_,
1215                                     /*default_value=*/true)) {
1216     return true;
1217   }
1218 
1219   if (cached_tokens.empty()) {
1220     *tokens = classification_feature_processor_->Tokenize(context);
1221   } else {
1222     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
1223                                          ClassifyTextUpperBoundNeededTokens());
1224   }
1225 
1226   int click_pos;
1227   classification_feature_processor_->RetokenizeAndFindClick(
1228       context, selection_indices,
1229       classification_feature_processor_->GetOptions()
1230           ->only_use_line_with_click(),
1231       tokens, &click_pos);
1232   const TokenSpan selection_token_span =
1233       CodepointSpanToTokenSpan(*tokens, selection_indices);
1234   const int selection_num_tokens = TokenSpanSize(selection_token_span);
1235   if (model_->classification_options()->max_num_tokens() > 0 &&
1236       model_->classification_options()->max_num_tokens() <
1237           selection_num_tokens) {
1238     *classification_results = {{Collections::Other(), 1.0}};
1239     return true;
1240   }
1241 
1242   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
1243       bounds_sensitive_features =
1244           classification_feature_processor_->GetOptions()
1245               ->bounds_sensitive_features();
1246   if (selection_token_span.first == kInvalidIndex ||
1247       selection_token_span.second == kInvalidIndex) {
1248     TC3_LOG(ERROR) << "Could not determine span.";
1249     return false;
1250   }
1251 
1252   // Compute the extraction span based on the model type.
1253   TokenSpan extraction_span;
1254   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1255     // The extraction span is the selection span expanded to include a relevant
1256     // number of tokens outside of the bounds of the selection.
1257     extraction_span = ExpandTokenSpan(
1258         selection_token_span,
1259         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
1260         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
1261   } else {
1262     if (click_pos == kInvalidIndex) {
1263       TC3_LOG(ERROR) << "Couldn't choose a click position.";
1264       return false;
1265     }
1266     // The extraction span is the clicked token with context_size tokens on
1267     // either side.
1268     const int context_size =
1269         classification_feature_processor_->GetOptions()->context_size();
1270     extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
1271                                       /*num_tokens_left=*/context_size,
1272                                       /*num_tokens_right=*/context_size);
1273   }
1274   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
1275 
1276   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
1277           *tokens, extraction_span)) {
1278     *classification_results = {{Collections::Other(), 1.0}};
1279     return true;
1280   }
1281 
1282   std::unique_ptr<CachedFeatures> cached_features;
1283   if (!classification_feature_processor_->ExtractFeatures(
1284           *tokens, extraction_span, selection_indices,
1285           embedding_executor_.get(), embedding_cache,
1286           classification_feature_processor_->EmbeddingSize() +
1287               classification_feature_processor_->DenseFeaturesCount(),
1288           &cached_features)) {
1289     TC3_LOG(ERROR) << "Could not extract features.";
1290     return false;
1291   }
1292 
1293   std::vector<float> features;
1294   features.reserve(cached_features->OutputFeaturesSize());
1295   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
1296     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
1297                                                           &features);
1298   } else {
1299     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
1300   }
1301 
1302   TensorView<float> logits = classification_executor_->ComputeLogits(
1303       TensorView<float>(features.data(),
1304                         {1, static_cast<int>(features.size())}),
1305       interpreter_manager->ClassificationInterpreter());
1306   if (!logits.is_valid()) {
1307     TC3_LOG(ERROR) << "Couldn't compute logits.";
1308     return false;
1309   }
1310 
1311   if (logits.dims() != 2 || logits.dim(0) != 1 ||
1312       logits.dim(1) != classification_feature_processor_->NumCollections()) {
1313     TC3_LOG(ERROR) << "Mismatching output";
1314     return false;
1315   }
1316 
1317   const std::vector<float> scores =
1318       ComputeSoftmax(logits.data(), logits.dim(1));
1319 
1320   if (scores.empty()) {
1321     *classification_results = {{Collections::Other(), 1.0}};
1322     return true;
1323   }
1324 
1325   const int best_score_index =
1326       std::max_element(scores.begin(), scores.end()) - scores.begin();
1327   const std::string top_collection =
1328       classification_feature_processor_->LabelToCollection(best_score_index);
1329 
1330   // Sanity checks.
1331   if (top_collection == Collections::Phone()) {
1332     const int digit_count = CountDigits(context, selection_indices);
1333     if (digit_count <
1334             model_->classification_options()->phone_min_num_digits() ||
1335         digit_count >
1336             model_->classification_options()->phone_max_num_digits()) {
1337       *classification_results = {{Collections::Other(), 1.0}};
1338       return true;
1339     }
1340   } else if (top_collection == Collections::Address()) {
1341     if (selection_num_tokens <
1342         model_->classification_options()->address_min_num_tokens()) {
1343       *classification_results = {{Collections::Other(), 1.0}};
1344       return true;
1345     }
1346   } else if (top_collection == Collections::Dictionary()) {
1347     if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1348                                       dictionary_locales_,
1349                                       /*default_value=*/false)) {
1350       *classification_results = {{Collections::Other(), 1.0}};
1351       return true;
1352     }
1353   }
1354 
1355   *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
1356   return true;
1357 }
1358 
RegexClassifyText(const std::string & context,CodepointSpan selection_indices,std::vector<ClassificationResult> * classification_result) const1359 bool Annotator::RegexClassifyText(
1360     const std::string& context, CodepointSpan selection_indices,
1361     std::vector<ClassificationResult>* classification_result) const {
1362   const std::string selection_text =
1363       UTF8ToUnicodeText(context, /*do_copy=*/false)
1364           .UTF8Substring(selection_indices.first, selection_indices.second);
1365   const UnicodeText selection_text_unicode(
1366       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
1367 
1368   // Check whether any of the regular expressions match.
1369   for (const int pattern_id : classification_regex_patterns_) {
1370     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1371     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1372         regex_pattern.pattern->Matcher(selection_text_unicode);
1373     int status = UniLib::RegexMatcher::kNoError;
1374     bool matches;
1375     if (regex_pattern.config->use_approximate_matching()) {
1376       matches = matcher->ApproximatelyMatches(&status);
1377     } else {
1378       matches = matcher->Matches(&status);
1379     }
1380     if (status != UniLib::RegexMatcher::kNoError) {
1381       return false;
1382     }
1383     if (matches && VerifyRegexMatchCandidate(
1384                        context, regex_pattern.config->verification_options(),
1385                        selection_text, matcher.get())) {
1386       classification_result->push_back(
1387           {regex_pattern.config->collection_name()->str(),
1388            regex_pattern.config->target_classification_score(),
1389            regex_pattern.config->priority_score()});
1390       if (!SerializedEntityDataFromRegexMatch(
1391               regex_pattern.config, matcher.get(),
1392               &classification_result->back().serialized_entity_data)) {
1393         TC3_LOG(ERROR) << "Could not get entity data.";
1394         return false;
1395       }
1396     }
1397   }
1398 
1399   return true;
1400 }
1401 
1402 namespace {
PickCollectionForDatetime(const DatetimeParseResult & datetime_parse_result)1403 std::string PickCollectionForDatetime(
1404     const DatetimeParseResult& datetime_parse_result) {
1405   switch (datetime_parse_result.granularity) {
1406     case GRANULARITY_HOUR:
1407     case GRANULARITY_MINUTE:
1408     case GRANULARITY_SECOND:
1409       return Collections::DateTime();
1410     default:
1411       return Collections::Date();
1412   }
1413 }
1414 
CreateDatetimeSerializedEntityData(const DatetimeParseResult & parse_result)1415 std::string CreateDatetimeSerializedEntityData(
1416     const DatetimeParseResult& parse_result) {
1417   EntityDataT entity_data;
1418   entity_data.datetime.reset(new EntityData_::DatetimeT());
1419   entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
1420   entity_data.datetime->granularity =
1421       static_cast<EntityData_::Datetime_::Granularity>(
1422           parse_result.granularity);
1423 
1424   flatbuffers::FlatBufferBuilder builder;
1425   FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
1426   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
1427                      builder.GetSize());
1428 }
1429 }  // namespace
1430 
DatetimeClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options,std::vector<ClassificationResult> * classification_results) const1431 bool Annotator::DatetimeClassifyText(
1432     const std::string& context, CodepointSpan selection_indices,
1433     const ClassificationOptions& options,
1434     std::vector<ClassificationResult>* classification_results) const {
1435   if (!datetime_parser_) {
1436     return false;
1437   }
1438 
1439   const std::string selection_text =
1440       UTF8ToUnicodeText(context, /*do_copy=*/false)
1441           .UTF8Substring(selection_indices.first, selection_indices.second);
1442 
1443   std::vector<DatetimeParseResultSpan> datetime_spans;
1444   if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1445                                options.reference_timezone, options.locales,
1446                                ModeFlag_CLASSIFICATION,
1447                                options.annotation_usecase,
1448                                /*anchor_start_end=*/true, &datetime_spans)) {
1449     TC3_LOG(ERROR) << "Error during parsing datetime.";
1450     return false;
1451   }
1452   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1453     // Only consider the result valid if the selection and extracted datetime
1454     // spans exactly match.
1455     if (std::make_pair(datetime_span.span.first + selection_indices.first,
1456                        datetime_span.span.second + selection_indices.first) ==
1457         selection_indices) {
1458       for (const DatetimeParseResult& parse_result : datetime_span.data) {
1459         classification_results->emplace_back(
1460             PickCollectionForDatetime(parse_result),
1461             datetime_span.target_classification_score);
1462         classification_results->back().datetime_parse_result = parse_result;
1463         classification_results->back().serialized_entity_data =
1464             CreateDatetimeSerializedEntityData(parse_result);
1465         classification_results->back().priority_score =
1466             datetime_span.priority_score;
1467       }
1468       return true;
1469     }
1470   }
1471   return true;
1472 }
1473 
ClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options) const1474 std::vector<ClassificationResult> Annotator::ClassifyText(
1475     const std::string& context, CodepointSpan selection_indices,
1476     const ClassificationOptions& options) const {
1477   if (!initialized_) {
1478     TC3_LOG(ERROR) << "Not initialized";
1479     return {};
1480   }
1481 
1482   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1483     return {};
1484   }
1485 
1486   std::vector<Locale> detected_text_language_tags;
1487   if (!ParseLocales(options.detected_text_language_tags,
1488                     &detected_text_language_tags)) {
1489     TC3_LOG(WARNING)
1490         << "Failed to parse the detected_text_language_tags in options: "
1491         << options.detected_text_language_tags;
1492   }
1493   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1494                                     model_triggering_locales_,
1495                                     /*default_value=*/true)) {
1496     return {};
1497   }
1498 
1499   if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1500     return {};
1501   }
1502 
1503   if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
1504     TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
1505                 << std::get<0>(selection_indices) << " "
1506                 << std::get<1>(selection_indices);
1507     return {};
1508   }
1509 
1510   // We'll accumulate a list of candidates, and pick the best candidate in the
1511   // end.
1512   std::vector<AnnotatedSpan> candidates;
1513 
1514   // Try the knowledge engine.
1515   // TODO(b/126579108): Propagate error status.
1516   ClassificationResult knowledge_result;
1517   if (knowledge_engine_ && knowledge_engine_->ClassifyText(
1518                                context, selection_indices, &knowledge_result)) {
1519     candidates.push_back({selection_indices, {knowledge_result}});
1520     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
1521   }
1522 
1523   // Try the contact engine.
1524   // TODO(b/126579108): Propagate error status.
1525   ClassificationResult contact_result;
1526   if (contact_engine_ && contact_engine_->ClassifyText(
1527                              context, selection_indices, &contact_result)) {
1528     candidates.push_back({selection_indices, {contact_result}});
1529   }
1530 
1531   // Try the installed app engine.
1532   // TODO(b/126579108): Propagate error status.
1533   ClassificationResult installed_app_result;
1534   if (installed_app_engine_ &&
1535       installed_app_engine_->ClassifyText(context, selection_indices,
1536                                           &installed_app_result)) {
1537     candidates.push_back({selection_indices, {installed_app_result}});
1538   }
1539 
1540   // Try the regular expression models.
1541   std::vector<ClassificationResult> regex_results;
1542   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
1543     return {};
1544   }
1545   for (const ClassificationResult& result : regex_results) {
1546     candidates.push_back({selection_indices, {result}});
1547   }
1548 
1549   // Try the date model.
1550   //
1551   // DatetimeClassifyText only returns the first result, which can however have
1552   // more interpretations. They are inserted in the candidates as a single
1553   // AnnotatedSpan, so that they get treated together by the conflict resolution
1554   // algorithm.
1555   std::vector<ClassificationResult> datetime_results;
1556   if (!DatetimeClassifyText(context, selection_indices, options,
1557                             &datetime_results)) {
1558     return {};
1559   }
1560   if (!datetime_results.empty()) {
1561     candidates.push_back({selection_indices, std::move(datetime_results)});
1562     candidates.back().source = AnnotatedSpan::Source::DATETIME;
1563   }
1564 
1565   // Try the number annotator.
1566   // TODO(b/126579108): Propagate error status.
1567   ClassificationResult number_annotator_result;
1568   if (number_annotator_ &&
1569       number_annotator_->ClassifyText(
1570           UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1571           options.annotation_usecase, &number_annotator_result)) {
1572     candidates.push_back({selection_indices, {number_annotator_result}});
1573   }
1574 
1575   // Try the duration annotator.
1576   ClassificationResult duration_annotator_result;
1577   if (duration_annotator_ &&
1578       duration_annotator_->ClassifyText(
1579           UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
1580           options.annotation_usecase, &duration_annotator_result)) {
1581     candidates.push_back({selection_indices, {duration_annotator_result}});
1582     candidates.back().source = AnnotatedSpan::Source::DURATION;
1583   }
1584 
1585   // Try the ML model.
1586   //
1587   // The output of the model is considered as an exclusive 1-of-N choice. That's
1588   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
1589   // span for each candidate, like e.g. the regex model.
1590   InterpreterManager interpreter_manager(selection_executor_.get(),
1591                                          classification_executor_.get());
1592   std::vector<ClassificationResult> model_results;
1593   std::vector<Token> tokens;
1594   if (!ModelClassifyText(
1595           context, /*cached_tokens=*/{}, detected_text_language_tags,
1596           selection_indices, &interpreter_manager,
1597           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
1598     return {};
1599   }
1600   if (!model_results.empty()) {
1601     candidates.push_back({selection_indices, std::move(model_results)});
1602   }
1603 
1604   std::vector<int> candidate_indices;
1605   if (!ResolveConflicts(candidates, context, tokens,
1606                         detected_text_language_tags, options.annotation_usecase,
1607                         &interpreter_manager, &candidate_indices)) {
1608     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1609     return {};
1610   }
1611 
1612   std::vector<ClassificationResult> results;
1613   for (const int i : candidate_indices) {
1614     for (const ClassificationResult& result : candidates[i].classification) {
1615       if (!FilteredForClassification(result)) {
1616         results.push_back(result);
1617       }
1618     }
1619   }
1620 
1621   // Sort results according to score.
1622   std::sort(results.begin(), results.end(),
1623             [](const ClassificationResult& a, const ClassificationResult& b) {
1624               return a.score > b.score;
1625             });
1626 
1627   if (results.empty()) {
1628     results = {{Collections::Other(), 1.0}};
1629   }
1630   return results;
1631 }
1632 
ModelAnnotate(const std::string & context,const std::vector<Locale> & detected_text_language_tags,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1633 bool Annotator::ModelAnnotate(
1634     const std::string& context,
1635     const std::vector<Locale>& detected_text_language_tags,
1636     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
1637     std::vector<AnnotatedSpan>* result) const {
1638   if (model_->triggering_options() == nullptr ||
1639       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1640     return true;
1641   }
1642 
1643   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1644                                     ml_model_triggering_locales_,
1645                                     /*default_value=*/true)) {
1646     return true;
1647   }
1648 
1649   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1650                                                         /*do_copy=*/false);
1651   std::vector<UnicodeTextRange> lines;
1652   if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1653     lines.push_back({context_unicode.begin(), context_unicode.end()});
1654   } else {
1655     lines = selection_feature_processor_->SplitContext(context_unicode);
1656   }
1657 
1658   const float min_annotate_confidence =
1659       (model_->triggering_options() != nullptr
1660            ? model_->triggering_options()->min_annotate_confidence()
1661            : 0.f);
1662 
1663   for (const UnicodeTextRange& line : lines) {
1664     FeatureProcessor::EmbeddingCache embedding_cache;
1665     const std::string line_str =
1666         UnicodeText::UTF8Substring(line.first, line.second);
1667 
1668     *tokens = selection_feature_processor_->Tokenize(line_str);
1669     selection_feature_processor_->RetokenizeAndFindClick(
1670         line_str, {0, std::distance(line.first, line.second)},
1671         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1672         tokens,
1673         /*click_pos=*/nullptr);
1674     const TokenSpan full_line_span = {0, tokens->size()};
1675 
1676     // TODO(zilka): Add support for greater granularity of this check.
1677     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1678             *tokens, full_line_span)) {
1679       continue;
1680     }
1681 
1682     std::unique_ptr<CachedFeatures> cached_features;
1683     if (!selection_feature_processor_->ExtractFeatures(
1684             *tokens, full_line_span,
1685             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1686             embedding_executor_.get(),
1687             /*embedding_cache=*/nullptr,
1688             selection_feature_processor_->EmbeddingSize() +
1689                 selection_feature_processor_->DenseFeaturesCount(),
1690             &cached_features)) {
1691       TC3_LOG(ERROR) << "Could not extract features.";
1692       return false;
1693     }
1694 
1695     std::vector<TokenSpan> local_chunks;
1696     if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
1697                     interpreter_manager->SelectionInterpreter(),
1698                     *cached_features, &local_chunks)) {
1699       TC3_LOG(ERROR) << "Could not chunk.";
1700       return false;
1701     }
1702 
1703     const int offset = std::distance(context_unicode.begin(), line.first);
1704     for (const TokenSpan& chunk : local_chunks) {
1705       const CodepointSpan codepoint_span =
1706           selection_feature_processor_->StripBoundaryCodepoints(
1707               line_str, TokenSpanToCodepointSpan(*tokens, chunk));
1708 
1709       // Skip empty spans.
1710       if (codepoint_span.first != codepoint_span.second) {
1711         std::vector<ClassificationResult> classification;
1712         if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
1713                                codepoint_span, interpreter_manager,
1714                                &embedding_cache, &classification)) {
1715           TC3_LOG(ERROR) << "Could not classify text: "
1716                          << (codepoint_span.first + offset) << " "
1717                          << (codepoint_span.second + offset);
1718           return false;
1719         }
1720 
1721         // Do not include the span if it's classified as "other".
1722         if (!classification.empty() && !ClassifiedAsOther(classification) &&
1723             classification[0].score >= min_annotate_confidence) {
1724           AnnotatedSpan result_span;
1725           result_span.span = {codepoint_span.first + offset,
1726                               codepoint_span.second + offset};
1727           result_span.classification = std::move(classification);
1728           result->push_back(std::move(result_span));
1729         }
1730       }
1731     }
1732   }
1733   return true;
1734 }
1735 
SelectionFeatureProcessorForTests() const1736 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
1737   return selection_feature_processor_.get();
1738 }
1739 
ClassificationFeatureProcessorForTests() const1740 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
1741     const {
1742   return classification_feature_processor_.get();
1743 }
1744 
DatetimeParserForTests() const1745 const DatetimeParser* Annotator::DatetimeParserForTests() const {
1746   return datetime_parser_.get();
1747 }
1748 
RemoveNotEnabledEntityTypes(const EnabledEntityTypes & is_entity_type_enabled,std::vector<AnnotatedSpan> * annotated_spans) const1749 void Annotator::RemoveNotEnabledEntityTypes(
1750     const EnabledEntityTypes& is_entity_type_enabled,
1751     std::vector<AnnotatedSpan>* annotated_spans) const {
1752   for (AnnotatedSpan& annotated_span : *annotated_spans) {
1753     std::vector<ClassificationResult>& classifications =
1754         annotated_span.classification;
1755     classifications.erase(
1756         std::remove_if(classifications.begin(), classifications.end(),
1757                        [&is_entity_type_enabled](
1758                            const ClassificationResult& classification_result) {
1759                          return !is_entity_type_enabled(
1760                              classification_result.collection);
1761                        }),
1762         classifications.end());
1763   }
1764   annotated_spans->erase(
1765       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
1766                      [](const AnnotatedSpan& annotated_span) {
1767                        return annotated_span.classification.empty();
1768                      }),
1769       annotated_spans->end());
1770 }
1771 
Annotate(const std::string & context,const AnnotationOptions & options) const1772 std::vector<AnnotatedSpan> Annotator::Annotate(
1773     const std::string& context, const AnnotationOptions& options) const {
1774   std::vector<AnnotatedSpan> candidates;
1775 
1776   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
1777     return {};
1778   }
1779 
1780   const UnicodeText context_unicode =
1781       UTF8ToUnicodeText(context, /*do_copy=*/false);
1782   if (!context_unicode.is_valid()) {
1783     return {};
1784   }
1785 
1786   std::vector<Locale> detected_text_language_tags;
1787   if (!ParseLocales(options.detected_text_language_tags,
1788                     &detected_text_language_tags)) {
1789     TC3_LOG(WARNING)
1790         << "Failed to parse the detected_text_language_tags in options: "
1791         << options.detected_text_language_tags;
1792   }
1793   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
1794                                     model_triggering_locales_,
1795                                     /*default_value=*/true)) {
1796     return {};
1797   }
1798 
1799   InterpreterManager interpreter_manager(selection_executor_.get(),
1800                                          classification_executor_.get());
1801 
1802   // Annotate with the selection model.
1803   std::vector<Token> tokens;
1804   if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
1805                      &tokens, &candidates)) {
1806     TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
1807     return {};
1808   }
1809 
1810   // Annotate with the regular expression models.
1811   if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1812                   annotation_regex_patterns_, &candidates,
1813                   options.is_serialized_entity_data_enabled)) {
1814     TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
1815     return {};
1816   }
1817 
1818   // Annotate with the datetime model.
1819   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
1820   if ((is_entity_type_enabled(Collections::Date()) ||
1821        is_entity_type_enabled(Collections::DateTime())) &&
1822       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1823                      options.reference_time_ms_utc, options.reference_timezone,
1824                      options.locales, ModeFlag_ANNOTATION,
1825                      options.annotation_usecase,
1826                      options.is_serialized_entity_data_enabled, &candidates)) {
1827     TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
1828     return {};
1829   }
1830 
1831   // Annotate with the knowledge engine.
1832   if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
1833     TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
1834     return {};
1835   }
1836 
1837   // Annotate with the contact engine.
1838   if (contact_engine_ &&
1839       !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
1840     TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
1841     return {};
1842   }
1843 
1844   // Annotate with the installed app engine.
1845   if (installed_app_engine_ &&
1846       !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
1847     TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
1848     return {};
1849   }
1850 
1851   // Annotate with the number annotator.
1852   if (number_annotator_ != nullptr &&
1853       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
1854                                   &candidates)) {
1855     TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
1856     return {};
1857   }
1858 
1859   // Annotate with the duration annotator.
1860   if (is_entity_type_enabled(Collections::Duration()) &&
1861       duration_annotator_ != nullptr &&
1862       !duration_annotator_->FindAll(context_unicode, tokens,
1863                                     options.annotation_usecase, &candidates)) {
1864     TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
1865     return {};
1866   }
1867 
1868   // Sort candidates according to their position in the input, so that the next
1869   // code can assume that any connected component of overlapping spans forms a
1870   // contiguous block.
1871   std::sort(candidates.begin(), candidates.end(),
1872             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1873               return a.span.first < b.span.first;
1874             });
1875 
1876   std::vector<int> candidate_indices;
1877   if (!ResolveConflicts(candidates, context, tokens,
1878                         detected_text_language_tags, options.annotation_usecase,
1879                         &interpreter_manager, &candidate_indices)) {
1880     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
1881     return {};
1882   }
1883 
1884   std::vector<AnnotatedSpan> result;
1885   result.reserve(candidate_indices.size());
1886   AnnotatedSpan aggregated_span;
1887   for (const int i : candidate_indices) {
1888     if (candidates[i].span != aggregated_span.span) {
1889       if (!aggregated_span.classification.empty()) {
1890         result.push_back(std::move(aggregated_span));
1891       }
1892       aggregated_span =
1893           AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
1894     }
1895     if (candidates[i].classification.empty() ||
1896         ClassifiedAsOther(candidates[i].classification) ||
1897         FilteredForAnnotation(candidates[i])) {
1898       continue;
1899     }
1900     for (ClassificationResult& classification : candidates[i].classification) {
1901       aggregated_span.classification.push_back(std::move(classification));
1902     }
1903   }
1904   if (!aggregated_span.classification.empty()) {
1905     result.push_back(std::move(aggregated_span));
1906   }
1907 
1908   // We generate all candidates and remove them later (with the exception of
1909   // date/time/duration entities) because there are complex interdependencies
1910   // between the entity types. E.g., the TLD of an email can be interpreted as a
1911   // URL, but most likely a user of the API does not want such annotations if
1912   // "url" is enabled and "email" is not.
1913   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
1914 
1915   for (AnnotatedSpan& annotated_span : result) {
1916     SortClassificationResults(&annotated_span.classification);
1917   }
1918 
1919   return result;
1920 }
1921 
ComputeSelectionBoundaries(const UniLib::RegexMatcher * match,const RegexModel_::Pattern * config) const1922 CodepointSpan Annotator::ComputeSelectionBoundaries(
1923     const UniLib::RegexMatcher* match,
1924     const RegexModel_::Pattern* config) const {
1925   if (config->capturing_group() == nullptr) {
1926     // Use first capturing group to specify the selection.
1927     int status = UniLib::RegexMatcher::kNoError;
1928     const CodepointSpan result = {match->Start(1, &status),
1929                                   match->End(1, &status)};
1930     if (status != UniLib::RegexMatcher::kNoError) {
1931       return {kInvalidIndex, kInvalidIndex};
1932     }
1933     return result;
1934   }
1935 
1936   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
1937   const int num_groups = config->capturing_group()->size();
1938   for (int i = 0; i < num_groups; i++) {
1939     if (!config->capturing_group()->Get(i)->extend_selection()) {
1940       continue;
1941     }
1942 
1943     int status = UniLib::RegexMatcher::kNoError;
1944     // Check match and adjust bounds.
1945     const int group_start = match->Start(i, &status);
1946     const int group_end = match->End(i, &status);
1947     if (status != UniLib::RegexMatcher::kNoError) {
1948       return {kInvalidIndex, kInvalidIndex};
1949     }
1950     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
1951       continue;
1952     }
1953     if (result.first == kInvalidIndex) {
1954       result = {group_start, group_end};
1955     } else {
1956       result.first = std::min(result.first, group_start);
1957       result.second = std::max(result.second, group_end);
1958     }
1959   }
1960   return result;
1961 }
1962 
HasEntityData(const RegexModel_::Pattern * pattern) const1963 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
1964   if (pattern->serialized_entity_data() != nullptr) {
1965     return true;
1966   }
1967   if (pattern->capturing_group() != nullptr) {
1968     for (const RegexModel_::Pattern_::CapturingGroup* group :
1969          *pattern->capturing_group()) {
1970       if (group->entity_field_path() != nullptr) {
1971         return true;
1972       }
1973     }
1974   }
1975   return false;
1976 }
1977 
SerializedEntityDataFromRegexMatch(const RegexModel_::Pattern * pattern,UniLib::RegexMatcher * matcher,std::string * serialized_entity_data) const1978 bool Annotator::SerializedEntityDataFromRegexMatch(
1979     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
1980     std::string* serialized_entity_data) const {
1981   if (!HasEntityData(pattern)) {
1982     serialized_entity_data->clear();
1983     return true;
1984   }
1985   TC3_CHECK(entity_data_builder_ != nullptr);
1986 
1987   std::unique_ptr<ReflectiveFlatbuffer> entity_data =
1988       entity_data_builder_->NewRoot();
1989 
1990   TC3_CHECK(entity_data != nullptr);
1991 
1992   // Set static entity data.
1993   if (pattern->serialized_entity_data() != nullptr) {
1994     TC3_CHECK(entity_data != nullptr);
1995     entity_data->MergeFromSerializedFlatbuffer(
1996         StringPiece(pattern->serialized_entity_data()->c_str(),
1997                     pattern->serialized_entity_data()->size()));
1998   }
1999 
2000   // Add entity data from rule capturing groups.
2001   if (pattern->capturing_group() != nullptr) {
2002     const int num_groups = pattern->capturing_group()->size();
2003     for (int i = 0; i < num_groups; i++) {
2004       const FlatbufferFieldPath* field_path =
2005           pattern->capturing_group()->Get(i)->entity_field_path();
2006       if (field_path == nullptr) {
2007         continue;
2008       }
2009       TC3_CHECK(entity_data != nullptr);
2010       if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
2011                                       entity_data.get())) {
2012         TC3_LOG(ERROR)
2013             << "Could not set entity data from rule capturing group.";
2014         return false;
2015       }
2016     }
2017   }
2018 
2019   *serialized_entity_data = entity_data->Serialize();
2020   return true;
2021 }
2022 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,std::vector<AnnotatedSpan> * result,bool is_serialized_entity_data_enabled) const2023 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
2024                            const std::vector<int>& rules,
2025                            std::vector<AnnotatedSpan>* result,
2026                            bool is_serialized_entity_data_enabled) const {
2027   for (int pattern_id : rules) {
2028     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
2029     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
2030     if (!matcher) {
2031       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
2032                      << pattern_id;
2033       return false;
2034     }
2035 
2036     int status = UniLib::RegexMatcher::kNoError;
2037     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
2038       if (regex_pattern.config->verification_options()) {
2039         if (!VerifyRegexMatchCandidate(
2040                 context_unicode.ToUTF8String(),
2041                 regex_pattern.config->verification_options(),
2042                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
2043           continue;
2044         }
2045       }
2046 
2047       std::string serialized_entity_data;
2048       if (is_serialized_entity_data_enabled) {
2049         if (!SerializedEntityDataFromRegexMatch(
2050                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
2051           TC3_LOG(ERROR) << "Could not get entity data.";
2052           return false;
2053         }
2054       }
2055 
2056       result->emplace_back();
2057 
2058       // Selection/annotation regular expressions need to specify a capturing
2059       // group specifying the selection.
2060       result->back().span =
2061           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
2062 
2063       result->back().classification = {
2064           {regex_pattern.config->collection_name()->str(),
2065            regex_pattern.config->target_classification_score(),
2066            regex_pattern.config->priority_score()}};
2067 
2068       result->back().classification[0].serialized_entity_data =
2069           serialized_entity_data;
2070     }
2071   }
2072   return true;
2073 }
2074 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const2075 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
2076                            tflite::Interpreter* selection_interpreter,
2077                            const CachedFeatures& cached_features,
2078                            std::vector<TokenSpan>* chunks) const {
2079   const int max_selection_span =
2080       selection_feature_processor_->GetOptions()->max_selection_span();
2081   // The inference span is the span of interest expanded to include
2082   // max_selection_span tokens on either side, which is how far a selection can
2083   // stretch from the click.
2084   const TokenSpan inference_span = IntersectTokenSpans(
2085       ExpandTokenSpan(span_of_interest,
2086                       /*num_tokens_left=*/max_selection_span,
2087                       /*num_tokens_right=*/max_selection_span),
2088       {0, num_tokens});
2089 
2090   std::vector<ScoredChunk> scored_chunks;
2091   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
2092       selection_feature_processor_->GetOptions()
2093           ->bounds_sensitive_features()
2094           ->enabled()) {
2095     if (!ModelBoundsSensitiveScoreChunks(
2096             num_tokens, span_of_interest, inference_span, cached_features,
2097             selection_interpreter, &scored_chunks)) {
2098       return false;
2099     }
2100   } else {
2101     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
2102                                       cached_features, selection_interpreter,
2103                                       &scored_chunks)) {
2104       return false;
2105     }
2106   }
2107   std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
2108             [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
2109               return lhs.score < rhs.score;
2110             });
2111 
2112   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
2113   // them greedily as long as they do not overlap with any previously picked
2114   // chunks.
2115   std::vector<bool> token_used(TokenSpanSize(inference_span));
2116   chunks->clear();
2117   for (const ScoredChunk& scored_chunk : scored_chunks) {
2118     bool feasible = true;
2119     for (int i = scored_chunk.token_span.first;
2120          i < scored_chunk.token_span.second; ++i) {
2121       if (token_used[i - inference_span.first]) {
2122         feasible = false;
2123         break;
2124       }
2125     }
2126 
2127     if (!feasible) {
2128       continue;
2129     }
2130 
2131     for (int i = scored_chunk.token_span.first;
2132          i < scored_chunk.token_span.second; ++i) {
2133       token_used[i - inference_span.first] = true;
2134     }
2135 
2136     chunks->push_back(scored_chunk.token_span);
2137   }
2138 
2139   std::sort(chunks->begin(), chunks->end());
2140 
2141   return true;
2142 }
2143 
namespace {
// Updates the value at the given key in the map to the maximum of the current
// value and the given value, or simply inserts the value if the key is not yet
// there.
//
// A single insert() call both probes and (on a miss) inserts, so the key is
// looked up only once. Unlike operator[], this does not require
// Map::mapped_type to be default-constructible.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insert_result =
      map->insert(typename Map::value_type(key, value));
  if (!insert_result.second) {
    // Key already present: keep the larger of the stored and the new value.
    insert_result.first->second =
        std::max(insert_result.first->second, value);
  }
}
}  // namespace
2158 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2159 bool Annotator::ModelClickContextScoreChunks(
2160     int num_tokens, const TokenSpan& span_of_interest,
2161     const CachedFeatures& cached_features,
2162     tflite::Interpreter* selection_interpreter,
2163     std::vector<ScoredChunk>* scored_chunks) const {
2164   const int max_batch_size = model_->selection_options()->batch_size();
2165 
2166   std::vector<float> all_features;
2167   std::map<TokenSpan, float> chunk_scores;
2168   for (int batch_start = span_of_interest.first;
2169        batch_start < span_of_interest.second; batch_start += max_batch_size) {
2170     const int batch_end =
2171         std::min(batch_start + max_batch_size, span_of_interest.second);
2172 
2173     // Prepare features for the whole batch.
2174     all_features.clear();
2175     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2176     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2177       cached_features.AppendClickContextFeaturesForClick(click_pos,
2178                                                          &all_features);
2179     }
2180 
2181     // Run batched inference.
2182     const int batch_size = batch_end - batch_start;
2183     const int features_size = cached_features.OutputFeaturesSize();
2184     TensorView<float> logits = selection_executor_->ComputeLogits(
2185         TensorView<float>(all_features.data(), {batch_size, features_size}),
2186         selection_interpreter);
2187     if (!logits.is_valid()) {
2188       TC3_LOG(ERROR) << "Couldn't compute logits.";
2189       return false;
2190     }
2191     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2192         logits.dim(1) !=
2193             selection_feature_processor_->GetSelectionLabelCount()) {
2194       TC3_LOG(ERROR) << "Mismatching output.";
2195       return false;
2196     }
2197 
2198     // Save results.
2199     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
2200       const std::vector<float> scores = ComputeSoftmax(
2201           logits.data() + logits.dim(1) * (click_pos - batch_start),
2202           logits.dim(1));
2203       for (int j = 0;
2204            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
2205         TokenSpan relative_token_span;
2206         if (!selection_feature_processor_->LabelToTokenSpan(
2207                 j, &relative_token_span)) {
2208           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
2209           return false;
2210         }
2211         const TokenSpan candidate_span = ExpandTokenSpan(
2212             SingleTokenSpan(click_pos), relative_token_span.first,
2213             relative_token_span.second);
2214         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
2215           UpdateMax(&chunk_scores, candidate_span, scores[j]);
2216         }
2217       }
2218     }
2219   }
2220 
2221   scored_chunks->clear();
2222   scored_chunks->reserve(chunk_scores.size());
2223   for (const auto& entry : chunk_scores) {
2224     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
2225   }
2226 
2227   return true;
2228 }
2229 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const2230 bool Annotator::ModelBoundsSensitiveScoreChunks(
2231     int num_tokens, const TokenSpan& span_of_interest,
2232     const TokenSpan& inference_span, const CachedFeatures& cached_features,
2233     tflite::Interpreter* selection_interpreter,
2234     std::vector<ScoredChunk>* scored_chunks) const {
2235   const int max_selection_span =
2236       selection_feature_processor_->GetOptions()->max_selection_span();
2237   const int max_chunk_length = selection_feature_processor_->GetOptions()
2238                                        ->selection_reduced_output_space()
2239                                    ? max_selection_span + 1
2240                                    : 2 * max_selection_span + 1;
2241   const bool score_single_token_spans_as_zero =
2242       selection_feature_processor_->GetOptions()
2243           ->bounds_sensitive_features()
2244           ->score_single_token_spans_as_zero();
2245 
2246   scored_chunks->clear();
2247   if (score_single_token_spans_as_zero) {
2248     scored_chunks->reserve(TokenSpanSize(span_of_interest));
2249   }
2250 
2251   // Prepare all chunk candidates into one batch:
2252   //   - Are contained in the inference span
2253   //   - Have a non-empty intersection with the span of interest
2254   //   - Are at least one token long
2255   //   - Are not longer than the maximum chunk length
2256   std::vector<TokenSpan> candidate_spans;
2257   for (int start = inference_span.first; start < span_of_interest.second;
2258        ++start) {
2259     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
2260     for (int end = leftmost_end_index;
2261          end <= inference_span.second && end - start <= max_chunk_length;
2262          ++end) {
2263       const TokenSpan candidate_span = {start, end};
2264       if (score_single_token_spans_as_zero &&
2265           TokenSpanSize(candidate_span) == 1) {
2266         // Do not include the single token span in the batch, add a zero score
2267         // for it directly to the output.
2268         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
2269       } else {
2270         candidate_spans.push_back(candidate_span);
2271       }
2272     }
2273   }
2274 
2275   const int max_batch_size = model_->selection_options()->batch_size();
2276 
2277   std::vector<float> all_features;
2278   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
2279   for (int batch_start = 0; batch_start < candidate_spans.size();
2280        batch_start += max_batch_size) {
2281     const int batch_end = std::min(batch_start + max_batch_size,
2282                                    static_cast<int>(candidate_spans.size()));
2283 
2284     // Prepare features for the whole batch.
2285     all_features.clear();
2286     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
2287     for (int i = batch_start; i < batch_end; ++i) {
2288       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
2289                                                            &all_features);
2290     }
2291 
2292     // Run batched inference.
2293     const int batch_size = batch_end - batch_start;
2294     const int features_size = cached_features.OutputFeaturesSize();
2295     TensorView<float> logits = selection_executor_->ComputeLogits(
2296         TensorView<float>(all_features.data(), {batch_size, features_size}),
2297         selection_interpreter);
2298     if (!logits.is_valid()) {
2299       TC3_LOG(ERROR) << "Couldn't compute logits.";
2300       return false;
2301     }
2302     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
2303         logits.dim(1) != 1) {
2304       TC3_LOG(ERROR) << "Mismatching output.";
2305       return false;
2306     }
2307 
2308     // Save results.
2309     for (int i = batch_start; i < batch_end; ++i) {
2310       scored_chunks->push_back(
2311           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
2312     }
2313   }
2314 
2315   return true;
2316 }
2317 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,AnnotationUsecase annotation_usecase,bool is_serialized_entity_data_enabled,std::vector<AnnotatedSpan> * result) const2318 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
2319                               int64 reference_time_ms_utc,
2320                               const std::string& reference_timezone,
2321                               const std::string& locales, ModeFlag mode,
2322                               AnnotationUsecase annotation_usecase,
2323                               bool is_serialized_entity_data_enabled,
2324                               std::vector<AnnotatedSpan>* result) const {
2325   if (!datetime_parser_) {
2326     return true;
2327   }
2328 
2329   std::vector<DatetimeParseResultSpan> datetime_spans;
2330   if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
2331                                reference_timezone, locales, mode,
2332                                annotation_usecase,
2333                                /*anchor_start_end=*/false, &datetime_spans)) {
2334     return false;
2335   }
2336   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
2337     AnnotatedSpan annotated_span;
2338     annotated_span.span = datetime_span.span;
2339     for (const DatetimeParseResult& parse_result : datetime_span.data) {
2340       annotated_span.classification.emplace_back(
2341           PickCollectionForDatetime(parse_result),
2342           datetime_span.target_classification_score,
2343           datetime_span.priority_score);
2344       annotated_span.classification.back().datetime_parse_result = parse_result;
2345       if (is_serialized_entity_data_enabled) {
2346         annotated_span.classification.back().serialized_entity_data =
2347             CreateDatetimeSerializedEntityData(parse_result);
2348       }
2349     }
2350     annotated_span.source = AnnotatedSpan::Source::DATETIME;
2351     result->push_back(std::move(annotated_span));
2352   }
2353   return true;
2354 }
2355 
model() const2356 const Model* Annotator::model() const { return model_; }
entity_data_schema() const2357 const reflection::Schema* Annotator::entity_data_schema() const {
2358   return entity_data_schema_;
2359 }
2360 
ViewModel(const void * buffer,int size)2361 const Model* ViewModel(const void* buffer, int size) {
2362   if (!buffer) {
2363     return nullptr;
2364   }
2365 
2366   return LoadAndVerifyModel(buffer, size);
2367 }
2368 
LookUpKnowledgeEntity(const std::string & id,std::string * serialized_knowledge_result) const2369 bool Annotator::LookUpKnowledgeEntity(
2370     const std::string& id, std::string* serialized_knowledge_result) const {
2371   return knowledge_engine_ &&
2372          knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
2373 }
2374 
2375 }  // namespace libtextclassifier3
2376