/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Feature processing for FFModel (feed-forward SmartSelection model). #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ #define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_ #include #include #include #include "smartselect/cached-features.h" #include "smartselect/text-classification-model.pb.h" #include "smartselect/token-feature-extractor.h" #include "smartselect/tokenizer.h" #include "smartselect/types.h" #include "util/base/logging.h" #include "util/utf8/unicodetext.h" namespace libtextclassifier { constexpr int kInvalidLabel = -1; // Maps a vector of sparse features and a vector of dense features to a vector // of features that combines both. // The output is written to the memory location pointed to by the last float* // argument. // Returns true on success false on failure. using FeatureVectorFn = std::function&, const std::vector&, float*)>; namespace internal { // Parses the serialized protocol buffer. FeatureProcessorOptions ParseSerializedOptions( const std::string& serialized_options); TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( const FeatureProcessorOptions& options); // Removes tokens that are not part of a line of the context which contains // given span. void StripTokensFromOtherLines(const std::string& context, CodepointSpan span, std::vector* tokens); // Splits tokens that contain the selection boundary inside them. // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com" void SplitTokensOnSelectionBoundaries(CodepointSpan selection, std::vector* tokens); // Returns the index of token that corresponds to the codepoint span. int CenterTokenFromClick(CodepointSpan span, const std::vector& tokens); // Returns the index of token that corresponds to the middle of the codepoint // span. int CenterTokenFromMiddleOfSelection( CodepointSpan span, const std::vector& selectable_tokens); // Strips the tokens from the tokens vector that are not used for feature // extraction because they are out of scope, or pads them so that there is // enough tokens in the required context_size for all inferences with a click // in relative_click_span. void StripOrPadTokens(TokenSpan relative_click_span, int context_size, std::vector* tokens, int* click_pos); } // namespace internal // Converts a codepoint span to a token span in the given list of tokens. TokenSpan CodepointSpanToTokenSpan(const std::vector& selectable_tokens, CodepointSpan codepoint_span); // Converts a token span to a codepoint span in the given list of tokens. CodepointSpan TokenSpanToCodepointSpan( const std::vector& selectable_tokens, TokenSpan token_span); // Takes care of preparing features for the span prediction model. class FeatureProcessor { public: explicit FeatureProcessor(const FeatureProcessorOptions& options) : feature_extractor_( internal::BuildTokenFeatureExtractorOptions(options)), options_(options), tokenizer_({options.tokenization_codepoint_config().begin(), options.tokenization_codepoint_config().end()}) { MakeLabelMaps(); PrepareCodepointRanges({options.supported_codepoint_ranges().begin(), options.supported_codepoint_ranges().end()}, &supported_codepoint_ranges_); PrepareCodepointRanges( {options.internal_tokenizer_codepoint_ranges().begin(), options.internal_tokenizer_codepoint_ranges().end()}, &internal_tokenizer_codepoint_ranges_); } explicit FeatureProcessor(const std::string& serialized_options) : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) { } // Tokenizes the input string using the selected tokenization method. std::vector Tokenize(const std::string& utf8_text) const; // Converts a label into a token span. bool LabelToTokenSpan(int label, TokenSpan* token_span) const; // Gets the total number of selection labels. int GetSelectionLabelCount() const { return label_to_selection_.size(); } // Gets the string value for given collection label. std::string LabelToCollection(int label) const; // Gets the total number of collections of the model. int NumCollections() const { return collection_to_label_.size(); } // Gets the name of the default collection. std::string GetDefaultCollection() const; const FeatureProcessorOptions& GetOptions() const { return options_; } // Tokenizes the context and input span, and finds the click position. void TokenizeAndFindClick(const std::string& context, CodepointSpan input_span, std::vector* tokens, int* click_pos) const; // Extracts features as a CachedFeatures object that can be used for repeated // inference over token spans in the given context. bool ExtractFeatures(const std::string& context, CodepointSpan input_span, TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn, int feature_vector_size, std::vector* tokens, int* click_pos, std::unique_ptr* cached_features) const; // Fills selection_label_spans with CodepointSpans that correspond to the // selection labels. The CodepointSpans are based on the codepoint ranges of // given tokens. bool SelectionLabelSpans( VectorSpan tokens, std::vector* selection_label_spans) const; int DenseFeaturesCount() const { return feature_extractor_.DenseFeaturesCount(); } protected: // Represents a codepoint range [start, end). struct CodepointRange { int32 start; int32 end; CodepointRange(int32 arg_start, int32 arg_end) : start(arg_start), end(arg_end) {} }; // Returns the class id corresponding to the given string collection // identifier. There is a catch-all class id that the function returns for // unknown collections. int CollectionToLabel(const std::string& collection) const; // Prepares mapping from collection names to labels. void MakeLabelMaps(); // Gets the number of spannable tokens for the model. // // Spannable tokens are those tokens of context, which the model predicts // selection spans over (i.e., there is 1:1 correspondence between the output // classes of the model and each of the spannable tokens). int GetNumContextTokens() const { return options_.context_size() * 2 + 1; } // Converts a label into a span of codepoint indices corresponding to it // given output_tokens. bool LabelToSpan(int label, const VectorSpan& output_tokens, CodepointSpan* span) const; // Converts a span to the corresponding label given output_tokens. bool SpanToLabel(const std::pair& span, const std::vector& output_tokens, int* label) const; // Converts a token span to the corresponding label. int TokenSpanToLabel(const std::pair& span) const; void PrepareCodepointRanges( const std::vector& codepoint_ranges, std::vector* prepared_codepoint_ranges); // Returns the ratio of supported codepoints to total number of codepoints in // the input context around given click position. float SupportedCodepointsRatio(int click_pos, const std::vector& tokens) const; // Returns true if given codepoint is covered by the given sorted vector of // codepoint ranges. bool IsCodepointInRanges( int codepoint, const std::vector& codepoint_ranges) const; // Finds the center token index in tokens vector, using the method defined // in options_. int FindCenterToken(CodepointSpan span, const std::vector& tokens) const; // Tokenizes the input text using ICU tokenizer. bool ICUTokenize(const std::string& context, std::vector* result) const; // Takes the result of ICU tokenization and retokenizes stretches of tokens // made of a specific subset of characters using the internal tokenizer. void InternalRetokenize(const std::string& context, std::vector* tokens) const; // Tokenizes a substring of the unicode string, appending the resulting tokens // to the output vector. The resulting tokens have bounds relative to the full // string. Does nothing if the start of the span is negative. void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span, std::vector* result) const; const TokenFeatureExtractor feature_extractor_; // Codepoint ranges that define what codepoints are supported by the model. // NOTE: Must be sorted. std::vector supported_codepoint_ranges_; // Codepoint ranges that define which tokens (consisting of which codepoints) // should be re-tokenized with the internal tokenizer in the mixed // tokenization mode. // NOTE: Must be sorted. std::vector internal_tokenizer_codepoint_ranges_; private: const FeatureProcessorOptions options_; // Mapping between token selection spans and labels ids. std::map selection_to_label_; std::vector label_to_selection_; // Mapping between collections and labels. std::map collection_to_label_; Tokenizer tokenizer_; }; } // namespace libtextclassifier #endif // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_