• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
// Feature processing for FFModel (feed-forward SmartSelection model).
18 
19 #ifndef LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
20 #define LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
21 
22 #include <map>
23 #include <memory>
24 #include <set>
25 #include <string>
26 #include <vector>
27 
28 #include "cached-features.h"
29 #include "model_generated.h"
30 #include "token-feature-extractor.h"
31 #include "tokenizer.h"
32 #include "types.h"
33 #include "util/base/integral_types.h"
34 #include "util/base/logging.h"
35 #include "util/utf8/unicodetext.h"
36 #include "util/utf8/unilib.h"
37 
38 namespace libtextclassifier2 {
39 
40 constexpr int kInvalidLabel = -1;
41 
42 namespace internal {
43 
44 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
45     const FeatureProcessorOptions* options);
46 
47 // Splits tokens that contain the selection boundary inside them.
48 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
49 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
50                                       std::vector<Token>* tokens);
51 
52 // Returns the index of token that corresponds to the codepoint span.
53 int CenterTokenFromClick(CodepointSpan span, const std::vector<Token>& tokens);
54 
55 // Returns the index of token that corresponds to the middle of the  codepoint
56 // span.
57 int CenterTokenFromMiddleOfSelection(
58     CodepointSpan span, const std::vector<Token>& selectable_tokens);
59 
60 // Strips the tokens from the tokens vector that are not used for feature
61 // extraction because they are out of scope, or pads them so that there is
62 // enough tokens in the required context_size for all inferences with a click
63 // in relative_click_span.
64 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
65                       std::vector<Token>* tokens, int* click_pos);
66 
67 // If unilib is not nullptr, just returns unilib. Otherwise, if unilib is
68 // nullptr, will create UniLib, assign ownership to owned_unilib, and return it.
69 const UniLib* MaybeCreateUnilib(const UniLib* unilib,
70                                 std::unique_ptr<UniLib>* owned_unilib);
71 
72 }  // namespace internal
73 
74 // Converts a codepoint span to a token span in the given list of tokens.
75 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a
76 // token to overlap with the codepoint range to be considered part of it.
77 // Otherwise it must be fully included in the range.
78 TokenSpan CodepointSpanToTokenSpan(
79     const std::vector<Token>& selectable_tokens, CodepointSpan codepoint_span,
80     bool snap_boundaries_to_containing_tokens = false);
81 
82 // Converts a token span to a codepoint span in the given list of tokens.
83 CodepointSpan TokenSpanToCodepointSpan(
84     const std::vector<Token>& selectable_tokens, TokenSpan token_span);
85 
86 // Takes care of preparing features for the span prediction model.
87 class FeatureProcessor {
88  public:
89   // A cache mapping codepoint spans to embedded tokens features. An instance
90   // can be provided to multiple calls to ExtractFeatures() operating on the
91   // same context (the same codepoint spans corresponding to the same tokens),
92   // as an optimization. Note that the tokenizations do not have to be
93   // identical.
94   typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
95 
96   // If unilib is nullptr, will create and own an instance of a UniLib,
97   // otherwise will use what's passed in.
98   explicit FeatureProcessor(const FeatureProcessorOptions* options,
99                             const UniLib* unilib = nullptr)
owned_unilib_(nullptr)100       : owned_unilib_(nullptr),
101         unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)),
102         feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
103                            *unilib_),
104         options_(options),
105         tokenizer_(
106             options->tokenization_codepoint_config() != nullptr
107                 ? Tokenizer({options->tokenization_codepoint_config()->begin(),
108                              options->tokenization_codepoint_config()->end()},
109                             options->tokenize_on_script_change())
110                 : Tokenizer({}, /*split_on_script_change=*/false)) {
111     MakeLabelMaps();
112     if (options->supported_codepoint_ranges() != nullptr) {
113       PrepareCodepointRanges({options->supported_codepoint_ranges()->begin(),
114                               options->supported_codepoint_ranges()->end()},
115                              &supported_codepoint_ranges_);
116     }
117     if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
118       PrepareCodepointRanges(
119           {options->internal_tokenizer_codepoint_ranges()->begin(),
120            options->internal_tokenizer_codepoint_ranges()->end()},
121           &internal_tokenizer_codepoint_ranges_);
122     }
123     PrepareIgnoredSpanBoundaryCodepoints();
124   }
125 
126   // Tokenizes the input string using the selected tokenization method.
127   std::vector<Token> Tokenize(const std::string& text) const;
128 
129   // Same as above but takes UnicodeText.
130   std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
131 
132   // Converts a label into a token span.
133   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
134 
135   // Gets the total number of selection labels.
GetSelectionLabelCount()136   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
137 
138   // Gets the string value for given collection label.
139   std::string LabelToCollection(int label) const;
140 
141   // Gets the total number of collections of the model.
NumCollections()142   int NumCollections() const { return collection_to_label_.size(); }
143 
144   // Gets the name of the default collection.
145   std::string GetDefaultCollection() const;
146 
GetOptions()147   const FeatureProcessorOptions* GetOptions() const { return options_; }
148 
149   // Retokenizes the context and input span, and finds the click position.
150   // Depending on the options, might modify tokens (split them or remove them).
151   void RetokenizeAndFindClick(const std::string& context,
152                               CodepointSpan input_span,
153                               bool only_use_line_with_click,
154                               std::vector<Token>* tokens, int* click_pos) const;
155 
156   // Same as above but takes UnicodeText.
157   void RetokenizeAndFindClick(const UnicodeText& context_unicode,
158                               CodepointSpan input_span,
159                               bool only_use_line_with_click,
160                               std::vector<Token>* tokens, int* click_pos) const;
161 
162   // Returns true if the token span has enough supported codepoints (as defined
163   // in the model config) or not and model should not run.
164   bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
165                                     TokenSpan token_span) const;
166 
167   // Extracts features as a CachedFeatures object that can be used for repeated
168   // inference over token spans in the given context.
169   bool ExtractFeatures(const std::vector<Token>& tokens, TokenSpan token_span,
170                        CodepointSpan selection_span_for_feature,
171                        const EmbeddingExecutor* embedding_executor,
172                        EmbeddingCache* embedding_cache, int feature_vector_size,
173                        std::unique_ptr<CachedFeatures>* cached_features) const;
174 
175   // Fills selection_label_spans with CodepointSpans that correspond to the
176   // selection labels. The CodepointSpans are based on the codepoint ranges of
177   // given tokens.
178   bool SelectionLabelSpans(
179       VectorSpan<Token> tokens,
180       std::vector<CodepointSpan>* selection_label_spans) const;
181 
DenseFeaturesCount()182   int DenseFeaturesCount() const {
183     return feature_extractor_.DenseFeaturesCount();
184   }
185 
EmbeddingSize()186   int EmbeddingSize() const { return options_->embedding_size(); }
187 
188   // Splits context to several segments.
189   std::vector<UnicodeTextRange> SplitContext(
190       const UnicodeText& context_unicode) const;
191 
192   // Strips boundary codepoints from the span in context and returns the new
193   // start and end indices. If the span comprises entirely of boundary
194   // codepoints, the first index of span is returned for both indices.
195   CodepointSpan StripBoundaryCodepoints(const std::string& context,
196                                         CodepointSpan span) const;
197 
198   // Same as above but takes UnicodeText.
199   CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
200                                         CodepointSpan span) const;
201 
202  protected:
203   // Represents a codepoint range [start, end).
204   struct CodepointRange {
205     int32 start;
206     int32 end;
207 
CodepointRangeCodepointRange208     CodepointRange(int32 arg_start, int32 arg_end)
209         : start(arg_start), end(arg_end) {}
210   };
211 
212   // Returns the class id corresponding to the given string collection
213   // identifier. There is a catch-all class id that the function returns for
214   // unknown collections.
215   int CollectionToLabel(const std::string& collection) const;
216 
217   // Prepares mapping from collection names to labels.
218   void MakeLabelMaps();
219 
220   // Gets the number of spannable tokens for the model.
221   //
222   // Spannable tokens are those tokens of context, which the model predicts
223   // selection spans over (i.e., there is 1:1 correspondence between the output
224   // classes of the model and each of the spannable tokens).
GetNumContextTokens()225   int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
226 
227   // Converts a label into a span of codepoint indices corresponding to it
228   // given output_tokens.
229   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
230                    CodepointSpan* span) const;
231 
232   // Converts a span to the corresponding label given output_tokens.
233   bool SpanToLabel(const std::pair<CodepointIndex, CodepointIndex>& span,
234                    const std::vector<Token>& output_tokens, int* label) const;
235 
236   // Converts a token span to the corresponding label.
237   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
238 
239   void PrepareCodepointRanges(
240       const std::vector<const FeatureProcessorOptions_::CodepointRange*>&
241           codepoint_ranges,
242       std::vector<CodepointRange>* prepared_codepoint_ranges);
243 
244   // Returns the ratio of supported codepoints to total number of codepoints in
245   // the given token span.
246   float SupportedCodepointsRatio(const TokenSpan& token_span,
247                                  const std::vector<Token>& tokens) const;
248 
249   // Returns true if given codepoint is covered by the given sorted vector of
250   // codepoint ranges.
251   bool IsCodepointInRanges(
252       int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const;
253 
254   void PrepareIgnoredSpanBoundaryCodepoints();
255 
256   // Counts the number of span boundary codepoints. If count_from_beginning is
257   // True, the counting will start at the span_start iterator (inclusive) and at
258   // maximum end at span_end (exclusive). If count_from_beginning is True, the
259   // counting will start from span_end (exclusive) and end at span_start
260   // (inclusive).
261   int CountIgnoredSpanBoundaryCodepoints(
262       const UnicodeText::const_iterator& span_start,
263       const UnicodeText::const_iterator& span_end,
264       bool count_from_beginning) const;
265 
266   // Finds the center token index in tokens vector, using the method defined
267   // in options_.
268   int FindCenterToken(CodepointSpan span,
269                       const std::vector<Token>& tokens) const;
270 
271   // Tokenizes the input text using ICU tokenizer.
272   bool ICUTokenize(const UnicodeText& context_unicode,
273                    std::vector<Token>* result) const;
274 
275   // Takes the result of ICU tokenization and retokenizes stretches of tokens
276   // made of a specific subset of characters using the internal tokenizer.
277   void InternalRetokenize(const UnicodeText& unicode_text,
278                           std::vector<Token>* tokens) const;
279 
280   // Tokenizes a substring of the unicode string, appending the resulting tokens
281   // to the output vector. The resulting tokens have bounds relative to the full
282   // string. Does nothing if the start of the span is negative.
283   void TokenizeSubstring(const UnicodeText& unicode_text, CodepointSpan span,
284                          std::vector<Token>* result) const;
285 
286   // Removes all tokens from tokens that are not on a line (defined by calling
287   // SplitContext on the context) to which span points.
288   void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
289                                  std::vector<Token>* tokens) const;
290 
291   // Same as above but takes UnicodeText.
292   void StripTokensFromOtherLines(const UnicodeText& context_unicode,
293                                  CodepointSpan span,
294                                  std::vector<Token>* tokens) const;
295 
296   // Extracts the features of a token and appends them to the output vector.
297   // Uses the embedding cache to to avoid re-extracting the re-embedding the
298   // sparse features for the same token.
299   bool AppendTokenFeaturesWithCache(const Token& token,
300                                     CodepointSpan selection_span_for_feature,
301                                     const EmbeddingExecutor* embedding_executor,
302                                     EmbeddingCache* embedding_cache,
303                                     std::vector<float>* output_features) const;
304 
305  private:
306   std::unique_ptr<UniLib> owned_unilib_;
307   const UniLib* unilib_;
308 
309  protected:
310   const TokenFeatureExtractor feature_extractor_;
311 
312   // Codepoint ranges that define what codepoints are supported by the model.
313   // NOTE: Must be sorted.
314   std::vector<CodepointRange> supported_codepoint_ranges_;
315 
316   // Codepoint ranges that define which tokens (consisting of which codepoints)
317   // should be re-tokenized with the internal tokenizer in the mixed
318   // tokenization mode.
319   // NOTE: Must be sorted.
320   std::vector<CodepointRange> internal_tokenizer_codepoint_ranges_;
321 
322  private:
323   // Set of codepoints that will be stripped from beginning and end of
324   // predicted spans.
325   std::set<int32> ignored_span_boundary_codepoints_;
326 
327   const FeatureProcessorOptions* const options_;
328 
329   // Mapping between token selection spans and labels ids.
330   std::map<TokenSpan, int> selection_to_label_;
331   std::vector<TokenSpan> label_to_selection_;
332 
333   // Mapping between collections and labels.
334   std::map<std::string, int> collection_to_label_;
335 
336   Tokenizer tokenizer_;
337 };
338 
339 }  // namespace libtextclassifier2
340 
341 #endif  // LIBTEXTCLASSIFIER_FEATURE_PROCESSOR_H_
342