1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "text-classifier.h"
18
19 #include <algorithm>
20 #include <cctype>
21 #include <cmath>
22 #include <iterator>
23 #include <numeric>
24
25 #include "util/base/logging.h"
26 #include "util/math/softmax.h"
27 #include "util/utf8/unicodetext.h"
28
29 namespace libtextclassifier2 {
// Well-known collection names shared across the library. Each is bound to a
// heap-allocated string that is intentionally never deleted, so the referenced
// object outlives all static destruction (avoids destruction-order issues).
const std::string& TextClassifier::kOtherCollection =
    *[]() { return new std::string("other"); }();
const std::string& TextClassifier::kPhoneCollection =
    *[]() { return new std::string("phone"); }();
const std::string& TextClassifier::kAddressCollection =
    *[]() { return new std::string("address"); }();
const std::string& TextClassifier::kDateCollection =
    *[]() { return new std::string("date"); }();
38
39 namespace {
LoadAndVerifyModel(const void * addr,int size)40 const Model* LoadAndVerifyModel(const void* addr, int size) {
41 const Model* model = GetModel(addr);
42
43 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
44 if (model->Verify(verifier)) {
45 return model;
46 } else {
47 return nullptr;
48 }
49 }
50 } // namespace
51
SelectionInterpreter()52 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
53 if (!selection_interpreter_) {
54 TC_CHECK(selection_executor_);
55 selection_interpreter_ = selection_executor_->CreateInterpreter();
56 if (!selection_interpreter_) {
57 TC_LOG(ERROR) << "Could not build TFLite interpreter.";
58 }
59 }
60 return selection_interpreter_.get();
61 }
62
ClassificationInterpreter()63 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
64 if (!classification_interpreter_) {
65 TC_CHECK(classification_executor_);
66 classification_interpreter_ = classification_executor_->CreateInterpreter();
67 if (!classification_interpreter_) {
68 TC_LOG(ERROR) << "Could not build TFLite interpreter.";
69 }
70 }
71 return classification_interpreter_.get();
72 }
73
FromUnownedBuffer(const char * buffer,int size,const UniLib * unilib)74 std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
75 const char* buffer, int size, const UniLib* unilib) {
76 const Model* model = LoadAndVerifyModel(buffer, size);
77 if (model == nullptr) {
78 return nullptr;
79 }
80
81 auto classifier =
82 std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib));
83 if (!classifier->IsInitialized()) {
84 return nullptr;
85 }
86
87 return classifier;
88 }
89
FromScopedMmap(std::unique_ptr<ScopedMmap> * mmap,const UniLib * unilib)90 std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
91 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) {
92 if (!(*mmap)->handle().ok()) {
93 TC_VLOG(1) << "Mmap failed.";
94 return nullptr;
95 }
96
97 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
98 (*mmap)->handle().num_bytes());
99 if (!model) {
100 TC_LOG(ERROR) << "Model verification failed.";
101 return nullptr;
102 }
103
104 auto classifier =
105 std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib));
106 if (!classifier->IsInitialized()) {
107 return nullptr;
108 }
109
110 return classifier;
111 }
112
FromFileDescriptor(int fd,int offset,int size,const UniLib * unilib)113 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
114 int fd, int offset, int size, const UniLib* unilib) {
115 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
116 return FromScopedMmap(&mmap, unilib);
117 }
118
FromFileDescriptor(int fd,const UniLib * unilib)119 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
120 int fd, const UniLib* unilib) {
121 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
122 return FromScopedMmap(&mmap, unilib);
123 }
124
FromPath(const std::string & path,const UniLib * unilib)125 std::unique_ptr<TextClassifier> TextClassifier::FromPath(
126 const std::string& path, const UniLib* unilib) {
127 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
128 return FromScopedMmap(&mmap, unilib);
129 }
130
// Validates the model's structure for the modes it is enabled for and builds
// the executors, feature processors, regex patterns, datetime parser, and
// output filters. On any validation failure the function returns early and
// 'initialized_' stays false; IsInitialized() reports the outcome to callers.
void TextClassifier::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC_LOG(ERROR) << "No model specified.";
    return;
  }

  // Missing triggering options means no mode is enabled.
  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC_LOG(ERROR) << "No selection feature options.";
      return;
    }
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::Instance(model_->selection_model());
    if (!selection_executor_) {
      TC_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC_LOG(ERROR) << "No classification options.";
      return;
    }

    if (!model_->classification_feature_options()) {
      TC_LOG(ERROR) << "No classification feature options.";
      return;
    }

    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC_LOG(ERROR) << "No clf model.";
      return;
    }

    classification_executor_ =
        ModelExecutor::Instance(model_->classification_model());
    if (!classification_executor_) {
      TC_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }

    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding size of the selection and classification model
    // matches, as they are using the same embeddings.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    // The single shared embedding executor is parameterized from the
    // classification options (validated above to match selection's).
    embedding_executor_ = TFLiteEmbeddingExecutor::Instance(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()
            ->embedding_quantization_bits());
    if (!embedding_executor_) {
      TC_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  // The decompressor is shared by the regex model and the datetime model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }

  if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(),
                                                *unilib_, decompressor.get());
    if (!datetime_parser_) {
      TC_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  // Collection names listed in output_options are suppressed from the
  // corresponding API results (see the Filtered* predicates below).
  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  initialized_ = true;
}
285
// Compiles all regex patterns in the model and registers each pattern id with
// the modes (annotation/classification/selection) it is enabled for. Returns
// false if any pattern fails to compile; a model without patterns succeeds
// trivially.
bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
  if (!model_->regex_model()->patterns()) {
    return true;
  }

  // Initialize pattern recognizers.
  int regex_pattern_id = 0;
  for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
    // Patterns may be stored compressed; 'decompressor' handles that case.
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
                                   regex_pattern->compressed_pattern(),
                                   decompressor);
    if (!compiled_pattern) {
      TC_LOG(INFO) << "Failed to load regex pattern";
      return false;
    }

    // Index the pattern id under every mode it participates in.
    if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
      annotation_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
      classification_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
      selection_regex_patterns_.push_back(regex_pattern_id);
    }
    regex_patterns_.push_back({regex_pattern->collection_name()->str(),
                               regex_pattern->target_classification_score(),
                               regex_pattern->priority_score(),
                               std::move(compiled_pattern)});
    if (regex_pattern->use_approximate_matching()) {
      regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
    }
    ++regex_pattern_id;
  }

  return true;
}
324
325 namespace {
326
CountDigits(const std::string & str,CodepointSpan selection_indices)327 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
328 int count = 0;
329 int i = 0;
330 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
331 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
332 if (i >= selection_indices.first && i < selection_indices.second &&
333 isdigit(*it)) {
334 ++count;
335 }
336 }
337 return count;
338 }
339
ExtractSelection(const std::string & context,CodepointSpan selection_indices)340 std::string ExtractSelection(const std::string& context,
341 CodepointSpan selection_indices) {
342 const UnicodeText context_unicode =
343 UTF8ToUnicodeText(context, /*do_copy=*/false);
344 auto selection_begin = context_unicode.begin();
345 std::advance(selection_begin, selection_indices.first);
346 auto selection_end = context_unicode.begin();
347 std::advance(selection_end, selection_indices.second);
348 return UnicodeText::UTF8Substring(selection_begin, selection_end);
349 }
350 } // namespace
351
352 namespace internal {
353 // Helper function, which if the initial 'span' contains only white-spaces,
354 // moves the selection to a single-codepoint selection on a left or right side
355 // of this space.
// Precondition: 'span' is a valid, non-empty codepoint span into
// 'context_unicode'. If the span covers only whitespace, returns a
// single-codepoint span on the first non-whitespace character to its left;
// otherwise (or if no such character exists) returns 'span' unchanged.
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib) {
  TC_CHECK(ValidNonEmptySpan(span));

  UnicodeText::const_iterator it;

  // Check that the current selection is all whitespaces.
  it = context_unicode.begin();
  std::advance(it, span.first);
  for (int i = 0; i < (span.second - span.first); ++i, ++it) {
    if (!unilib.IsWhitespace(*it)) {
      // Mixed content: leave the selection alone.
      return span;
    }
  }

  CodepointSpan result;

  // Try moving left.
  result = span;
  it = context_unicode.begin();
  std::advance(it, span.first);
  // Walk left until a non-whitespace codepoint or the start of the text.
  while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
    --result.first;
    --it;
  }
  // Collapse to a single-codepoint span at the stopping position.
  result.second = result.first + 1;
  if (!unilib.IsWhitespace(*it)) {
    return result;
  }

  // If moving left didn't find a non-whitespace character, just return the
  // original span.
  return span;
}
391 } // namespace internal
392
FilteredForAnnotation(const AnnotatedSpan & span) const393 bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const {
394 return !span.classification.empty() &&
395 filtered_collections_annotation_.find(
396 span.classification[0].collection) !=
397 filtered_collections_annotation_.end();
398 }
399
FilteredForClassification(const ClassificationResult & classification) const400 bool TextClassifier::FilteredForClassification(
401 const ClassificationResult& classification) const {
402 return filtered_collections_classification_.find(classification.collection) !=
403 filtered_collections_classification_.end();
404 }
405
FilteredForSelection(const AnnotatedSpan & span) const406 bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const {
407 return !span.classification.empty() &&
408 filtered_collections_selection_.find(
409 span.classification[0].collection) !=
410 filtered_collections_selection_.end();
411 }
412
// Suggests a selection span for the given click in 'context'. Candidates come
// from the ML selection model, regex patterns, and the datetime parser; after
// conflict resolution the first candidate overlapping the click wins. On any
// failure (or when nothing applies) the original click indices are returned
// unchanged.
CodepointSpan TextClassifier::SuggestSelection(
    const std::string& context, CodepointSpan click_indices,
    const SelectionOptions& options) const {
  CodepointSpan original_click_indices = click_indices;
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return original_click_indices;
  }
  if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
    return original_click_indices;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);

  // Malformed UTF-8 input: bail out rather than operate on garbage.
  if (!context_unicode.is_valid()) {
    return original_click_indices;
  }

  const int context_codepoint_size = context_unicode.size_codepoints();

  // Reject out-of-range or empty/inverted click spans.
  if (click_indices.first < 0 || click_indices.second < 0 ||
      click_indices.first >= context_codepoint_size ||
      click_indices.second > context_codepoint_size ||
      click_indices.first >= click_indices.second) {
    TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
               << click_indices.first << " " << click_indices.second;
    return original_click_indices;
  }

  if (model_->snap_whitespace_selections()) {
    // We want to expand a purely white-space selection to a multi-selection it
    // would've been part of. But with this feature disabled we would do a no-
    // op, because no token is found. Therefore, we need to modify the
    // 'click_indices' a bit to include a part of the token, so that the click-
    // finding logic finds the clicked token correctly. This modification is
    // done by the following function. Note, that it's enough to check the left
    // side of the current selection, because if the white-space is a part of a
    // multi-selection, neccessarily both tokens - on the left and the right
    // sides need to be selected. Thus snapping only to the left is sufficient
    // (there's a check at the bottom that makes sure that if we snap to the
    // left token but the result does not contain the initial white-space,
    // returns the original indices).
    click_indices = internal::SnapLeftIfWhitespaceSelection(
        click_indices, context_unicode, *unilib_);
  }

  // Gather candidates from all three sources; each appends to 'candidates'.
  std::vector<AnnotatedSpan> candidates;
  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  std::vector<Token> tokens;
  if (!ModelSuggestSelection(context_unicode, click_indices,
                             &interpreter_manager, &tokens, &candidates)) {
    TC_LOG(ERROR) << "Model suggest selection failed.";
    return original_click_indices;
  }
  if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
    TC_LOG(ERROR) << "Regex suggest selection failed.";
    return original_click_indices;
  }
  if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                     /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
                     options.locales, ModeFlag_SELECTION, &candidates)) {
    TC_LOG(ERROR) << "Datetime suggest selection failed.";
    return original_click_indices;
  }

  // Sort candidates according to their position in the input, so that the next
  // code can assume that any connected component of overlapping spans forms a
  // contiguous block.
  std::sort(candidates.begin(), candidates.end(),
            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
              return a.span.first < b.span.first;
            });

  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
                        &candidate_indices)) {
    TC_LOG(ERROR) << "Couldn't resolve conflicts.";
    return original_click_indices;
  }

  // Pick the first surviving candidate that overlaps both the (possibly
  // snapped) click and the original click.
  for (const int i : candidate_indices) {
    if (SpansOverlap(candidates[i].span, click_indices) &&
        SpansOverlap(candidates[i].span, original_click_indices)) {
      // Run model classification if not present but requested and there's a
      // classification collection filter specified.
      if (candidates[i].classification.empty() &&
          model_->selection_options()->always_classify_suggested_selection() &&
          !filtered_collections_selection_.empty()) {
        if (!ModelClassifyText(
                context, candidates[i].span, &interpreter_manager,
                /*embedding_cache=*/nullptr, &candidates[i].classification)) {
          return original_click_indices;
        }
      }

      // Ignore if span classification is filtered.
      if (FilteredForSelection(candidates[i])) {
        return original_click_indices;
      }

      return candidates[i].span;
    }
  }

  return original_click_indices;
}
521
522 namespace {
523 // Helper function that returns the index of the first candidate that
524 // transitively does not overlap with the candidate on 'start_index'. If the end
525 // of 'candidates' is reached, it returns the index that points right behind the
526 // array.
FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan> & candidates,int start_index)527 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
528 int start_index) {
529 int first_non_overlapping = start_index + 1;
530 CodepointSpan conflicting_span = candidates[start_index].span;
531 while (
532 first_non_overlapping < candidates.size() &&
533 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
534 // Grow the span to include the current one.
535 conflicting_span.second = std::max(
536 conflicting_span.second, candidates[first_non_overlapping].span.second);
537
538 ++first_non_overlapping;
539 }
540 return first_non_overlapping;
541 }
542 } // namespace
543
ResolveConflicts(const std::vector<AnnotatedSpan> & candidates,const std::string & context,const std::vector<Token> & cached_tokens,InterpreterManager * interpreter_manager,std::vector<int> * result) const544 bool TextClassifier::ResolveConflicts(
545 const std::vector<AnnotatedSpan>& candidates, const std::string& context,
546 const std::vector<Token>& cached_tokens,
547 InterpreterManager* interpreter_manager, std::vector<int>* result) const {
548 result->clear();
549 result->reserve(candidates.size());
550 for (int i = 0; i < candidates.size();) {
551 int first_non_overlapping =
552 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
553
554 const bool conflict_found = first_non_overlapping != (i + 1);
555 if (conflict_found) {
556 std::vector<int> candidate_indices;
557 if (!ResolveConflict(context, cached_tokens, candidates, i,
558 first_non_overlapping, interpreter_manager,
559 &candidate_indices)) {
560 return false;
561 }
562 result->insert(result->end(), candidate_indices.begin(),
563 candidate_indices.end());
564 } else {
565 result->push_back(i);
566 }
567
568 // Skip over the whole conflicting group/go to next candidate.
569 i = first_non_overlapping;
570 }
571 return true;
572 }
573
574 namespace {
ClassifiedAsOther(const std::vector<ClassificationResult> & classification)575 inline bool ClassifiedAsOther(
576 const std::vector<ClassificationResult>& classification) {
577 return !classification.empty() &&
578 classification[0].collection == TextClassifier::kOtherCollection;
579 }
580
GetPriorityScore(const std::vector<ClassificationResult> & classification)581 float GetPriorityScore(
582 const std::vector<ClassificationResult>& classification) {
583 if (!ClassifiedAsOther(classification)) {
584 return classification[0].priority_score;
585 } else {
586 return -1.0;
587 }
588 }
589 } // namespace
590
// Resolves one group of mutually overlapping candidates in
// [start_index, end_index): ranks them by priority score and greedily keeps
// the highest-scoring ones that do not overlap an already-kept candidate.
// Kept indices are written to 'chosen_indices' in text (left-to-right) order.
// Returns false if an on-demand classification fails.
bool TextClassifier::ResolveConflict(
    const std::string& context, const std::vector<Token>& cached_tokens,
    const std::vector<AnnotatedSpan>& candidates, int start_index,
    int end_index, InterpreterManager* interpreter_manager,
    std::vector<int>* chosen_indices) const {
  std::vector<int> conflicting_indices;
  std::unordered_map<int, float> scores;
  for (int i = start_index; i < end_index; ++i) {
    conflicting_indices.push_back(i);
    // Candidates that already carry a classification keep its priority score.
    if (!candidates[i].classification.empty()) {
      scores[i] = GetPriorityScore(candidates[i].classification);
      continue;
    }

    // OPTIMIZATION: So that we don't have to classify all the ML model
    // spans apriori, we wait until we get here, when they conflict with
    // something and we need the actual classification scores. So if the
    // candidate conflicts and comes from the model, we need to run a
    // classification to determine its priority:
    std::vector<ClassificationResult> classification;
    if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
                           interpreter_manager,
                           /*embedding_cache=*/nullptr, &classification)) {
      return false;
    }

    if (!classification.empty()) {
      scores[i] = GetPriorityScore(classification);
    }
  }

  // Sort by priority, highest first. NOTE(review): an index whose
  // classification came back empty has no entry in 'scores', so operator[]
  // value-initializes its score to 0.0f here — which ranks it above "other"
  // results (-1.0). Confirm this default is intended.
  std::sort(conflicting_indices.begin(), conflicting_indices.end(),
            [&scores](int i, int j) { return scores[i] > scores[j]; });

  // Keeps the candidates sorted by their position in the text (their left span
  // index) for fast retrieval down.
  std::set<int, std::function<bool(int, int)>> chosen_indices_set(
      [&candidates](int a, int b) {
        return candidates[a].span.first < candidates[b].span.first;
      });

  // Greedily place the candidates if they don't conflict with the already
  // placed ones.
  for (int i = 0; i < conflicting_indices.size(); ++i) {
    const int considered_candidate = conflicting_indices[i];
    if (!DoesCandidateConflict(considered_candidate, candidates,
                               chosen_indices_set)) {
      chosen_indices_set.insert(considered_candidate);
    }
  }

  // The set's comparator yields left-to-right text order.
  *chosen_indices =
      std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());

  return true;
}
647
// Produces selection candidates from the ML selection model around the
// clicked token. Tokenizes the context (tokens are returned via 'tokens' for
// reuse by later classification), extracts features over a window around the
// click, chunks the window, and appends non-empty chunk spans to 'result'.
// Returns true on success or when the model is not enabled for selection;
// false on tokenization/feature-extraction/chunking failure.
bool TextClassifier::ModelSuggestSelection(
    const UnicodeText& context_unicode, CodepointSpan click_indices,
    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
    std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
    return true;
  }

  // 'click_pos' is set by RetokenizeAndFindClick below (kInvalidIndex when
  // the click cannot be mapped to a token).
  int click_pos;
  *tokens = selection_feature_processor_->Tokenize(context_unicode);
  selection_feature_processor_->RetokenizeAndFindClick(
      context_unicode, click_indices,
      selection_feature_processor_->GetOptions()->only_use_line_with_click(),
      tokens, &click_pos);
  if (click_pos == kInvalidIndex) {
    TC_VLOG(1) << "Could not calculate the click position.";
    return false;
  }

  const int symmetry_context_size =
      model_->selection_options()->symmetry_context_size();
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features = selection_feature_processor_->GetOptions()
                                      ->bounds_sensitive_features();

  // The symmetry context span is the clicked token with symmetry_context_size
  // tokens on either side.
  const TokenSpan symmetry_context_span = IntersectTokenSpans(
      ExpandTokenSpan(SingleTokenSpan(click_pos),
                      /*num_tokens_left=*/symmetry_context_size,
                      /*num_tokens_right=*/symmetry_context_size),
      {0, tokens->size()});

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the symmetry context span expanded to include
    // max_selection_span tokens on either side, which is how far a selection
    // can stretch from the click, plus a relevant number of tokens outside of
    // the bounds of the selection.
    const int max_selection_span =
        selection_feature_processor_->GetOptions()->max_selection_span();
    extraction_span =
        ExpandTokenSpan(symmetry_context_span,
                        /*num_tokens_left=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_before(),
                        /*num_tokens_right=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_after());
  } else {
    // The extraction span is the symmetry context span expanded to include
    // context_size tokens on either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(symmetry_context_span,
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  // Clamp to the available tokens.
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});

  // Silently produce no candidates when the window contains too many
  // unsupported codepoints.
  if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
          *tokens, extraction_span)) {
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!selection_feature_processor_->ExtractFeatures(
          *tokens, extraction_span,
          /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
          embedding_executor_.get(),
          /*embedding_cache=*/nullptr,
          selection_feature_processor_->EmbeddingSize() +
              selection_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Produce selection model candidates.
  std::vector<TokenSpan> chunks;
  if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
                  interpreter_manager->SelectionInterpreter(), *cached_features,
                  &chunks)) {
    TC_LOG(ERROR) << "Could not chunk.";
    return false;
  }

  for (const TokenSpan& chunk : chunks) {
    AnnotatedSpan candidate;
    // Convert back to codepoints and strip boundary/bracket noise.
    candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
        context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
    if (model_->selection_options()->strip_unpaired_brackets()) {
      candidate.span =
          StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
    }

    // Only output non-empty spans.
    if (candidate.span.first != candidate.span.second) {
      result->push_back(candidate);
    }
  }
  return true;
}
751
ModelClassifyText(const std::string & context,CodepointSpan selection_indices,InterpreterManager * interpreter_manager,FeatureProcessor::EmbeddingCache * embedding_cache,std::vector<ClassificationResult> * classification_results) const752 bool TextClassifier::ModelClassifyText(
753 const std::string& context, CodepointSpan selection_indices,
754 InterpreterManager* interpreter_manager,
755 FeatureProcessor::EmbeddingCache* embedding_cache,
756 std::vector<ClassificationResult>* classification_results) const {
757 if (model_->triggering_options() == nullptr ||
758 !(model_->triggering_options()->enabled_modes() &
759 ModeFlag_CLASSIFICATION)) {
760 return true;
761 }
762 return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
763 embedding_cache, classification_results);
764 }
765
766 namespace internal {
CopyCachedTokens(const std::vector<Token> & cached_tokens,CodepointSpan selection_indices,TokenSpan tokens_around_selection_to_copy)767 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
768 CodepointSpan selection_indices,
769 TokenSpan tokens_around_selection_to_copy) {
770 const auto first_selection_token = std::upper_bound(
771 cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
772 [](int selection_start, const Token& token) {
773 return selection_start < token.end;
774 });
775 const auto last_selection_token = std::lower_bound(
776 cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
777 [](const Token& token, int selection_end) {
778 return token.start < selection_end;
779 });
780
781 const int64 first_token = std::max(
782 static_cast<int64>(0),
783 static_cast<int64>((first_selection_token - cached_tokens.begin()) -
784 tokens_around_selection_to_copy.first));
785 const int64 last_token = std::min(
786 static_cast<int64>(cached_tokens.size()),
787 static_cast<int64>((last_selection_token - cached_tokens.begin()) +
788 tokens_around_selection_to_copy.second));
789
790 std::vector<Token> tokens;
791 tokens.reserve(last_token - first_token);
792 for (int i = first_token; i < last_token; ++i) {
793 tokens.push_back(cached_tokens[i]);
794 }
795 return tokens;
796 }
797 } // namespace internal
798
ClassifyTextUpperBoundNeededTokens() const799 TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const {
800 const FeatureProcessorOptions_::BoundsSensitiveFeatures*
801 bounds_sensitive_features =
802 classification_feature_processor_->GetOptions()
803 ->bounds_sensitive_features();
804 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
805 // The extraction span is the selection span expanded to include a relevant
806 // number of tokens outside of the bounds of the selection.
807 return {bounds_sensitive_features->num_tokens_before(),
808 bounds_sensitive_features->num_tokens_after()};
809 } else {
810 // The extraction span is the clicked token with context_size tokens on
811 // either side.
812 const int context_size =
813 selection_feature_processor_->GetOptions()->context_size();
814 return {context_size, context_size};
815 }
816 }
817
// Classifies the selection inside |context| with the TFLite classification
// model and fills |classification_results| with (collection, score) entries
// sorted by decreasing score. Returns false only on internal failures
// (tokenization, feature extraction, inference); inputs rejected by the
// sanity checks still return true with the single "other" result.
bool TextClassifier::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  std::vector<Token> tokens;
  if (cached_tokens.empty()) {
    tokens = classification_feature_processor_->Tokenize(context);
  } else {
    // Reuse tokens from an earlier tokenization; only the tokens around the
    // selection that can influence the features are copied.
    tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
                                        ClassifyTextUpperBoundNeededTokens());
  }

  int click_pos;
  classification_feature_processor_->RetokenizeAndFindClick(
      context, selection_indices,
      classification_feature_processor_->GetOptions()
          ->only_use_line_with_click(),
      &tokens, &click_pos);
  const TokenSpan selection_token_span =
      CodepointSpanToTokenSpan(tokens, selection_indices);
  const int selection_num_tokens = TokenSpanSize(selection_token_span);
  // Selections longer than the configured token limit are not sent through
  // the model at all; they are reported as "other".
  if (model_->classification_options()->max_num_tokens() > 0 &&
      model_->classification_options()->max_num_tokens() <
          selection_num_tokens) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (selection_token_span.first == kInvalidIndex ||
      selection_token_span.second == kInvalidIndex) {
    TC_LOG(ERROR) << "Could not determine span.";
    return false;
  }

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    extraction_span = ExpandTokenSpan(
        selection_token_span,
        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
  } else {
    if (click_pos == kInvalidIndex) {
      TC_LOG(ERROR) << "Couldn't choose a click position.";
      return false;
    }
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        classification_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  // Clamp the extraction span to the available tokens.
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});

  // Fall back to "other" when the span contains too many codepoints the
  // model does not support.
  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
          tokens, extraction_span)) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!classification_feature_processor_->ExtractFeatures(
          tokens, extraction_span, selection_indices, embedding_executor_.get(),
          embedding_cache,
          classification_feature_processor_->EmbeddingSize() +
              classification_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Assemble the feature vector for the single classification example.
  std::vector<float> features;
  features.reserve(cached_features->OutputFeaturesSize());
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
                                                          &features);
  } else {
    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
  }

  TensorView<float> logits = classification_executor_->ComputeLogits(
      TensorView<float>(features.data(),
                        {1, static_cast<int>(features.size())}),
      interpreter_manager->ClassificationInterpreter());
  if (!logits.is_valid()) {
    TC_LOG(ERROR) << "Couldn't compute logits.";
    return false;
  }

  // Expect a [1 x num_collections] logits tensor.
  if (logits.dims() != 2 || logits.dim(0) != 1 ||
      logits.dim(1) != classification_feature_processor_->NumCollections()) {
    TC_LOG(ERROR) << "Mismatching output";
    return false;
  }

  const std::vector<float> scores =
      ComputeSoftmax(logits.data(), logits.dim(1));

  // Convert label scores into (collection, score) results, best first.
  classification_results->resize(scores.size());
  for (int i = 0; i < scores.size(); i++) {
    (*classification_results)[i] = {
        classification_feature_processor_->LabelToCollection(i), scores[i]};
  }
  std::sort(classification_results->begin(), classification_results->end(),
            [](const ClassificationResult& a, const ClassificationResult& b) {
              return a.score > b.score;
            });

  // Phone class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count <
            model_->classification_options()->phone_min_num_digits() ||
        digit_count >
            model_->classification_options()->phone_max_num_digits()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  // Address class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kAddressCollection) {
    if (selection_num_tokens <
        model_->classification_options()->address_min_num_tokens()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  return true;
}
958
RegexClassifyText(const std::string & context,CodepointSpan selection_indices,ClassificationResult * classification_result) const959 bool TextClassifier::RegexClassifyText(
960 const std::string& context, CodepointSpan selection_indices,
961 ClassificationResult* classification_result) const {
962 const std::string selection_text =
963 ExtractSelection(context, selection_indices);
964 const UnicodeText selection_text_unicode(
965 UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
966
967 // Check whether any of the regular expressions match.
968 for (const int pattern_id : classification_regex_patterns_) {
969 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
970 const std::unique_ptr<UniLib::RegexMatcher> matcher =
971 regex_pattern.pattern->Matcher(selection_text_unicode);
972 int status = UniLib::RegexMatcher::kNoError;
973 bool matches;
974 if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
975 regex_approximate_match_pattern_ids_.end()) {
976 matches = matcher->ApproximatelyMatches(&status);
977 } else {
978 matches = matcher->Matches(&status);
979 }
980 if (status != UniLib::RegexMatcher::kNoError) {
981 return false;
982 }
983 if (matches) {
984 *classification_result = {regex_pattern.collection_name,
985 regex_pattern.target_classification_score,
986 regex_pattern.priority_score};
987 return true;
988 }
989 if (status != UniLib::RegexMatcher::kNoError) {
990 TC_LOG(ERROR) << "Cound't match regex: " << pattern_id;
991 }
992 }
993
994 return false;
995 }
996
DatetimeClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options,ClassificationResult * classification_result) const997 bool TextClassifier::DatetimeClassifyText(
998 const std::string& context, CodepointSpan selection_indices,
999 const ClassificationOptions& options,
1000 ClassificationResult* classification_result) const {
1001 if (!datetime_parser_) {
1002 return false;
1003 }
1004
1005 const std::string selection_text =
1006 ExtractSelection(context, selection_indices);
1007
1008 std::vector<DatetimeParseResultSpan> datetime_spans;
1009 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
1010 options.reference_timezone, options.locales,
1011 ModeFlag_CLASSIFICATION,
1012 /*anchor_start_end=*/true, &datetime_spans)) {
1013 TC_LOG(ERROR) << "Error during parsing datetime.";
1014 return false;
1015 }
1016 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1017 // Only consider the result valid if the selection and extracted datetime
1018 // spans exactly match.
1019 if (std::make_pair(datetime_span.span.first + selection_indices.first,
1020 datetime_span.span.second + selection_indices.first) ==
1021 selection_indices) {
1022 *classification_result = {kDateCollection,
1023 datetime_span.target_classification_score};
1024 classification_result->datetime_parse_result = datetime_span.data;
1025 return true;
1026 }
1027 }
1028 return false;
1029 }
1030
ClassifyText(const std::string & context,CodepointSpan selection_indices,const ClassificationOptions & options) const1031 std::vector<ClassificationResult> TextClassifier::ClassifyText(
1032 const std::string& context, CodepointSpan selection_indices,
1033 const ClassificationOptions& options) const {
1034 if (!initialized_) {
1035 TC_LOG(ERROR) << "Not initialized";
1036 return {};
1037 }
1038
1039 if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
1040 return {};
1041 }
1042
1043 if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1044 return {};
1045 }
1046
1047 if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
1048 TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
1049 << std::get<0>(selection_indices) << " "
1050 << std::get<1>(selection_indices);
1051 return {};
1052 }
1053
1054 // Try the regular expression models.
1055 ClassificationResult regex_result;
1056 if (RegexClassifyText(context, selection_indices, ®ex_result)) {
1057 if (!FilteredForClassification(regex_result)) {
1058 return {regex_result};
1059 } else {
1060 return {{kOtherCollection, 1.0}};
1061 }
1062 }
1063
1064 // Try the date model.
1065 ClassificationResult datetime_result;
1066 if (DatetimeClassifyText(context, selection_indices, options,
1067 &datetime_result)) {
1068 if (!FilteredForClassification(datetime_result)) {
1069 return {datetime_result};
1070 } else {
1071 return {{kOtherCollection, 1.0}};
1072 }
1073 }
1074
1075 // Fallback to the model.
1076 std::vector<ClassificationResult> model_result;
1077
1078 InterpreterManager interpreter_manager(selection_executor_.get(),
1079 classification_executor_.get());
1080 if (ModelClassifyText(context, selection_indices, &interpreter_manager,
1081 /*embedding_cache=*/nullptr, &model_result) &&
1082 !model_result.empty()) {
1083 if (!FilteredForClassification(model_result[0])) {
1084 return model_result;
1085 } else {
1086 return {{kOtherCollection, 1.0}};
1087 }
1088 }
1089
1090 // No classifications.
1091 return {};
1092 }
1093
// Runs the selection model over |context| (per line, or over the whole text
// when only_use_line_with_click is disabled), classifies each proposed chunk
// and appends the accepted spans to |result|. |tokens| is overwritten per
// line and thus holds the tokens of the last processed line on return.
// Returns false on internal failures; returns true (doing nothing) when
// annotation is disabled for this model.
bool TextClassifier::ModelAnnotate(const std::string& context,
                                   InterpreterManager* interpreter_manager,
                                   std::vector<Token>* tokens,
                                   std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
    // Annotation disabled: not an error.
    return true;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  std::vector<UnicodeTextRange> lines;
  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
    // Process the entire context as a single "line".
    lines.push_back({context_unicode.begin(), context_unicode.end()});
  } else {
    lines = selection_feature_processor_->SplitContext(context_unicode);
  }

  const float min_annotate_confidence =
      (model_->triggering_options() != nullptr
           ? model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  // Shared across all lines so repeated tokens are embedded only once.
  FeatureProcessor::EmbeddingCache embedding_cache;
  for (const UnicodeTextRange& line : lines) {
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    // Propose non-overlapping candidate token spans for this line.
    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Codepoint offset of this line within the whole context; used to map
    // line-relative spans back to context-relative ones.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
                               interpreter_manager, &embedding_cache,
                               &classification)) {
          TC_LOG(ERROR) << "Could not classify text: "
                        << (codepoint_span.first + offset) << " "
                        << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}
