1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Feature processing for FFModel (feed-forward SmartSelection model). 18 19 #ifndef LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ 20 #define LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ 21 22 #include <map> 23 #include <memory> 24 #include <set> 25 #include <string> 26 #include <vector> 27 28 #include "cached-features.h" 29 #include "model_generated.h" 30 #include "token-feature-extractor.h" 31 #include "tokenizer.h" 32 #include "types.h" 33 #include "util/base/integral_types.h" 34 #include "util/base/logging.h" 35 #include "util/utf8/unicodetext.h" 36 #include "util/utf8/unilib.h" 37 38 namespace libtextclassifier2 { 39 40 constexpr int kInvalidLabel = -1; 41 42 namespace internal { 43 44 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( 45 const FeatureProcessorOptions* options); 46 47 // Splits tokens that contain the selection boundary inside them. 48 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" 49 void SplitTokensOnSelectionBoundaries(CodepointSpan selection, 50 std::vector<Token>* tokens); 51 52 // Returns the index of token that corresponds to the codepoint span. 53 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens); 54 55 // Returns the index of token that corresponds to the middle of the codepoint 56 // span. 57 int CenterTokenFromMiddleOfSelection( 58 CodepointSpan span, const std::vector<Token>& selectable_tokens); 59 60 // Strips the tokens from the tokens vector that are not used for feature 61 // extraction because they are out of scope, or pads them so that there is 62 // enough tokens in the required context_size for all inferences with a click 63 // in relative_click_span. 64 void StripOrPadTokens(TokenSpan relative_click_span, int context_size, 65 std::vector<Token>* tokens, int* click_pos); 66 67 // If unilib is not nullptr, just returns unilib. Otherwise, if unilib is 68 // nullptr, will create UniLib, assign ownership to owned_unilib, and return it. 69 const UniLib* MaybeCreateUnilib(const UniLib* unilib, 70 std::unique_ptr<UniLib>* owned_unilib); 71 72 } // namespace internal 73 74 // Converts a codepoint span to a token span in the given list of tokens. 75 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a 76 // token to overlap with the codepoint range to be considered part of it. 77 // Otherwise it must be fully included in the range. 78 TokenSpan CodepointSpanToTokenSpan( 79 const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span, 80 bool snap_boundaries_to_containing_tokens = false); 81 82 // Converts a token span to a codepoint span in the given list of tokens. 83 CodepointSpan TokenSpanToCodepointSpan( 84 const std::vector<Token>& selectable_tokens, TokenSpan token_span); 85 86 // Takes care of preparing features for the span prediction model. 87 class FeatureProcessor { 88 public: 89 // A cache mapping codepoint spans to embedded tokens features. An instance 90 // can be provided to multiple calls to ExtractFeatures() operating on the 91 // same context (the same codepoint spans corresponding to the same tokens), 92 // as an optimization. Note that the tokenizations do not have to be 93 // identical. 94 typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache; 95 96 // If unilib is nullptr, will create and own an instance of a UniLib, 97 // otherwise will use what's passed in. 98 explicit FeatureProcessor(const FeatureProcessorOptions* options, 99 const UniLib* unilib = nullptr) owned_unilib_(nullptr)100 : owned_unilib_(nullptr), 101 unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)), 102 feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options), 103 *unilib_), 104 options_(options), 105 tokenizer_( 106 options->tokenization_codepoint_config() != nullptr 107 ? Tokenizer({options->tokenization_codepoint_config()->begin(), 108 options->tokenization_codepoint_config()->end()}, 109 options->tokenize_on_script_change()) 110 : Tokenizer({}, /*split_on_script_change=*/false)) { 111 MakeLabelMaps(); 112 if (options->supported_codepoint_ranges() != nullptr) { 113 PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(), 114 options->supported_codepoint_ranges()->end()}, 115 &supported_codepoint_ranges_); 116 } 117 if (options->internal_tokenizer_codepoint_ranges() != nullptr) { 118 PrepareCodepointRanges( 119 {options->internal_tokenizer_codepoint_ranges()->begin(), 120 options->internal_tokenizer_codepoint_ranges()->end()}, 121 &internal_tokenizer_codepoint_ranges_); 122 } 123 PrepareIgnoredSpanBoundaryCodepoints(); 124 } 125 126 // Tokenizes the input string using the selected tokenization method. 127 std::vector<Token> Tokenize(const std::string& text) const; 128 129 // Same as above but takes UnicodeText. 130 std::vector<Token> Tokenize(const UnicodeText& text_unicode) const; 131 132 // Converts a label into a token span. 133 bool LabelToTokenSpan(int label, TokenSpan* token_span) const; 134 135 // Gets the total number of selection labels. GetSelectionLabelCount()136 int GetSelectionLabelCount() const { return label_to_selection_.size(); } 137 138 // Gets the string value for given collection label. 139 std::string LabelToCollection(int label) const; 140 141 // Gets the total number of collections of the model. NumCollections()142 int NumCollections() const { return collection_to_label_.size(); } 143 144 // Gets the name of the default collection. 145 std::string GetDefaultCollection() const; 146 GetOptions()147 const FeatureProcessorOptions* GetOptions() const { return options_; } 148 149 // Retokenizes the context and input span, and finds the click position. 150 // Depending on the options, might modify tokens (split them or remove them). 151 void RetokenizeAndFindClick(const std::string& context, 152 CodepointSpan input_span, 153 bool only_use_line_with_click, 154 std::vector<Token>* tokens, int* click_pos) const; 155 156 // Same as above but takes UnicodeText. 157 void RetokenizeAndFindClick(const UnicodeText& context_unicode, 158 CodepointSpan input_span, 159 bool only_use_line_with_click, 160 std::vector<Token>* tokens, int* click_pos) const; 161 162 // Returns true if the token span has enough supported codepoints (as defined 163 // in the model config) or not and model should not run. 164 bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens, 165 TokenSpan token_span) const; 166 167 // Extracts features as a CachedFeatures object that can be used for repeated 168 // inference over token spans in the given context. 169 bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span, 170 CodepointSpan selection_span_for_feature, 171 const EmbeddingExecutor* embedding_executor, 172 EmbeddingCache* embedding_cache, int feature_vector_size, 173 std::unique_ptr<CachedFeatures>* cached_features) const; 174 175 // Fills selection_label_spans with CodepointSpans that correspond to the 176 // selection labels. The CodepointSpans are based on the codepoint ranges of 177 // given tokens. 178 bool SelectionLabelSpans( 179 VectorSpan<Token> tokens, 180 std::vector<CodepointSpan>* selection_label_spans) const; 181 DenseFeaturesCount()182 int DenseFeaturesCount() const { 183 return feature_extractor_.DenseFeaturesCount(); 184 } 185 EmbeddingSize()186 int EmbeddingSize() const { return options_->embedding_size(); } 187 188 // Splits context to several segments. 189 std::vector<UnicodeTextRange> SplitContext( 190 const UnicodeText& context_unicode) const; 191 192 // Strips boundary codepoints from the span in context and returns the new 193 // start and end indices. If the span comprises entirely of boundary 194 // codepoints, the first index of span is returned for both indices. 195 CodepointSpan StripBoundaryCodepoints(const std::string& context, 196 CodepointSpan span) const; 197 198 // Same as above but takes UnicodeText. 199 CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode, 200 CodepointSpan span) const; 201 202 protected: 203 // Represents a codepoint range [start, end). 204 struct CodepointRange { 205 int32 start; 206 int32 end; 207 CodepointRangeCodepointRange208 CodepointRange(int32 arg_start, int32 arg_end) 209 : start(arg_start), end(arg_end) {} 210 }; 211 212 // Returns the class id corresponding to the given string collection 213 // identifier. There is a catch-all class id that the function returns for 214 // unknown collections. 215 int CollectionToLabel(const std::string& collection) const; 216 217 // Prepares mapping from collection names to labels. 218 void MakeLabelMaps(); 219 220 // Gets the number of spannable tokens for the model. 221 // 222 // Spannable tokens are those tokens of context, which the model predicts 223 // selection spans over (i.e., there is 1:1 correspondence between the output 224 // classes of the model and each of the spannable tokens). GetNumContextTokens()225 int GetNumContextTokens() const { return options_->context_size() * 2 + 1; } 226 227 // Converts a label into a span of codepoint indices corresponding to it 228 // given output_tokens. 229 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens, 230 CodepointSpan* span) const; 231 232 // Converts a span to the corresponding label given output_tokens. 233 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span, 234 const std::vector<Token>& output_tokens, int* label) const; 235 236 // Converts a token span to the corresponding label. 237 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const; 238 239 void PrepareCodepointRanges( 240 const std::vector<const FeatureProcessorOptions_::CodepointRange*>& 241 codepoint_ranges, 242 std::vector<CodepointRange>* prepared_codepoint_ranges); 243 244 // Returns the ratio of supported codepoints to total number of codepoints in 245 // the given token span. 246 float SupportedCodepointsRatio(const TokenSpan& token_span, 247 const std::vector<Token>& tokens) const; 248 249 // Returns true if given codepoint is covered by the given sorted vector of 250 // codepoint ranges. 251 bool IsCodepointInRanges( 252 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const; 253 254 void PrepareIgnoredSpanBoundaryCodepoints(); 255 256 // Counts the number of span boundary codepoints. If count_from_beginning is 257 // True, the counting will start at the span_start iterator (inclusive) and at 258 // maximum end at span_end (exclusive). If count_from_beginning is True, the 259 // counting will start from span_end (exclusive) and end at span_start 260 // (inclusive). 261 int CountIgnoredSpanBoundaryCodepoints( 262 const UnicodeText::const_iterator& span_start, 263 const UnicodeText::const_iterator& span_end, 264 bool count_from_beginning) const; 265 266 // Finds the center token index in tokens vector, using the method defined 267 // in options_. 268 int FindCenterToken(CodepointSpan span, 269 const std::vector<Token>& tokens) const; 270 271 // Tokenizes the input text using ICU tokenizer. 272 bool ICUTokenize(const UnicodeText& context_unicode, 273 std::vector<Token>* result) const; 274 275 // Takes the result of ICU tokenization and retokenizes stretches of tokens 276 // made of a specific subset of characters using the internal tokenizer. 277 void InternalRetokenize(const UnicodeText& unicode_text, 278 std::vector<Token>* tokens) const; 279 280 // Tokenizes a substring of the unicode string, appending the resulting tokens 281 // to the output vector. The resulting tokens have bounds relative to the full 282 // string. Does nothing if the start of the span is negative. 283 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span, 284 std::vector<Token>* result) const; 285 286 // Removes all tokens from tokens that are not on a line (defined by calling 287 // SplitContext on the context) to which span points. 288 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span, 289 std::vector<Token>* tokens) const; 290 291 // Same as above but takes UnicodeText. 292 void StripTokensFromOtherLines(const UnicodeText& context_unicode, 293 CodepointSpan span, 294 std::vector<Token>* tokens) const; 295 296 // Extracts the features of a token and appends them to the output vector. 297 // Uses the embedding cache to to avoid re-extracting the re-embedding the 298 // sparse features for the same token. 299 bool AppendTokenFeaturesWithCache(const Token& token, 300 CodepointSpan selection_span_for_feature, 301 const EmbeddingExecutor* embedding_executor, 302 EmbeddingCache* embedding_cache, 303 std::vector<float>* output_features) const; 304 305 private: 306 std::unique_ptr<UniLib> owned_unilib_; 307 const UniLib* unilib_; 308 309 protected: 310 const TokenFeatureExtractor feature_extractor_; 311 312 // Codepoint ranges that define what codepoints are supported by the model. 313 // NOTE: Must be sorted. 314 std::vector<CodepointRange> supported_codepoint_ranges_; 315 316 // Codepoint ranges that define which tokens (consisting of which codepoints) 317 // should be re-tokenized with the internal tokenizer in the mixed 318 // tokenization mode. 319 // NOTE: Must be sorted. 320 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_; 321 322 private: 323 // Set of codepoints that will be stripped from beginning and end of 324 // predicted spans. 325 std::set<int32> ignored_span_boundary_codepoints_; 326 327 const FeatureProcessorOptions* const options_; 328 329 // Mapping between token selection spans and labels ids. 330 std::map<TokenSpan, int> selection_to_label_; 331 std::vector<TokenSpan> label_to_selection_; 332 333 // Mapping between collections and labels. 334 std::map<std::string, int> collection_to_label_; 335 336 Tokenizer tokenizer_; 337 }; 338 339 } // namespace libtextclassifier2 340 341 #endif // LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_ 342