/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/feature-processor.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"

namespace libtextclassifier3 {

namespace internal {

// Builds the Tokenizer described by `options`.
// Both codepoint-range lists are optional in the options; a missing list
// becomes an empty vector. Tokenizing on script change is only enabled when a
// tokenization codepoint config is actually present.
Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
                         const UniLib* unilib) {
  std::vector<const TokenizationCodepointRange*> codepoint_config;
  if (options->tokenization_codepoint_config() != nullptr) {
    codepoint_config.insert(codepoint_config.end(),
                            options->tokenization_codepoint_config()->begin(),
                            options->tokenization_codepoint_config()->end());
  }
  std::vector<const CodepointRange*> internal_codepoint_config;
  if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
    internal_codepoint_config.insert(
        internal_codepoint_config.end(),
        options->internal_tokenizer_codepoint_ranges()->begin(),
        options->internal_tokenizer_codepoint_ranges()->end());
  }
  const bool tokenize_on_script_change =
      options->tokenization_codepoint_config() != nullptr &&
      options->tokenize_on_script_change();
  return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
                   internal_codepoint_config, tokenize_on_script_change,
                   options->icu_preserve_whitespace_tokens());
}

// Copies the token-feature-extraction settings from `options` into a
// TokenFeatureExtractorOptions value. Optional list fields (chargram orders,
// regexp features, allowed chargrams) are skipped when absent.
TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
    const FeatureProcessorOptions* const options) {
  TokenFeatureExtractorOptions extractor_options;

  extractor_options.num_buckets = options->num_buckets();
  if (options->chargram_orders() != nullptr) {
    for (int order : *options->chargram_orders()) {
      extractor_options.chargram_orders.push_back(order);
    }
  }
  extractor_options.max_word_length = options->max_word_length();
  extractor_options.extract_case_feature = options->extract_case_feature();
  extractor_options.unicode_aware_features = options->unicode_aware_features();
  extractor_options.extract_selection_mask_feature =
      options->extract_selection_mask_feature();
  if (options->regexp_feature() != nullptr) {
    for (const auto& regexp_feature : *options->regexp_feature()) {
      extractor_options.regexp_features.push_back(regexp_feature->str());
    }
  }
  extractor_options.remap_digits = options->remap_digits();
  extractor_options.lowercase_tokens = options->lowercase_tokens();

  if (options->allowed_chargrams() != nullptr) {
    for (const auto& chargram : *options->allowed_chargrams()) {
      extractor_options.allowed_chargrams.insert(chargram->str());
    }
  }
  return extractor_options;
}

// Splits every token that has a selection boundary strictly inside it into
// consecutive sub-tokens at that boundary. A token may be split once or twice,
// for the start and end boundaries. Sub-tokens keep contiguous codepoint
// offsets. Tokens are modified in place.
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens) {
  for (auto it = tokens->begin(); it != tokens->end(); ++it) {
    const UnicodeText token_word =
        UTF8ToUnicodeText(it->value, /*do_copy=*/false);

    auto last_start = token_word.begin();
    int last_start_index = it->start;
    std::vector<UnicodeText::const_iterator> split_points;

    // Selection start split point.
    if (selection.first > it->start && selection.first < it->end) {
      std::advance(last_start, selection.first - last_start_index);
      split_points.push_back(last_start);
      last_start_index = selection.first;
    }

    // Selection end split point.
    if (selection.second > it->start && selection.second < it->end) {
      std::advance(last_start, selection.second - last_start_index);
      split_points.push_back(last_start);
    }

    if (!split_points.empty()) {
      // Add a final split for the rest of the token unless it's been all
      // consumed already.
      if (split_points.back() != token_word.end()) {
        split_points.push_back(token_word.end());
      }

      std::vector<Token> replacement_tokens;
      last_start = token_word.begin();
      int current_pos = it->start;
      for (const auto& split_point : split_points) {
        Token new_token(token_word.UTF8Substring(last_start, split_point),
                        current_pos,
                        current_pos + std::distance(last_start, split_point));

        last_start = split_point;
        current_pos = new_token.end;

        replacement_tokens.push_back(new_token);
      }

      // Replace the original token by its pieces and continue the scan after
      // the last inserted piece.
      it = tokens->erase(it);
      it = tokens->insert(it, replacement_tokens.begin(),
                          replacement_tokens.end());
      std::advance(it, replacement_tokens.size() - 1);
    }
  }
}

}  // namespace internal