1189
// Test-only accessor for the selection feature processor.
const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests()
    const {
  return selection_feature_processor_.get();
}
1194
// Test-only accessor for the classification feature processor.
const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}
1199
// Test-only accessor for the datetime parser (may be null).
const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
  return datetime_parser_.get();
}
1203
Annotate(const std::string & context,const AnnotationOptions & options) const1204 std::vector<AnnotatedSpan> TextClassifier::Annotate(
1205 const std::string& context, const AnnotationOptions& options) const {
1206 std::vector<AnnotatedSpan> candidates;
1207
1208 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
1209 return {};
1210 }
1211
1212 if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
1213 return {};
1214 }
1215
1216 InterpreterManager interpreter_manager(selection_executor_.get(),
1217 classification_executor_.get());
1218 // Annotate with the selection model.
1219 std::vector<Token> tokens;
1220 if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
1221 TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
1222 return {};
1223 }
1224
1225 // Annotate with the regular expression models.
1226 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1227 annotation_regex_patterns_, &candidates)) {
1228 TC_LOG(ERROR) << "Couldn't run RegexChunk.";
1229 return {};
1230 }
1231
1232 // Annotate with the datetime model.
1233 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
1234 options.reference_time_ms_utc, options.reference_timezone,
1235 options.locales, ModeFlag_ANNOTATION, &candidates)) {
1236 TC_LOG(ERROR) << "Couldn't run RegexChunk.";
1237 return {};
1238 }
1239
1240 // Sort candidates according to their position in the input, so that the next
1241 // code can assume that any connected component of overlapping spans forms a
1242 // contiguous block.
1243 std::sort(candidates.begin(), candidates.end(),
1244 [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
1245 return a.span.first < b.span.first;
1246 });
1247
1248 std::vector<int> candidate_indices;
1249 if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
1250 &candidate_indices)) {
1251 TC_LOG(ERROR) << "Couldn't resolve conflicts.";
1252 return {};
1253 }
1254
1255 std::vector<AnnotatedSpan> result;
1256 result.reserve(candidate_indices.size());
1257 for (const int i : candidate_indices) {
1258 if (!candidates[i].classification.empty() &&
1259 !ClassifiedAsOther(candidates[i].classification) &&
1260 !FilteredForAnnotation(candidates[i])) {
1261 result.push_back(std::move(candidates[i]));
1262 }
1263 }
1264
1265 return result;
1266 }
1267
// Runs the given regex annotation |rules| over |context_unicode| and appends
// an AnnotatedSpan for each match to |result|. Returns false only when a
// matcher cannot be created for one of the patterns.
bool TextClassifier::RegexChunk(const UnicodeText& context_unicode,
                                const std::vector<int>& rules,
                                std::vector<AnnotatedSpan>* result) const {
  for (int pattern_id : rules) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
    if (!matcher) {
      TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
                    << pattern_id;
      return false;
    }

    int status = UniLib::RegexMatcher::kNoError;
    // Iterate over all matches; the loop stops when Find fails or reports an
    // error.
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      result->emplace_back();
      // Selection/annotation regular expressions need to specify a capturing
      // group specifying the selection.
      // NOTE(review): |status| is not re-checked after Start()/End(), so a
      // group-extraction error would silently produce an invalid span —
      // confirm whether this is intentional.
      result->back().span = {matcher->Start(1, &status),
                             matcher->End(1, &status)};
      result->back().classification = {
          {regex_pattern.collection_name,
           regex_pattern.target_classification_score,
           regex_pattern.priority_score}};
    }
  }
  return true;
}
1295
// Produces a set of non-overlapping token spans (chunks) for the given span
// of interest, scored with the selection model. Chunks are accepted greedily
// in order of decreasing score and returned sorted by position.
bool TextClassifier::ModelChunk(int num_tokens,
                                const TokenSpan& span_of_interest,
                                tflite::Interpreter* selection_interpreter,
                                const CachedFeatures& cached_features,
                                std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  // Score candidate chunks with whichever scoring variant the model uses.
  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Sorting the reversed range with operator< leaves the vector ordered by
  // decreasing score.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    // Mark the chunk's tokens as used so lower-scored overlapping chunks get
    // rejected above.
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  // Return the accepted chunks in document order.
  std::sort(chunks->begin(), chunks->end());

  return true;
}
1365
namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  // emplace is a no-op if the key already exists; in that case keep the
  // larger of the stored and the new value.
  const auto result = map->emplace(key, value);
  if (!result.second && result.first->second < value) {
    result.first->second = value;
  }
}
}  // namespace
1380
// Scores candidate spans with the click-context selection model. Inference is
// batched over click positions within |span_of_interest|; each output label is
// mapped to a span relative to the click, and when several clicks propose the
// same span only the maximum score is kept.
bool TextClassifier::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Expect a [batch_size x num_selection_labels] logits tensor.
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      // Softmax over the logits row belonging to this click position.
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        const TokenSpan candidate_span = ExpandTokenSpan(
            SingleTokenSpan(click_pos), relative_token_span.first,
            relative_token_span.second);
        // Drop candidates reaching outside the tokenized text; keep the best
        // score seen per span.
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  // Emit the deduplicated (span, max score) pairs.
  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
  }

  return true;
}
1451
// Scores candidate chunks with the bounds-sensitive selection model. The
// candidates are generated directly (contained in the inference span,
// intersecting the span of interest, at most max_chunk_length tokens long)
// and scored in batches; single-token spans may be assigned a fixed zero
// score without inference when the model is configured that way.
bool TextClassifier::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ? max_selection_span + 1
                                   : 2 * max_selection_span + 1;
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(TokenSpanSize(span_of_interest));
  }

  // Prepare all chunk candidates into one batch:
  //   - Are contained in the inference span
  //   - Have a non-empty intersection with the span of interest
  //   - Are at least one token long
  //   - Are not longer than the maximum chunk length
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero &&
          TokenSpanSize(candidate_span) == 1) {
        // Do not include the single token span in the batch, add a zero score
        // for it directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Expect a [batch_size x 1] logits tensor: one score per candidate span.
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}
1539
DatetimeChunk(const UnicodeText & context_unicode,int64 reference_time_ms_utc,const std::string & reference_timezone,const std::string & locales,ModeFlag mode,std::vector<AnnotatedSpan> * result) const1540 bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
1541 int64 reference_time_ms_utc,
1542 const std::string& reference_timezone,
1543 const std::string& locales, ModeFlag mode,
1544 std::vector<AnnotatedSpan>* result) const {
1545 if (!datetime_parser_) {
1546 return true;
1547 }
1548
1549 std::vector<DatetimeParseResultSpan> datetime_spans;
1550 if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
1551 reference_timezone, locales, mode,
1552 /*anchor_start_end=*/false, &datetime_spans)) {
1553 return false;
1554 }
1555 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
1556 AnnotatedSpan annotated_span;
1557 annotated_span.span = datetime_span.span;
1558 annotated_span.classification = {{kDateCollection,
1559 datetime_span.target_classification_score,
1560 datetime_span.priority_score}};
1561 annotated_span.classification[0].datetime_parse_result = datetime_span.data;
1562
1563 result->push_back(std::move(annotated_span));
1564 }
1565 return true;
1566 }
1567
ViewModel(const void * buffer,int size)1568 const Model* ViewModel(const void* buffer, int size) {
1569 if (!buffer) {
1570 return nullptr;
1571 }
1572
1573 return LoadAndVerifyModel(buffer, size);
1574 }
1575
1576 } // namespace libtextclassifier2
1577