1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Feature processing for FFModel (feed-forward SmartSelection model). 18 19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ 20 #define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ 21 22 #include <memory> 23 #include <string> 24 #include <vector> 25 26 #include "smartselect/cached-features.h" 27 #include "smartselect/text-classification-model.pb.h" 28 #include "smartselect/token-feature-extractor.h" 29 #include "smartselect/tokenizer.h" 30 #include "smartselect/types.h" 31 #include "util/base/logging.h" 32 #include "util/utf8/unicodetext.h" 33 34 namespace libtextclassifier { 35 36 constexpr int kInvalidLabel = -1; 37 38 // Maps a vector of sparse features and a vector of dense features to a vector 39 // of features that combines both. 40 // The output is written to the memory location pointed to by the last float* 41 // argument. 42 // Returns true on success false on failure. 43 using FeatureVectorFn = std::function<bool(const std::vector<int>&, 44 const std::vector<float>&, float*)>; 45 46 namespace internal { 47 48 // Parses the serialized protocol buffer. 49 FeatureProcessorOptions ParseSerializedOptions( 50 const std::string& serialized_options); 51 52 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( 53 const FeatureProcessorOptions& options); 54 55 // Removes tokens that are not part of a line of the context which contains 56 // given span. 57 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span, 58 std::vector<Token>* tokens); 59 60 // Splits tokens that contain the selection boundary inside them. 61 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" 62 void SplitTokensOnSelectionBoundaries(CodepointSpan selection, 63 std::vector<Token>* tokens); 64 65 // Returns the index of token that corresponds to the codepoint span. 66 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens); 67 68 // Returns the index of token that corresponds to the middle of the codepoint 69 // span. 70 int CenterTokenFromMiddleOfSelection( 71 CodepointSpan span, const std::vector<Token>& selectable_tokens); 72 73 // Strips the tokens from the tokens vector that are not used for feature 74 // extraction because they are out of scope, or pads them so that there is 75 // enough tokens in the required context_size for all inferences with a click 76 // in relative_click_span. 77 void StripOrPadTokens(TokenSpan relative_click_span, int context_size, 78 std::vector<Token>* tokens, int* click_pos); 79 80 } // namespace internal 81 82 // Converts a codepoint span to a token span in the given list of tokens. 83 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens, 84 CodepointSpan codepoint_span); 85 86 // Converts a token span to a codepoint span in the given list of tokens. 87 CodepointSpan TokenSpanToCodepointSpan( 88 const std::vector<Token>& selectable_tokens, TokenSpan token_span); 89 90 // Takes care of preparing features for the span prediction model. 91 class FeatureProcessor { 92 public: FeatureProcessor(const FeatureProcessorOptions & options)93 explicit FeatureProcessor(const FeatureProcessorOptions& options) 94 : feature_extractor_( 95 internal::BuildTokenFeatureExtractorOptions(options)), 96 options_(options), 97 tokenizer_({options.tokenization_codepoint_config().begin(), 98 options.tokenization_codepoint_config().end()}) { 99 MakeLabelMaps(); 100 PrepareCodepointRanges({options.supported_codepoint_ranges().begin(), 101 options.supported_codepoint_ranges().end()}, 102 &supported_codepoint_ranges_); 103 PrepareCodepointRanges( 104 {options.internal_tokenizer_codepoint_ranges().begin(), 105 options.internal_tokenizer_codepoint_ranges().end()}, 106 &internal_tokenizer_codepoint_ranges_); 107 } 108 FeatureProcessor(const std::string & serialized_options)109 explicit FeatureProcessor(const std::string& serialized_options) 110 : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) { 111 } 112 113 // Tokenizes the input string using the selected tokenization method. 114 std::vector<Token> Tokenize(const std::string& utf8_text) const; 115 116 // Converts a label into a token span. 117 bool LabelToTokenSpan(int label, TokenSpan* token_span) const; 118 119 // Gets the total number of selection labels. GetSelectionLabelCount()120 int GetSelectionLabelCount() const { return label_to_selection_.size(); } 121 122 // Gets the string value for given collection label. 123 std::string LabelToCollection(int label) const; 124 125 // Gets the total number of collections of the model. NumCollections()126 int NumCollections() const { return collection_to_label_.size(); } 127 128 // Gets the name of the default collection. 129 std::string GetDefaultCollection() const; 130 GetOptions()131 const FeatureProcessorOptions& GetOptions() const { return options_; } 132 133 // Tokenizes the context and input span, and finds the click position. 134 void TokenizeAndFindClick(const std::string& context, 135 CodepointSpan input_span, 136 std::vector<Token>* tokens, int* click_pos) const; 137 138 // Extracts features as a CachedFeatures object that can be used for repeated 139 // inference over token spans in the given context. 140 bool ExtractFeatures(const std::string& context, CodepointSpan input_span, 141 TokenSpan relative_click_span, 142 const FeatureVectorFn& feature_vector_fn, 143 int feature_vector_size, std::vector<Token>* tokens, 144 int* click_pos, 145 std::unique_ptr<CachedFeatures>* cached_features) const; 146 147 // Fills selection_label_spans with CodepointSpans that correspond to the 148 // selection labels. The CodepointSpans are based on the codepoint ranges of 149 // given tokens. 150 bool SelectionLabelSpans( 151 VectorSpan<Token> tokens, 152 std::vector<CodepointSpan>* selection_label_spans) const; 153 DenseFeaturesCount()154 int DenseFeaturesCount() const { 155 return feature_extractor_.DenseFeaturesCount(); 156 } 157 158 protected: 159 // Represents a codepoint range [start, end). 160 struct CodepointRange { 161 int32 start; 162 int32 end; 163 CodepointRangeCodepointRange164 CodepointRange(int32 arg_start, int32 arg_end) 165 : start(arg_start), end(arg_end) {} 166 }; 167 168 // Returns the class id corresponding to the given string collection 169 // identifier. There is a catch-all class id that the function returns for 170 // unknown collections. 171 int CollectionToLabel(const std::string& collection) const; 172 173 // Prepares mapping from collection names to labels. 174 void MakeLabelMaps(); 175 176 // Gets the number of spannable tokens for the model. 177 // 178 // Spannable tokens are those tokens of context, which the model predicts 179 // selection spans over (i.e., there is 1:1 correspondence between the output 180 // classes of the model and each of the spannable tokens). GetNumContextTokens()181 int GetNumContextTokens() const { return options_.context_size() * 2 + 1; } 182 183 // Converts a label into a span of codepoint indices corresponding to it 184 // given output_tokens. 185 bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens, 186 CodepointSpan* span) const; 187 188 // Converts a span to the corresponding label given output_tokens. 189 bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span, 190 const std::vector<Token>& output_tokens, int* label) const; 191 192 // Converts a token span to the corresponding label. 193 int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const; 194 195 void PrepareCodepointRanges( 196 const std::vector<FeatureProcessorOptions::CodepointRange>& 197 codepoint_ranges, 198 std::vector<CodepointRange>* prepared_codepoint_ranges); 199 200 // Returns the ratio of supported codepoints to total number of codepoints in 201 // the input context around given click position. 202 float SupportedCodepointsRatio(int click_pos, 203 const std::vector<Token>& tokens) const; 204 205 // Returns true if given codepoint is covered by the given sorted vector of 206 // codepoint ranges. 207 bool IsCodepointInRanges( 208 int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const; 209 210 // Finds the center token index in tokens vector, using the method defined 211 // in options_. 212 int FindCenterToken(CodepointSpan span, 213 const std::vector<Token>& tokens) const; 214 215 // Tokenizes the input text using ICU tokenizer. 216 bool ICUTokenize(const std::string& context, 217 std::vector<Token>* result) const; 218 219 // Takes the result of ICU tokenization and retokenizes stretches of tokens 220 // made of a specific subset of characters using the internal tokenizer. 221 void InternalRetokenize(const std::string& context, 222 std::vector<Token>* tokens) const; 223 224 // Tokenizes a substring of the unicode string, appending the resulting tokens 225 // to the output vector. The resulting tokens have bounds relative to the full 226 // string. Does nothing if the start of the span is negative. 227 void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span, 228 std::vector<Token>* result) const; 229 230 const TokenFeatureExtractor feature_extractor_; 231 232 // Codepoint ranges that define what codepoints are supported by the model. 233 // NOTE: Must be sorted. 234 std::vector<CodepointRange> supported_codepoint_ranges_; 235 236 // Codepoint ranges that define which tokens (consisting of which codepoints) 237 // should be re-tokenized with the internal tokenizer in the mixed 238 // tokenization mode. 239 // NOTE: Must be sorted. 240 std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_; 241 242 private: 243 const FeatureProcessorOptions options_; 244 245 // Mapping between token selection spans and labels ids. 246 std::map<TokenSpan, int> selection_to_label_; 247 std::vector<TokenSpan> label_to_selection_; 248 249 // Mapping between collections and labels. 250 std::map<std::string, int> collection_to_label_; 251 252 Tokenizer tokenizer_; 253 }; 254 255 } // namespace libtextclassifier 256 257 #endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ 258