• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "text-classifier.h"
18 
19 #include <algorithm>
20 #include <cctype>
21 #include <cmath>
22 #include <iterator>
23 #include <numeric>
24 
25 #include "util/base/logging.h"
26 #include "util/math/softmax.h"
27 #include "util/utf8/unicodetext.h"
28 
29 namespace libtextclassifier2 {
30 const std::string& TextClassifier::kOtherCollection =
__anon1665e8690102() 31     *[]() { return new std::string("other"); }();
32 const std::string& TextClassifier::kPhoneCollection =
__anon1665e8690202() 33     *[]() { return new std::string("phone"); }();
34 const std::string& TextClassifier::kAddressCollection =
__anon1665e8690302() 35     *[]() { return new std::string("address"); }();
36 const std::string& TextClassifier::kDateCollection =
__anon1665e8690402() 37     *[]() { return new std::string("date"); }();
38 
39 namespace {
LoadAndVerifyModel(const void * addr,int size)40 const Model* LoadAndVerifyModel(const void* addr, int size) {
41   const Model* model = GetModel(addr);
42 
43   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
44   if (model->Verify(verifier)) {
45     return model;
46   } else {
47     return nullptr;
48   }
49 }
50 }  // namespace
51 
SelectionInterpreter()52 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
53   if (!selection_interpreter_) {
54     TC_CHECK(selection_executor_);
55     selection_interpreter_ = selection_executor_->CreateInterpreter();
56     if (!selection_interpreter_) {
57       TC_LOG(ERROR) << "Could not build TFLite interpreter.";
58     }
59   }
60   return selection_interpreter_.get();
61 }
62 
ClassificationInterpreter()63 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
64   if (!classification_interpreter_) {
65     TC_CHECK(classification_executor_);
66     classification_interpreter_ = classification_executor_->CreateInterpreter();
67     if (!classification_interpreter_) {
68       TC_LOG(ERROR) << "Could not build TFLite interpreter.";
69     }
70   }
71   return classification_interpreter_.get();
72 }
73 
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib)74 std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
75     const char* buffer, int size, const UniLib* unilib) {
76   const Model* model = LoadAndVerifyModel(buffer, size);
77   if (model == nullptr) {
78     return nullptr;
79   }
80 
81   auto classifier =
82       std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib));
83   if (!classifier->IsInitialized()) {
84     return nullptr;
85   }
86 
87   return classifier;
88 }
89 
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib)90 std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
91     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) {
92   if (!(*mmap)->handle().ok()) {
93     TC_VLOG(1) << "Mmap failed.";
94     return nullptr;
95   }
96 
97   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
98                                           (*mmap)->handle().num_bytes());
99   if (!model) {
100     TC_LOG(ERROR) << "Model verification failed.";
101     return nullptr;
102   }
103 
104   auto classifier =
105       std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib));
106   if (!classifier->IsInitialized()) {
107     return nullptr;
108   }
109 
110   return classifier;
111 }
112 
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib)113 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
114     int fd, int offset, int size, const UniLib* unilib) {
115   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
116   return FromScopedMmap(&mmap, unilib);
117 }
118 
FromFileDescriptor(int fd,const UniLib * unilib)119 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
120     int fd, const UniLib* unilib) {
121   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
122   return FromScopedMmap(&mmap, unilib);
123 }
124 
FromPath(const std::string & path,const UniLib * unilib)125 std::unique_ptr<TextClassifier> TextClassifier::FromPath(
126     const std::string& path, const UniLib* unilib) {
127   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
128   return FromScopedMmap(&mmap, unilib);
129 }
130 
ValidateAndInitialize()131 void TextClassifier::ValidateAndInitialize() {
132   initialized_ = false;
133 
134   if (model_ == nullptr) {
135     TC_LOG(ERROR) << "No model specified.";
136     return;
137   }
138 
139   const bool model_enabled_for_annotation =
140       (model_->triggering_options() != nullptr &&
141        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
142   const bool model_enabled_for_classification =
143       (model_->triggering_options() != nullptr &&
144        (model_->triggering_options()->enabled_modes() &
145         ModeFlag_CLASSIFICATION));
146   const bool model_enabled_for_selection =
147       (model_->triggering_options() != nullptr &&
148        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
149 
150   // Annotation requires the selection model.
151   if (model_enabled_for_annotation || model_enabled_for_selection) {
152     if (!model_->selection_options()) {
153       TC_LOG(ERROR) << "No selection options.";
154       return;
155     }
156     if (!model_->selection_feature_options()) {
157       TC_LOG(ERROR) << "No selection feature options.";
158       return;
159     }
160     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
161       TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
162       return;
163     }
164     if (!model_->selection_model()) {
165       TC_LOG(ERROR) << "No selection model.";
166       return;
167     }
168     selection_executor_ = ModelExecutor::Instance(model_->selection_model());
169     if (!selection_executor_) {
170       TC_LOG(ERROR) << "Could not initialize selection executor.";
171       return;
172     }
173     selection_feature_processor_.reset(
174         new FeatureProcessor(model_->selection_feature_options(), unilib_));
175   }
176 
177   // Annotation requires the classification model for conflict resolution and
178   // scoring.
179   // Selection requires the classification model for conflict resolution.
180   if (model_enabled_for_annotation || model_enabled_for_classification ||
181       model_enabled_for_selection) {
182     if (!model_->classification_options()) {
183       TC_LOG(ERROR) << "No classification options.";
184       return;
185     }
186 
187     if (!model_->classification_feature_options()) {
188       TC_LOG(ERROR) << "No classification feature options.";
189       return;
190     }
191 
192     if (!model_->classification_feature_options()
193              ->bounds_sensitive_features()) {
194       TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
195       return;
196     }
197     if (!model_->classification_model()) {
198       TC_LOG(ERROR) << "No clf model.";
199       return;
200     }
201 
202     classification_executor_ =
203         ModelExecutor::Instance(model_->classification_model());
204     if (!classification_executor_) {
205       TC_LOG(ERROR) << "Could not initialize classification executor.";
206       return;
207     }
208 
209     classification_feature_processor_.reset(new FeatureProcessor(
210         model_->classification_feature_options(), unilib_));
211   }
212 
213   // The embeddings need to be specified if the model is to be used for
214   // classification or selection.
215   if (model_enabled_for_annotation || model_enabled_for_classification ||
216       model_enabled_for_selection) {
217     if (!model_->embedding_model()) {
218       TC_LOG(ERROR) << "No embedding model.";
219       return;
220     }
221 
222     // Check that the embedding size of the selection and classification model
223     // matches, as they are using the same embeddings.
224     if (model_enabled_for_selection &&
225         (model_->selection_feature_options()->embedding_size() !=
226              model_->classification_feature_options()->embedding_size() ||
227          model_->selection_feature_options()->embedding_quantization_bits() !=
228              model_->classification_feature_options()
229                  ->embedding_quantization_bits())) {
230       TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
231       return;
232     }
233 
234     embedding_executor_ = TFLiteEmbeddingExecutor::Instance(
235         model_->embedding_model(),
236         model_->classification_feature_options()->embedding_size(),
237         model_->classification_feature_options()
238             ->embedding_quantization_bits());
239     if (!embedding_executor_) {
240       TC_LOG(ERROR) << "Could not initialize embedding executor.";
241       return;
242     }
243   }
244 
245   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
246   if (model_->regex_model()) {
247     if (!InitializeRegexModel(decompressor.get())) {
248       TC_LOG(ERROR) << "Could not initialize regex model.";
249       return;
250     }
251   }
252 
253   if (model_->datetime_model()) {
254     datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(),
255                                                 *unilib_, decompressor.get());
256     if (!datetime_parser_) {
257       TC_LOG(ERROR) << "Could not initialize datetime parser.";
258       return;
259     }
260   }
261 
262   if (model_->output_options()) {
263     if (model_->output_options()->filtered_collections_annotation()) {
264       for (const auto collection :
265            *model_->output_options()->filtered_collections_annotation()) {
266         filtered_collections_annotation_.insert(collection->str());
267       }
268     }
269     if (model_->output_options()->filtered_collections_classification()) {
270       for (const auto collection :
271            *model_->output_options()->filtered_collections_classification()) {
272         filtered_collections_classification_.insert(collection->str());
273       }
274     }
275     if (model_->output_options()->filtered_collections_selection()) {
276       for (const auto collection :
277            *model_->output_options()->filtered_collections_selection()) {
278         filtered_collections_selection_.insert(collection->str());
279       }
280     }
281   }
282 
283   initialized_ = true;
284 }
285 
InitializeRegexModel(ZlibDecompressor * decompressor)286 bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
287   if (!model_->regex_model()->patterns()) {
288     return true;
289   }
290 
291   // Initialize pattern recognizers.
292   int regex_pattern_id = 0;
293   for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
294     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
295         UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
296                                    regex_pattern->compressed_pattern(),
297                                    decompressor);
298     if (!compiled_pattern) {
299       TC_LOG(INFO) << "Failed to load regex pattern";
300       return false;
301     }
302 
303     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
304       annotation_regex_patterns_.push_back(regex_pattern_id);
305     }
306     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
307       classification_regex_patterns_.push_back(regex_pattern_id);
308     }
309     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
310       selection_regex_patterns_.push_back(regex_pattern_id);
311     }
312     regex_patterns_.push_back({regex_pattern->collection_name()->str(),
313                                regex_pattern->target_classification_score(),
314                                regex_pattern->priority_score(),
315                                std::move(compiled_pattern)});
316     if (regex_pattern->use_approximate_matching()) {
317       regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
318     }
319     ++regex_pattern_id;
320   }
321 
322   return true;
323 }
324 
325 namespace {
326 
CountDigits(const std::string & str,CodepointSpan selection_indices)327 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
328   int count = 0;
329   int i = 0;
330   const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
331   for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
332     if (i >= selection_indices.first && i < selection_indices.second &&
333         isdigit(*it)) {
334       ++count;
335     }
336   }
337   return count;
338 }
339 
ExtractSelection(const std::string & context,CodepointSpan selection_indices)340 std::string ExtractSelection(const std::string& context,
341                              CodepointSpan selection_indices) {
342   const UnicodeText context_unicode =
343       UTF8ToUnicodeText(context, /*do_copy=*/false);
344   auto selection_begin = context_unicode.begin();
345   std::advance(selection_begin, selection_indices.first);
346   auto selection_end = context_unicode.begin();
347   std::advance(selection_end, selection_indices.second);
348   return UnicodeText::UTF8Substring(selection_begin, selection_end);
349 }
350 }  // namespace
351 
352 namespace internal {
353 // Helper function, which if the initial 'span' contains only white-spaces,
354 // moves the selection to a single-codepoint selection on a left or right side
355 // of this space.
SnapLeftIfWhitespaceSelection(CodepointSpan span,const UnicodeText & context_unicode,const UniLib & unilib)356 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
357                                             const UnicodeText& context_unicode,
358                                             const UniLib& unilib) {
359   TC_CHECK(ValidNonEmptySpan(span));
360 
361   UnicodeText::const_iterator it;
362 
363   // Check that the current selection is all whitespaces.
364   it = context_unicode.begin();
365   std::advance(it, span.first);
366   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
367     if (!unilib.IsWhitespace(*it)) {
368       return span;
369     }
370   }
371 
372   CodepointSpan result;
373 
374   // Try moving left.
375   result = span;
376   it = context_unicode.begin();
377   std::advance(it, span.first);
378   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
379     --result.first;
380     --it;
381   }
382   result.second = result.first + 1;
383   if (!unilib.IsWhitespace(*it)) {
384     return result;
385   }
386 
387   // If moving left didn't find a non-whitespace character, just return the
388   // original span.
389   return span;
390 }
391 }  // namespace internal
392 
FilteredForAnnotation(const AnnotatedSpan & span) const393 bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const {
394   return !span.classification.empty() &&
395          filtered_collections_annotation_.find(
396              span.classification[0].collection) !=
397              filtered_collections_annotation_.end();
398 }
399 
FilteredForClassification(const ClassificationResult & classification) const400 bool TextClassifier::FilteredForClassification(
401     const ClassificationResult& classification) const {
402   return filtered_collections_classification_.find(classification.collection) !=
403          filtered_collections_classification_.end();
404 }
405 
FilteredForSelection(const AnnotatedSpan & span) const406 bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const {
407   return !span.classification.empty() &&
408          filtered_collections_selection_.find(
409              span.classification[0].collection) !=
410              filtered_collections_selection_.end();
411 }
412 
SuggestSelection(const std::string & context,CodepointSpan click_indices,const SelectionOptions & options) const413 CodepointSpan TextClassifier::SuggestSelection(
414     const std::string& context, CodepointSpan click_indices,
415     const SelectionOptions& options) const {
416   CodepointSpan original_click_indices = click_indices;
417   if (!initialized_) {
418     TC_LOG(ERROR) << "Not initialized";
419     return original_click_indices;
420   }
421   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
422     return original_click_indices;
423   }
424 
425   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
426                                                         /*do_copy=*/false);
427 
428   if (!context_unicode.is_valid()) {
429     return original_click_indices;
430   }
431 
432   const int context_codepoint_size = context_unicode.size_codepoints();
433 
434   if (click_indices.first < 0 || click_indices.second < 0 ||
435       click_indices.first >= context_codepoint_size ||
436       click_indices.second > context_codepoint_size ||
437       click_indices.first >= click_indices.second) {
438     TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
439                << click_indices.first << " " << click_indices.second;
440     return original_click_indices;
441   }
442 
443   if (model_->snap_whitespace_selections()) {
444     // We want to expand a purely white-space selection to a multi-selection it
445     // would've been part of. But with this feature disabled we would do a no-
446     // op, because no token is found. Therefore, we need to modify the
447     // 'click_indices' a bit to include a part of the token, so that the click-
448     // finding logic finds the clicked token correctly. This modification is
449     // done by the following function. Note, that it's enough to check the left
450     // side of the current selection, because if the white-space is a part of a
451     // multi-selection, neccessarily both tokens - on the left and the right
452     // sides need to be selected. Thus snapping only to the left is sufficient
453     // (there's a check at the bottom that makes sure that if we snap to the
454     // left token but the result does not contain the initial white-space,
455     // returns the original indices).
456     click_indices = internal::SnapLeftIfWhitespaceSelection(
457         click_indices, context_unicode, *unilib_);
458   }
459 
460   std::vector<AnnotatedSpan> candidates;
461   InterpreterManager interpreter_manager(selection_executor_.get(),
462                                          classification_executor_.get());
463   std::vector<Token> tokens;
464   if (!ModelSuggestSelection(context_unicode, click_indices,
465                              &interpreter_manager, &tokens, &candidates)) {
466     TC_LOG(ERROR) << "Model suggest selection failed.";
467     return original_click_indices;
468   }
469   if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
470     TC_LOG(ERROR) << "Regex suggest selection failed.";
471     return original_click_indices;
472   }
473   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
474                      /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
475                      options.locales, ModeFlag_SELECTION, &candidates)) {
476     TC_LOG(ERROR) << "Datetime suggest selection failed.";
477     return original_click_indices;
478   }
479 
480   // Sort candidates according to their position in the input, so that the next
481   // code can assume that any connected component of overlapping spans forms a
482   // contiguous block.
483   std::sort(candidates.begin(), candidates.end(),
484             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
485               return a.span.first < b.span.first;
486             });
487 
488   std::vector<int> candidate_indices;
489   if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
490                         &candidate_indices)) {
491     TC_LOG(ERROR) << "Couldn't resolve conflicts.";
492     return original_click_indices;
493   }
494 
495   for (const int i : candidate_indices) {
496     if (SpansOverlap(candidates[i].span, click_indices) &&
497         SpansOverlap(candidates[i].span, original_click_indices)) {
498       // Run model classification if not present but requested and there's a
499       // classification collection filter specified.
500       if (candidates[i].classification.empty() &&
501           model_->selection_options()->always_classify_suggested_selection() &&
502           !filtered_collections_selection_.empty()) {
503         if (!ModelClassifyText(
504                 context, candidates[i].span, &interpreter_manager,
505                 /*embedding_cache=*/nullptr, &candidates[i].classification)) {
506           return original_click_indices;
507         }
508       }
509 
510       // Ignore if span classification is filtered.
511       if (FilteredForSelection(candidates[i])) {
512         return original_click_indices;
513       }
514 
515       return candidates[i].span;
516     }
517   }
518 
519   return original_click_indices;
520 }
521 
522 namespace {
523 // Helper function that returns the index of the first candidate that
524 // transitively does not overlap with the candidate on 'start_index'. If the end
525 // of 'candidates' is reached, it returns the index that points right behind the
526 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)527 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
528                                  int start_index) {
529   int first_non_overlapping = start_index + 1;
530   CodepointSpan conflicting_span = candidates[start_index].span;
531   while (
532       first_non_overlapping < candidates.size() &&
533       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
534     // Grow the span to include the current one.
535     conflicting_span.second = std::max(
536         conflicting_span.second, candidates[first_non_overlapping].span.second);
537 
538     ++first_non_overlapping;
539   }
540   return first_non_overlapping;
541 }
542 }  // namespace
543 
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,InterpreterManager * interpreter_manager,std::vector<int> * result) const544 bool TextClassifier::ResolveConflicts(
545     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
546     const std::vector<Token>& cached_tokens,
547     InterpreterManager* interpreter_manager, std::vector<int>* result) const {
548   result->clear();
549   result->reserve(candidates.size());
550   for (int i = 0; i < candidates.size();) {
551     int first_non_overlapping =
552         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
553 
554     const bool conflict_found = first_non_overlapping != (i + 1);
555     if (conflict_found) {
556       std::vector<int> candidate_indices;
557       if (!ResolveConflict(context, cached_tokens, candidates, i,
558                            first_non_overlapping, interpreter_manager,
559                            &candidate_indices)) {
560         return false;
561       }
562       result->insert(result->end(), candidate_indices.begin(),
563                      candidate_indices.end());
564     } else {
565       result->push_back(i);
566     }
567 
568     // Skip over the whole conflicting group/go to next candidate.
569     i = first_non_overlapping;
570   }
571   return true;
572 }
573 
574 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)575 inline bool ClassifiedAsOther(
576     const std::vector<ClassificationResult>& classification) {
577   return !classification.empty() &&
578          classification[0].collection == TextClassifier::kOtherCollection;
579 }
580 
GetPriorityScore(const std::vector<ClassificationResult> & classification)581 float GetPriorityScore(
582     const std::vector<ClassificationResult>& classification) {
583   if (!ClassifiedAsOther(classification)) {
584     return classification[0].priority_score;
585   } else {
586     return -1.0;
587   }
588 }
589 }  // namespace
590 
ResolveConflict(const std::string & context,const std::vector<Token> & cached_tokens,const std::vector<AnnotatedSpan> & candidates,int start_index,int end_index,InterpreterManager * interpreter_manager,std::vector<int> * chosen_indices) const591 bool TextClassifier::ResolveConflict(
592     const std::string& context, const std::vector<Token>& cached_tokens,
593     const std::vector<AnnotatedSpan>& candidates, int start_index,
594     int end_index, InterpreterManager* interpreter_manager,
595     std::vector<int>* chosen_indices) const {
596   std::vector<int> conflicting_indices;
597   std::unordered_map<int, float> scores;
598   for (int i = start_index; i < end_index; ++i) {
599     conflicting_indices.push_back(i);
600     if (!candidates[i].classification.empty()) {
601       scores[i] = GetPriorityScore(candidates[i].classification);
602       continue;
603     }
604 
605     // OPTIMIZATION: So that we don't have to classify all the ML model
606     // spans apriori, we wait until we get here, when they conflict with
607     // something and we need the actual classification scores. So if the
608     // candidate conflicts and comes from the model, we need to run a
609     // classification to determine its priority:
610     std::vector<ClassificationResult> classification;
611     if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
612                            interpreter_manager,
613                            /*embedding_cache=*/nullptr, &classification)) {
614       return false;
615     }
616 
617     if (!classification.empty()) {
618       scores[i] = GetPriorityScore(classification);
619     }
620   }
621 
622   std::sort(conflicting_indices.begin(), conflicting_indices.end(),
623             [&scores](int i, int j) { return scores[i] > scores[j]; });
624 
625   // Keeps the candidates sorted by their position in the text (their left span
626   // index) for fast retrieval down.
627   std::set<int, std::function<bool(int, int)>> chosen_indices_set(
628       [&candidates](int a, int b) {
629         return candidates[a].span.first < candidates[b].span.first;
630       });
631 
632   // Greedily place the candidates if they don't conflict with the already
633   // placed ones.
634   for (int i = 0; i < conflicting_indices.size(); ++i) {
635     const int considered_candidate = conflicting_indices[i];
636     if (!DoesCandidateConflict(considered_candidate, candidates,
637                                chosen_indices_set)) {
638       chosen_indices_set.insert(considered_candidate);
639     }
640   }
641 
642   *chosen_indices =
643       std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());
644 
645   return true;
646 }
647 
ModelSuggestSelection(const UnicodeText & context_unicode,CodepointSpan click_indices,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const648 bool TextClassifier::ModelSuggestSelection(
649     const UnicodeText& context_unicode, CodepointSpan click_indices,
650     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
651     std::vector<AnnotatedSpan>* result) const {
652   if (model_->triggering_options() == nullptr ||
653       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
654     return true;
655   }
656 
657   int click_pos;
658   *tokens = selection_feature_processor_->Tokenize(context_unicode);
659   selection_feature_processor_->RetokenizeAndFindClick(
660       context_unicode, click_indices,
661       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
662       tokens, &click_pos);
663   if (click_pos == kInvalidIndex) {
664     TC_VLOG(1) << "Could not calculate the click position.";
665     return false;
666   }
667 
668   const int symmetry_context_size =
669       model_->selection_options()->symmetry_context_size();
670   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
671       bounds_sensitive_features = selection_feature_processor_->GetOptions()
672                                       ->bounds_sensitive_features();
673 
674   // The symmetry context span is the clicked token with symmetry_context_size
675   // tokens on either side.
676   const TokenSpan symmetry_context_span = IntersectTokenSpans(
677       ExpandTokenSpan(SingleTokenSpan(click_pos),
678                       /*num_tokens_left=*/symmetry_context_size,
679                       /*num_tokens_right=*/symmetry_context_size),
680       {0, tokens->size()});
681 
682   // Compute the extraction span based on the model type.
683   TokenSpan extraction_span;
684   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
685     // The extraction span is the symmetry context span expanded to include
686     // max_selection_span tokens on either side, which is how far a selection
687     // can stretch from the click, plus a relevant number of tokens outside of
688     // the bounds of the selection.
689     const int max_selection_span =
690         selection_feature_processor_->GetOptions()->max_selection_span();
691     extraction_span =
692         ExpandTokenSpan(symmetry_context_span,
693                         /*num_tokens_left=*/max_selection_span +
694                             bounds_sensitive_features->num_tokens_before(),
695                         /*num_tokens_right=*/max_selection_span +
696                             bounds_sensitive_features->num_tokens_after());
697   } else {
698     // The extraction span is the symmetry context span expanded to include
699     // context_size tokens on either side.
700     const int context_size =
701         selection_feature_processor_->GetOptions()->context_size();
702     extraction_span = ExpandTokenSpan(symmetry_context_span,
703                                       /*num_tokens_left=*/context_size,
704                                       /*num_tokens_right=*/context_size);
705   }
706   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
707 
708   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
709           *tokens, extraction_span)) {
710     return true;
711   }
712 
713   std::unique_ptr<CachedFeatures> cached_features;
714   if (!selection_feature_processor_->ExtractFeatures(
715           *tokens, extraction_span,
716           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
717           embedding_executor_.get(),
718           /*embedding_cache=*/nullptr,
719           selection_feature_processor_->EmbeddingSize() +
720               selection_feature_processor_->DenseFeaturesCount(),
721           &cached_features)) {
722     TC_LOG(ERROR) << "Could not extract features.";
723     return false;
724   }
725 
726   // Produce selection model candidates.
727   std::vector<TokenSpan> chunks;
728   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
729                   interpreter_manager->SelectionInterpreter(), *cached_features,
730                   &chunks)) {
731     TC_LOG(ERROR) << "Could not chunk.";
732     return false;
733   }
734 
735   for (const TokenSpan& chunk : chunks) {
736     AnnotatedSpan candidate;
737     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
738         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
739     if (model_->selection_options()->strip_unpaired_brackets()) {
740       candidate.span =
741           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
742     }
743 
744     // Only output non-empty spans.
745     if (candidate.span.first != candidate.span.second) {
746       result->push_back(candidate);
747     }
748   }
749   return true;
750 }
751 
ModelClassifyText(const std::string & context,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results) const752 bool TextClassifier::ModelClassifyText(
753     const std::string& context, CodepointSpan selection_indices,
754     InterpreterManager* interpreter_manager,
755     FeatureProcessor::EmbeddingCache* embedding_cache,
756     std::vector<ClassificationResult>* classification_results) const {
757   if (model_->triggering_options() == nullptr ||
758       !(model_->triggering_options()->enabled_modes() &
759         ModeFlag_CLASSIFICATION)) {
760     return true;
761   }
762   return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
763                            embedding_cache, classification_results);
764 }
765 
766 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,CodepointSpan selection_indices,TokenSpan tokens_around_selection_to_copy)767 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
768                                     CodepointSpan selection_indices,
769                                     TokenSpan tokens_around_selection_to_copy) {
770   const auto first_selection_token = std::upper_bound(
771       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
772       [](int selection_start, const Token& token) {
773         return selection_start < token.end;
774       });
775   const auto last_selection_token = std::lower_bound(
776       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
777       [](const Token& token, int selection_end) {
778         return token.start < selection_end;
779       });
780 
781   const int64 first_token = std::max(
782       static_cast<int64>(0),
783       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
784                          tokens_around_selection_to_copy.first));
785   const int64 last_token = std::min(
786       static_cast<int64>(cached_tokens.size()),
787       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
788                          tokens_around_selection_to_copy.second));
789 
790   std::vector<Token> tokens;
791   tokens.reserve(last_token - first_token);
792   for (int i = first_token; i < last_token; ++i) {
793     tokens.push_back(cached_tokens[i]);
794   }
795   return tokens;
796 }
797 }  // namespace internal
798 
ClassifyTextUpperBoundNeededTokens() const799 TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const {
800   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
801       bounds_sensitive_features =
802           classification_feature_processor_->GetOptions()
803               ->bounds_sensitive_features();
804   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
805     // The extraction span is the selection span expanded to include a relevant
806     // number of tokens outside of the bounds of the selection.
807     return {bounds_sensitive_features->num_tokens_before(),
808             bounds_sensitive_features->num_tokens_after()};
809   } else {
810     // The extraction span is the clicked token with context_size tokens on
811     // either side.
812     const int context_size =
813         selection_feature_processor_->GetOptions()->context_size();
814     return {context_size, context_size};
815   }
816 }
817 
ModelClassifyText(const std::string & context,const std::vector<Token> & cached_tokens,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results) const818 bool TextClassifier::ModelClassifyText(
819     const std::string& context, const std::vector<Token>& cached_tokens,
820     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
821     FeatureProcessor::EmbeddingCache* embedding_cache,
822     std::vector<ClassificationResult>* classification_results) const {
823   std::vector<Token> tokens;
824   if (cached_tokens.empty()) {
825     tokens = classification_feature_processor_->Tokenize(context);
826   } else {
827     tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
828                                         ClassifyTextUpperBoundNeededTokens());
829   }
830 
831   int click_pos;
832   classification_feature_processor_->RetokenizeAndFindClick(
833       context, selection_indices,
834       classification_feature_processor_->GetOptions()
835           ->only_use_line_with_click(),
836       &tokens, &click_pos);
837   const TokenSpan selection_token_span =
838       CodepointSpanToTokenSpan(tokens, selection_indices);
839   const int selection_num_tokens = TokenSpanSize(selection_token_span);
840   if (model_->classification_options()->max_num_tokens() > 0 &&
841       model_->classification_options()->max_num_tokens() <
842           selection_num_tokens) {
843     *classification_results = {{kOtherCollection, 1.0}};
844     return true;
845   }
846 
847   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
848       bounds_sensitive_features =
849           classification_feature_processor_->GetOptions()
850               ->bounds_sensitive_features();
851   if (selection_token_span.first == kInvalidIndex ||
852       selection_token_span.second == kInvalidIndex) {
853     TC_LOG(ERROR) << "Could not determine span.";
854     return false;
855   }
856 
857   // Compute the extraction span based on the model type.
858   TokenSpan extraction_span;
859   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
860     // The extraction span is the selection span expanded to include a relevant
861     // number of tokens outside of the bounds of the selection.
862     extraction_span = ExpandTokenSpan(
863         selection_token_span,
864         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
865         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
866   } else {
867     if (click_pos == kInvalidIndex) {
868       TC_LOG(ERROR) << "Couldn't choose a click position.";
869       return false;
870     }
871     // The extraction span is the clicked token with context_size tokens on
872     // either side.
873     const int context_size =
874         classification_feature_processor_->GetOptions()->context_size();
875     extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
876                                       /*num_tokens_left=*/context_size,
877                                       /*num_tokens_right=*/context_size);
878   }
879   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});
880 
881   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
882           tokens, extraction_span)) {
883     *classification_results = {{kOtherCollection, 1.0}};
884     return true;
885   }
886 
887   std::unique_ptr<CachedFeatures> cached_features;
888   if (!classification_feature_processor_->ExtractFeatures(
889           tokens, extraction_span, selection_indices, embedding_executor_.get(),
890           embedding_cache,
891           classification_feature_processor_->EmbeddingSize() +
892               classification_feature_processor_->DenseFeaturesCount(),
893           &cached_features)) {
894     TC_LOG(ERROR) << "Could not extract features.";
895     return false;
896   }
897 
898   std::vector<float> features;
899   features.reserve(cached_features->OutputFeaturesSize());
900   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
901     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
902                                                           &features);
903   } else {
904     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
905   }
906 
907   TensorView<float> logits = classification_executor_->ComputeLogits(
908       TensorView<float>(features.data(),
909                         {1, static_cast<int>(features.size())}),
910       interpreter_manager->ClassificationInterpreter());
911   if (!logits.is_valid()) {
912     TC_LOG(ERROR) << "Couldn't compute logits.";
913     return false;
914   }
915 
916   if (logits.dims() != 2 || logits.dim(0) != 1 ||
917       logits.dim(1) != classification_feature_processor_->NumCollections()) {
918     TC_LOG(ERROR) << "Mismatching output";
919     return false;
920   }
921 
922   const std::vector<float> scores =
923       ComputeSoftmax(logits.data(), logits.dim(1));
924 
925   classification_results->resize(scores.size());
926   for (int i = 0; i < scores.size(); i++) {
927     (*classification_results)[i] = {
928         classification_feature_processor_->LabelToCollection(i), scores[i]};
929   }
930   std::sort(classification_results->begin(), classification_results->end(),
931             [](const ClassificationResult& a, const ClassificationResult& b) {
932               return a.score > b.score;
933             });
934 
935   // Phone class sanity check.
936   if (!classification_results->empty() &&
937       classification_results->begin()->collection == kPhoneCollection) {
938     const int digit_count = CountDigits(context, selection_indices);
939     if (digit_count <
940             model_->classification_options()->phone_min_num_digits() ||
941         digit_count >
942             model_->classification_options()->phone_max_num_digits()) {
943       *classification_results = {{kOtherCollection, 1.0}};
944     }
945   }
946 
947   // Address class sanity check.
948   if (!classification_results->empty() &&
949       classification_results->begin()->collection == kAddressCollection) {
950     if (selection_num_tokens <
951         model_->classification_options()->address_min_num_tokens()) {
952       *classification_results = {{kOtherCollection, 1.0}};
953     }
954   }
955 
956   return true;
957 }
958 
RegexClassifyText(const std::string & context,CodepointSpan selection_indices,ClassificationResult * classification_result) const959 bool TextClassifier::RegexClassifyText(
960     const std::string& context, CodepointSpan selection_indices,
961     ClassificationResult* classification_result) const {
962   const std::string selection_text =
963       ExtractSelection(context, selection_indices);
964   const UnicodeText selection_text_unicode(
965       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
966 
967   // Check whether any of the regular expressions match.
968   for (const int pattern_id : classification_regex_patterns_) {
969     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
970     const std::unique_ptr<UniLib::RegexMatcher> matcher =
971         regex_pattern.pattern->Matcher(selection_text_unicode);
972     int status = UniLib::RegexMatcher::kNoError;
973     bool matches;
974     if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
975         regex_approximate_match_pattern_ids_.end()) {
976       matches = matcher->ApproximatelyMatches(&status);
977     } else {
978       matches = matcher->Matches(&status);
979     }
980     if (status != UniLib::RegexMatcher::kNoError) {
981       return false;
982     }
983     if (matches) {
984       *classification_result = {regex_pattern.collection_name,
985                                 regex_pattern.target_classification_score,
986                                 regex_pattern.priority_score};
987       return true;
988     }
989     if (status != UniLib::RegexMatcher::kNoError) {
990       TC_LOG(ERROR) << "Cound't match regex: " << pattern_id;
991     }
992   }
993 
994   return false;
995 }
996 
DatetimeClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options,ClassificationResult * classification_result) const997 bool TextClassifier::DatetimeClassifyText(
998     const std::string& context, CodepointSpan selection_indices,
999     const ClassificationOptions& options,
1000     ClassificationResult* classification_result) const {
1001   if (!datetime_parser_) {
1002     return false;
1003   }
1004 
1005   const std::string selection_text =
1006       ExtractSelection(context, selection_indices);
1007 
1008   std::vector<DatetimeParseResultSpan> datetime_spans;
1009   if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1010                                options.reference_timezone, options.locales,
1011                                ModeFlag_CLASSIFICATION,
1012                                /*anchor_start_end=*/true, &datetime_spans)) {
1013     TC_LOG(ERROR) << "Error during parsing datetime.";
1014     return false;
1015   }
1016   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1017     // Only consider the result valid if the selection and extracted datetime
1018     // spans exactly match.
1019     if (std::make_pair(datetime_span.span.first + selection_indices.first,
1020                        datetime_span.span.second + selection_indices.first) ==
1021         selection_indices) {
1022       *classification_result = {kDateCollection,
1023                                 datetime_span.target_classification_score};
1024       classification_result->datetime_parse_result = datetime_span.data;
1025       return true;
1026     }
1027   }
1028   return false;
1029 }
1030 
ClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options) const1031 std::vector<ClassificationResult> TextClassifier::ClassifyText(
1032     const std::string& context, CodepointSpan selection_indices,
1033     const ClassificationOptions& options) const {
1034   if (!initialized_) {
1035     TC_LOG(ERROR) << "Not initialized";
1036     return {};
1037   }
1038 
1039   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1040     return {};
1041   }
1042 
1043   if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1044     return {};
1045   }
1046 
1047   if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
1048     TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
1049                << std::get<0>(selection_indices) << " "
1050                << std::get<1>(selection_indices);
1051     return {};
1052   }
1053 
1054   // Try the regular expression models.
1055   ClassificationResult regex_result;
1056   if (RegexClassifyText(context, selection_indices, &regex_result)) {
1057     if (!FilteredForClassification(regex_result)) {
1058       return {regex_result};
1059     } else {
1060       return {{kOtherCollection, 1.0}};
1061     }
1062   }
1063 
1064   // Try the date model.
1065   ClassificationResult datetime_result;
1066   if (DatetimeClassifyText(context, selection_indices, options,
1067                            &datetime_result)) {
1068     if (!FilteredForClassification(datetime_result)) {
1069       return {datetime_result};
1070     } else {
1071       return {{kOtherCollection, 1.0}};
1072     }
1073   }
1074 
1075   // Fallback to the model.
1076   std::vector<ClassificationResult> model_result;
1077 
1078   InterpreterManager interpreter_manager(selection_executor_.get(),
1079                                          classification_executor_.get());
1080   if (ModelClassifyText(context, selection_indices, &interpreter_manager,
1081                         /*embedding_cache=*/nullptr, &model_result) &&
1082       !model_result.empty()) {
1083     if (!FilteredForClassification(model_result[0])) {
1084       return model_result;
1085     } else {
1086       return {{kOtherCollection, 1.0}};
1087     }
1088   }
1089 
1090   // No classifications.
1091   return {};
1092 }
1093 
ModelAnnotate(const std::string & context,InterpreterManager * interpreter_manager,std::vector<Token> * tokens,std::vector<AnnotatedSpan> * result) const1094 bool TextClassifier::ModelAnnotate(const std::string& context,
1095                                    InterpreterManager* interpreter_manager,
1096                                    std::vector<Token>* tokens,
1097                                    std::vector<AnnotatedSpan>* result) const {
1098   if (model_->triggering_options() == nullptr ||
1099       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
1100     return true;
1101   }
1102 
1103   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
1104                                                         /*do_copy=*/false);
1105   std::vector<UnicodeTextRange> lines;
1106   if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
1107     lines.push_back({context_unicode.begin(), context_unicode.end()});
1108   } else {
1109     lines = selection_feature_processor_->SplitContext(context_unicode);
1110   }
1111 
1112   const float min_annotate_confidence =
1113       (model_->triggering_options() != nullptr
1114            ? model_->triggering_options()->min_annotate_confidence()
1115            : 0.f);
1116 
1117   FeatureProcessor::EmbeddingCache embedding_cache;
1118   for (const UnicodeTextRange& line : lines) {
1119     const std::string line_str =
1120         UnicodeText::UTF8Substring(line.first, line.second);
1121 
1122     *tokens = selection_feature_processor_->Tokenize(line_str);
1123     selection_feature_processor_->RetokenizeAndFindClick(
1124         line_str, {0, std::distance(line.first, line.second)},
1125         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
1126         tokens,
1127         /*click_pos=*/nullptr);
1128     const TokenSpan full_line_span = {0, tokens->size()};
1129 
1130     // TODO(zilka): Add support for greater granularity of this check.
1131     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
1132             *tokens, full_line_span)) {
1133       continue;
1134     }
1135 
1136     std::unique_ptr<CachedFeatures> cached_features;
1137     if (!selection_feature_processor_->ExtractFeatures(
1138             *tokens, full_line_span,
1139             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
1140             embedding_executor_.get(),
1141             /*embedding_cache=*/nullptr,
1142             selection_feature_processor_->EmbeddingSize() +
1143                 selection_feature_processor_->DenseFeaturesCount(),
1144             &cached_features)) {
1145       TC_LOG(ERROR) << "Could not extract features.";
1146       return false;
1147     }
1148 
1149     std::vector<TokenSpan> local_chunks;
1150     if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
1151                     interpreter_manager->SelectionInterpreter(),
1152                     *cached_features, &local_chunks)) {
1153       TC_LOG(ERROR) << "Could not chunk.";
1154       return false;
1155     }
1156 
1157     const int offset = std::distance(context_unicode.begin(), line.first);
1158     for (const TokenSpan& chunk : local_chunks) {
1159       const CodepointSpan codepoint_span =
1160           selection_feature_processor_->StripBoundaryCodepoints(
1161               line_str, TokenSpanToCodepointSpan(*tokens, chunk));
1162 
1163       // Skip empty spans.
1164       if (codepoint_span.first != codepoint_span.second) {
1165         std::vector<ClassificationResult> classification;
1166         if (!ModelClassifyText(line_str, *tokens, codepoint_span,
1167                                interpreter_manager, &embedding_cache,
1168                                &classification)) {
1169           TC_LOG(ERROR) << "Could not classify text: "
1170                         << (codepoint_span.first + offset) << " "
1171                         << (codepoint_span.second + offset);
1172           return false;
1173         }
1174 
1175         // Do not include the span if it's classified as "other".
1176         if (!classification.empty() && !ClassifiedAsOther(classification) &&
1177             classification[0].score >= min_annotate_confidence) {
1178           AnnotatedSpan result_span;
1179           result_span.span = {codepoint_span.first + offset,
1180                               codepoint_span.second + offset};
1181           result_span.classification = std::move(classification);
1182           result->push_back(std::move(result_span));
1183         }
1184       }
1185     }
1186   }
1187   return true;
1188 }
1189 
SelectionFeatureProcessorForTests() const1190 const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests()
1191     const {
1192   return selection_feature_processor_.get();
1193 }
1194 
ClassificationFeatureProcessorForTests() const1195 const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests()
1196     const {
1197   return classification_feature_processor_.get();
1198 }
1199 
DatetimeParserForTests() const1200 const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
1201   return datetime_parser_.get();
1202 }
1203 
Annotate(const std::string & context,const AnnotationOptions & options) const1204 std::vector<AnnotatedSpan> TextClassifier::Annotate(
1205     const std::string& context, const AnnotationOptions& options) const {
1206   std::vector<AnnotatedSpan> candidates;
1207 
1208   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
1209     return {};
1210   }
1211 
1212   if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1213     return {};
1214   }
1215 
1216   InterpreterManager interpreter_manager(selection_executor_.get(),
1217                                          classification_executor_.get());
1218   // Annotate with the selection model.
1219   std::vector<Token> tokens;
1220   if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
1221     TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
1222     return {};
1223   }
1224 
1225   // Annotate with the regular expression models.
1226   if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1227                   annotation_regex_patterns_, &candidates)) {
1228     TC_LOG(ERROR) << "Couldn't run RegexChunk.";
1229     return {};
1230   }
1231 
1232   // Annotate with the datetime model.
1233   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1234                      options.reference_time_ms_utc, options.reference_timezone,
1235                      options.locales, ModeFlag_ANNOTATION, &candidates)) {
1236     TC_LOG(ERROR) << "Couldn't run RegexChunk.";
1237     return {};
1238   }
1239 
1240   // Sort candidates according to their position in the input, so that the next
1241   // code can assume that any connected component of overlapping spans forms a
1242   // contiguous block.
1243   std::sort(candidates.begin(), candidates.end(),
1244             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1245               return a.span.first < b.span.first;
1246             });
1247 
1248   std::vector<int> candidate_indices;
1249   if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
1250                         &candidate_indices)) {
1251     TC_LOG(ERROR) << "Couldn't resolve conflicts.";
1252     return {};
1253   }
1254 
1255   std::vector<AnnotatedSpan> result;
1256   result.reserve(candidate_indices.size());
1257   for (const int i : candidate_indices) {
1258     if (!candidates[i].classification.empty() &&
1259         !ClassifiedAsOther(candidates[i].classification) &&
1260         !FilteredForAnnotation(candidates[i])) {
1261       result.push_back(std::move(candidates[i]));
1262     }
1263   }
1264 
1265   return result;
1266 }
1267 
RegexChunk(const UnicodeText & context_unicode,const std::vector<int> & rules,std::vector<AnnotatedSpan> * result) const1268 bool TextClassifier::RegexChunk(const UnicodeText& context_unicode,
1269                                 const std::vector<int>& rules,
1270                                 std::vector<AnnotatedSpan>* result) const {
1271   for (int pattern_id : rules) {
1272     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
1273     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
1274     if (!matcher) {
1275       TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
1276                     << pattern_id;
1277       return false;
1278     }
1279 
1280     int status = UniLib::RegexMatcher::kNoError;
1281     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
1282       result->emplace_back();
1283       // Selection/annotation regular expressions need to specify a capturing
1284       // group specifying the selection.
1285       result->back().span = {matcher->Start(1, &status),
1286                              matcher->End(1, &status)};
1287       result->back().classification = {
1288           {regex_pattern.collection_name,
1289            regex_pattern.target_classification_score,
1290            regex_pattern.priority_score}};
1291     }
1292   }
1293   return true;
1294 }
1295 
ModelChunk(int num_tokens,const TokenSpan & span_of_interest,tflite::Interpreter * selection_interpreter,const CachedFeatures & cached_features,std::vector<TokenSpan> * chunks) const1296 bool TextClassifier::ModelChunk(int num_tokens,
1297                                 const TokenSpan& span_of_interest,
1298                                 tflite::Interpreter* selection_interpreter,
1299                                 const CachedFeatures& cached_features,
1300                                 std::vector<TokenSpan>* chunks) const {
1301   const int max_selection_span =
1302       selection_feature_processor_->GetOptions()->max_selection_span();
1303   // The inference span is the span of interest expanded to include
1304   // max_selection_span tokens on either side, which is how far a selection can
1305   // stretch from the click.
1306   const TokenSpan inference_span = IntersectTokenSpans(
1307       ExpandTokenSpan(span_of_interest,
1308                       /*num_tokens_left=*/max_selection_span,
1309                       /*num_tokens_right=*/max_selection_span),
1310       {0, num_tokens});
1311 
1312   std::vector<ScoredChunk> scored_chunks;
1313   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
1314       selection_feature_processor_->GetOptions()
1315           ->bounds_sensitive_features()
1316           ->enabled()) {
1317     if (!ModelBoundsSensitiveScoreChunks(
1318             num_tokens, span_of_interest, inference_span, cached_features,
1319             selection_interpreter, &scored_chunks)) {
1320       return false;
1321     }
1322   } else {
1323     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
1324                                       cached_features, selection_interpreter,
1325                                       &scored_chunks)) {
1326       return false;
1327     }
1328   }
1329   std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
1330             [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
1331               return lhs.score < rhs.score;
1332             });
1333 
1334   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
1335   // them greedily as long as they do not overlap with any previously picked
1336   // chunks.
1337   std::vector<bool> token_used(TokenSpanSize(inference_span));
1338   chunks->clear();
1339   for (const ScoredChunk& scored_chunk : scored_chunks) {
1340     bool feasible = true;
1341     for (int i = scored_chunk.token_span.first;
1342          i < scored_chunk.token_span.second; ++i) {
1343       if (token_used[i - inference_span.first]) {
1344         feasible = false;
1345         break;
1346       }
1347     }
1348 
1349     if (!feasible) {
1350       continue;
1351     }
1352 
1353     for (int i = scored_chunk.token_span.first;
1354          i < scored_chunk.token_span.second; ++i) {
1355       token_used[i - inference_span.first] = true;
1356     }
1357 
1358     chunks->push_back(scored_chunk.token_span);
1359   }
1360 
1361   std::sort(chunks->begin(), chunks->end());
1362 
1363   return true;
1364 }
1365 
namespace {
// Sets the value at `key` in `map` to the maximum of its current value and
// `value`. If `key` is not present yet, `value` is inserted as is.
//
// Uses a single lookup (insert reports whether the key already existed). It
// does not require Map::mapped_type to be default-constructible, unlike a
// find() followed by operator[].
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto insertion = map->insert(typename Map::value_type(key, value));
  if (!insertion.second) {
    insertion.first->second = std::max(insertion.first->second, value);
  }
}
}  // namespace
1380 
ModelClickContextScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const1381 bool TextClassifier::ModelClickContextScoreChunks(
1382     int num_tokens, const TokenSpan& span_of_interest,
1383     const CachedFeatures& cached_features,
1384     tflite::Interpreter* selection_interpreter,
1385     std::vector<ScoredChunk>* scored_chunks) const {
1386   const int max_batch_size = model_->selection_options()->batch_size();
1387 
1388   std::vector<float> all_features;
1389   std::map<TokenSpan, float> chunk_scores;
1390   for (int batch_start = span_of_interest.first;
1391        batch_start < span_of_interest.second; batch_start += max_batch_size) {
1392     const int batch_end =
1393         std::min(batch_start + max_batch_size, span_of_interest.second);
1394 
1395     // Prepare features for the whole batch.
1396     all_features.clear();
1397     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
1398     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
1399       cached_features.AppendClickContextFeaturesForClick(click_pos,
1400                                                          &all_features);
1401     }
1402 
1403     // Run batched inference.
1404     const int batch_size = batch_end - batch_start;
1405     const int features_size = cached_features.OutputFeaturesSize();
1406     TensorView<float> logits = selection_executor_->ComputeLogits(
1407         TensorView<float>(all_features.data(), {batch_size, features_size}),
1408         selection_interpreter);
1409     if (!logits.is_valid()) {
1410       TC_LOG(ERROR) << "Couldn't compute logits.";
1411       return false;
1412     }
1413     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
1414         logits.dim(1) !=
1415             selection_feature_processor_->GetSelectionLabelCount()) {
1416       TC_LOG(ERROR) << "Mismatching output.";
1417       return false;
1418     }
1419 
1420     // Save results.
1421     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
1422       const std::vector<float> scores = ComputeSoftmax(
1423           logits.data() + logits.dim(1) * (click_pos - batch_start),
1424           logits.dim(1));
1425       for (int j = 0;
1426            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
1427         TokenSpan relative_token_span;
1428         if (!selection_feature_processor_->LabelToTokenSpan(
1429                 j, &relative_token_span)) {
1430           TC_LOG(ERROR) << "Couldn't map the label to a token span.";
1431           return false;
1432         }
1433         const TokenSpan candidate_span = ExpandTokenSpan(
1434             SingleTokenSpan(click_pos), relative_token_span.first,
1435             relative_token_span.second);
1436         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
1437           UpdateMax(&chunk_scores, candidate_span, scores[j]);
1438         }
1439       }
1440     }
1441   }
1442 
1443   scored_chunks->clear();
1444   scored_chunks->reserve(chunk_scores.size());
1445   for (const auto& entry : chunk_scores) {
1446     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
1447   }
1448 
1449   return true;
1450 }
1451 
ModelBoundsSensitiveScoreChunks(int num_tokens,const TokenSpan & span_of_interest,const TokenSpan & inference_span,const CachedFeatures & cached_features,tflite::Interpreter * selection_interpreter,std::vector<ScoredChunk> * scored_chunks) const1452 bool TextClassifier::ModelBoundsSensitiveScoreChunks(
1453     int num_tokens, const TokenSpan& span_of_interest,
1454     const TokenSpan& inference_span, const CachedFeatures& cached_features,
1455     tflite::Interpreter* selection_interpreter,
1456     std::vector<ScoredChunk>* scored_chunks) const {
1457   const int max_selection_span =
1458       selection_feature_processor_->GetOptions()->max_selection_span();
1459   const int max_chunk_length = selection_feature_processor_->GetOptions()
1460                                        ->selection_reduced_output_space()
1461                                    ? max_selection_span + 1
1462                                    : 2 * max_selection_span + 1;
1463   const bool score_single_token_spans_as_zero =
1464       selection_feature_processor_->GetOptions()
1465           ->bounds_sensitive_features()
1466           ->score_single_token_spans_as_zero();
1467 
1468   scored_chunks->clear();
1469   if (score_single_token_spans_as_zero) {
1470     scored_chunks->reserve(TokenSpanSize(span_of_interest));
1471   }
1472 
1473   // Prepare all chunk candidates into one batch:
1474   //   - Are contained in the inference span
1475   //   - Have a non-empty intersection with the span of interest
1476   //   - Are at least one token long
1477   //   - Are not longer than the maximum chunk length
1478   std::vector<TokenSpan> candidate_spans;
1479   for (int start = inference_span.first; start < span_of_interest.second;
1480        ++start) {
1481     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
1482     for (int end = leftmost_end_index;
1483          end <= inference_span.second && end - start <= max_chunk_length;
1484          ++end) {
1485       const TokenSpan candidate_span = {start, end};
1486       if (score_single_token_spans_as_zero &&
1487           TokenSpanSize(candidate_span) == 1) {
1488         // Do not include the single token span in the batch, add a zero score
1489         // for it directly to the output.
1490         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
1491       } else {
1492         candidate_spans.push_back(candidate_span);
1493       }
1494     }
1495   }
1496 
1497   const int max_batch_size = model_->selection_options()->batch_size();
1498 
1499   std::vector<float> all_features;
1500   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
1501   for (int batch_start = 0; batch_start < candidate_spans.size();
1502        batch_start += max_batch_size) {
1503     const int batch_end = std::min(batch_start + max_batch_size,
1504                                    static_cast<int>(candidate_spans.size()));
1505 
1506     // Prepare features for the whole batch.
1507     all_features.clear();
1508     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
1509     for (int i = batch_start; i < batch_end; ++i) {
1510       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
1511                                                            &all_features);
1512     }
1513 
1514     // Run batched inference.
1515     const int batch_size = batch_end - batch_start;
1516     const int features_size = cached_features.OutputFeaturesSize();
1517     TensorView<float> logits = selection_executor_->ComputeLogits(
1518         TensorView<float>(all_features.data(), {batch_size, features_size}),
1519         selection_interpreter);
1520     if (!logits.is_valid()) {
1521       TC_LOG(ERROR) << "Couldn't compute logits.";
1522       return false;
1523     }
1524     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
1525         logits.dim(1) != 1) {
1526       TC_LOG(ERROR) << "Mismatching output.";
1527       return false;
1528     }
1529 
1530     // Save results.
1531     for (int i = batch_start; i < batch_end; ++i) {
1532       scored_chunks->push_back(
1533           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
1534     }
1535   }
1536 
1537   return true;
1538 }
1539 
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,std::vector<AnnotatedSpan> * result) const1540 bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
1541                                    int64 reference_time_ms_utc,
1542                                    const std::string& reference_timezone,
1543                                    const std::string& locales, ModeFlag mode,
1544                                    std::vector<AnnotatedSpan>* result) const {
1545   if (!datetime_parser_) {
1546     return true;
1547   }
1548 
1549   std::vector<DatetimeParseResultSpan> datetime_spans;
1550   if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
1551                                reference_timezone, locales, mode,
1552                                /*anchor_start_end=*/false, &datetime_spans)) {
1553     return false;
1554   }
1555   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1556     AnnotatedSpan annotated_span;
1557     annotated_span.span = datetime_span.span;
1558     annotated_span.classification = {{kDateCollection,
1559                                       datetime_span.target_classification_score,
1560                                       datetime_span.priority_score}};
1561     annotated_span.classification[0].datetime_parse_result = datetime_span.data;
1562 
1563     result->push_back(std::move(annotated_span));
1564   }
1565   return true;
1566 }
1567 
ViewModel(const void * buffer,int size)1568 const Model* ViewModel(const void* buffer, int size) {
1569   if (!buffer) {
1570     return nullptr;
1571   }
1572 
1573   return LoadAndVerifyModel(buffer, size);
1574 }
1575 
1576 }  // namespace libtextclassifier2
1577