1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/feature-processor.h"
18
19 #include <iterator>
20 #include <set>
21 #include <vector>
22
23 #include "utils/base/logging.h"
24 #include "utils/strings/utf8.h"
25 #include "utils/utf8/unicodetext.h"
26
27 namespace libtextclassifier3 {
28
29 namespace internal {
30
BuildTokenizer(const FeatureProcessorOptions * options,const UniLib * unilib)31 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
32 const UniLib* unilib) {
33 std::vector<const TokenizationCodepointRange*> codepoint_config;
34 if (options->tokenization_codepoint_config() != nullptr) {
35 codepoint_config.insert(codepoint_config.end(),
36 options->tokenization_codepoint_config()->begin(),
37 options->tokenization_codepoint_config()->end());
38 }
39 std::vector<const CodepointRange*> internal_codepoint_config;
40 if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
41 internal_codepoint_config.insert(
42 internal_codepoint_config.end(),
43 options->internal_tokenizer_codepoint_ranges()->begin(),
44 options->internal_tokenizer_codepoint_ranges()->end());
45 }
46 const bool tokenize_on_script_change =
47 options->tokenization_codepoint_config() != nullptr &&
48 options->tokenize_on_script_change();
49 return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
50 internal_codepoint_config, tokenize_on_script_change,
51 options->icu_preserve_whitespace_tokens());
52 }
53
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions * const options)54 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
55 const FeatureProcessorOptions* const options) {
56 TokenFeatureExtractorOptions extractor_options;
57
58 extractor_options.num_buckets = options->num_buckets();
59 if (options->chargram_orders() != nullptr) {
60 for (int order : *options->chargram_orders()) {
61 extractor_options.chargram_orders.push_back(order);
62 }
63 }
64 extractor_options.max_word_length = options->max_word_length();
65 extractor_options.extract_case_feature = options->extract_case_feature();
66 extractor_options.unicode_aware_features = options->unicode_aware_features();
67 extractor_options.extract_selection_mask_feature =
68 options->extract_selection_mask_feature();
69 if (options->regexp_feature() != nullptr) {
70 for (const auto& regexp_feature : *options->regexp_feature()) {
71 extractor_options.regexp_features.push_back(regexp_feature->str());
72 }
73 }
74 extractor_options.remap_digits = options->remap_digits();
75 extractor_options.lowercase_tokens = options->lowercase_tokens();
76
77 if (options->allowed_chargrams() != nullptr) {
78 for (const auto& chargram : *options->allowed_chargrams()) {
79 extractor_options.allowed_chargrams.insert(chargram->str());
80 }
81 }
82 return extractor_options;
83 }
84
SplitTokensOnSelectionBoundaries(const CodepointSpan & selection,std::vector<Token> * tokens)85 void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
86 std::vector<Token>* tokens) {
87 for (auto it = tokens->begin(); it != tokens->end(); ++it) {
88 const UnicodeText token_word =
89 UTF8ToUnicodeText(it->value, /*do_copy=*/false);
90
91 auto last_start = token_word.begin();
92 int last_start_index = it->start;
93 std::vector<UnicodeText::const_iterator> split_points;
94
95 // Selection start split point.
96 if (selection.first > it->start && selection.first < it->end) {
97 std::advance(last_start, selection.first - last_start_index);
98 split_points.push_back(last_start);
99 last_start_index = selection.first;
100 }
101
102 // Selection end split point.
103 if (selection.second > it->start && selection.second < it->end) {
104 std::advance(last_start, selection.second - last_start_index);
105 split_points.push_back(last_start);
106 }
107
108 if (!split_points.empty()) {
109 // Add a final split for the rest of the token unless it's been all
110 // consumed already.
111 if (split_points.back() != token_word.end()) {
112 split_points.push_back(token_word.end());
113 }
114
115 std::vector<Token> replacement_tokens;
116 last_start = token_word.begin();
117 int current_pos = it->start;
118 for (const auto& split_point : split_points) {
119 Token new_token(token_word.UTF8Substring(last_start, split_point),
120 current_pos,
121 current_pos + std::distance(last_start, split_point));
122
123 last_start = split_point;
124 current_pos = new_token.end;
125
126 replacement_tokens.push_back(new_token);
127 }
128
129 it = tokens->erase(it);
130 it = tokens->insert(it, replacement_tokens.begin(),
131 replacement_tokens.end());
132 std::advance(it, replacement_tokens.size() - 1);
133 }
134 }
135 }
136
137 } // namespace internal
138
StripTokensFromOtherLines(const std::string & context,const CodepointSpan & span,std::vector<Token> * tokens) const139 void FeatureProcessor::StripTokensFromOtherLines(
140 const std::string& context, const CodepointSpan& span,
141 std::vector<Token>* tokens) const {
142 const UnicodeText context_unicode = UTF8ToUnicodeText(context,
143 /*do_copy=*/false);
144 const auto [span_begin, span_end] =
145 CodepointSpanToUnicodeTextRange(context_unicode, span);
146 StripTokensFromOtherLines(context_unicode, span_begin, span_end, span,
147 tokens);
148 }
149
StripTokensFromOtherLines(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span,std::vector<Token> * tokens) const150 void FeatureProcessor::StripTokensFromOtherLines(
151 const UnicodeText& context_unicode,
152 const UnicodeText::const_iterator& span_begin,
153 const UnicodeText::const_iterator& span_end, const CodepointSpan& span,
154 std::vector<Token>* tokens) const {
155 std::vector<UnicodeTextRange> lines =
156 SplitContext(context_unicode, options_->use_pipe_character_for_newline());
157
158 for (const UnicodeTextRange& line : lines) {
159 // Find the line that completely contains the span.
160 if (line.first <= span_begin && line.second >= span_end) {
161 const CodepointIndex last_line_begin_index =
162 std::distance(context_unicode.begin(), line.first);
163 const CodepointIndex last_line_end_index =
164 last_line_begin_index + std::distance(line.first, line.second);
165
166 for (auto token = tokens->begin(); token != tokens->end();) {
167 if (token->start >= last_line_begin_index &&
168 token->end <= last_line_end_index) {
169 ++token;
170 } else {
171 token = tokens->erase(token);
172 }
173 }
174 }
175 }
176 }
177
GetDefaultCollection() const178 std::string FeatureProcessor::GetDefaultCollection() const {
179 if (options_->default_collection() < 0 ||
180 options_->collections() == nullptr ||
181 options_->default_collection() >= options_->collections()->size()) {
182 TC3_LOG(ERROR)
183 << "Invalid or missing default collection. Returning empty string.";
184 return "";
185 }
186 return (*options_->collections())[options_->default_collection()]->str();
187 }
188
Tokenize(const std::string & text) const189 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
190 return tokenizer_.Tokenize(text);
191 }
192
Tokenize(const UnicodeText & text_unicode) const193 std::vector<Token> FeatureProcessor::Tokenize(
194 const UnicodeText& text_unicode) const {
195 return tokenizer_.Tokenize(text_unicode);
196 }
197
LabelToSpan(const int label,const VectorSpan<Token> & tokens,CodepointSpan * span) const198 bool FeatureProcessor::LabelToSpan(const int label,
199 const VectorSpan<Token>& tokens,
200 CodepointSpan* span) const {
201 if (tokens.size() != GetNumContextTokens()) {
202 return false;
203 }
204
205 TokenSpan token_span;
206 if (!LabelToTokenSpan(label, &token_span)) {
207 return false;
208 }
209
210 const int result_begin_token_index = token_span.first;
211 const Token& result_begin_token =
212 tokens[options_->context_size() - result_begin_token_index];
213 const int result_begin_codepoint = result_begin_token.start;
214 const int result_end_token_index = token_span.second;
215 const Token& result_end_token =
216 tokens[options_->context_size() + result_end_token_index];
217 const int result_end_codepoint = result_end_token.end;
218
219 if (result_begin_codepoint == kInvalidIndex ||
220 result_end_codepoint == kInvalidIndex) {
221 *span = CodepointSpan::kInvalid;
222 } else {
223 const UnicodeText token_begin_unicode =
224 UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
225 UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
226 const UnicodeText token_end_unicode =
227 UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
228 UnicodeText::const_iterator token_end = token_end_unicode.end();
229
230 const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
231 token_begin, token_begin_unicode.end(),
232 /*count_from_beginning=*/true);
233 const int end_ignored =
234 CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
235 /*count_from_beginning=*/false);
236 // In case everything would be stripped, set the span to the original
237 // beginning and zero length.
238 if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
239 *span = {result_begin_codepoint, result_begin_codepoint};
240 } else {
241 *span = CodepointSpan(result_begin_codepoint + begin_ignored,
242 result_end_codepoint - end_ignored);
243 }
244 }
245 return true;
246 }
247
LabelToTokenSpan(const int label,TokenSpan * token_span) const248 bool FeatureProcessor::LabelToTokenSpan(const int label,
249 TokenSpan* token_span) const {
250 if (label >= 0 && label < label_to_selection_.size()) {
251 *token_span = label_to_selection_[label];
252 return true;
253 } else {
254 return false;
255 }
256 }
257
SpanToLabel(const CodepointSpan & span,const std::vector<Token> & tokens,int * label) const258 bool FeatureProcessor::SpanToLabel(const CodepointSpan& span,
259 const std::vector<Token>& tokens,
260 int* label) const {
261 if (tokens.size() != GetNumContextTokens()) {
262 return false;
263 }
264
265 const int click_position =
266 options_->context_size(); // Click is always in the middle.
267 const int padding = options_->context_size() - options_->max_selection_span();
268
269 int span_left = 0;
270 for (int i = click_position - 1; i >= padding; i--) {
271 if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
272 ++span_left;
273 } else {
274 break;
275 }
276 }
277
278 int span_right = 0;
279 for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
280 if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
281 ++span_right;
282 } else {
283 break;
284 }
285 }
286
287 // Check that the spanned tokens cover the whole span.
288 bool tokens_match_span;
289 const CodepointIndex tokens_start = tokens[click_position - span_left].start;
290 const CodepointIndex tokens_end = tokens[click_position + span_right].end;
291 if (options_->snap_label_span_boundaries_to_containing_tokens()) {
292 tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
293 } else {
294 const UnicodeText token_left_unicode = UTF8ToUnicodeText(
295 tokens[click_position - span_left].value, /*do_copy=*/false);
296 const UnicodeText token_right_unicode = UTF8ToUnicodeText(
297 tokens[click_position + span_right].value, /*do_copy=*/false);
298
299 UnicodeText::const_iterator span_begin = token_left_unicode.begin();
300 UnicodeText::const_iterator span_end = token_right_unicode.end();
301
302 const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
303 span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
304 const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
305 token_right_unicode.begin(), span_end,
306 /*count_from_beginning=*/false);
307
308 tokens_match_span = tokens_start <= span.first &&
309 tokens_start + num_punctuation_start >= span.first &&
310 tokens_end >= span.second &&
311 tokens_end - num_punctuation_end <= span.second;
312 }
313
314 if (tokens_match_span) {
315 *label = TokenSpanToLabel({span_left, span_right});
316 } else {
317 *label = kInvalidLabel;
318 }
319
320 return true;
321 }
322
TokenSpanToLabel(const TokenSpan & token_span) const323 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& token_span) const {
324 auto it = selection_to_label_.find(token_span);
325 if (it != selection_to_label_.end()) {
326 return it->second;
327 } else {
328 return kInvalidLabel;
329 }
330 }
331
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span,bool snap_boundaries_to_containing_tokens)332 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
333 const CodepointSpan& codepoint_span,
334 bool snap_boundaries_to_containing_tokens) {
335 const int codepoint_start = codepoint_span.first;
336 const int codepoint_end = codepoint_span.second;
337
338 TokenIndex start_token = kInvalidIndex;
339 TokenIndex end_token = kInvalidIndex;
340 for (int i = 0; i < selectable_tokens.size(); ++i) {
341 bool is_token_in_span;
342 if (snap_boundaries_to_containing_tokens) {
343 is_token_in_span = codepoint_start < selectable_tokens[i].end &&
344 codepoint_end > selectable_tokens[i].start;
345 } else {
346 is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
347 codepoint_end >= selectable_tokens[i].end;
348 }
349 if (is_token_in_span && !selectable_tokens[i].is_padding) {
350 if (start_token == kInvalidIndex) {
351 start_token = i;
352 }
353 end_token = i + 1;
354 }
355 }
356 return {start_token, end_token};
357 }
358
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,const TokenSpan & token_span)359 CodepointSpan TokenSpanToCodepointSpan(
360 const std::vector<Token>& selectable_tokens, const TokenSpan& token_span) {
361 return {selectable_tokens[token_span.first].start,
362 selectable_tokens[token_span.second - 1].end};
363 }
364
CodepointSpanToUnicodeTextRange(const UnicodeText & unicode_text,const CodepointSpan & span)365 UnicodeTextRange CodepointSpanToUnicodeTextRange(
366 const UnicodeText& unicode_text, const CodepointSpan& span) {
367 auto begin = unicode_text.begin();
368 if (span.first > 0) {
369 std::advance(begin, span.first);
370 }
371 auto end = unicode_text.begin();
372 if (span.second > 0) {
373 std::advance(end, span.second);
374 }
375 return {begin, end};
376 }
377
378 namespace {
379
380 // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span)381 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
382 const CodepointSpan& codepoint_span) {
383 const int codepoint_start = codepoint_span.first;
384 const int codepoint_end = codepoint_span.second;
385
386 for (int i = 0; i < selectable_tokens.size(); ++i) {
387 if (codepoint_start >= selectable_tokens[i].start &&
388 codepoint_end <= selectable_tokens[i].end) {
389 return i;
390 }
391 }
392 return kInvalidIndex;
393 }
394
395 } // namespace
396
397 namespace internal {
398
CenterTokenFromClick(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)399 int CenterTokenFromClick(const CodepointSpan& span,
400 const std::vector<Token>& selectable_tokens) {
401 const TokenSpan token_span =
402 CodepointSpanToTokenSpan(selectable_tokens, span);
403 int range_begin = token_span.first;
404 int range_end = token_span.second;
405
406 // If no exact match was found, try finding a token that completely contains
407 // the click span. This is useful e.g. when Android builds the selection
408 // using ICU tokenization, and ends up with only a portion of our space-
409 // separated token. E.g. for "(857)" Android would select "857".
410 if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
411 int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
412 if (token_index != kInvalidIndex) {
413 range_begin = token_index;
414 range_end = token_index + 1;
415 }
416 }
417
418 // We only allow clicks that are exactly 1 selectable token.
419 if (range_end - range_begin == 1) {
420 return range_begin;
421 } else {
422 return kInvalidIndex;
423 }
424 }
425
CenterTokenFromMiddleOfSelection(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)426 int CenterTokenFromMiddleOfSelection(
427 const CodepointSpan& span, const std::vector<Token>& selectable_tokens) {
428 const TokenSpan token_span =
429 CodepointSpanToTokenSpan(selectable_tokens, span);
430 const int range_begin = token_span.first;
431 const int range_end = token_span.second;
432
433 // Center the clicked token in the selection range.
434 if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
435 return (range_begin + range_end - 1) / 2;
436 } else {
437 return kInvalidIndex;
438 }
439 }
440
441 } // namespace internal
442
FindCenterToken(const CodepointSpan & span,const std::vector<Token> & tokens) const443 int FeatureProcessor::FindCenterToken(const CodepointSpan& span,
444 const std::vector<Token>& tokens) const {
445 if (options_->center_token_selection_method() ==
446 FeatureProcessorOptions_::
447 CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
448 return internal::CenterTokenFromClick(span, tokens);
449 } else if (options_->center_token_selection_method() ==
450 FeatureProcessorOptions_::
451 CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
452 return internal::CenterTokenFromMiddleOfSelection(span, tokens);
453 } else if (options_->center_token_selection_method() ==
454 FeatureProcessorOptions_::
455 CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
456 // TODO(zilka): Remove once we have new models on the device.
457 // It uses the fact that sharing model use
458 // split_tokens_on_selection_boundaries and selection not. So depending on
459 // this we select the right way of finding the click location.
460 if (!options_->split_tokens_on_selection_boundaries()) {
461 // SmartSelection model.
462 return internal::CenterTokenFromClick(span, tokens);
463 } else {
464 // SmartSharing model.
465 return internal::CenterTokenFromMiddleOfSelection(span, tokens);
466 }
467 } else {
468 TC3_LOG(ERROR) << "Invalid center token selection method.";
469 return kInvalidIndex;
470 }
471 }
472
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const473 bool FeatureProcessor::SelectionLabelSpans(
474 const VectorSpan<Token> tokens,
475 std::vector<CodepointSpan>* selection_label_spans) const {
476 for (int i = 0; i < label_to_selection_.size(); ++i) {
477 CodepointSpan span = CodepointSpan::kInvalid;
478 if (!LabelToSpan(i, tokens, &span)) {
479 TC3_LOG(ERROR) << "Could not convert label to span: " << i;
480 return false;
481 }
482 selection_label_spans->push_back(span);
483 }
484 return true;
485 }
486
SelectionLabelRelativeTokenSpans(std::vector<TokenSpan> * selection_label_relative_token_spans) const487 bool FeatureProcessor::SelectionLabelRelativeTokenSpans(
488 std::vector<TokenSpan>* selection_label_relative_token_spans) const {
489 selection_label_relative_token_spans->assign(label_to_selection_.begin(),
490 label_to_selection_.end());
491 return true;
492 }
493
PrepareIgnoredSpanBoundaryCodepoints()494 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
495 if (options_->ignored_span_boundary_codepoints() != nullptr) {
496 for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
497 ignored_span_boundary_codepoints_.insert(codepoint);
498 }
499 }
500 }
501
CountIgnoredSpanBoundaryCodepoints(const UnicodeText::const_iterator & span_start,const UnicodeText::const_iterator & span_end,bool count_from_beginning) const502 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
503 const UnicodeText::const_iterator& span_start,
504 const UnicodeText::const_iterator& span_end,
505 bool count_from_beginning) const {
506 if (span_start == span_end) {
507 return 0;
508 }
509
510 UnicodeText::const_iterator it;
511 UnicodeText::const_iterator it_last;
512 if (count_from_beginning) {
513 it = span_start;
514 it_last = span_end;
515 // We can assume that the string is non-zero length because of the check
516 // above, thus the decrement is always valid here.
517 --it_last;
518 } else {
519 it = span_end;
520 it_last = span_start;
521 // We can assume that the string is non-zero length because of the check
522 // above, thus the decrement is always valid here.
523 --it;
524 }
525
526 // Move until we encounter a non-ignored character.
527 int num_ignored = 0;
528 while (ignored_span_boundary_codepoints_.find(*it) !=
529 ignored_span_boundary_codepoints_.end()) {
530 ++num_ignored;
531
532 if (it == it_last) {
533 break;
534 }
535
536 if (count_from_beginning) {
537 ++it;
538 } else {
539 --it;
540 }
541 }
542
543 return num_ignored;
544 }
545
546 namespace {
547
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)548 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
549 std::vector<UnicodeTextRange>* ranges) {
550 UnicodeText::const_iterator start = t.begin();
551 UnicodeText::const_iterator curr = start;
552 UnicodeText::const_iterator end = t.end();
553 for (; curr != end; ++curr) {
554 if (codepoints.find(*curr) != codepoints.end()) {
555 if (start != curr) {
556 ranges->push_back(std::make_pair(start, curr));
557 }
558 start = curr;
559 ++start;
560 }
561 }
562 if (start != end) {
563 ranges->push_back(std::make_pair(start, end));
564 }
565 }
566
567 } // namespace
568
SplitContext(const UnicodeText & context_unicode,const bool use_pipe_character_for_newline) const569 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
570 const UnicodeText& context_unicode,
571 const bool use_pipe_character_for_newline) const {
572 std::vector<UnicodeTextRange> lines;
573 std::set<char32> codepoints{'\n'};
574 if (use_pipe_character_for_newline) {
575 codepoints.insert('|');
576 }
577 FindSubstrings(context_unicode, codepoints, &lines);
578 return lines;
579 }
580
// Convenience overload taking a UTF-8 string: wraps `context` as UnicodeText
// (without copying) and forwards to the UnicodeText overload.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const std::string& context, const CodepointSpan& span) const {
  return StripBoundaryCodepoints(
      UTF8ToUnicodeText(context, /*do_copy=*/false), span);
}
587
// Strips ignored boundary codepoints from `span` within `context_unicode`.
// The span is returned unchanged when the context is empty or the span is
// invalid or empty.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText& context_unicode, const CodepointSpan& span) const {
  const bool nothing_to_strip =
      context_unicode.empty() || !span.IsValid() || span.IsEmpty();
  if (nothing_to_strip) {
    return span;
  }

  const UnicodeTextRange span_range =
      CodepointSpanToUnicodeTextRange(context_unicode, span);
  return StripBoundaryCodepoints(span_range.first, span_range.second, span);
}
599
StripBoundaryCodepoints(const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span) const600 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
601 const UnicodeText::const_iterator& span_begin,
602 const UnicodeText::const_iterator& span_end,
603 const CodepointSpan& span) const {
604 if (!span.IsValid() || span.IsEmpty() || span_begin == span_end) {
605 return span;
606 }
607
608 const int start_offset = CountIgnoredSpanBoundaryCodepoints(
609 span_begin, span_end, /*count_from_beginning=*/true);
610 const int end_offset = CountIgnoredSpanBoundaryCodepoints(
611 span_begin, span_end, /*count_from_beginning=*/false);
612
613 if (span.first + start_offset < span.second - end_offset) {
614 return {span.first + start_offset, span.second - end_offset};
615 } else {
616 return {span.first, span.first};
617 }
618 }
619
SupportedCodepointsRatio(const TokenSpan & token_span,const std::vector<Token> & tokens) const620 float FeatureProcessor::SupportedCodepointsRatio(
621 const TokenSpan& token_span, const std::vector<Token>& tokens) const {
622 int num_supported = 0;
623 int num_total = 0;
624 for (int i = token_span.first; i < token_span.second; ++i) {
625 const UnicodeText value =
626 UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
627 for (auto codepoint : value) {
628 if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
629 ++num_supported;
630 }
631 ++num_total;
632 }
633 }
634 // Avoid division by zero.
635 if (num_total == 0) {
636 return 0.0;
637 }
638 return static_cast<float>(num_supported) / static_cast<float>(num_total);
639 }
640
StripBoundaryCodepoints(const std::string & value,std::string * buffer) const641 const std::string& FeatureProcessor::StripBoundaryCodepoints(
642 const std::string& value, std::string* buffer) const {
643 const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
644 const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
645 const CodepointSpan stripped_span =
646 StripBoundaryCodepoints(value_unicode, initial_span);
647
648 if (initial_span != stripped_span) {
649 const UnicodeText stripped_token_value =
650 UnicodeText::Substring(value_unicode, stripped_span.first,
651 stripped_span.second, /*do_copy=*/false);
652 *buffer = stripped_token_value.ToUTF8String();
653 return *buffer;
654 }
655 return value;
656 }
657
CollectionToLabel(const std::string & collection) const658 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
659 const auto it = collection_to_label_.find(collection);
660 if (it == collection_to_label_.end()) {
661 return options_->default_collection();
662 } else {
663 return it->second;
664 }
665 }
666
LabelToCollection(int label) const667 std::string FeatureProcessor::LabelToCollection(int label) const {
668 if (label >= 0 && label < collection_to_label_.size()) {
669 return (*options_->collections())[label]->str();
670 } else {
671 return GetDefaultCollection();
672 }
673 }
674
MakeLabelMaps()675 void FeatureProcessor::MakeLabelMaps() {
676 if (options_->collections() != nullptr) {
677 for (int i = 0; i < options_->collections()->size(); ++i) {
678 collection_to_label_[(*options_->collections())[i]->str()] = i;
679 }
680 }
681
682 int selection_label_id = 0;
683 for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
684 for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
685 if (!options_->selection_reduced_output_space() ||
686 r + l <= options_->max_selection_span()) {
687 TokenSpan token_span{l, r};
688 selection_to_label_[token_span] = selection_label_id;
689 label_to_selection_.push_back(token_span);
690 ++selection_label_id;
691 }
692 }
693 }
694 }
695
RetokenizeAndFindClick(const std::string & context,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const696 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
697 const CodepointSpan& input_span,
698 bool only_use_line_with_click,
699 std::vector<Token>* tokens,
700 int* click_pos) const {
701 const UnicodeText context_unicode =
702 UTF8ToUnicodeText(context, /*do_copy=*/false);
703 const auto [span_begin, span_end] =
704 CodepointSpanToUnicodeTextRange(context_unicode, input_span);
705 RetokenizeAndFindClick(context_unicode, span_begin, span_end, input_span,
706 only_use_line_with_click, tokens, click_pos);
707 }
708
RetokenizeAndFindClick(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const709 void FeatureProcessor::RetokenizeAndFindClick(
710 const UnicodeText& context_unicode,
711 const UnicodeText::const_iterator& span_begin,
712 const UnicodeText::const_iterator& span_end,
713 const CodepointSpan& input_span, bool only_use_line_with_click,
714 std::vector<Token>* tokens, int* click_pos) const {
715 TC3_CHECK(tokens != nullptr);
716
717 if (options_->split_tokens_on_selection_boundaries()) {
718 internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
719 }
720
721 if (only_use_line_with_click) {
722 StripTokensFromOtherLines(context_unicode, span_begin, span_end, input_span,
723 tokens);
724 }
725
726 int local_click_pos;
727 if (click_pos == nullptr) {
728 click_pos = &local_click_pos;
729 }
730 *click_pos = FindCenterToken(input_span, *tokens);
731 if (*click_pos == kInvalidIndex) {
732 // If the default click method failed, let's try to do sub-token matching
733 // before we fail.
734 *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
735 }
736 }
737
738 namespace internal {
739
StripOrPadTokens(const TokenSpan & relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)740 void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
741 std::vector<Token>* tokens, int* click_pos) {
742 int right_context_needed = relative_click_span.second + context_size;
743 if (*click_pos + right_context_needed + 1 >= tokens->size()) {
744 // Pad max the context size.
745 const int num_pad_tokens = std::min(
746 context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
747 tokens->size()));
748 std::vector<Token> pad_tokens(num_pad_tokens);
749 tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
750 } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
751 // Strip unused tokens.
752 auto it = tokens->begin();
753 std::advance(it, *click_pos + right_context_needed + 1);
754 tokens->erase(it, tokens->end());
755 }
756
757 int left_context_needed = relative_click_span.first + context_size;
758 if (*click_pos < left_context_needed) {
759 // Pad max the context size.
760 const int num_pad_tokens =
761 std::min(context_size, left_context_needed - *click_pos);
762 std::vector<Token> pad_tokens(num_pad_tokens);
763 tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
764 *click_pos += num_pad_tokens;
765 } else if (*click_pos > left_context_needed) {
766 // Strip unused tokens.
767 auto it = tokens->begin();
768 std::advance(it, *click_pos - left_context_needed);
769 *click_pos -= it - tokens->begin();
770 tokens->erase(tokens->begin(), it);
771 }
772 }
773
774 } // namespace internal
775
HasEnoughSupportedCodepoints(const std::vector<Token> & tokens,const TokenSpan & token_span) const776 bool FeatureProcessor::HasEnoughSupportedCodepoints(
777 const std::vector<Token>& tokens, const TokenSpan& token_span) const {
778 if (options_->min_supported_codepoint_ratio() > 0) {
779 const float supported_codepoint_ratio =
780 SupportedCodepointsRatio(token_span, tokens);
781 if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
782 TC3_VLOG(1) << "Not enough supported codepoints in the context: "
783 << supported_codepoint_ratio;
784 return false;
785 }
786 }
787 return true;
788 }
789
ExtractFeatures(const std::vector<Token> & tokens,const TokenSpan & token_span,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,int feature_vector_size,std::unique_ptr<CachedFeatures> * cached_features) const790 bool FeatureProcessor::ExtractFeatures(
791 const std::vector<Token>& tokens, const TokenSpan& token_span,
792 const CodepointSpan& selection_span_for_feature,
793 const EmbeddingExecutor* embedding_executor,
794 EmbeddingCache* embedding_cache, int feature_vector_size,
795 std::unique_ptr<CachedFeatures>* cached_features) const {
796 std::unique_ptr<std::vector<float>> features(new std::vector<float>());
797 features->reserve(feature_vector_size * token_span.Size());
798 for (int i = token_span.first; i < token_span.second; ++i) {
799 if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
800 embedding_executor, embedding_cache,
801 features.get())) {
802 TC3_LOG(ERROR) << "Could not get token features.";
803 return false;
804 }
805 }
806
807 std::unique_ptr<std::vector<float>> padding_features(
808 new std::vector<float>());
809 padding_features->reserve(feature_vector_size);
810 if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
811 embedding_executor, embedding_cache,
812 padding_features.get())) {
813 TC3_LOG(ERROR) << "Count not get padding token features.";
814 return false;
815 }
816
817 *cached_features = CachedFeatures::Create(token_span, std::move(features),
818 std::move(padding_features),
819 options_, feature_vector_size);
820 if (!*cached_features) {
821 TC3_LOG(ERROR) << "Cound not create cached features.";
822 return false;
823 }
824
825 return true;
826 }
827
AppendTokenFeaturesWithCache(const Token & token,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,std::vector<float> * output_features) const828 bool FeatureProcessor::AppendTokenFeaturesWithCache(
829 const Token& token, const CodepointSpan& selection_span_for_feature,
830 const EmbeddingExecutor* embedding_executor,
831 EmbeddingCache* embedding_cache,
832 std::vector<float>* output_features) const {
833 // Look for the embedded features for the token in the cache, if there is one.
834 if (embedding_cache) {
835 const auto it = embedding_cache->find({token.start, token.end});
836 if (it != embedding_cache->end()) {
837 // The embedded features were found in the cache, extract only the dense
838 // features.
839 std::vector<float> dense_features;
840 if (!feature_extractor_.Extract(
841 token, token.IsContainedInSpan(selection_span_for_feature),
842 /*sparse_features=*/nullptr, &dense_features)) {
843 TC3_LOG(ERROR) << "Could not extract token's dense features.";
844 return false;
845 }
846
847 // Append both embedded and dense features to the output and return.
848 output_features->insert(output_features->end(), it->second.begin(),
849 it->second.end());
850 output_features->insert(output_features->end(), dense_features.begin(),
851 dense_features.end());
852 return true;
853 }
854 }
855
856 // Extract the sparse and dense features.
857 std::vector<int> sparse_features;
858 std::vector<float> dense_features;
859 if (!feature_extractor_.Extract(
860 token, token.IsContainedInSpan(selection_span_for_feature),
861 &sparse_features, &dense_features)) {
862 TC3_LOG(ERROR) << "Could not extract token's features.";
863 return false;
864 }
865
866 // Embed the sparse features, appending them directly to the output.
867 const int embedding_size = GetOptions()->embedding_size();
868 output_features->resize(output_features->size() + embedding_size);
869 float* output_features_end =
870 output_features->data() + output_features->size();
871 if (!embedding_executor->AddEmbedding(
872 TensorView<int>(sparse_features.data(),
873 {static_cast<int>(sparse_features.size())}),
874 /*dest=*/output_features_end - embedding_size,
875 /*dest_size=*/embedding_size)) {
876 TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
877 return false;
878 }
879
880 // If there is a cache, the embedded features for the token were not in it,
881 // so insert them.
882 if (embedding_cache) {
883 (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
884 output_features_end - embedding_size, output_features_end);
885 }
886
887 // Append the dense features to the output.
888 output_features->insert(output_features->end(), dense_features.begin(),
889 dense_features.end());
890 return true;
891 }
892
893 } // namespace libtextclassifier3
894