// Convenience overload: decodes `context` as UTF-8 without copying and
// forwards to the UnicodeText overload.
void FeatureProcessor::StripTokensFromOtherLines(
    const std::string& context, CodepointSpan span,
    std::vector<Token>* tokens) const {
  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  StripTokensFromOtherLines(context_unicode, span, tokens);
}

// Keeps only the tokens that lie entirely within the line (as produced by
// SplitContext) that completely contains `span`. If no line contains the
// span, `tokens` is left untouched.
void FeatureProcessor::StripTokensFromOtherLines(
    const UnicodeText& context_unicode, CodepointSpan span,
    std::vector<Token>* tokens) const {
  std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);

  auto span_start = context_unicode.begin();
  if (span.first > 0) {
    std::advance(span_start, span.first);
  }
  auto span_end = context_unicode.begin();
  if (span.second > 0) {
    std::advance(span_end, span.second);
  }
  for (const UnicodeTextRange& line : lines) {
    // Find the line that completely contains the span.
    if (line.first <= span_start && line.second >= span_end) {
      const CodepointIndex last_line_begin_index =
          std::distance(context_unicode.begin(), line.first);
      const CodepointIndex last_line_end_index =
          last_line_begin_index + std::distance(line.first, line.second);

      // Erase-remove keeps the relative order of the surviving tokens and
      // avoids the quadratic cost of erasing one element at a time.
      tokens->erase(
          std::remove_if(tokens->begin(), tokens->end(),
                         [last_line_begin_index,
                          last_line_end_index](const Token& token) {
                           return token.start < last_line_begin_index ||
                                  token.end > last_line_end_index;
                         }),
          tokens->end());
    }
  }
}

// Returns the name of the configured default collection, or an empty string
// (after logging an error) when the index is out of range or there are no
// collections.
std::string FeatureProcessor::GetDefaultCollection() const {
  if (options_->default_collection() < 0 ||
      options_->collections() == nullptr ||
      options_->default_collection() >= options_->collections()->size()) {
    TC3_LOG(ERROR)
        << "Invalid or missing default collection. Returning empty string.";
    return "";
  }
  return (*options_->collections())[options_->default_collection()]->str();
}

// Tokenizes UTF-8 text with the configured tokenizer.
std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
  return tokenizer_.Tokenize(text);
}

// Tokenizes already-decoded text with the configured tokenizer.
std::vector<Token> FeatureProcessor::Tokenize(
    const UnicodeText& text_unicode) const {
  return tokenizer_.Tokenize(text_unicode);
}

// Converts a selection label into a codepoint span over `tokens`.
// The label's token span (l, r) is taken relative to tokens[context_size()]:
// the result starts at tokens[context_size() - l] and ends at
// tokens[context_size() + r]. Ignored boundary codepoints at the outer edges of
// those two tokens are trimmed. If either edge token is padding
// (kInvalidIndex), the span is {kInvalidIndex, kInvalidIndex}.
// Returns false if `tokens` has the wrong size or the label is unknown.
bool FeatureProcessor::LabelToSpan(
    const int label, const VectorSpan<Token>& tokens,
    std::pair<CodepointIndex, CodepointIndex>* span) const {
  if (tokens.size() != GetNumContextTokens()) {
    return false;
  }

  TokenSpan token_span;
  if (!LabelToTokenSpan(label, &token_span)) {
    return false;
  }

  const int result_begin_token_index = token_span.first;
  const Token& result_begin_token =
      tokens[options_->context_size() - result_begin_token_index];
  const int result_begin_codepoint = result_begin_token.start;
  const int result_end_token_index = token_span.second;
  const Token& result_end_token =
      tokens[options_->context_size() + result_end_token_index];
  const int result_end_codepoint = result_end_token.end;

  if (result_begin_codepoint == kInvalidIndex ||
      result_end_codepoint == kInvalidIndex) {
    *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
  } else {
    const UnicodeText token_begin_unicode =
        UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
    UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
    const UnicodeText token_end_unicode =
        UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
    UnicodeText::const_iterator token_end = token_end_unicode.end();

    const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
        token_begin, token_begin_unicode.end(),
        /*count_from_beginning=*/true);
    const int end_ignored =
        CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
                                           /*count_from_beginning=*/false);
    // In case everything would be stripped, set the span to the original
    // beginning and zero length.
    if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
      *span = {result_begin_codepoint, result_begin_codepoint};
    } else {
      *span = CodepointSpan({result_begin_codepoint + begin_ignored,
                             result_end_codepoint - end_ignored});
    }
  }
  return true;
}

