• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/feature-processor.h"
18 
#include <algorithm>
#include <iterator>
#include <set>
#include <vector>
22 
23 #include "utils/base/logging.h"
24 #include "utils/strings/utf8.h"
25 #include "utils/utf8/unicodetext.h"
26 
27 namespace libtextclassifier3 {
28 
29 namespace internal {
30 
BuildTokenizer(const FeatureProcessorOptions * options,const UniLib * unilib)31 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
32                          const UniLib* unilib) {
33   std::vector<const TokenizationCodepointRange*> codepoint_config;
34   if (options->tokenization_codepoint_config() != nullptr) {
35     codepoint_config.insert(codepoint_config.end(),
36                             options->tokenization_codepoint_config()->begin(),
37                             options->tokenization_codepoint_config()->end());
38   }
39   std::vector<const CodepointRange*> internal_codepoint_config;
40   if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
41     internal_codepoint_config.insert(
42         internal_codepoint_config.end(),
43         options->internal_tokenizer_codepoint_ranges()->begin(),
44         options->internal_tokenizer_codepoint_ranges()->end());
45   }
46   const bool tokenize_on_script_change =
47       options->tokenization_codepoint_config() != nullptr &&
48       options->tokenize_on_script_change();
49   return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
50                    internal_codepoint_config, tokenize_on_script_change,
51                    options->icu_preserve_whitespace_tokens());
52 }
53 
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions * const options)54 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
55     const FeatureProcessorOptions* const options) {
56   TokenFeatureExtractorOptions extractor_options;
57 
58   extractor_options.num_buckets = options->num_buckets();
59   if (options->chargram_orders() != nullptr) {
60     for (int order : *options->chargram_orders()) {
61       extractor_options.chargram_orders.push_back(order);
62     }
63   }
64   extractor_options.max_word_length = options->max_word_length();
65   extractor_options.extract_case_feature = options->extract_case_feature();
66   extractor_options.unicode_aware_features = options->unicode_aware_features();
67   extractor_options.extract_selection_mask_feature =
68       options->extract_selection_mask_feature();
69   if (options->regexp_feature() != nullptr) {
70     for (const auto& regexp_feature : *options->regexp_feature()) {
71       extractor_options.regexp_features.push_back(regexp_feature->str());
72     }
73   }
74   extractor_options.remap_digits = options->remap_digits();
75   extractor_options.lowercase_tokens = options->lowercase_tokens();
76 
77   if (options->allowed_chargrams() != nullptr) {
78     for (const auto& chargram : *options->allowed_chargrams()) {
79       extractor_options.allowed_chargrams.insert(chargram->str());
80     }
81   }
82   return extractor_options;
83 }
84 
SplitTokensOnSelectionBoundaries(const CodepointSpan & selection,std::vector<Token> * tokens)85 void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
86                                       std::vector<Token>* tokens) {
87   for (auto it = tokens->begin(); it != tokens->end(); ++it) {
88     const UnicodeText token_word =
89         UTF8ToUnicodeText(it->value, /*do_copy=*/false);
90 
91     auto last_start = token_word.begin();
92     int last_start_index = it->start;
93     std::vector<UnicodeText::const_iterator> split_points;
94 
95     // Selection start split point.
96     if (selection.first > it->start && selection.first < it->end) {
97       std::advance(last_start, selection.first - last_start_index);
98       split_points.push_back(last_start);
99       last_start_index = selection.first;
100     }
101 
102     // Selection end split point.
103     if (selection.second > it->start && selection.second < it->end) {
104       std::advance(last_start, selection.second - last_start_index);
105       split_points.push_back(last_start);
106     }
107 
108     if (!split_points.empty()) {
109       // Add a final split for the rest of the token unless it's been all
110       // consumed already.
111       if (split_points.back() != token_word.end()) {
112         split_points.push_back(token_word.end());
113       }
114 
115       std::vector<Token> replacement_tokens;
116       last_start = token_word.begin();
117       int current_pos = it->start;
118       for (const auto& split_point : split_points) {
119         Token new_token(token_word.UTF8Substring(last_start, split_point),
120                         current_pos,
121                         current_pos + std::distance(last_start, split_point));
122 
123         last_start = split_point;
124         current_pos = new_token.end;
125 
126         replacement_tokens.push_back(new_token);
127       }
128 
129       it = tokens->erase(it);
130       it = tokens->insert(it, replacement_tokens.begin(),
131                           replacement_tokens.end());
132       std::advance(it, replacement_tokens.size() - 1);
133     }
134   }
135 }
136 
137 }  // namespace internal
138 
StripTokensFromOtherLines(const std::string & context,const CodepointSpan & span,std::vector<Token> * tokens) const139 void FeatureProcessor::StripTokensFromOtherLines(
140     const std::string& context, const CodepointSpan& span,
141     std::vector<Token>* tokens) const {
142   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
143                                                         /*do_copy=*/false);
144   const auto [span_begin, span_end] =
145       CodepointSpanToUnicodeTextRange(context_unicode, span);
146   StripTokensFromOtherLines(context_unicode, span_begin, span_end, span,
147                             tokens);
148 }
149 
StripTokensFromOtherLines(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span,std::vector<Token> * tokens) const150 void FeatureProcessor::StripTokensFromOtherLines(
151     const UnicodeText& context_unicode,
152     const UnicodeText::const_iterator& span_begin,
153     const UnicodeText::const_iterator& span_end, const CodepointSpan& span,
154     std::vector<Token>* tokens) const {
155   std::vector<UnicodeTextRange> lines =
156       SplitContext(context_unicode, options_->use_pipe_character_for_newline());
157 
158   for (const UnicodeTextRange& line : lines) {
159     // Find the line that completely contains the span.
160     if (line.first <= span_begin && line.second >= span_end) {
161       const CodepointIndex last_line_begin_index =
162           std::distance(context_unicode.begin(), line.first);
163       const CodepointIndex last_line_end_index =
164           last_line_begin_index + std::distance(line.first, line.second);
165 
166       for (auto token = tokens->begin(); token != tokens->end();) {
167         if (token->start >= last_line_begin_index &&
168             token->end <= last_line_end_index) {
169           ++token;
170         } else {
171           token = tokens->erase(token);
172         }
173       }
174     }
175   }
176 }
177 
GetDefaultCollection() const178 std::string FeatureProcessor::GetDefaultCollection() const {
179   if (options_->default_collection() < 0 ||
180       options_->collections() == nullptr ||
181       options_->default_collection() >= options_->collections()->size()) {
182     TC3_LOG(ERROR)
183         << "Invalid or missing default collection. Returning empty string.";
184     return "";
185   }
186   return (*options_->collections())[options_->default_collection()]->str();
187 }
188 
Tokenize(const std::string & text) const189 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
190   return tokenizer_.Tokenize(text);
191 }
192 
Tokenize(const UnicodeText & text_unicode) const193 std::vector<Token> FeatureProcessor::Tokenize(
194     const UnicodeText& text_unicode) const {
195   return tokenizer_.Tokenize(text_unicode);
196 }
197 
LabelToSpan(const int label,const VectorSpan<Token> & tokens,CodepointSpan * span) const198 bool FeatureProcessor::LabelToSpan(const int label,
199                                    const VectorSpan<Token>& tokens,
200                                    CodepointSpan* span) const {
201   if (tokens.size() != GetNumContextTokens()) {
202     return false;
203   }
204 
205   TokenSpan token_span;
206   if (!LabelToTokenSpan(label, &token_span)) {
207     return false;
208   }
209 
210   const int result_begin_token_index = token_span.first;
211   const Token& result_begin_token =
212       tokens[options_->context_size() - result_begin_token_index];
213   const int result_begin_codepoint = result_begin_token.start;
214   const int result_end_token_index = token_span.second;
215   const Token& result_end_token =
216       tokens[options_->context_size() + result_end_token_index];
217   const int result_end_codepoint = result_end_token.end;
218 
219   if (result_begin_codepoint == kInvalidIndex ||
220       result_end_codepoint == kInvalidIndex) {
221     *span = CodepointSpan::kInvalid;
222   } else {
223     const UnicodeText token_begin_unicode =
224         UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
225     UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
226     const UnicodeText token_end_unicode =
227         UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
228     UnicodeText::const_iterator token_end = token_end_unicode.end();
229 
230     const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
231         token_begin, token_begin_unicode.end(),
232         /*count_from_beginning=*/true);
233     const int end_ignored =
234         CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
235                                            /*count_from_beginning=*/false);
236     // In case everything would be stripped, set the span to the original
237     // beginning and zero length.
238     if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
239       *span = {result_begin_codepoint, result_begin_codepoint};
240     } else {
241       *span = CodepointSpan(result_begin_codepoint + begin_ignored,
242                             result_end_codepoint - end_ignored);
243     }
244   }
245   return true;
246 }
247 
LabelToTokenSpan(const int label,TokenSpan * token_span) const248 bool FeatureProcessor::LabelToTokenSpan(const int label,
249                                         TokenSpan* token_span) const {
250   if (label >= 0 && label < label_to_selection_.size()) {
251     *token_span = label_to_selection_[label];
252     return true;
253   } else {
254     return false;
255   }
256 }
257 
SpanToLabel(const CodepointSpan & span,const std::vector<Token> & tokens,int * label) const258 bool FeatureProcessor::SpanToLabel(const CodepointSpan& span,
259                                    const std::vector<Token>& tokens,
260                                    int* label) const {
261   if (tokens.size() != GetNumContextTokens()) {
262     return false;
263   }
264 
265   const int click_position =
266       options_->context_size();  // Click is always in the middle.
267   const int padding = options_->context_size() - options_->max_selection_span();
268 
269   int span_left = 0;
270   for (int i = click_position - 1; i >= padding; i--) {
271     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
272       ++span_left;
273     } else {
274       break;
275     }
276   }
277 
278   int span_right = 0;
279   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
280     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
281       ++span_right;
282     } else {
283       break;
284     }
285   }
286 
287   // Check that the spanned tokens cover the whole span.
288   bool tokens_match_span;
289   const CodepointIndex tokens_start = tokens[click_position - span_left].start;
290   const CodepointIndex tokens_end = tokens[click_position + span_right].end;
291   if (options_->snap_label_span_boundaries_to_containing_tokens()) {
292     tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
293   } else {
294     const UnicodeText token_left_unicode = UTF8ToUnicodeText(
295         tokens[click_position - span_left].value, /*do_copy=*/false);
296     const UnicodeText token_right_unicode = UTF8ToUnicodeText(
297         tokens[click_position + span_right].value, /*do_copy=*/false);
298 
299     UnicodeText::const_iterator span_begin = token_left_unicode.begin();
300     UnicodeText::const_iterator span_end = token_right_unicode.end();
301 
302     const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
303         span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
304     const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
305         token_right_unicode.begin(), span_end,
306         /*count_from_beginning=*/false);
307 
308     tokens_match_span = tokens_start <= span.first &&
309                         tokens_start + num_punctuation_start >= span.first &&
310                         tokens_end >= span.second &&
311                         tokens_end - num_punctuation_end <= span.second;
312   }
313 
314   if (tokens_match_span) {
315     *label = TokenSpanToLabel({span_left, span_right});
316   } else {
317     *label = kInvalidLabel;
318   }
319 
320   return true;
321 }
322 
TokenSpanToLabel(const TokenSpan & token_span) const323 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& token_span) const {
324   auto it = selection_to_label_.find(token_span);
325   if (it != selection_to_label_.end()) {
326     return it->second;
327   } else {
328     return kInvalidLabel;
329   }
330 }
331 
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span,bool snap_boundaries_to_containing_tokens)332 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
333                                    const CodepointSpan& codepoint_span,
334                                    bool snap_boundaries_to_containing_tokens) {
335   const int codepoint_start = codepoint_span.first;
336   const int codepoint_end = codepoint_span.second;
337 
338   TokenIndex start_token = kInvalidIndex;
339   TokenIndex end_token = kInvalidIndex;
340   for (int i = 0; i < selectable_tokens.size(); ++i) {
341     bool is_token_in_span;
342     if (snap_boundaries_to_containing_tokens) {
343       is_token_in_span = codepoint_start < selectable_tokens[i].end &&
344                          codepoint_end > selectable_tokens[i].start;
345     } else {
346       is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
347                          codepoint_end >= selectable_tokens[i].end;
348     }
349     if (is_token_in_span && !selectable_tokens[i].is_padding) {
350       if (start_token == kInvalidIndex) {
351         start_token = i;
352       }
353       end_token = i + 1;
354     }
355   }
356   return {start_token, end_token};
357 }
358 
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,const TokenSpan & token_span)359 CodepointSpan TokenSpanToCodepointSpan(
360     const std::vector<Token>& selectable_tokens, const TokenSpan& token_span) {
361   return {selectable_tokens[token_span.first].start,
362           selectable_tokens[token_span.second - 1].end};
363 }
364 
CodepointSpanToUnicodeTextRange(const UnicodeText & unicode_text,const CodepointSpan & span)365 UnicodeTextRange CodepointSpanToUnicodeTextRange(
366     const UnicodeText& unicode_text, const CodepointSpan& span) {
367   auto begin = unicode_text.begin();
368   if (span.first > 0) {
369     std::advance(begin, span.first);
370   }
371   auto end = unicode_text.begin();
372   if (span.second > 0) {
373     std::advance(end, span.second);
374   }
375   return {begin, end};
376 }
377 
378 namespace {
379 
380 // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span)381 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
382                               const CodepointSpan& codepoint_span) {
383   const int codepoint_start = codepoint_span.first;
384   const int codepoint_end = codepoint_span.second;
385 
386   for (int i = 0; i < selectable_tokens.size(); ++i) {
387     if (codepoint_start >= selectable_tokens[i].start &&
388         codepoint_end <= selectable_tokens[i].end) {
389       return i;
390     }
391   }
392   return kInvalidIndex;
393 }
394 
395 }  // namespace
396 
397 namespace internal {
398 
CenterTokenFromClick(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)399 int CenterTokenFromClick(const CodepointSpan& span,
400                          const std::vector<Token>& selectable_tokens) {
401   const TokenSpan token_span =
402       CodepointSpanToTokenSpan(selectable_tokens, span);
403   int range_begin = token_span.first;
404   int range_end = token_span.second;
405 
406   // If no exact match was found, try finding a token that completely contains
407   // the click span. This is useful e.g. when Android builds the selection
408   // using ICU tokenization, and ends up with only a portion of our space-
409   // separated token. E.g. for "(857)" Android would select "857".
410   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
411     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
412     if (token_index != kInvalidIndex) {
413       range_begin = token_index;
414       range_end = token_index + 1;
415     }
416   }
417 
418   // We only allow clicks that are exactly 1 selectable token.
419   if (range_end - range_begin == 1) {
420     return range_begin;
421   } else {
422     return kInvalidIndex;
423   }
424 }
425 
CenterTokenFromMiddleOfSelection(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)426 int CenterTokenFromMiddleOfSelection(
427     const CodepointSpan& span, const std::vector<Token>& selectable_tokens) {
428   const TokenSpan token_span =
429       CodepointSpanToTokenSpan(selectable_tokens, span);
430   const int range_begin = token_span.first;
431   const int range_end = token_span.second;
432 
433   // Center the clicked token in the selection range.
434   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
435     return (range_begin + range_end - 1) / 2;
436   } else {
437     return kInvalidIndex;
438   }
439 }
440 
441 }  // namespace internal
442 
FindCenterToken(const CodepointSpan & span,const std::vector<Token> & tokens) const443 int FeatureProcessor::FindCenterToken(const CodepointSpan& span,
444                                       const std::vector<Token>& tokens) const {
445   if (options_->center_token_selection_method() ==
446       FeatureProcessorOptions_::
447           CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
448     return internal::CenterTokenFromClick(span, tokens);
449   } else if (options_->center_token_selection_method() ==
450              FeatureProcessorOptions_::
451                  CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
452     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
453   } else if (options_->center_token_selection_method() ==
454              FeatureProcessorOptions_::
455                  CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
456     // TODO(zilka): Remove once we have new models on the device.
457     // It uses the fact that sharing model use
458     // split_tokens_on_selection_boundaries and selection not. So depending on
459     // this we select the right way of finding the click location.
460     if (!options_->split_tokens_on_selection_boundaries()) {
461       // SmartSelection model.
462       return internal::CenterTokenFromClick(span, tokens);
463     } else {
464       // SmartSharing model.
465       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
466     }
467   } else {
468     TC3_LOG(ERROR) << "Invalid center token selection method.";
469     return kInvalidIndex;
470   }
471 }
472 
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const473 bool FeatureProcessor::SelectionLabelSpans(
474     const VectorSpan<Token> tokens,
475     std::vector<CodepointSpan>* selection_label_spans) const {
476   for (int i = 0; i < label_to_selection_.size(); ++i) {
477     CodepointSpan span = CodepointSpan::kInvalid;
478     if (!LabelToSpan(i, tokens, &span)) {
479       TC3_LOG(ERROR) << "Could not convert label to span: " << i;
480       return false;
481     }
482     selection_label_spans->push_back(span);
483   }
484   return true;
485 }
486 
SelectionLabelRelativeTokenSpans(std::vector<TokenSpan> * selection_label_relative_token_spans) const487 bool FeatureProcessor::SelectionLabelRelativeTokenSpans(
488     std::vector<TokenSpan>* selection_label_relative_token_spans) const {
489   selection_label_relative_token_spans->assign(label_to_selection_.begin(),
490                                                label_to_selection_.end());
491   return true;
492 }
493 
PrepareIgnoredSpanBoundaryCodepoints()494 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
495   if (options_->ignored_span_boundary_codepoints() != nullptr) {
496     for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
497       ignored_span_boundary_codepoints_.insert(codepoint);
498     }
499   }
500 }
501 
CountIgnoredSpanBoundaryCodepoints(const UnicodeText::const_iterator & span_start,const UnicodeText::const_iterator & span_end,bool count_from_beginning) const502 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
503     const UnicodeText::const_iterator& span_start,
504     const UnicodeText::const_iterator& span_end,
505     bool count_from_beginning) const {
506   if (span_start == span_end) {
507     return 0;
508   }
509 
510   UnicodeText::const_iterator it;
511   UnicodeText::const_iterator it_last;
512   if (count_from_beginning) {
513     it = span_start;
514     it_last = span_end;
515     // We can assume that the string is non-zero length because of the check
516     // above, thus the decrement is always valid here.
517     --it_last;
518   } else {
519     it = span_end;
520     it_last = span_start;
521     // We can assume that the string is non-zero length because of the check
522     // above, thus the decrement is always valid here.
523     --it;
524   }
525 
526   // Move until we encounter a non-ignored character.
527   int num_ignored = 0;
528   while (ignored_span_boundary_codepoints_.find(*it) !=
529          ignored_span_boundary_codepoints_.end()) {
530     ++num_ignored;
531 
532     if (it == it_last) {
533       break;
534     }
535 
536     if (count_from_beginning) {
537       ++it;
538     } else {
539       --it;
540     }
541   }
542 
543   return num_ignored;
544 }
545 
546 namespace {
547 
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)548 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
549                     std::vector<UnicodeTextRange>* ranges) {
550   UnicodeText::const_iterator start = t.begin();
551   UnicodeText::const_iterator curr = start;
552   UnicodeText::const_iterator end = t.end();
553   for (; curr != end; ++curr) {
554     if (codepoints.find(*curr) != codepoints.end()) {
555       if (start != curr) {
556         ranges->push_back(std::make_pair(start, curr));
557       }
558       start = curr;
559       ++start;
560     }
561   }
562   if (start != end) {
563     ranges->push_back(std::make_pair(start, end));
564   }
565 }
566 
567 }  // namespace
568 
SplitContext(const UnicodeText & context_unicode,const bool use_pipe_character_for_newline) const569 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
570     const UnicodeText& context_unicode,
571     const bool use_pipe_character_for_newline) const {
572   std::vector<UnicodeTextRange> lines;
573   std::set<char32> codepoints{'\n'};
574   if (use_pipe_character_for_newline) {
575     codepoints.insert('|');
576   }
577   FindSubstrings(context_unicode, codepoints, &lines);
578   return lines;
579 }
580 
// UTF-8 convenience overload: decodes `context` without copying and strips
// ignored boundary codepoints from `span`.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const std::string& context, const CodepointSpan& span) const {
  return StripBoundaryCodepoints(UTF8ToUnicodeText(context, /*do_copy=*/false),
                                 span);
}
587 
// Strips ignored boundary codepoints from `span` within `context_unicode`.
// Empty contexts and invalid or empty spans are returned unchanged.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText& context_unicode, const CodepointSpan& span) const {
  const bool nothing_to_strip =
      context_unicode.empty() || !span.IsValid() || span.IsEmpty();
  if (nothing_to_strip) {
    return span;
  }
  const UnicodeTextRange range =
      CodepointSpanToUnicodeTextRange(context_unicode, span);
  return StripBoundaryCodepoints(range.first, range.second, span);
}
599 
StripBoundaryCodepoints(const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span) const600 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
601     const UnicodeText::const_iterator& span_begin,
602     const UnicodeText::const_iterator& span_end,
603     const CodepointSpan& span) const {
604   if (!span.IsValid() || span.IsEmpty() || span_begin == span_end) {
605     return span;
606   }
607 
608   const int start_offset = CountIgnoredSpanBoundaryCodepoints(
609       span_begin, span_end, /*count_from_beginning=*/true);
610   const int end_offset = CountIgnoredSpanBoundaryCodepoints(
611       span_begin, span_end, /*count_from_beginning=*/false);
612 
613   if (span.first + start_offset < span.second - end_offset) {
614     return {span.first + start_offset, span.second - end_offset};
615   } else {
616     return {span.first, span.first};
617   }
618 }
619 
SupportedCodepointsRatio(const TokenSpan & token_span,const std::vector<Token> & tokens) const620 float FeatureProcessor::SupportedCodepointsRatio(
621     const TokenSpan& token_span, const std::vector<Token>& tokens) const {
622   int num_supported = 0;
623   int num_total = 0;
624   for (int i = token_span.first; i < token_span.second; ++i) {
625     const UnicodeText value =
626         UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
627     for (auto codepoint : value) {
628       if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
629         ++num_supported;
630       }
631       ++num_total;
632     }
633   }
634   // Avoid division by zero.
635   if (num_total == 0) {
636     return 0.0;
637   }
638   return static_cast<float>(num_supported) / static_cast<float>(num_total);
639 }
640 
StripBoundaryCodepoints(const std::string & value,std::string * buffer) const641 const std::string& FeatureProcessor::StripBoundaryCodepoints(
642     const std::string& value, std::string* buffer) const {
643   const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
644   const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
645   const CodepointSpan stripped_span =
646       StripBoundaryCodepoints(value_unicode, initial_span);
647 
648   if (initial_span != stripped_span) {
649     const UnicodeText stripped_token_value =
650         UnicodeText::Substring(value_unicode, stripped_span.first,
651                                stripped_span.second, /*do_copy=*/false);
652     *buffer = stripped_token_value.ToUTF8String();
653     return *buffer;
654   }
655   return value;
656 }
657 
CollectionToLabel(const std::string & collection) const658 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
659   const auto it = collection_to_label_.find(collection);
660   if (it == collection_to_label_.end()) {
661     return options_->default_collection();
662   } else {
663     return it->second;
664   }
665 }
666 
LabelToCollection(int label) const667 std::string FeatureProcessor::LabelToCollection(int label) const {
668   if (label >= 0 && label < collection_to_label_.size()) {
669     return (*options_->collections())[label]->str();
670   } else {
671     return GetDefaultCollection();
672   }
673 }
674 
MakeLabelMaps()675 void FeatureProcessor::MakeLabelMaps() {
676   if (options_->collections() != nullptr) {
677     for (int i = 0; i < options_->collections()->size(); ++i) {
678       collection_to_label_[(*options_->collections())[i]->str()] = i;
679     }
680   }
681 
682   int selection_label_id = 0;
683   for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
684     for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
685       if (!options_->selection_reduced_output_space() ||
686           r + l <= options_->max_selection_span()) {
687         TokenSpan token_span{l, r};
688         selection_to_label_[token_span] = selection_label_id;
689         label_to_selection_.push_back(token_span);
690         ++selection_label_id;
691       }
692     }
693   }
694 }
695 
RetokenizeAndFindClick(const std::string & context,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const696 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
697                                               const CodepointSpan& input_span,
698                                               bool only_use_line_with_click,
699                                               std::vector<Token>* tokens,
700                                               int* click_pos) const {
701   const UnicodeText context_unicode =
702       UTF8ToUnicodeText(context, /*do_copy=*/false);
703   const auto [span_begin, span_end] =
704       CodepointSpanToUnicodeTextRange(context_unicode, input_span);
705   RetokenizeAndFindClick(context_unicode, span_begin, span_end, input_span,
706                          only_use_line_with_click, tokens, click_pos);
707 }
708 
RetokenizeAndFindClick(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const709 void FeatureProcessor::RetokenizeAndFindClick(
710     const UnicodeText& context_unicode,
711     const UnicodeText::const_iterator& span_begin,
712     const UnicodeText::const_iterator& span_end,
713     const CodepointSpan& input_span, bool only_use_line_with_click,
714     std::vector<Token>* tokens, int* click_pos) const {
715   TC3_CHECK(tokens != nullptr);
716 
717   if (options_->split_tokens_on_selection_boundaries()) {
718     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
719   }
720 
721   if (only_use_line_with_click) {
722     StripTokensFromOtherLines(context_unicode, span_begin, span_end, input_span,
723                               tokens);
724   }
725 
726   int local_click_pos;
727   if (click_pos == nullptr) {
728     click_pos = &local_click_pos;
729   }
730   *click_pos = FindCenterToken(input_span, *tokens);
731   if (*click_pos == kInvalidIndex) {
732     // If the default click method failed, let's try to do sub-token matching
733     // before we fail.
734     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
735   }
736 }
737 
738 namespace internal {
739 
StripOrPadTokens(const TokenSpan & relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)740 void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
741                       std::vector<Token>* tokens, int* click_pos) {
742   int right_context_needed = relative_click_span.second + context_size;
743   if (*click_pos + right_context_needed + 1 >= tokens->size()) {
744     // Pad max the context size.
745     const int num_pad_tokens = std::min(
746         context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
747                                        tokens->size()));
748     std::vector<Token> pad_tokens(num_pad_tokens);
749     tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
750   } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
751     // Strip unused tokens.
752     auto it = tokens->begin();
753     std::advance(it, *click_pos + right_context_needed + 1);
754     tokens->erase(it, tokens->end());
755   }
756 
757   int left_context_needed = relative_click_span.first + context_size;
758   if (*click_pos < left_context_needed) {
759     // Pad max the context size.
760     const int num_pad_tokens =
761         std::min(context_size, left_context_needed - *click_pos);
762     std::vector<Token> pad_tokens(num_pad_tokens);
763     tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
764     *click_pos += num_pad_tokens;
765   } else if (*click_pos > left_context_needed) {
766     // Strip unused tokens.
767     auto it = tokens->begin();
768     std::advance(it, *click_pos - left_context_needed);
769     *click_pos -= it - tokens->begin();
770     tokens->erase(tokens->begin(), it);
771   }
772 }
773 
774 }  // namespace internal
775 
HasEnoughSupportedCodepoints(const std::vector<Token> & tokens,const TokenSpan & token_span) const776 bool FeatureProcessor::HasEnoughSupportedCodepoints(
777     const std::vector<Token>& tokens, const TokenSpan& token_span) const {
778   if (options_->min_supported_codepoint_ratio() > 0) {
779     const float supported_codepoint_ratio =
780         SupportedCodepointsRatio(token_span, tokens);
781     if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
782       TC3_VLOG(1) << "Not enough supported codepoints in the context: "
783                   << supported_codepoint_ratio;
784       return false;
785     }
786   }
787   return true;
788 }
789 
ExtractFeatures(const std::vector<Token> & tokens,const TokenSpan & token_span,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,int feature_vector_size,std::unique_ptr<CachedFeatures> * cached_features) const790 bool FeatureProcessor::ExtractFeatures(
791     const std::vector<Token>& tokens, const TokenSpan& token_span,
792     const CodepointSpan& selection_span_for_feature,
793     const EmbeddingExecutor* embedding_executor,
794     EmbeddingCache* embedding_cache, int feature_vector_size,
795     std::unique_ptr<CachedFeatures>* cached_features) const {
796   std::unique_ptr<std::vector<float>> features(new std::vector<float>());
797   features->reserve(feature_vector_size * token_span.Size());
798   for (int i = token_span.first; i < token_span.second; ++i) {
799     if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
800                                       embedding_executor, embedding_cache,
801                                       features.get())) {
802       TC3_LOG(ERROR) << "Could not get token features.";
803       return false;
804     }
805   }
806 
807   std::unique_ptr<std::vector<float>> padding_features(
808       new std::vector<float>());
809   padding_features->reserve(feature_vector_size);
810   if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
811                                     embedding_executor, embedding_cache,
812                                     padding_features.get())) {
813     TC3_LOG(ERROR) << "Count not get padding token features.";
814     return false;
815   }
816 
817   *cached_features = CachedFeatures::Create(token_span, std::move(features),
818                                             std::move(padding_features),
819                                             options_, feature_vector_size);
820   if (!*cached_features) {
821     TC3_LOG(ERROR) << "Cound not create cached features.";
822     return false;
823   }
824 
825   return true;
826 }
827 
AppendTokenFeaturesWithCache(const Token & token,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,std::vector<float> * output_features) const828 bool FeatureProcessor::AppendTokenFeaturesWithCache(
829     const Token& token, const CodepointSpan& selection_span_for_feature,
830     const EmbeddingExecutor* embedding_executor,
831     EmbeddingCache* embedding_cache,
832     std::vector<float>* output_features) const {
833   // Look for the embedded features for the token in the cache, if there is one.
834   if (embedding_cache) {
835     const auto it = embedding_cache->find({token.start, token.end});
836     if (it != embedding_cache->end()) {
837       // The embedded features were found in the cache, extract only the dense
838       // features.
839       std::vector<float> dense_features;
840       if (!feature_extractor_.Extract(
841               token, token.IsContainedInSpan(selection_span_for_feature),
842               /*sparse_features=*/nullptr, &dense_features)) {
843         TC3_LOG(ERROR) << "Could not extract token's dense features.";
844         return false;
845       }
846 
847       // Append both embedded and dense features to the output and return.
848       output_features->insert(output_features->end(), it->second.begin(),
849                               it->second.end());
850       output_features->insert(output_features->end(), dense_features.begin(),
851                               dense_features.end());
852       return true;
853     }
854   }
855 
856   // Extract the sparse and dense features.
857   std::vector<int> sparse_features;
858   std::vector<float> dense_features;
859   if (!feature_extractor_.Extract(
860           token, token.IsContainedInSpan(selection_span_for_feature),
861           &sparse_features, &dense_features)) {
862     TC3_LOG(ERROR) << "Could not extract token's features.";
863     return false;
864   }
865 
866   // Embed the sparse features, appending them directly to the output.
867   const int embedding_size = GetOptions()->embedding_size();
868   output_features->resize(output_features->size() + embedding_size);
869   float* output_features_end =
870       output_features->data() + output_features->size();
871   if (!embedding_executor->AddEmbedding(
872           TensorView<int>(sparse_features.data(),
873                           {static_cast<int>(sparse_features.size())}),
874           /*dest=*/output_features_end - embedding_size,
875           /*dest_size=*/embedding_size)) {
876     TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
877     return false;
878   }
879 
880   // If there is a cache, the embedded features for the token were not in it,
881   // so insert them.
882   if (embedding_cache) {
883     (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
884         output_features_end - embedding_size, output_features_end);
885   }
886 
887   // Append the dense features to the output.
888   output_features->insert(output_features->end(), dense_features.begin(),
889                           dense_features.end());
890   return true;
891 }
892 
893 }  // namespace libtextclassifier3
894