• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Feature processing for FFModel (feed-forward SmartSelection model).
18 
19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
20 #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
21 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "annotator/cached-features.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
#include "utils/token-feature-extractor.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
37 
38 namespace libtextclassifier3 {
39 
// Label value that denotes an invalid label.
constexpr int kInvalidLabel{-1};
41 
42 namespace internal {
43 
44 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
45                          const UniLib* unilib);
46 
47 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
48     const FeatureProcessorOptions* options);
49 
50 // Splits tokens that contain the selection boundary inside them.
51 // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
52 void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
53                                       std::vector<Token>* tokens);
54 
55 // Returns the index of token that corresponds to the codepoint span.
56 int CenterTokenFromClick(const CodepointSpan& span,
57                          const std::vector<Token>& tokens);
58 
59 // Returns the index of token that corresponds to the middle of the  codepoint
60 // span.
61 int CenterTokenFromMiddleOfSelection(
62     const CodepointSpan& span, const std::vector<Token>& selectable_tokens);
63 
64 // Strips the tokens from the tokens vector that are not used for feature
65 // extraction because they are out of scope, or pads them so that there is
66 // enough tokens in the required context_size for all inferences with a click
67 // in relative_click_span.
68 void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
69                       std::vector<Token>* tokens, int* click_pos);
70 
71 }  // namespace internal
72 
73 // Converts a codepoint span to a token span in the given list of tokens.
74 // If snap_boundaries_to_containing_tokens is set to true, it is enough for a
75 // token to overlap with the codepoint range to be considered part of it.
76 // Otherwise it must be fully included in the range.
77 TokenSpan CodepointSpanToTokenSpan(
78     const std::vector<Token>& selectable_tokens,
79     const CodepointSpan& codepoint_span,
80     bool snap_boundaries_to_containing_tokens = false);
81 
82 // Converts a token span to a codepoint span in the given list of tokens.
83 CodepointSpan TokenSpanToCodepointSpan(
84     const std::vector<Token>& selectable_tokens, const TokenSpan& token_span);
85 
86 // Converts a codepoint span to a unicode text range, within the given unicode
87 // text.
88 // For an invalid span (with a negative index), returns (begin, begin). This
89 // means that it is safe to call this function before checking the validity of
90 // the span.
91 // The indices must fit within the unicode text.
92 // Note that the execution time is linear with respect to the codepoint indices.
93 // Calling this function repeatedly for spans on the same text might lead to
94 // inefficient code.
95 UnicodeTextRange CodepointSpanToUnicodeTextRange(
96     const UnicodeText& unicode_text, const CodepointSpan& span);
97 
98 // Takes care of preparing features for the span prediction model.
99 class FeatureProcessor {
100  public:
101   // A cache mapping codepoint spans to embedded tokens features. An instance
102   // can be provided to multiple calls to ExtractFeatures() operating on the
103   // same context (the same codepoint spans corresponding to the same tokens),
104   // as an optimization. Note that the tokenizations do not have to be
105   // identical.
106   typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
107 
FeatureProcessor(const FeatureProcessorOptions * options,const UniLib * unilib)108   explicit FeatureProcessor(const FeatureProcessorOptions* options,
109                             const UniLib* unilib)
110       : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
111                            unilib),
112         options_(options),
113         tokenizer_(internal::BuildTokenizer(options, unilib)) {
114     MakeLabelMaps();
115     if (options->supported_codepoint_ranges() != nullptr) {
116       SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
117                            options->supported_codepoint_ranges()->end()},
118                           &supported_codepoint_ranges_);
119     }
120     PrepareIgnoredSpanBoundaryCodepoints();
121   }
122 
123   // Tokenizes the input string using the selected tokenization method.
124   std::vector<Token> Tokenize(const std::string& text) const;
125 
126   // Same as above but takes UnicodeText.
127   std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
128 
129   // Converts a label into a token span.
130   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
131 
132   // Gets the total number of selection labels.
GetSelectionLabelCount()133   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
134 
135   // Gets the string value for given collection label.
136   std::string LabelToCollection(int label) const;
137 
138   // Gets the total number of collections of the model.
NumCollections()139   int NumCollections() const { return collection_to_label_.size(); }
140 
141   // Gets the name of the default collection.
142   std::string GetDefaultCollection() const;
143 
GetOptions()144   const FeatureProcessorOptions* GetOptions() const { return options_; }
145 
146   // Retokenizes the context and input span, and finds the click position.
147   // Depending on the options, might modify tokens (split them or remove them).
148   void RetokenizeAndFindClick(const std::string& context,
149                               const CodepointSpan& input_span,
150                               bool only_use_line_with_click,
151                               std::vector<Token>* tokens, int* click_pos) const;
152 
153   // Same as above, but takes UnicodeText and iterators within it corresponding
154   // to input_span.
155   void RetokenizeAndFindClick(const UnicodeText& context_unicode,
156                               const UnicodeText::const_iterator& span_begin,
157                               const UnicodeText::const_iterator& span_end,
158                               const CodepointSpan& input_span,
159                               bool only_use_line_with_click,
160                               std::vector<Token>* tokens, int* click_pos) const;
161 
162   // Returns true if the token span has enough supported codepoints (as defined
163   // in the model config) or not and model should not run.
164   bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
165                                     const TokenSpan& token_span) const;
166 
167   // Extracts features as a CachedFeatures object that can be used for repeated
168   // inference over token spans in the given context.
169   bool ExtractFeatures(const std::vector<Token>& tokens,
170                        const TokenSpan& token_span,
171                        const CodepointSpan& selection_span_for_feature,
172                        const EmbeddingExecutor* embedding_executor,
173                        EmbeddingCache* embedding_cache, int feature_vector_size,
174                        std::unique_ptr<CachedFeatures>* cached_features) const;
175 
176   // Fills selection_label_spans with CodepointSpans that correspond to the
177   // selection labels. The CodepointSpans are based on the codepoint ranges of
178   // given tokens.
179   bool SelectionLabelSpans(
180       VectorSpan<Token> tokens,
181       std::vector<CodepointSpan>* selection_label_spans) const;
182 
183   // Fills selection_label_relative_token_spans with number of tokens left and
184   // right from the click.
185   bool SelectionLabelRelativeTokenSpans(
186       std::vector<TokenSpan>* selection_label_relative_token_spans) const;
187 
DenseFeaturesCount()188   int DenseFeaturesCount() const {
189     return feature_extractor_.DenseFeaturesCount();
190   }
191 
EmbeddingSize()192   int EmbeddingSize() const { return options_->embedding_size(); }
193 
194   // Splits context to several segments.
195   std::vector<UnicodeTextRange> SplitContext(
196       const UnicodeText& context_unicode,
197       const bool use_pipe_character_for_newline) const;
198 
199   // Strips boundary codepoints from the span in context and returns the new
200   // start and end indices. If the span comprises entirely of boundary
201   // codepoints, the first index of span is returned for both indices.
202   CodepointSpan StripBoundaryCodepoints(const std::string& context,
203                                         const CodepointSpan& span) const;
204 
205   // Same as above but takes UnicodeText.
206   CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
207                                         const CodepointSpan& span) const;
208 
209   // Same as above but takes a pair of iterators for the span, for efficiency.
210   CodepointSpan StripBoundaryCodepoints(
211       const UnicodeText::const_iterator& span_begin,
212       const UnicodeText::const_iterator& span_end,
213       const CodepointSpan& span) const;
214 
215   // Same as above, but takes an optional buffer for saving the modified value.
216   // As an optimization, returns pointer to 'value' if nothing was stripped, or
217   // pointer to 'buffer' if something was stripped.
218   const std::string& StripBoundaryCodepoints(const std::string& value,
219                                              std::string* buffer) const;
220 
221  protected:
222   // Returns the class id corresponding to the given string collection
223   // identifier. There is a catch-all class id that the function returns for
224   // unknown collections.
225   int CollectionToLabel(const std::string& collection) const;
226 
227   // Prepares mapping from collection names to labels.
228   void MakeLabelMaps();
229 
230   // Gets the number of spannable tokens for the model.
231   //
232   // Spannable tokens are those tokens of context, which the model predicts
233   // selection spans over (i.e., there is 1:1 correspondence between the output
234   // classes of the model and each of the spannable tokens).
GetNumContextTokens()235   int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
236 
237   // Converts a label into a span of codepoint indices corresponding to it
238   // given output_tokens.
239   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
240                    CodepointSpan* span) const;
241 
242   // Converts a span to the corresponding label given output_tokens.
243   bool SpanToLabel(const CodepointSpan& span,
244                    const std::vector<Token>& output_tokens, int* label) const;
245 
246   // Converts a token span to the corresponding label.
247   int TokenSpanToLabel(const TokenSpan& token_span) const;
248 
249   // Returns the ratio of supported codepoints to total number of codepoints in
250   // the given token span.
251   float SupportedCodepointsRatio(const TokenSpan& token_span,
252                                  const std::vector<Token>& tokens) const;
253 
254   void PrepareIgnoredSpanBoundaryCodepoints();
255 
256   // Counts the number of span boundary codepoints. If count_from_beginning is
257   // True, the counting will start at the span_start iterator (inclusive) and at
258   // maximum end at span_end (exclusive). If count_from_beginning is True, the
259   // counting will start from span_end (exclusive) and end at span_start
260   // (inclusive).
261   int CountIgnoredSpanBoundaryCodepoints(
262       const UnicodeText::const_iterator& span_start,
263       const UnicodeText::const_iterator& span_end,
264       bool count_from_beginning) const;
265 
266   // Finds the center token index in tokens vector, using the method defined
267   // in options_.
268   int FindCenterToken(const CodepointSpan& span,
269                       const std::vector<Token>& tokens) const;
270 
271   // Removes all tokens from tokens that are not on a line (defined by calling
272   // SplitContext on the context) to which span points.
273   void StripTokensFromOtherLines(const std::string& context,
274                                  const CodepointSpan& span,
275                                  std::vector<Token>* tokens) const;
276 
277   // Same as above but takes UnicodeText.
278   void StripTokensFromOtherLines(const UnicodeText& context_unicode,
279                                  const UnicodeText::const_iterator& span_begin,
280                                  const UnicodeText::const_iterator& span_end,
281                                  const CodepointSpan& span,
282                                  std::vector<Token>* tokens) const;
283 
284   // Extracts the features of a token and appends them to the output vector.
285   // Uses the embedding cache to to avoid re-extracting the re-embedding the
286   // sparse features for the same token.
287   bool AppendTokenFeaturesWithCache(
288       const Token& token, const CodepointSpan& selection_span_for_feature,
289       const EmbeddingExecutor* embedding_executor,
290       EmbeddingCache* embedding_cache,
291       std::vector<float>* output_features) const;
292 
293  protected:
294   const TokenFeatureExtractor feature_extractor_;
295 
296   // Codepoint ranges that define what codepoints are supported by the model.
297   // NOTE: Must be sorted.
298   std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
299 
300  private:
301   // Set of codepoints that will be stripped from beginning and end of
302   // predicted spans.
303   std::unordered_set<int32> ignored_span_boundary_codepoints_;
304 
305   const FeatureProcessorOptions* const options_;
306 
307   // Mapping between token selection spans and labels ids.
308   std::map<TokenSpan, int> selection_to_label_;
309   std::vector<TokenSpan> label_to_selection_;
310 
311   // Mapping between collections and labels.
312   std::map<std::string, int> collection_to_label_;
313 
314   Tokenizer tokenizer_;
315 };
316 
317 }  // namespace libtextclassifier3
318 
319 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
320