// Looks up the token span for `label`; returns false if the label is out of
// range.
bool FeatureProcessor::LabelToTokenSpan(const int label,
                                        TokenSpan* token_span) const {
  if (label >= 0 && label < label_to_selection_.size()) {
    *token_span = label_to_selection_[label];
    return true;
  } else {
    return false;
  }
}

// Converts a codepoint span into a selection label, given the context tokens
// around the click (tokens[context_size()]). Counts how many tokens to the
// left and right of the click overlap the span, then checks that those tokens
// cover the span. If snapping is disabled, the check only tolerates ignored
// boundary codepoints at the outer edges. Sets *label to kInvalidLabel if they
// do not. Returns false only when `tokens` has the wrong size.
bool FeatureProcessor::SpanToLabel(
    const std::pair<CodepointIndex, CodepointIndex>& span,
    const std::vector<Token>& tokens, int* label) const {
  if (tokens.size() != GetNumContextTokens()) {
    return false;
  }

  const int click_position =
      options_->context_size();  // Click is always in the middle.
  const int padding = options_->context_size() - options_->max_selection_span();

  int span_left = 0;
  for (int i = click_position - 1; i >= padding; i--) {
    if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
      ++span_left;
    } else {
      break;
    }
  }

  int span_right = 0;
  for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
    if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
      ++span_right;
    } else {
      break;
    }
  }

  // Check that the spanned tokens cover the whole span.
  bool tokens_match_span;
  const CodepointIndex tokens_start = tokens[click_position - span_left].start;
  const CodepointIndex tokens_end = tokens[click_position + span_right].end;
  if (options_->snap_label_span_boundaries_to_containing_tokens()) {
    tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
  } else {
    const UnicodeText token_left_unicode = UTF8ToUnicodeText(
        tokens[click_position - span_left].value, /*do_copy=*/false);
    const UnicodeText token_right_unicode = UTF8ToUnicodeText(
        tokens[click_position + span_right].value, /*do_copy=*/false);

    UnicodeText::const_iterator span_begin = token_left_unicode.begin();
    UnicodeText::const_iterator span_end = token_right_unicode.end();

    const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
        span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
    const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
        token_right_unicode.begin(), span_end,
        /*count_from_beginning=*/false);

    // The span may start/end inside the leading/trailing run of ignored
    // codepoints of the edge tokens, but not beyond it.
    tokens_match_span = tokens_start <= span.first &&
                        tokens_start + num_punctuation_start >= span.first &&
                        tokens_end >= span.second &&
                        tokens_end - num_punctuation_end <= span.second;
  }

  if (tokens_match_span) {
    *label = TokenSpanToLabel({span_left, span_right});
  } else {
    *label = kInvalidLabel;
  }

  return true;
}

// Returns the label for a (left, right) token span, or kInvalidLabel if the
// span is not part of the label space built by MakeLabelMaps().
int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
  auto it = selection_to_label_.find(span);
  if (it != selection_to_label_.end()) {
    return it->second;
  } else {
    return kInvalidLabel;
  }
}

