• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Feature processing for FFModel (feed-forward SmartSelection model).
18 
19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
20 #define LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
21 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
25 
26 #include "smartselect/cached-features.h"
27 #include "smartselect/text-classification-model.pb.h"
28 #include "smartselect/token-feature-extractor.h"
29 #include "smartselect/tokenizer.h"
30 #include "smartselect/types.h"
31 #include "util/base/logging.h"
32 #include "util/utf8/unicodetext.h"
33 
34 namespace libtextclassifier {
35 
// Sentinel label value.
// NOTE(review): presumably marks "no valid label" (e.g. a span that cannot be
// mapped to a selection label) — confirm against the .cc usage.
constexpr int kInvalidLabel = -1;

// Maps a vector of sparse features and a vector of dense features to a single
// feature vector that combines both.
//
// Arguments, in order:
//   - sparse feature ids,
//   - dense feature values,
//   - output buffer: the combined feature vector is written to the memory
//     pointed to by this float*. The caller owns the buffer and must make it
//     large enough (ExtractFeatures() takes the size as feature_vector_size).
// Returns true on success, false on failure.
using FeatureVectorFn = std::function<bool(const std::vector<int>&,
                                           const std::vector<float>&, float*)>;
45 
46 namespace internal {
47 
48 // Parses the serialized protocol buffer.
49 FeatureProcessorOptions ParseSerializedOptions(
50     const std::string& serialized_options);
51 
52 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
53     const FeatureProcessorOptions& options);
54 
55 // Removes tokens that are not part of a line of the context which contains
56 // given span.
57 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
58                                std::vector<Token>* tokens);
59 
60 // Splits tokens that contain the selection boundary inside them.
61 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
62 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
63                                       std::vector<Token>* tokens);
64 
65 // Returns the index of token that corresponds to the codepoint span.
66 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
67 
68 // Returns the index of token that corresponds to the middle of the  codepoint
69 // span.
70 int CenterTokenFromMiddleOfSelection(
71     CodepointSpan span, const std::vector<Token>& selectable_tokens);
72 
73 // Strips the tokens from the tokens vector that are not used for feature
74 // extraction because they are out of scope, or pads them so that there is
75 // enough tokens in the required context_size for all inferences with a click
76 // in relative_click_span.
77 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
78                       std::vector<Token>* tokens, int* click_pos);
79 
80 }  // namespace internal
81 
82 // Converts a codepoint span to a token span in the given list of tokens.
83 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
84                                    CodepointSpan codepoint_span);
85 
86 // Converts a token span to a codepoint span in the given list of tokens.
87 CodepointSpan TokenSpanToCodepointSpan(
88     const std::vector<Token>& selectable_tokens, TokenSpan token_span);
89 
90 // Takes care of preparing features for the span prediction model.
91 class FeatureProcessor {
92  public:
FeatureProcessor(const FeatureProcessorOptions & options)93   explicit FeatureProcessor(const FeatureProcessorOptions& options)
94       : feature_extractor_(
95             internal::BuildTokenFeatureExtractorOptions(options)),
96         options_(options),
97         tokenizer_({options.tokenization_codepoint_config().begin(),
98                     options.tokenization_codepoint_config().end()}) {
99     MakeLabelMaps();
100     PrepareCodepointRanges({options.supported_codepoint_ranges().begin(),
101                             options.supported_codepoint_ranges().end()},
102                            &supported_codepoint_ranges_);
103     PrepareCodepointRanges(
104         {options.internal_tokenizer_codepoint_ranges().begin(),
105          options.internal_tokenizer_codepoint_ranges().end()},
106         &internal_tokenizer_codepoint_ranges_);
107   }
108 
FeatureProcessor(const std::string & serialized_options)109   explicit FeatureProcessor(const std::string& serialized_options)
110       : FeatureProcessor(internal::ParseSerializedOptions(serialized_options)) {
111   }
112 
113   // Tokenizes the input string using the selected tokenization method.
114   std::vector<Token> Tokenize(const std::string& utf8_text) const;
115 
116   // Converts a label into a token span.
117   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
118 
119   // Gets the total number of selection labels.
GetSelectionLabelCount()120   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
121 
122   // Gets the string value for given collection label.
123   std::string LabelToCollection(int label) const;
124 
125   // Gets the total number of collections of the model.
NumCollections()126   int NumCollections() const { return collection_to_label_.size(); }
127 
128   // Gets the name of the default collection.
129   std::string GetDefaultCollection() const;
130 
GetOptions()131   const FeatureProcessorOptions& GetOptions() const { return options_; }
132 
133   // Tokenizes the context and input span, and finds the click position.
134   void TokenizeAndFindClick(const std::string& context,
135                             CodepointSpan input_span,
136                             std::vector<Token>* tokens, int* click_pos) const;
137 
138   // Extracts features as a CachedFeatures object that can be used for repeated
139   // inference over token spans in the given context.
140   bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
141                        TokenSpan relative_click_span,
142                        const FeatureVectorFn& feature_vector_fn,
143                        int feature_vector_size, std::vector<Token>* tokens,
144                        int* click_pos,
145                        std::unique_ptr<CachedFeatures>* cached_features) const;
146 
147   // Fills selection_label_spans with CodepointSpans that correspond to the
148   // selection labels. The CodepointSpans are based on the codepoint ranges of
149   // given tokens.
150   bool SelectionLabelSpans(
151       VectorSpan<Token> tokens,
152       std::vector<CodepointSpan>* selection_label_spans) const;
153 
DenseFeaturesCount()154   int DenseFeaturesCount() const {
155     return feature_extractor_.DenseFeaturesCount();
156   }
157 
158  protected:
159   // Represents a codepoint range [start, end).
160   struct CodepointRange {
161     int32 start;
162     int32 end;
163 
CodepointRangeCodepointRange164     CodepointRange(int32 arg_start, int32 arg_end)
165         : start(arg_start), end(arg_end) {}
166   };
167 
168   // Returns the class id corresponding to the given string collection
169   // identifier. There is a catch-all class id that the function returns for
170   // unknown collections.
171   int CollectionToLabel(const std::string& collection) const;
172 
173   // Prepares mapping from collection names to labels.
174   void MakeLabelMaps();
175 
176   // Gets the number of spannable tokens for the model.
177   //
178   // Spannable tokens are those tokens of context, which the model predicts
179   // selection spans over (i.e., there is 1:1 correspondence between the output
180   // classes of the model and each of the spannable tokens).
GetNumContextTokens()181   int GetNumContextTokens() const { return options_.context_size() * 2 + 1; }
182 
183   // Converts a label into a span of codepoint indices corresponding to it
184   // given output_tokens.
185   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
186                    CodepointSpan* span) const;
187 
188   // Converts a span to the corresponding label given output_tokens.
189   bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
190                    const std::vector<Token>& output_tokens, int* label) const;
191 
192   // Converts a token span to the corresponding label.
193   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
194 
195   void PrepareCodepointRanges(
196       const std::vector<FeatureProcessorOptions::CodepointRange>&
197           codepoint_ranges,
198       std::vector<CodepointRange>* prepared_codepoint_ranges);
199 
200   // Returns the ratio of supported codepoints to total number of codepoints in
201   // the input context around given click position.
202   float SupportedCodepointsRatio(int click_pos,
203                                  const std::vector<Token>& tokens) const;
204 
205   // Returns true if given codepoint is covered by the given sorted vector of
206   // codepoint ranges.
207   bool IsCodepointInRanges(
208       int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
209 
210   // Finds the center token index in tokens vector, using the method defined
211   // in options_.
212   int FindCenterToken(CodepointSpan span,
213                       const std::vector<Token>& tokens) const;
214 
215   // Tokenizes the input text using ICU tokenizer.
216   bool ICUTokenize(const std::string& context,
217                    std::vector<Token>* result) const;
218 
219   // Takes the result of ICU tokenization and retokenizes stretches of tokens
220   // made of a specific subset of characters using the internal tokenizer.
221   void InternalRetokenize(const std::string& context,
222                           std::vector<Token>* tokens) const;
223 
224   // Tokenizes a substring of the unicode string, appending the resulting tokens
225   // to the output vector. The resulting tokens have bounds relative to the full
226   // string. Does nothing if the start of the span is negative.
227   void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
228                          std::vector<Token>* result) const;
229 
230   const TokenFeatureExtractor feature_extractor_;
231 
232   // Codepoint ranges that define what codepoints are supported by the model.
233   // NOTE: Must be sorted.
234   std::vector<CodepointRange> supported_codepoint_ranges_;
235 
236   // Codepoint ranges that define which tokens (consisting of which codepoints)
237   // should be re-tokenized with the internal tokenizer in the mixed
238   // tokenization mode.
239   // NOTE: Must be sorted.
240   std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
241 
242  private:
243   const FeatureProcessorOptions options_;
244 
245   // Mapping between token selection spans and labels ids.
246   std::map<TokenSpan, int> selection_to_label_;
247   std::vector<TokenSpan> label_to_selection_;
248 
249   // Mapping between collections and labels.
250   std::map<std::string, int> collection_to_label_;
251 
252   Tokenizer tokenizer_;
253 };
254 
255 }  // namespace libtextclassifier
256 
257 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_FEATURE_PROCESSOR_H_
258