• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "feature-processor.h"
18 
19 #include <iterator>
20 #include <set>
21 #include <vector>
22 
23 #include "util/base/logging.h"
24 #include "util/strings/utf8.h"
25 #include "util/utf8/unicodetext.h"
26 
27 namespace libtextclassifier2 {
28 
29 namespace internal {
30 
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions * const options)31 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
32     const FeatureProcessorOptions* const options) {
33   TokenFeatureExtractorOptions extractor_options;
34 
35   extractor_options.num_buckets = options->num_buckets();
36   if (options->chargram_orders() != nullptr) {
37     for (int order : *options->chargram_orders()) {
38       extractor_options.chargram_orders.push_back(order);
39     }
40   }
41   extractor_options.max_word_length = options->max_word_length();
42   extractor_options.extract_case_feature = options->extract_case_feature();
43   extractor_options.unicode_aware_features = options->unicode_aware_features();
44   extractor_options.extract_selection_mask_feature =
45       options->extract_selection_mask_feature();
46   if (options->regexp_feature() != nullptr) {
47     for (const auto& regexp_feauture : *options->regexp_feature()) {
48       extractor_options.regexp_features.push_back(regexp_feauture->str());
49     }
50   }
51   extractor_options.remap_digits = options->remap_digits();
52   extractor_options.lowercase_tokens = options->lowercase_tokens();
53 
54   if (options->allowed_chargrams() != nullptr) {
55     for (const auto& chargram : *options->allowed_chargrams()) {
56       extractor_options.allowed_chargrams.insert(chargram->str());
57     }
58   }
59   return extractor_options;
60 }
61 
SplitTokensOnSelectionBoundaries(CodepointSpan selection,std::vector<Token> * tokens)62 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
63                                       std::vector<Token>* tokens) {
64   for (auto it = tokens->begin(); it != tokens->end(); ++it) {
65     const UnicodeText token_word =
66         UTF8ToUnicodeText(it->value, /*do_copy=*/false);
67 
68     auto last_start = token_word.begin();
69     int last_start_index = it->start;
70     std::vector<UnicodeText::const_iterator> split_points;
71 
72     // Selection start split point.
73     if (selection.first > it->start && selection.first < it->end) {
74       std::advance(last_start, selection.first - last_start_index);
75       split_points.push_back(last_start);
76       last_start_index = selection.first;
77     }
78 
79     // Selection end split point.
80     if (selection.second > it->start && selection.second < it->end) {
81       std::advance(last_start, selection.second - last_start_index);
82       split_points.push_back(last_start);
83     }
84 
85     if (!split_points.empty()) {
86       // Add a final split for the rest of the token unless it's been all
87       // consumed already.
88       if (split_points.back() != token_word.end()) {
89         split_points.push_back(token_word.end());
90       }
91 
92       std::vector<Token> replacement_tokens;
93       last_start = token_word.begin();
94       int current_pos = it->start;
95       for (const auto& split_point : split_points) {
96         Token new_token(token_word.UTF8Substring(last_start, split_point),
97                         current_pos,
98                         current_pos + std::distance(last_start, split_point));
99 
100         last_start = split_point;
101         current_pos = new_token.end;
102 
103         replacement_tokens.push_back(new_token);
104       }
105 
106       it = tokens->erase(it);
107       it = tokens->insert(it, replacement_tokens.begin(),
108                           replacement_tokens.end());
109       std::advance(it, replacement_tokens.size() - 1);
110     }
111   }
112 }
113 
// Returns `unilib` when it is non-null. Otherwise allocates a new UniLib,
// stores it in `*owned_unilib` (which then owns it) and returns a pointer to
// that instance.
const UniLib* MaybeCreateUnilib(const UniLib* unilib,
                                std::unique_ptr<UniLib>* owned_unilib) {
  if (unilib) {
    return unilib;
  }
  owned_unilib->reset(new UniLib);
  return owned_unilib->get();
}
123 
124 }  // namespace internal
125 
StripTokensFromOtherLines(const std::string & context,CodepointSpan span,std::vector<Token> * tokens) const126 void FeatureProcessor::StripTokensFromOtherLines(
127     const std::string& context, CodepointSpan span,
128     std::vector<Token>* tokens) const {
129   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
130                                                         /*do_copy=*/false);
131   StripTokensFromOtherLines(context_unicode, span, tokens);
132 }
133 
StripTokensFromOtherLines(const UnicodeText & context_unicode,CodepointSpan span,std::vector<Token> * tokens) const134 void FeatureProcessor::StripTokensFromOtherLines(
135     const UnicodeText& context_unicode, CodepointSpan span,
136     std::vector<Token>* tokens) const {
137   std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
138 
139   auto span_start = context_unicode.begin();
140   if (span.first > 0) {
141     std::advance(span_start, span.first);
142   }
143   auto span_end = context_unicode.begin();
144   if (span.second > 0) {
145     std::advance(span_end, span.second);
146   }
147   for (const UnicodeTextRange& line : lines) {
148     // Find the line that completely contains the span.
149     if (line.first <= span_start && line.second >= span_end) {
150       const CodepointIndex last_line_begin_index =
151           std::distance(context_unicode.begin(), line.first);
152       const CodepointIndex last_line_end_index =
153           last_line_begin_index + std::distance(line.first, line.second);
154 
155       for (auto token = tokens->begin(); token != tokens->end();) {
156         if (token->start >= last_line_begin_index &&
157             token->end <= last_line_end_index) {
158           ++token;
159         } else {
160           token = tokens->erase(token);
161         }
162       }
163     }
164   }
165 }
166 
GetDefaultCollection() const167 std::string FeatureProcessor::GetDefaultCollection() const {
168   if (options_->default_collection() < 0 ||
169       options_->collections() == nullptr ||
170       options_->default_collection() >= options_->collections()->size()) {
171     TC_LOG(ERROR)
172         << "Invalid or missing default collection. Returning empty string.";
173     return "";
174   }
175   return (*options_->collections())[options_->default_collection()]->str();
176 }
177 
Tokenize(const std::string & text) const178 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
179   const UnicodeText text_unicode = UTF8ToUnicodeText(text, /*do_copy=*/false);
180   return Tokenize(text_unicode);
181 }
182 
Tokenize(const UnicodeText & text_unicode) const183 std::vector<Token> FeatureProcessor::Tokenize(
184     const UnicodeText& text_unicode) const {
185   if (options_->tokenization_type() ==
186       FeatureProcessorOptions_::TokenizationType_INTERNAL_TOKENIZER) {
187     return tokenizer_.Tokenize(text_unicode);
188   } else if (options_->tokenization_type() ==
189                  FeatureProcessorOptions_::TokenizationType_ICU ||
190              options_->tokenization_type() ==
191                  FeatureProcessorOptions_::TokenizationType_MIXED) {
192     std::vector<Token> result;
193     if (!ICUTokenize(text_unicode, &result)) {
194       return {};
195     }
196     if (options_->tokenization_type() ==
197         FeatureProcessorOptions_::TokenizationType_MIXED) {
198       InternalRetokenize(text_unicode, &result);
199     }
200     return result;
201   } else {
202     TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
203                      "internal.";
204     return tokenizer_.Tokenize(text_unicode);
205   }
206 }
207 
LabelToSpan(const int label,const VectorSpan<Token> & tokens,std::pair<CodepointIndex,CodepointIndex> * span) const208 bool FeatureProcessor::LabelToSpan(
209     const int label, const VectorSpan<Token>& tokens,
210     std::pair<CodepointIndex, CodepointIndex>* span) const {
211   if (tokens.size() != GetNumContextTokens()) {
212     return false;
213   }
214 
215   TokenSpan token_span;
216   if (!LabelToTokenSpan(label, &token_span)) {
217     return false;
218   }
219 
220   const int result_begin_token_index = token_span.first;
221   const Token& result_begin_token =
222       tokens[options_->context_size() - result_begin_token_index];
223   const int result_begin_codepoint = result_begin_token.start;
224   const int result_end_token_index = token_span.second;
225   const Token& result_end_token =
226       tokens[options_->context_size() + result_end_token_index];
227   const int result_end_codepoint = result_end_token.end;
228 
229   if (result_begin_codepoint == kInvalidIndex ||
230       result_end_codepoint == kInvalidIndex) {
231     *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
232   } else {
233     const UnicodeText token_begin_unicode =
234         UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
235     UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
236     const UnicodeText token_end_unicode =
237         UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
238     UnicodeText::const_iterator token_end = token_end_unicode.end();
239 
240     const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
241         token_begin, token_begin_unicode.end(),
242         /*count_from_beginning=*/true);
243     const int end_ignored =
244         CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
245                                            /*count_from_beginning=*/false);
246     // In case everything would be stripped, set the span to the original
247     // beginning and zero length.
248     if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
249       *span = {result_begin_codepoint, result_begin_codepoint};
250     } else {
251       *span = CodepointSpan({result_begin_codepoint + begin_ignored,
252                              result_end_codepoint - end_ignored});
253     }
254   }
255   return true;
256 }
257 
LabelToTokenSpan(const int label,TokenSpan * token_span) const258 bool FeatureProcessor::LabelToTokenSpan(const int label,
259                                         TokenSpan* token_span) const {
260   if (label >= 0 && label < label_to_selection_.size()) {
261     *token_span = label_to_selection_[label];
262     return true;
263   } else {
264     return false;
265   }
266 }
267 
SpanToLabel(const std::pair<CodepointIndex,CodepointIndex> & span,const std::vector<Token> & tokens,int * label) const268 bool FeatureProcessor::SpanToLabel(
269     const std::pair<CodepointIndex, CodepointIndex>& span,
270     const std::vector<Token>& tokens, int* label) const {
271   if (tokens.size() != GetNumContextTokens()) {
272     return false;
273   }
274 
275   const int click_position =
276       options_->context_size();  // Click is always in the middle.
277   const int padding = options_->context_size() - options_->max_selection_span();
278 
279   int span_left = 0;
280   for (int i = click_position - 1; i >= padding; i--) {
281     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
282       ++span_left;
283     } else {
284       break;
285     }
286   }
287 
288   int span_right = 0;
289   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
290     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
291       ++span_right;
292     } else {
293       break;
294     }
295   }
296 
297   // Check that the spanned tokens cover the whole span.
298   bool tokens_match_span;
299   const CodepointIndex tokens_start = tokens[click_position - span_left].start;
300   const CodepointIndex tokens_end = tokens[click_position + span_right].end;
301   if (options_->snap_label_span_boundaries_to_containing_tokens()) {
302     tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
303   } else {
304     const UnicodeText token_left_unicode = UTF8ToUnicodeText(
305         tokens[click_position - span_left].value, /*do_copy=*/false);
306     const UnicodeText token_right_unicode = UTF8ToUnicodeText(
307         tokens[click_position + span_right].value, /*do_copy=*/false);
308 
309     UnicodeText::const_iterator span_begin = token_left_unicode.begin();
310     UnicodeText::const_iterator span_end = token_right_unicode.end();
311 
312     const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
313         span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
314     const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
315         token_right_unicode.begin(), span_end,
316         /*count_from_beginning=*/false);
317 
318     tokens_match_span = tokens_start <= span.first &&
319                         tokens_start + num_punctuation_start >= span.first &&
320                         tokens_end >= span.second &&
321                         tokens_end - num_punctuation_end <= span.second;
322   }
323 
324   if (tokens_match_span) {
325     *label = TokenSpanToLabel({span_left, span_right});
326   } else {
327     *label = kInvalidLabel;
328   }
329 
330   return true;
331 }
332 
TokenSpanToLabel(const TokenSpan & span) const333 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
334   auto it = selection_to_label_.find(span);
335   if (it != selection_to_label_.end()) {
336     return it->second;
337   } else {
338     return kInvalidLabel;
339   }
340 }
341 
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span,bool snap_boundaries_to_containing_tokens)342 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
343                                    CodepointSpan codepoint_span,
344                                    bool snap_boundaries_to_containing_tokens) {
345   const int codepoint_start = std::get<0>(codepoint_span);
346   const int codepoint_end = std::get<1>(codepoint_span);
347 
348   TokenIndex start_token = kInvalidIndex;
349   TokenIndex end_token = kInvalidIndex;
350   for (int i = 0; i < selectable_tokens.size(); ++i) {
351     bool is_token_in_span;
352     if (snap_boundaries_to_containing_tokens) {
353       is_token_in_span = codepoint_start < selectable_tokens[i].end &&
354                          codepoint_end > selectable_tokens[i].start;
355     } else {
356       is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
357                          codepoint_end >= selectable_tokens[i].end;
358     }
359     if (is_token_in_span && !selectable_tokens[i].is_padding) {
360       if (start_token == kInvalidIndex) {
361         start_token = i;
362       }
363       end_token = i + 1;
364     }
365   }
366   return {start_token, end_token};
367 }
368 
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,TokenSpan token_span)369 CodepointSpan TokenSpanToCodepointSpan(
370     const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
371   return {selectable_tokens[token_span.first].start,
372           selectable_tokens[token_span.second - 1].end};
373 }
374 
375 namespace {
376 
377 // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,CodepointSpan codepoint_span)378 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
379                               CodepointSpan codepoint_span) {
380   const int codepoint_start = std::get<0>(codepoint_span);
381   const int codepoint_end = std::get<1>(codepoint_span);
382 
383   for (int i = 0; i < selectable_tokens.size(); ++i) {
384     if (codepoint_start >= selectable_tokens[i].start &&
385         codepoint_end <= selectable_tokens[i].end) {
386       return i;
387     }
388   }
389   return kInvalidIndex;
390 }
391 
392 }  // namespace
393 
394 namespace internal {
395 
CenterTokenFromClick(CodepointSpan span,const std::vector<Token> & selectable_tokens)396 int CenterTokenFromClick(CodepointSpan span,
397                          const std::vector<Token>& selectable_tokens) {
398   int range_begin;
399   int range_end;
400   std::tie(range_begin, range_end) =
401       CodepointSpanToTokenSpan(selectable_tokens, span);
402 
403   // If no exact match was found, try finding a token that completely contains
404   // the click span. This is useful e.g. when Android builds the selection
405   // using ICU tokenization, and ends up with only a portion of our space-
406   // separated token. E.g. for "(857)" Android would select "857".
407   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
408     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
409     if (token_index != kInvalidIndex) {
410       range_begin = token_index;
411       range_end = token_index + 1;
412     }
413   }
414 
415   // We only allow clicks that are exactly 1 selectable token.
416   if (range_end - range_begin == 1) {
417     return range_begin;
418   } else {
419     return kInvalidIndex;
420   }
421 }
422 
CenterTokenFromMiddleOfSelection(CodepointSpan span,const std::vector<Token> & selectable_tokens)423 int CenterTokenFromMiddleOfSelection(
424     CodepointSpan span, const std::vector<Token>& selectable_tokens) {
425   int range_begin;
426   int range_end;
427   std::tie(range_begin, range_end) =
428       CodepointSpanToTokenSpan(selectable_tokens, span);
429 
430   // Center the clicked token in the selection range.
431   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
432     return (range_begin + range_end - 1) / 2;
433   } else {
434     return kInvalidIndex;
435   }
436 }
437 
438 }  // namespace internal
439 
FindCenterToken(CodepointSpan span,const std::vector<Token> & tokens) const440 int FeatureProcessor::FindCenterToken(CodepointSpan span,
441                                       const std::vector<Token>& tokens) const {
442   if (options_->center_token_selection_method() ==
443       FeatureProcessorOptions_::
444           CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
445     return internal::CenterTokenFromClick(span, tokens);
446   } else if (options_->center_token_selection_method() ==
447              FeatureProcessorOptions_::
448                  CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
449     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
450   } else if (options_->center_token_selection_method() ==
451              FeatureProcessorOptions_::
452                  CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
453     // TODO(zilka): Remove once we have new models on the device.
454     // It uses the fact that sharing model use
455     // split_tokens_on_selection_boundaries and selection not. So depending on
456     // this we select the right way of finding the click location.
457     if (!options_->split_tokens_on_selection_boundaries()) {
458       // SmartSelection model.
459       return internal::CenterTokenFromClick(span, tokens);
460     } else {
461       // SmartSharing model.
462       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
463     }
464   } else {
465     TC_LOG(ERROR) << "Invalid center token selection method.";
466     return kInvalidIndex;
467   }
468 }
469 
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const470 bool FeatureProcessor::SelectionLabelSpans(
471     const VectorSpan<Token> tokens,
472     std::vector<CodepointSpan>* selection_label_spans) const {
473   for (int i = 0; i < label_to_selection_.size(); ++i) {
474     CodepointSpan span;
475     if (!LabelToSpan(i, tokens, &span)) {
476       TC_LOG(ERROR) << "Could not convert label to span: " << i;
477       return false;
478     }
479     selection_label_spans->push_back(span);
480   }
481   return true;
482 }
483 
PrepareCodepointRanges(const std::vector<const FeatureProcessorOptions_::CodepointRange * > & codepoint_ranges,std::vector<CodepointRange> * prepared_codepoint_ranges)484 void FeatureProcessor::PrepareCodepointRanges(
485     const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
486         codepoint_ranges,
487     std::vector<CodepointRange>* prepared_codepoint_ranges) {
488   prepared_codepoint_ranges->clear();
489   prepared_codepoint_ranges->reserve(codepoint_ranges.size());
490   for (const FeatureProcessorOptions_::CodepointRange* range :
491        codepoint_ranges) {
492     prepared_codepoint_ranges->push_back(
493         CodepointRange(range->start(), range->end()));
494   }
495 
496   std::sort(prepared_codepoint_ranges->begin(),
497             prepared_codepoint_ranges->end(),
498             [](const CodepointRange& a, const CodepointRange& b) {
499               return a.start < b.start;
500             });
501 }
502 
PrepareIgnoredSpanBoundaryCodepoints()503 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
504   if (options_->ignored_span_boundary_codepoints() != nullptr) {
505     for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
506       ignored_span_boundary_codepoints_.insert(codepoint);
507     }
508   }
509 }
510 
CountIgnoredSpanBoundaryCodepoints(const UnicodeText::const_iterator & span_start,const UnicodeText::const_iterator & span_end,bool count_from_beginning) const511 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
512     const UnicodeText::const_iterator& span_start,
513     const UnicodeText::const_iterator& span_end,
514     bool count_from_beginning) const {
515   if (span_start == span_end) {
516     return 0;
517   }
518 
519   UnicodeText::const_iterator it;
520   UnicodeText::const_iterator it_last;
521   if (count_from_beginning) {
522     it = span_start;
523     it_last = span_end;
524     // We can assume that the string is non-zero length because of the check
525     // above, thus the decrement is always valid here.
526     --it_last;
527   } else {
528     it = span_end;
529     it_last = span_start;
530     // We can assume that the string is non-zero length because of the check
531     // above, thus the decrement is always valid here.
532     --it;
533   }
534 
535   // Move until we encounter a non-ignored character.
536   int num_ignored = 0;
537   while (ignored_span_boundary_codepoints_.find(*it) !=
538          ignored_span_boundary_codepoints_.end()) {
539     ++num_ignored;
540 
541     if (it == it_last) {
542       break;
543     }
544 
545     if (count_from_beginning) {
546       ++it;
547     } else {
548       --it;
549     }
550   }
551 
552   return num_ignored;
553 }
554 
555 namespace {
556 
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)557 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
558                     std::vector<UnicodeTextRange>* ranges) {
559   UnicodeText::const_iterator start = t.begin();
560   UnicodeText::const_iterator curr = start;
561   UnicodeText::const_iterator end = t.end();
562   for (; curr != end; ++curr) {
563     if (codepoints.find(*curr) != codepoints.end()) {
564       if (start != curr) {
565         ranges->push_back(std::make_pair(start, curr));
566       }
567       start = curr;
568       ++start;
569     }
570   }
571   if (start != end) {
572     ranges->push_back(std::make_pair(start, end));
573   }
574 }
575 
576 }  // namespace
577 
SplitContext(const UnicodeText & context_unicode) const578 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
579     const UnicodeText& context_unicode) const {
580   std::vector<UnicodeTextRange> lines;
581   const std::set<char32> codepoints{{'\n', '|'}};
582   FindSubstrings(context_unicode, codepoints, &lines);
583   return lines;
584 }
585 
StripBoundaryCodepoints(const std::string & context,CodepointSpan span) const586 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
587     const std::string& context, CodepointSpan span) const {
588   const UnicodeText context_unicode =
589       UTF8ToUnicodeText(context, /*do_copy=*/false);
590   return StripBoundaryCodepoints(context_unicode, span);
591 }
592 
StripBoundaryCodepoints(const UnicodeText & context_unicode,CodepointSpan span) const593 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
594     const UnicodeText& context_unicode, CodepointSpan span) const {
595   if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
596     return span;
597   }
598 
599   UnicodeText::const_iterator span_begin = context_unicode.begin();
600   std::advance(span_begin, span.first);
601   UnicodeText::const_iterator span_end = context_unicode.begin();
602   std::advance(span_end, span.second);
603 
604   const int start_offset = CountIgnoredSpanBoundaryCodepoints(
605       span_begin, span_end, /*count_from_beginning=*/true);
606   const int end_offset = CountIgnoredSpanBoundaryCodepoints(
607       span_begin, span_end, /*count_from_beginning=*/false);
608 
609   if (span.first + start_offset < span.second - end_offset) {
610     return {span.first + start_offset, span.second - end_offset};
611   } else {
612     return {span.first, span.first};
613   }
614 }
615 
SupportedCodepointsRatio(const TokenSpan & token_span,const std::vector<Token> & tokens) const616 float FeatureProcessor::SupportedCodepointsRatio(
617     const TokenSpan& token_span, const std::vector<Token>& tokens) const {
618   int num_supported = 0;
619   int num_total = 0;
620   for (int i = token_span.first; i < token_span.second; ++i) {
621     const UnicodeText value =
622         UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
623     for (auto codepoint : value) {
624       if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
625         ++num_supported;
626       }
627       ++num_total;
628     }
629   }
630   return static_cast<float>(num_supported) / static_cast<float>(num_total);
631 }
632 
IsCodepointInRanges(int codepoint,const std::vector<CodepointRange> & codepoint_ranges) const633 bool FeatureProcessor::IsCodepointInRanges(
634     int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
635   auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
636                              codepoint,
637                              [](const CodepointRange& range, int codepoint) {
638                                // This function compares range with the
639                                // codepoint for the purpose of finding the first
640                                // greater or equal range. Because of the use of
641                                // std::lower_bound it needs to return true when
642                                // range < codepoint; the first time it will
643                                // return false the lower bound is found and
644                                // returned.
645                                //
646                                // It might seem weird that the condition is
647                                // range.end <= codepoint here but when codepoint
648                                // == range.end it means it's actually just
649                                // outside of the range, thus the range is less
650                                // than the codepoint.
651                                return range.end <= codepoint;
652                              });
653   if (it != codepoint_ranges.end() && it->start <= codepoint &&
654       it->end > codepoint) {
655     return true;
656   } else {
657     return false;
658   }
659 }
660 
CollectionToLabel(const std::string & collection) const661 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
662   const auto it = collection_to_label_.find(collection);
663   if (it == collection_to_label_.end()) {
664     return options_->default_collection();
665   } else {
666     return it->second;
667   }
668 }
669 
LabelToCollection(int label) const670 std::string FeatureProcessor::LabelToCollection(int label) const {
671   if (label >= 0 && label < collection_to_label_.size()) {
672     return (*options_->collections())[label]->str();
673   } else {
674     return GetDefaultCollection();
675   }
676 }
677 
MakeLabelMaps()678 void FeatureProcessor::MakeLabelMaps() {
679   if (options_->collections() != nullptr) {
680     for (int i = 0; i < options_->collections()->size(); ++i) {
681       collection_to_label_[(*options_->collections())[i]->str()] = i;
682     }
683   }
684 
685   int selection_label_id = 0;
686   for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
687     for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
688       if (!options_->selection_reduced_output_space() ||
689           r + l <= options_->max_selection_span()) {
690         TokenSpan token_span{l, r};
691         selection_to_label_[token_span] = selection_label_id;
692         label_to_selection_.push_back(token_span);
693         ++selection_label_id;
694       }
695     }
696   }
697 }
698 
RetokenizeAndFindClick(const std::string & context,CodepointSpan input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const699 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
700                                               CodepointSpan input_span,
701                                               bool only_use_line_with_click,
702                                               std::vector<Token>* tokens,
703                                               int* click_pos) const {
704   const UnicodeText context_unicode =
705       UTF8ToUnicodeText(context, /*do_copy=*/false);
706   RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
707                          tokens, click_pos);
708 }
709 
RetokenizeAndFindClick(const UnicodeText & context_unicode,CodepointSpan input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const710 void FeatureProcessor::RetokenizeAndFindClick(
711     const UnicodeText& context_unicode, CodepointSpan input_span,
712     bool only_use_line_with_click, std::vector<Token>* tokens,
713     int* click_pos) const {
714   TC_CHECK(tokens != nullptr);
715 
716   if (options_->split_tokens_on_selection_boundaries()) {
717     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
718   }
719 
720   if (only_use_line_with_click) {
721     StripTokensFromOtherLines(context_unicode, input_span, tokens);
722   }
723 
724   int local_click_pos;
725   if (click_pos == nullptr) {
726     click_pos = &local_click_pos;
727   }
728   *click_pos = FindCenterToken(input_span, *tokens);
729   if (*click_pos == kInvalidIndex) {
730     // If the default click method failed, let's try to do sub-token matching
731     // before we fail.
732     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
733   }
734 }
735 
736 namespace internal {
737 
StripOrPadTokens(TokenSpan relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)738 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
739                       std::vector<Token>* tokens, int* click_pos) {
740   int right_context_needed = relative_click_span.second + context_size;
741   if (*click_pos + right_context_needed + 1 >= tokens->size()) {
742     // Pad max the context size.
743     const int num_pad_tokens = std::min(
744         context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
745                                        tokens->size()));
746     std::vector<Token> pad_tokens(num_pad_tokens);
747     tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
748   } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
749     // Strip unused tokens.
750     auto it = tokens->begin();
751     std::advance(it, *click_pos + right_context_needed + 1);
752     tokens->erase(it, tokens->end());
753   }
754 
755   int left_context_needed = relative_click_span.first + context_size;
756   if (*click_pos < left_context_needed) {
757     // Pad max the context size.
758     const int num_pad_tokens =
759         std::min(context_size, left_context_needed - *click_pos);
760     std::vector<Token> pad_tokens(num_pad_tokens);
761     tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
762     *click_pos += num_pad_tokens;
763   } else if (*click_pos > left_context_needed) {
764     // Strip unused tokens.
765     auto it = tokens->begin();
766     std::advance(it, *click_pos - left_context_needed);
767     *click_pos -= it - tokens->begin();
768     tokens->erase(tokens->begin(), it);
769   }
770 }
771 
772 }  // namespace internal
773 
HasEnoughSupportedCodepoints(const std::vector<Token> & tokens,TokenSpan token_span) const774 bool FeatureProcessor::HasEnoughSupportedCodepoints(
775     const std::vector<Token>& tokens, TokenSpan token_span) const {
776   if (options_->min_supported_codepoint_ratio() > 0) {
777     const float supported_codepoint_ratio =
778         SupportedCodepointsRatio(token_span, tokens);
779     if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
780       TC_VLOG(1) << "Not enough supported codepoints in the context: "
781                  << supported_codepoint_ratio;
782       return false;
783     }
784   }
785   return true;
786 }
787 
ExtractFeatures(const std::vector<Token> & tokens,TokenSpan token_span,CodepointSpan selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,int feature_vector_size,std::unique_ptr<CachedFeatures> * cached_features) const788 bool FeatureProcessor::ExtractFeatures(
789     const std::vector<Token>& tokens, TokenSpan token_span,
790     CodepointSpan selection_span_for_feature,
791     const EmbeddingExecutor* embedding_executor,
792     EmbeddingCache* embedding_cache, int feature_vector_size,
793     std::unique_ptr<CachedFeatures>* cached_features) const {
794   std::unique_ptr<std::vector<float>> features(new std::vector<float>());
795   features->reserve(feature_vector_size * TokenSpanSize(token_span));
796   for (int i = token_span.first; i < token_span.second; ++i) {
797     if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
798                                       embedding_executor, embedding_cache,
799                                       features.get())) {
800       TC_LOG(ERROR) << "Could not get token features.";
801       return false;
802     }
803   }
804 
805   std::unique_ptr<std::vector<float>> padding_features(
806       new std::vector<float>());
807   padding_features->reserve(feature_vector_size);
808   if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
809                                     embedding_executor, embedding_cache,
810                                     padding_features.get())) {
811     TC_LOG(ERROR) << "Count not get padding token features.";
812     return false;
813   }
814 
815   *cached_features = CachedFeatures::Create(token_span, std::move(features),
816                                             std::move(padding_features),
817                                             options_, feature_vector_size);
818   if (!*cached_features) {
819     TC_LOG(ERROR) << "Cound not create cached features.";
820     return false;
821   }
822 
823   return true;
824 }
825 
ICUTokenize(const UnicodeText & context_unicode,std::vector<Token> * result) const826 bool FeatureProcessor::ICUTokenize(const UnicodeText& context_unicode,
827                                    std::vector<Token>* result) const {
828   std::unique_ptr<UniLib::BreakIterator> break_iterator =
829       unilib_->CreateBreakIterator(context_unicode);
830   if (!break_iterator) {
831     return false;
832   }
833   int last_break_index = 0;
834   int break_index = 0;
835   int last_unicode_index = 0;
836   int unicode_index = 0;
837   auto token_begin_it = context_unicode.begin();
838   while ((break_index = break_iterator->Next()) !=
839          UniLib::BreakIterator::kDone) {
840     const int token_length = break_index - last_break_index;
841     unicode_index = last_unicode_index + token_length;
842 
843     auto token_end_it = token_begin_it;
844     std::advance(token_end_it, token_length);
845 
846     // Determine if the whole token is whitespace.
847     bool is_whitespace = true;
848     for (auto char_it = token_begin_it; char_it < token_end_it; ++char_it) {
849       if (!unilib_->IsWhitespace(*char_it)) {
850         is_whitespace = false;
851         break;
852       }
853     }
854 
855     const std::string token =
856         context_unicode.UTF8Substring(token_begin_it, token_end_it);
857 
858     if (!is_whitespace || options_->icu_preserve_whitespace_tokens()) {
859       result->push_back(Token(token, last_unicode_index, unicode_index));
860     }
861 
862     last_break_index = break_index;
863     last_unicode_index = unicode_index;
864     token_begin_it = token_end_it;
865   }
866 
867   return true;
868 }
869 
InternalRetokenize(const UnicodeText & unicode_text,std::vector<Token> * tokens) const870 void FeatureProcessor::InternalRetokenize(const UnicodeText& unicode_text,
871                                           std::vector<Token>* tokens) const {
872   std::vector<Token> result;
873   CodepointSpan span(-1, -1);
874   for (Token& token : *tokens) {
875     const UnicodeText unicode_token_value =
876         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
877     bool should_retokenize = true;
878     for (const int codepoint : unicode_token_value) {
879       if (!IsCodepointInRanges(codepoint,
880                                internal_tokenizer_codepoint_ranges_)) {
881         should_retokenize = false;
882         break;
883       }
884     }
885 
886     if (should_retokenize) {
887       if (span.first < 0) {
888         span.first = token.start;
889       }
890       span.second = token.end;
891     } else {
892       TokenizeSubstring(unicode_text, span, &result);
893       span.first = -1;
894       result.emplace_back(std::move(token));
895     }
896   }
897   TokenizeSubstring(unicode_text, span, &result);
898 
899   *tokens = std::move(result);
900 }
901 
TokenizeSubstring(const UnicodeText & unicode_text,CodepointSpan span,std::vector<Token> * result) const902 void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
903                                          CodepointSpan span,
904                                          std::vector<Token>* result) const {
905   if (span.first < 0) {
906     // There is no span to tokenize.
907     return;
908   }
909 
910   // Extract the substring.
911   UnicodeText::const_iterator it_begin = unicode_text.begin();
912   for (int i = 0; i < span.first; ++i) {
913     ++it_begin;
914   }
915   UnicodeText::const_iterator it_end = unicode_text.begin();
916   for (int i = 0; i < span.second; ++i) {
917     ++it_end;
918   }
919   const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
920 
921   // Run the tokenizer and update the token bounds to reflect the offset of the
922   // substring.
923   std::vector<Token> tokens = tokenizer_.Tokenize(text);
924   // Avoids progressive capacity increases in the for loop.
925   result->reserve(result->size() + tokens.size());
926   for (Token& token : tokens) {
927     token.start += span.first;
928     token.end += span.first;
929     result->emplace_back(std::move(token));
930   }
931 }
932 
AppendTokenFeaturesWithCache(const Token & token,CodepointSpan selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,std::vector<float> * output_features) const933 bool FeatureProcessor::AppendTokenFeaturesWithCache(
934     const Token& token, CodepointSpan selection_span_for_feature,
935     const EmbeddingExecutor* embedding_executor,
936     EmbeddingCache* embedding_cache,
937     std::vector<float>* output_features) const {
938   // Look for the embedded features for the token in the cache, if there is one.
939   if (embedding_cache) {
940     const auto it = embedding_cache->find({token.start, token.end});
941     if (it != embedding_cache->end()) {
942       // The embedded features were found in the cache, extract only the dense
943       // features.
944       std::vector<float> dense_features;
945       if (!feature_extractor_.Extract(
946               token, token.IsContainedInSpan(selection_span_for_feature),
947               /*sparse_features=*/nullptr, &dense_features)) {
948         TC_LOG(ERROR) << "Could not extract token's dense features.";
949         return false;
950       }
951 
952       // Append both embedded and dense features to the output and return.
953       output_features->insert(output_features->end(), it->second.begin(),
954                               it->second.end());
955       output_features->insert(output_features->end(), dense_features.begin(),
956                               dense_features.end());
957       return true;
958     }
959   }
960 
961   // Extract the sparse and dense features.
962   std::vector<int> sparse_features;
963   std::vector<float> dense_features;
964   if (!feature_extractor_.Extract(
965           token, token.IsContainedInSpan(selection_span_for_feature),
966           &sparse_features, &dense_features)) {
967     TC_LOG(ERROR) << "Could not extract token's features.";
968     return false;
969   }
970 
971   // Embed the sparse features, appending them directly to the output.
972   const int embedding_size = GetOptions()->embedding_size();
973   output_features->resize(output_features->size() + embedding_size);
974   float* output_features_end =
975       output_features->data() + output_features->size();
976   if (!embedding_executor->AddEmbedding(
977           TensorView<int>(sparse_features.data(),
978                           {static_cast<int>(sparse_features.size())}),
979           /*dest=*/output_features_end - embedding_size,
980           /*dest_size=*/embedding_size)) {
981     TC_LOG(ERROR) << "Cound not embed token's sparse features.";
982     return false;
983   }
984 
985   // If there is a cache, the embedded features for the token were not in it,
986   // so insert them.
987   if (embedding_cache) {
988     (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
989         output_features_end - embedding_size, output_features_end);
990   }
991 
992   // Append the dense features to the output.
993   output_features->insert(output_features->end(), dense_features.begin(),
994                           dense_features.end());
995   return true;
996 }
997 
998 }  // namespace libtextclassifier2
999