// Returns the half-open token range [start, end) of non-padding tokens that
// belong to `codepoint_span`. With snapping, a token only needs to overlap the
// span; otherwise it must be fully contained. Returns
// {kInvalidIndex, kInvalidIndex} when no token qualifies.
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                   CodepointSpan codepoint_span,
                                   bool snap_boundaries_to_containing_tokens) {
  const int codepoint_start = std::get<0>(codepoint_span);
  const int codepoint_end = std::get<1>(codepoint_span);

  TokenIndex start_token = kInvalidIndex;
  TokenIndex end_token = kInvalidIndex;
  for (int i = 0; i < selectable_tokens.size(); ++i) {
    bool is_token_in_span;
    if (snap_boundaries_to_containing_tokens) {
      is_token_in_span = codepoint_start < selectable_tokens[i].end &&
                         codepoint_end > selectable_tokens[i].start;
    } else {
      is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
                         codepoint_end >= selectable_tokens[i].end;
    }
    if (is_token_in_span && !selectable_tokens[i].is_padding) {
      if (start_token == kInvalidIndex) {
        start_token = i;
      }
      end_token = i + 1;
    }
  }
  return {start_token, end_token};
}

// Returns the codepoint span from the start of the first token to the end of
// the last token in the half-open token range. The caller must pass a valid,
// non-empty token span; no bounds checking is done here.
CodepointSpan TokenSpanToCodepointSpan(
    const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
  return {selectable_tokens[token_span.first].start,
          selectable_tokens[token_span.second - 1].end};
}

namespace {

// Finds a single token that completely contains the given span.
// Returns the index of the first such token, or kInvalidIndex.
int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
                              CodepointSpan codepoint_span) {
  const int codepoint_start = std::get<0>(codepoint_span);
  const int codepoint_end = std::get<1>(codepoint_span);

  for (int i = 0; i < selectable_tokens.size(); ++i) {
    if (codepoint_start >= selectable_tokens[i].start &&
        codepoint_end <= selectable_tokens[i].end) {
      return i;
    }
  }
  return kInvalidIndex;
}

}  // namespace

namespace internal {

// Returns the index of the single token selected by the click `span`, or
// kInvalidIndex if the click does not correspond to exactly one token.
int CenterTokenFromClick(CodepointSpan span,
                         const std::vector<Token>& selectable_tokens) {
  int range_begin;
  int range_end;
  std::tie(range_begin, range_end) =
      CodepointSpanToTokenSpan(selectable_tokens, span);

  // If no exact match was found, try finding a token that completely contains
  // the click span. This is useful e.g. when Android builds the selection
  // using ICU tokenization, and ends up with only a portion of our space-
  // separated token. E.g. for "(857)" Android would select "857".
  if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
    int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
    if (token_index != kInvalidIndex) {
      range_begin = token_index;
      range_end = token_index + 1;
    }
  }

  // We only allow clicks that are exactly 1 selectable token.
  if (range_end - range_begin == 1) {
    return range_begin;
  } else {
    return kInvalidIndex;
  }
}

// Returns the index of the middle token (lower middle for an even count) of
// the tokens fully covered by `span`, or kInvalidIndex if none are covered.
int CenterTokenFromMiddleOfSelection(
    CodepointSpan span, const std::vector<Token>& selectable_tokens) {
  int range_begin;
  int range_end;
  std::tie(range_begin, range_end) =
      CodepointSpanToTokenSpan(selectable_tokens, span);

  // Center the clicked token in the selection range.
  if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
    return (range_begin + range_end - 1) / 2;
  } else {
    return kInvalidIndex;
  }
}

}  // namespace internal

