1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Feature processing for FFModel (feed-forward SmartSelection model). 18 19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ 20 #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ 21 22 #include <map> 23 #include <memory> 24 #include <set> 25 #include <string> 26 #include <vector> 27 28 #include "annotator/cached-features.h" 29 #include "annotator/model_generated.h" 30 #include "annotator/types.h" 31 #include "utils/base/integral_types.h" 32 #include "utils/base/logging.h" 33 #include "utils/token-feature-extractor.h" 34 #include "utils/tokenizer.h" 35 #include "utils/utf8/unicodetext.h" 36 #include "utils/utf8/unilib.h" 37 38 namespace libtextclassifier3 { 39 40 constexpr int kInvalidLabel = -1; 41 42 namespace internal { 43 44 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options, 45 const UniLib* unilib); 46 47 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( 48 const FeatureProcessorOptions* options); 49 50 // Splits tokens that contain the selection boundary inside them. 51 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" 52 void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection, 53 std::vector<Token>* tokens); 54 55 // Returns the index of token that corresponds to the codepoint span. 56 int CenterTokenFromClick(const CodepointSpan& span, 57 const std::vector<Token>& tokens); 58 59 // Returns the index of token that corresponds to the middle of the codepoint 60 // span. 61 int CenterTokenFromMiddleOfSelection( 62 const CodepointSpan& span, const std::vector<Token>& selectable_tokens); 63 64 // Strips the tokens from the tokens vector that are not used for feature 65 // extraction because they are out of scope, or pads them so that there is 66 // enough tokens in the required context_size for all inferences with a click 67 // in relative_click_span. 68 void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size, 69 std::vector<Token>* tokens, int* click_pos); 70 71 } // namespace internal 72 73 // Converts a codepoint span to a token span in the given list of tokens. 74 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a 75 // token to overlap with the codepoint range to be considered part of it. 76 // Otherwise it must be fully included in the range. 77 TokenSpan CodepointSpanToTokenSpan( 78 const std::vector<Token>& selectable_tokens, 79 const CodepointSpan& codepoint_span, 80 bool snap_boundaries_to_containing_tokens = false); 81 82 // Converts a token span to a codepoint span in the given list of tokens. 83 CodepointSpan TokenSpanToCodepointSpan( 84 const std::vector<Token>& selectable_tokens, const TokenSpan& token_span); 85 86 // Converts a codepoint span to a unicode text range, within the given unicode 87 // text. 88 // For an invalid span (with a negative index), returns (begin, begin). This 89 // means that it is safe to call this function before checking the validity of 90 // the span. 91 // The indices must fit within the unicode text. 92 // Note that the execution time is linear with respect to the codepoint indices. 93 // Calling this function repeatedly for spans on the same text might lead to 94 // inefficient code. 95 UnicodeTextRange CodepointSpanToUnicodeTextRange( 96 const UnicodeText& unicode_text, const CodepointSpan& span); 97 98 // Takes care of preparing features for the span prediction model. 99 class FeatureProcessor { 100 public: 101 // A cache mapping codepoint spans to embedded tokens features. An instance 102 // can be provided to multiple calls to ExtractFeatures() operating on the 103 // same context (the same codepoint spans corresponding to the same tokens), 104 // as an optimization. Note that the tokenizations do not have to be 105 // identical. 106 typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache; 107 FeatureProcessor(const FeatureProcessorOptions * options,const UniLib * unilib)108 explicit FeatureProcessor(const FeatureProcessorOptions* options, 109 const UniLib* unilib) 110 : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options), 111 unilib), 112 options_(options), 113 tokenizer_(internal::BuildTokenizer(options, unilib)) { 114 MakeLabelMaps(); 115 if (options->supported_codepoint_ranges() != nullptr) { 116 SortCodepointRanges({options->supported_codepoint_ranges()->begin(), 117 options->supported_codepoint_ranges()->end()}, 118 &supported_codepoint_ranges_); 119 } 120 PrepareIgnoredSpanBoundaryCodepoints(); 121 } 122 123 // Tokenizes the input string using the selected tokenization method. 124 std::vector<Token> Tokenize(const std::string& text) const; 125 126 // Same as above but takes UnicodeText. 127 std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; 128 129 // Converts a label into a token span. 130 bool LabelToTokenSpan(int label, TokenSpan* token_span) const; 131 132 // Gets the total number of selection labels. GetSelectionLabelCount()133 int GetSelectionLabelCount() const { return label_to_selection_.size(); } 134 135 // Gets the string value for given collection label. 136 std::string LabelToCollection(int label) const; 137 138 // Gets the total number of collections of the model. NumCollections()139 int NumCollections() const { return collection_to_label_.size(); } 140 141 // Gets the name of the default collection. 142 std::string GetDefaultCollection() const; 143 GetOptions()144 const FeatureProcessorOptions* GetOptions() const { return options_; } 145 146 // Retokenizes the context and input span, and finds the click position. 147 // Depending on the options, might modify tokens (split them or remove them). 148 void RetokenizeAndFindClick(const std::string& context, 149 const CodepointSpan& input_span, 150 bool only_use_line_with_click, 151 std::vector<Token>* tokens, int* click_pos) const; 152 153 // Same as above, but takes UnicodeText and iterators within it corresponding 154 // to input_span. 155 void RetokenizeAndFindClick(const UnicodeText& context_unicode, 156 const UnicodeText::const_iterator& span_begin, 157 const UnicodeText::const_iterator& span_end, 158 const CodepointSpan& input_span, 159 bool only_use_line_with_click, 160 std::vector<Token>* tokens, int* click_pos) const; 161 162 // Returns true if the token span has enough supported codepoints (as defined 163 // in the model config) or not and model should not run. 164 bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens, 165 const TokenSpan& token_span) const; 166 167 // Extracts features as a CachedFeatures object that can be used for repeated 168 // inference over token spans in the given context. 169 bool ExtractFeatures(const std::vector<Token>& tokens, 170 const TokenSpan& token_span, 171 const CodepointSpan& selection_span_for_feature, 172 const EmbeddingExecutor* embedding_executor, 173 EmbeddingCache* embedding_cache, int feature_vector_size, 174 std::unique_ptr<CachedFeatures>* cached_features) const; 175 176 // Fills selection_label_spans with CodepointSpans that correspond to the 177 // selection labels. The CodepointSpans are based on the codepoint ranges of 178 // given tokens. 179 bool SelectionLabelSpans( 180 VectorSpan<Token> tokens, 181 std::vector<CodepointSpan>* selection_label_spans) const; 182 183 // Fills selection_label_relative_token_spans with number of tokens left and 184 // right from the click. 185 bool SelectionLabelRelativeTokenSpans( 186 std::vector<TokenSpan>* selection_label_relative_token_spans) const; 187 DenseFeaturesCount()188 int DenseFeaturesCount() const { 189 return feature_extractor_.DenseFeaturesCount(); 190 } 191 EmbeddingSize()192 int EmbeddingSize() const { return options_->embedding_size(); } 193 194 // Splits context to several segments. 195 std::vector<UnicodeTextRange> SplitContext( 196 const UnicodeText& context_unicode, 197 const bool use_pipe_character_for_newline) const; 198 199 // Strips boundary codepoints from the span in context and returns the new 200 // start and end indices. If the span comprises entirely of boundary 201 // codepoints, the first index of span is returned for both indices. 202 CodepointSpan StripBoundaryCodepoints(const std::string& context, 203 const CodepointSpan& span) const; 204 205 // Same as above but takes UnicodeText. 206 CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode, 207 const CodepointSpan& span) const; 208 209 // Same as above but takes a pair of iterators for the span, for efficiency. 210 CodepointSpan StripBoundaryCodepoints( 211 const UnicodeText::const_iterator& span_begin, 212 const UnicodeText::const_iterator& span_end, 213 const CodepointSpan& span) const; 214 215 // Same as above, but takes an optional buffer for saving the modified value. 216 // As an optimization, returns pointer to 'value' if nothing was stripped, or 217 // pointer to 'buffer' if something was stripped. 218 const std::string& StripBoundaryCodepoints(const std::string& value, 219 std::string* buffer) const; 220 221 protected: 222 // Returns the class id corresponding to the given string collection 223 // identifier. There is a catch-all class id that the function returns for 224 // unknown collections. 225 int CollectionToLabel(const std::string& collection) const; 226 227 // Prepares mapping from collection names to labels. 228 void MakeLabelMaps(); 229 230 // Gets the number of spannable tokens for the model. 231 // 232 // Spannable tokens are those tokens of context, which the model predicts 233 // selection spans over (i.e., there is 1:1 correspondence between the output 234 // classes of the model and each of the spannable tokens). GetNumContextTokens()235 int GetNumContextTokens() const { return options_->context_size() * 2 + 1; } 236 237 // Converts a label into a span of codepoint indices corresponding to it 238 // given output_tokens. 239 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens, 240 CodepointSpan* span) const; 241 242 // Converts a span to the corresponding label given output_tokens. 243 bool SpanToLabel(const CodepointSpan& span, 244 const std::vector<Token>& output_tokens, int* label) const; 245 246 // Converts a token span to the corresponding label. 247 int TokenSpanToLabel(const TokenSpan& token_span) const; 248 249 // Returns the ratio of supported codepoints to total number of codepoints in 250 // the given token span. 251 float SupportedCodepointsRatio(const TokenSpan& token_span, 252 const std::vector<Token>& tokens) const; 253 254 void PrepareIgnoredSpanBoundaryCodepoints(); 255 256 // Counts the number of span boundary codepoints. If count_from_beginning is 257 // True, the counting will start at the span_start iterator (inclusive) and at 258 // maximum end at span_end (exclusive). If count_from_beginning is True, the 259 // counting will start from span_end (exclusive) and end at span_start 260 // (inclusive). 261 int CountIgnoredSpanBoundaryCodepoints( 262 const UnicodeText::const_iterator& span_start, 263 const UnicodeText::const_iterator& span_end, 264 bool count_from_beginning) const; 265 266 // Finds the center token index in tokens vector, using the method defined 267 // in options_. 268 int FindCenterToken(const CodepointSpan& span, 269 const std::vector<Token>& tokens) const; 270 271 // Removes all tokens from tokens that are not on a line (defined by calling 272 // SplitContext on the context) to which span points. 273 void StripTokensFromOtherLines(const std::string& context, 274 const CodepointSpan& span, 275 std::vector<Token>* tokens) const; 276 277 // Same as above but takes UnicodeText. 278 void StripTokensFromOtherLines(const UnicodeText& context_unicode, 279 const UnicodeText::const_iterator& span_begin, 280 const UnicodeText::const_iterator& span_end, 281 const CodepointSpan& span, 282 std::vector<Token>* tokens) const; 283 284 // Extracts the features of a token and appends them to the output vector. 285 // Uses the embedding cache to to avoid re-extracting the re-embedding the 286 // sparse features for the same token. 287 bool AppendTokenFeaturesWithCache( 288 const Token& token, const CodepointSpan& selection_span_for_feature, 289 const EmbeddingExecutor* embedding_executor, 290 EmbeddingCache* embedding_cache, 291 std::vector<float>* output_features) const; 292 293 protected: 294 const TokenFeatureExtractor feature_extractor_; 295 296 // Codepoint ranges that define what codepoints are supported by the model. 297 // NOTE: Must be sorted. 298 std::vector<CodepointRangeStruct> supported_codepoint_ranges_; 299 300 private: 301 // Set of codepoints that will be stripped from beginning and end of 302 // predicted spans. 303 std::unordered_set<int32> ignored_span_boundary_codepoints_; 304 305 const FeatureProcessorOptions* const options_; 306 307 // Mapping between token selection spans and labels ids. 308 std::map<TokenSpan, int> selection_to_label_; 309 std::vector<TokenSpan> label_to_selection_; 310 311 // Mapping between collections and labels. 312 std::map<std::string, int> collection_to_label_; 313 314 Tokenizer tokenizer_; 315 }; 316 317 } // namespace libtextclassifier3 318 319 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_ 320