// Finds the center token for `span` using the method selected in the options.
// For the DEFAULT method the choice depends on
// split_tokens_on_selection_boundaries (see TODO below). Logs an error and
// returns kInvalidIndex for an unknown method.
int FeatureProcessor::FindCenterToken(CodepointSpan span,
                                      const std::vector<Token>& tokens) const {
  const auto method = options_->center_token_selection_method();
  if (method == FeatureProcessorOptions_::
                    CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
    return internal::CenterTokenFromClick(span, tokens);
  } else if (method ==
             FeatureProcessorOptions_::
                 CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
    return internal::CenterTokenFromMiddleOfSelection(span, tokens);
  } else if (method ==
             FeatureProcessorOptions_::
                 CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
    // TODO(zilka): Remove once we have new models on the device.
    // It uses the fact that sharing model use
    // split_tokens_on_selection_boundaries and selection not. So depending on
    // this we select the right way of finding the click location.
    if (!options_->split_tokens_on_selection_boundaries()) {
      // SmartSelection model.
      return internal::CenterTokenFromClick(span, tokens);
    } else {
      // SmartSharing model.
      return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    }
  } else {
    TC3_LOG(ERROR) << "Invalid center token selection method.";
    return kInvalidIndex;
  }
}

// Appends, for every selection label in label order, the codepoint span it
// denotes over `tokens`. Returns false (after logging) on the first label that
// cannot be converted.
bool FeatureProcessor::SelectionLabelSpans(
    const VectorSpan<Token> tokens,
    std::vector<CodepointSpan>* selection_label_spans) const {
  for (int i = 0; i < label_to_selection_.size(); ++i) {
    CodepointSpan span;
    if (!LabelToSpan(i, tokens, &span)) {
      TC3_LOG(ERROR) << "Could not convert label to span: " << i;
      return false;
    }
    selection_label_spans->push_back(span);
  }
  return true;
}

// Loads the ignored span-boundary codepoints from the options into a set for
// fast lookup. The options field is optional.
void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
  if (options_->ignored_span_boundary_codepoints() != nullptr) {
    for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
      ignored_span_boundary_codepoints_.insert(codepoint);
    }
  }
}

// Counts the run of consecutive ignored boundary codepoints at one edge of
// [span_start, span_end). The run is taken from the front when
// count_from_beginning is true, otherwise from the back. The count stops at
// the first non-ignored codepoint or once every codepoint has been visited.
int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
    const UnicodeText::const_iterator& span_start,
    const UnicodeText::const_iterator& span_end,
    bool count_from_beginning) const {
  if (span_start == span_end) {
    return 0;
  }

  UnicodeText::const_iterator it;
  UnicodeText::const_iterator it_last;
  if (count_from_beginning) {
    it = span_start;
    it_last = span_end;
    // We can assume that the string is non-zero length because of the check
    // above, thus the decrement is always valid here.
    --it_last;
  } else {
    it = span_end;
    it_last = span_start;
    // We can assume that the string is non-zero length because of the check
    // above, thus the decrement is always valid here.
    --it;
  }

  // Move until we encounter a non-ignored character.
  int num_ignored = 0;
  while (ignored_span_boundary_codepoints_.find(*it) !=
         ignored_span_boundary_codepoints_.end()) {
    ++num_ignored;

    if (it == it_last) {
      break;
    }

    if (count_from_beginning) {
      ++it;
    } else {
      --it;
    }
  }

  return num_ignored;
}

namespace {

// Appends to `ranges` the maximal non-empty substrings of `t` that contain
// none of `codepoints`, i.e. `t` split on any of the given separators with
// empty pieces dropped.
void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
                    std::vector<UnicodeTextRange>* ranges) {
  UnicodeText::const_iterator start = t.begin();
  UnicodeText::const_iterator curr = start;
  UnicodeText::const_iterator end = t.end();
  for (; curr != end; ++curr) {
    if (codepoints.find(*curr) != codepoints.end()) {
      if (start != curr) {
        ranges->push_back(std::make_pair(start, curr));
      }
      start = curr;
      ++start;
    }
  }
  if (start != end) {
    ranges->push_back(std::make_pair(start, end));
  }
}

}  // namespace

// Splits the context into non-empty segments separated by '\n' or '|'.
std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
    const UnicodeText& context_unicode) const {
  std::vector<UnicodeTextRange> lines;
  const std::set<char32> codepoints{{'\n', '|'}};
  FindSubstrings(context_unicode, codepoints, &lines);
  return lines;
}

// Convenience overload: decodes `context` as UTF-8 without copying and strips
// ignored boundary codepoints from `span`.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const std::string& context, CodepointSpan span) const {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  return StripBoundaryCodepoints(context_unicode, span);
}

// Locates `span` inside the context and strips ignored boundary codepoints
// from both of its ends. An empty context or an invalid/empty span is
// returned unchanged.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText& context_unicode, CodepointSpan span) const {
  if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
    return span;
  }

  UnicodeText::const_iterator span_begin = context_unicode.begin();
  std::advance(span_begin, span.first);
  UnicodeText::const_iterator span_end = context_unicode.begin();
  std::advance(span_end, span.second);

  return StripBoundaryCodepoints(span_begin, span_end, span);
}

// Shrinks `span` by the number of ignored boundary codepoints at each end of
// [span_begin, span_end). If nothing would remain, returns the zero-length
// span {span.first, span.first}.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText::const_iterator& span_begin,
    const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
  if (!ValidNonEmptySpan(span) || span_begin == span_end) {
    return span;
  }

  const int start_offset = CountIgnoredSpanBoundaryCodepoints(
      span_begin, span_end, /*count_from_beginning=*/true);
  const int end_offset = CountIgnoredSpanBoundaryCodepoints(
      span_begin, span_end, /*count_from_beginning=*/false);

  if (span.first + start_offset < span.second - end_offset) {
    return {span.first + start_offset, span.second - end_offset};
  } else {
    return {span.first, span.first};
  }
}

// Returns the fraction of codepoints in tokens[token_span) that fall within
// the supported codepoint ranges. An empty span (no codepoints) yields 0.
float FeatureProcessor::SupportedCodepointsRatio(
    const TokenSpan& token_span, const std::vector<Token>& tokens) const {
  int num_supported = 0;
  int num_total = 0;
  for (int i = token_span.first; i < token_span.second; ++i) {
    const UnicodeText value =
        UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
    for (auto codepoint : value) {
      if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
        ++num_supported;
      }
      ++num_total;
    }
  }
  // Avoid 0/0: it would produce NaN, and since every comparison with NaN is
  // false, an empty span would silently pass the minimum-ratio check.
  if (num_total == 0) {
    return 0.0f;
  }
  return static_cast<float>(num_supported) / static_cast<float>(num_total);
}

// Strips ignored boundary codepoints from the whole of `value`. Returns
// `value` itself when nothing was stripped; otherwise stores the stripped text
// in *buffer and returns a reference to it.
const std::string& FeatureProcessor::StripBoundaryCodepoints(
    const std::string& value, std::string* buffer) const {
  const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
  const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
  const CodepointSpan stripped_span =
      StripBoundaryCodepoints(value_unicode, initial_span);

  if (initial_span != stripped_span) {
    const UnicodeText stripped_token_value =
        UnicodeText::Substring(value_unicode, stripped_span.first,
                               stripped_span.second, /*do_copy=*/false);
    *buffer = stripped_token_value.ToUTF8String();
    return *buffer;
  }
  return value;
}

// Maps a collection name to its label; unknown names map to the default
// collection index from the options.
int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
  const auto it = collection_to_label_.find(collection);
  if (it == collection_to_label_.end()) {
    return options_->default_collection();
  } else {
    return it->second;
  }
}

// Maps a label back to its collection name; out-of-range labels map to the
// default collection.
std::string FeatureProcessor::LabelToCollection(int label) const {
  if (label >= 0 && label < collection_to_label_.size()) {
    return (*options_->collections())[label]->str();
  } else {
    return GetDefaultCollection();
  }
}

// Builds the collection-name -> label map and the bidirectional mapping
// between selection token spans (left, right) and label ids. Both left and
// right range over [0, max_selection_span]. When selection_reduced_output_space
// is set, only pairs with left + right <= max_selection_span are kept.
void FeatureProcessor::MakeLabelMaps() {
  if (options_->collections() != nullptr) {
    for (int i = 0; i < options_->collections()->size(); ++i) {
      collection_to_label_[(*options_->collections())[i]->str()] = i;
    }
  }

  int selection_label_id = 0;
  for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
    for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
      if (!options_->selection_reduced_output_space() ||
          r + l <= options_->max_selection_span()) {
        TokenSpan token_span{l, r};
        selection_to_label_[token_span] = selection_label_id;
        label_to_selection_.push_back(token_span);
        ++selection_label_id;
      }
    }
  }
}

// Convenience overload: decodes `context` as UTF-8 without copying and
// forwards to the UnicodeText overload.
void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
                                              CodepointSpan input_span,
                                              bool only_use_line_with_click,
                                              std::vector<Token>* tokens,
                                              int* click_pos) const {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
                         tokens, click_pos);
}

// Optionally splits tokens on the selection boundaries and drops tokens from
// other lines (both per options/arguments), then locates the center token.
// `click_pos` may be null, in which case the result is discarded. If the
// configured method fails, click-based sub-token matching is tried as a
// fallback.
void FeatureProcessor::RetokenizeAndFindClick(
    const UnicodeText& context_unicode, CodepointSpan input_span,
    bool only_use_line_with_click, std::vector<Token>* tokens,
    int* click_pos) const {
  TC3_CHECK(tokens != nullptr);

  if (options_->split_tokens_on_selection_boundaries()) {
    internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
  }

  if (only_use_line_with_click) {
    StripTokensFromOtherLines(context_unicode, input_span, tokens);
  }

  int local_click_pos;
  if (click_pos == nullptr) {
    click_pos = &local_click_pos;
  }
  *click_pos = FindCenterToken(input_span, *tokens);
  if (*click_pos == kInvalidIndex) {
    // If the default click method failed, let's try to do sub-token matching
    // before we fail.
    *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
  }
}

namespace internal {

// Trims or pads `tokens` so that it spans the context needed around
// *click_pos for clicks anywhere in `relative_click_span`. Padding uses
// default-constructed tokens, at most `context_size` per side. Stripping
// removes tokens beyond what is needed. *click_pos is updated to keep pointing
// at the same token.
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos) {
  int right_context_needed = relative_click_span.second + context_size;
  if (*click_pos + right_context_needed + 1 >= tokens->size()) {
    // Pad max the context size.
    const int num_pad_tokens = std::min(
        context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
                                       tokens->size()));
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
  } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
    // Strip unused tokens.
    auto it = tokens->begin();
    std::advance(it, *click_pos + right_context_needed + 1);
    tokens->erase(it, tokens->end());
  }

  int left_context_needed = relative_click_span.first + context_size;
  if (*click_pos < left_context_needed) {
    // Pad max the context size.
    const int num_pad_tokens =
        std::min(context_size, left_context_needed - *click_pos);
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
    *click_pos += num_pad_tokens;
  } else if (*click_pos > left_context_needed) {
    // Strip unused tokens.
    auto it = tokens->begin();
    std::advance(it, *click_pos - left_context_needed);
    *click_pos -= it - tokens->begin();
    tokens->erase(tokens->begin(), it);
  }
}

}  // namespace internal

// Returns false when min_supported_codepoint_ratio is positive and the
// supported-codepoint ratio of tokens[token_span) falls below it; true
// otherwise.
bool FeatureProcessor::HasEnoughSupportedCodepoints(
    const std::vector<Token>& tokens, TokenSpan token_span) const {
  if (options_->min_supported_codepoint_ratio() > 0) {
    const float supported_codepoint_ratio =
        SupportedCodepointsRatio(token_span, tokens);
    if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
      TC3_VLOG(1) << "Not enough supported codepoints in the context: "
                  << supported_codepoint_ratio;
      return false;
    }
  }
  return true;
}

// Extracts features for every token in `token_span` plus one default (padding)
// token, and wraps them in a CachedFeatures object stored in
// *cached_features. Returns false (after logging) if any extraction step or
// the CachedFeatures creation fails.
bool FeatureProcessor::ExtractFeatures(
    const std::vector<Token>& tokens, TokenSpan token_span,
    CodepointSpan selection_span_for_feature,
    const EmbeddingExecutor* embedding_executor,
    EmbeddingCache* embedding_cache, int feature_vector_size,
    std::unique_ptr<CachedFeatures>* cached_features) const {
  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
  features->reserve(feature_vector_size * TokenSpanSize(token_span));
  for (int i = token_span.first; i < token_span.second; ++i) {
    if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
                                      embedding_executor, embedding_cache,
                                      features.get())) {
      TC3_LOG(ERROR) << "Could not get token features.";
      return false;
    }
  }

  std::unique_ptr<std::vector<float>> padding_features(
      new std::vector<float>());
  padding_features->reserve(feature_vector_size);
  if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
                                    embedding_executor, embedding_cache,
                                    padding_features.get())) {
    TC3_LOG(ERROR) << "Could not get padding token features.";
    return false;
  }

  *cached_features = CachedFeatures::Create(token_span, std::move(features),
                                            std::move(padding_features),
                                            options_, feature_vector_size);
  if (!*cached_features) {
    TC3_LOG(ERROR) << "Could not create cached features.";
    return false;
  }

  return true;
}

// Appends the features of `token` to *output_features: first the
// embedding_size() embedded sparse features, then the dense features.
// The cache, if given, is keyed by the token's {start, end}. On a hit, only
// the dense features are recomputed. On a miss, the freshly embedded features
// are stored in the cache.
bool FeatureProcessor::AppendTokenFeaturesWithCache(
    const Token& token, CodepointSpan selection_span_for_feature,
    const EmbeddingExecutor* embedding_executor,
    EmbeddingCache* embedding_cache,
    std::vector<float>* output_features) const {
  // Look for the embedded features for the token in the cache, if there is one.
  if (embedding_cache) {
    const auto it = embedding_cache->find({token.start, token.end});
    if (it != embedding_cache->end()) {
      // The embedded features were found in the cache, extract only the dense
      // features.
      std::vector<float> dense_features;
      if (!feature_extractor_.Extract(
              token, token.IsContainedInSpan(selection_span_for_feature),
              /*sparse_features=*/nullptr, &dense_features)) {
        TC3_LOG(ERROR) << "Could not extract token's dense features.";
        return false;
      }

      // Append both embedded and dense features to the output and return.
      output_features->insert(output_features->end(), it->second.begin(),
                              it->second.end());
      output_features->insert(output_features->end(), dense_features.begin(),
                              dense_features.end());
      return true;
    }
  }

  // Extract the sparse and dense features.
  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  if (!feature_extractor_.Extract(
          token, token.IsContainedInSpan(selection_span_for_feature),
          &sparse_features, &dense_features)) {
    TC3_LOG(ERROR) << "Could not extract token's features.";
    return false;
  }

  // Embed the sparse features, appending them directly to the output.
  const int embedding_size = GetOptions()->embedding_size();
  output_features->resize(output_features->size() + embedding_size);
  float* output_features_end =
      output_features->data() + output_features->size();
  if (!embedding_executor->AddEmbedding(
          TensorView<int>(sparse_features.data(),
                          {static_cast<int>(sparse_features.size())}),
          /*dest=*/output_features_end - embedding_size,
          /*dest_size=*/embedding_size)) {
    TC3_LOG(ERROR) << "Could not embed token's sparse features.";
    return false;
  }

  // If there is a cache, the embedded features for the token were not in it,
  // so insert them.
  if (embedding_cache) {
    (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
        output_features_end - embedding_size, output_features_end);
  }

  // Append the dense features to the output.
  output_features->insert(output_features->end(), dense_features.begin(),
                          dense_features.end());
  return true;
}

}  // namespace libtextclassifier3