/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
// Inference code for the text classification model.
18 
19 #ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
20 #define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
21 
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "datetime/parser.h"
#include "feature-processor.h"
#include "model-executor.h"
#include "model_generated.h"
#include "strip-unpaired-brackets.h"
#include "types.h"
#include "util/memory/mmap.h"
#include "util/utf8/unilib.h"
#include "zlib-utils.h"
36 
37 namespace libtextclassifier2 {
38 
// Options accepted by TextClassifier::SuggestSelection().
struct SelectionOptions {
  // Comma-separated list of locale specification for the input text (BCP 47
  // tags).
  std::string locales;

  // Returns a default-constructed options object, i.e. one with an empty
  // locale list.
  static SelectionOptions Default() { return SelectionOptions(); }
};
46 
47 struct ClassificationOptions {
48   // For parsing relative datetimes, the reference now time against which the
49   // relative datetimes get resolved.
50   // UTC milliseconds since epoch.
51   int64 reference_time_ms_utc = 0;
52 
53   // Timezone in which the input text was written (format as accepted by ICU).
54   std::string reference_timezone;
55 
56   // Comma-separated list of locale specification for the input text (BCP 47
57   // tags).
58   std::string locales;
59 
DefaultClassificationOptions60   static ClassificationOptions Default() { return ClassificationOptions(); }
61 };
62 
63 struct AnnotationOptions {
64   // For parsing relative datetimes, the reference now time against which the
65   // relative datetimes get resolved.
66   // UTC milliseconds since epoch.
67   int64 reference_time_ms_utc = 0;
68 
69   // Timezone in which the input text was written (format as accepted by ICU).
70   std::string reference_timezone;
71 
72   // Comma-separated list of locale specification for the input text (BCP 47
73   // tags).
74   std::string locales;
75 
DefaultAnnotationOptions76   static AnnotationOptions Default() { return AnnotationOptions(); }
77 };
78 
79 // Holds TFLite interpreters for selection and classification models.
80 // NOTE: his class is not thread-safe, thus should NOT be re-used across
81 // threads.
82 class InterpreterManager {
83  public:
84   // The constructor can be called with nullptr for any of the executors, and is
85   // a defined behavior, as long as the corresponding *Interpreter() method is
86   // not called when the executor is null.
InterpreterManager(const ModelExecutor * selection_executor,const ModelExecutor * classification_executor)87   InterpreterManager(const ModelExecutor* selection_executor,
88                      const ModelExecutor* classification_executor)
89       : selection_executor_(selection_executor),
90         classification_executor_(classification_executor) {}
91 
92   // Gets or creates and caches an interpreter for the selection model.
93   tflite::Interpreter* SelectionInterpreter();
94 
95   // Gets or creates and caches an interpreter for the classification model.
96   tflite::Interpreter* ClassificationInterpreter();
97 
98  private:
99   const ModelExecutor* selection_executor_;
100   const ModelExecutor* classification_executor_;
101 
102   std::unique_ptr<tflite::Interpreter> selection_interpreter_;
103   std::unique_ptr<tflite::Interpreter> classification_interpreter_;
104 };
105 
106 // A text processing model that provides text classification, annotation,
107 // selection suggestion for various types.
108 // NOTE: This class is not thread-safe.
109 class TextClassifier {
110  public:
111   static std::unique_ptr<TextClassifier> FromUnownedBuffer(
112       const char* buffer, int size, const UniLib* unilib = nullptr);
113   // Takes ownership of the mmap.
114   static std::unique_ptr<TextClassifier> FromScopedMmap(
115       std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr);
116   static std::unique_ptr<TextClassifier> FromFileDescriptor(
117       int fd, int offset, int size, const UniLib* unilib = nullptr);
118   static std::unique_ptr<TextClassifier> FromFileDescriptor(
119       int fd, const UniLib* unilib = nullptr);
120   static std::unique_ptr<TextClassifier> FromPath(
121       const std::string& path, const UniLib* unilib = nullptr);
122 
123   // Returns true if the model is ready for use.
IsInitialized()124   bool IsInitialized() { return initialized_; }
125 
126   // Runs inference for given a context and current selection (i.e. index
127   // of the first and one past last selected characters (utf8 codepoint
128   // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
129   // beginning character and one past selection end character.
130   // Returns the original click_indices if an error occurs.
131   // NOTE: The selection indices are passed in and returned in terms of
132   // UTF8 codepoints (not bytes).
133   // Requires that the model is a smart selection model.
134   CodepointSpan SuggestSelection(
135       const std::string& context, CodepointSpan click_indices,
136       const SelectionOptions& options = SelectionOptions::Default()) const;
137 
138   // Classifies the selected text given the context string.
139   // Returns an empty result if an error occurs.
140   std::vector<ClassificationResult> ClassifyText(
141       const std::string& context, CodepointSpan selection_indices,
142       const ClassificationOptions& options =
143           ClassificationOptions::Default()) const;
144 
145   // Annotates given input text. The annotations are sorted by their position
146   // in the context string and exclude spans classified as 'other'.
147   std::vector<AnnotatedSpan> Annotate(
148       const std::string& context,
149       const AnnotationOptions& options = AnnotationOptions::Default()) const;
150 
151   // Exposes the feature processor for tests and evaluations.
152   const FeatureProcessor* SelectionFeatureProcessorForTests() const;
153   const FeatureProcessor* ClassificationFeatureProcessorForTests() const;
154 
155   // Exposes the date time parser for tests and evaluations.
156   const DatetimeParser* DatetimeParserForTests() const;
157 
158   // String collection names for various classes.
159   static const std::string& kOtherCollection;
160   static const std::string& kPhoneCollection;
161   static const std::string& kAddressCollection;
162   static const std::string& kDateCollection;
163 
164  protected:
165   struct ScoredChunk {
166     TokenSpan token_span;
167     float score;
168   };
169 
170   // Constructs and initializes text classifier from given model.
171   // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'.
TextClassifier(std::unique_ptr<ScopedMmap> * mmap,const Model * model,const UniLib * unilib)172   TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
173                  const UniLib* unilib)
174       : model_(model),
175         mmap_(std::move(*mmap)),
176         owned_unilib_(nullptr),
177         unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
178     ValidateAndInitialize();
179   }
180 
181   // Constructs, validates and initializes text classifier from given model.
182   // Does not own the buffer that backs 'model'.
TextClassifier(const Model * model,const UniLib * unilib)183   explicit TextClassifier(const Model* model, const UniLib* unilib)
184       : model_(model),
185         owned_unilib_(nullptr),
186         unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) {
187     ValidateAndInitialize();
188   }
189 
190   // Checks that model contains all required fields, and initializes internal
191   // datastructures.
192   void ValidateAndInitialize();
193 
194   // Initializes regular expressions for the regex model.
195   bool InitializeRegexModel(ZlibDecompressor* decompressor);
196 
197   // Resolves conflicts in the list of candidates by removing some overlapping
198   // ones. Returns indices of the surviving ones.
199   // NOTE: Assumes that the candidates are sorted according to their position in
200   // the span.
201   bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates,
202                         const std::string& context,
203                         const std::vector<Token>& cached_tokens,
204                         InterpreterManager* interpreter_manager,
205                         std::vector<int>* result) const;
206 
207   // Resolves one conflict between candidates on indices 'start_index'
208   // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate
209   // indices to 'chosen_indices'. Returns false if a problem arises.
210   bool ResolveConflict(const std::string& context,
211                        const std::vector<Token>& cached_tokens,
212                        const std::vector<AnnotatedSpan>& candidates,
213                        int start_index, int end_index,
214                        InterpreterManager* interpreter_manager,
215                        std::vector<int>* chosen_indices) const;
216 
217   // Gets selection candidates from the ML model.
218   // Provides the tokens produced during tokenization of the context string for
219   // reuse.
220   bool ModelSuggestSelection(const UnicodeText& context_unicode,
221                              CodepointSpan click_indices,
222                              InterpreterManager* interpreter_manager,
223                              std::vector<Token>* tokens,
224                              std::vector<AnnotatedSpan>* result) const;
225 
226   // Classifies the selected text given the context string with the
227   // classification model.
228   // Returns true if no error occurred.
229   bool ModelClassifyText(
230       const std::string& context, const std::vector<Token>& cached_tokens,
231       CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
232       FeatureProcessor::EmbeddingCache* embedding_cache,
233       std::vector<ClassificationResult>* classification_results) const;
234 
235   bool ModelClassifyText(
236       const std::string& context, CodepointSpan selection_indices,
237       InterpreterManager* interpreter_manager,
238       FeatureProcessor::EmbeddingCache* embedding_cache,
239       std::vector<ClassificationResult>* classification_results) const;
240 
241   // Returns a relative token span that represents how many tokens on the left
242   // from the selection and right from the selection are needed for the
243   // classifier input.
244   TokenSpan ClassifyTextUpperBoundNeededTokens() const;
245 
246   // Classifies the selected text with the regular expressions models.
247   // Returns true if any regular expression matched and the result was set.
248   bool RegexClassifyText(const std::string& context,
249                          CodepointSpan selection_indices,
250                          ClassificationResult* classification_result) const;
251 
252   // Classifies the selected text with the date time model.
253   // Returns true if there was a match and the result was set.
254   bool DatetimeClassifyText(const std::string& context,
255                             CodepointSpan selection_indices,
256                             const ClassificationOptions& options,
257                             ClassificationResult* classification_result) const;
258 
259   // Chunks given input text with the selection model and classifies the spans
260   // with the classification model.
261   // The annotations are sorted by their position in the context string and
262   // exclude spans classified as 'other'.
263   // Provides the tokens produced during tokenization of the context string for
264   // reuse.
265   bool ModelAnnotate(const std::string& context,
266                      InterpreterManager* interpreter_manager,
267                      std::vector<Token>* tokens,
268                      std::vector<AnnotatedSpan>* result) const;
269 
270   // Groups the tokens into chunks. A chunk is a token span that should be the
271   // suggested selection when any of its contained tokens is clicked. The chunks
272   // are non-overlapping and are sorted by their position in the context string.
273   // "num_tokens" is the total number of tokens available (as this method does
274   // not need the actual vector of tokens).
275   // "span_of_interest" is a span of all the tokens that could be clicked.
276   // The resulting chunks all have to overlap with it and they cover this span
277   // completely. The first and last chunk might extend beyond it.
278   // The chunks vector is cleared before filling.
279   bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
280                   tflite::Interpreter* selection_interpreter,
281                   const CachedFeatures& cached_features,
282                   std::vector<TokenSpan>* chunks) const;
283 
284   // A helper method for ModelChunk(). It generates scored chunk candidates for
285   // a click context model.
286   // NOTE: The returned chunks can (and most likely do) overlap.
287   bool ModelClickContextScoreChunks(
288       int num_tokens, const TokenSpan& span_of_interest,
289       const CachedFeatures& cached_features,
290       tflite::Interpreter* selection_interpreter,
291       std::vector<ScoredChunk>* scored_chunks) const;
292 
293   // A helper method for ModelChunk(). It generates scored chunk candidates for
294   // a bounds-sensitive model.
295   // NOTE: The returned chunks can (and most likely do) overlap.
296   bool ModelBoundsSensitiveScoreChunks(
297       int num_tokens, const TokenSpan& span_of_interest,
298       const TokenSpan& inference_span, const CachedFeatures& cached_features,
299       tflite::Interpreter* selection_interpreter,
300       std::vector<ScoredChunk>* scored_chunks) const;
301 
302   // Produces chunks isolated by a set of regular expressions.
303   bool RegexChunk(const UnicodeText& context_unicode,
304                   const std::vector<int>& rules,
305                   std::vector<AnnotatedSpan>* result) const;
306 
307   // Produces chunks from the datetime parser.
308   bool DatetimeChunk(const UnicodeText& context_unicode,
309                      int64 reference_time_ms_utc,
310                      const std::string& reference_timezone,
311                      const std::string& locales, ModeFlag mode,
312                      std::vector<AnnotatedSpan>* result) const;
313 
314   // Returns whether a classification should be filtered.
315   bool FilteredForAnnotation(const AnnotatedSpan& span) const;
316   bool FilteredForClassification(
317       const ClassificationResult& classification) const;
318   bool FilteredForSelection(const AnnotatedSpan& span) const;
319 
320   const Model* model_;
321 
322   std::unique_ptr<const ModelExecutor> selection_executor_;
323   std::unique_ptr<const ModelExecutor> classification_executor_;
324   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
325 
326   std::unique_ptr<const FeatureProcessor> selection_feature_processor_;
327   std::unique_ptr<const FeatureProcessor> classification_feature_processor_;
328 
329   std::unique_ptr<const DatetimeParser> datetime_parser_;
330 
331  private:
332   struct CompiledRegexPattern {
333     std::string collection_name;
334     float target_classification_score;
335     float priority_score;
336     std::unique_ptr<UniLib::RegexPattern> pattern;
337   };
338 
339   std::unique_ptr<ScopedMmap> mmap_;
340   bool initialized_ = false;
341   bool enabled_for_annotation_ = false;
342   bool enabled_for_classification_ = false;
343   bool enabled_for_selection_ = false;
344   std::unordered_set<std::string> filtered_collections_annotation_;
345   std::unordered_set<std::string> filtered_collections_classification_;
346   std::unordered_set<std::string> filtered_collections_selection_;
347 
348   std::vector<CompiledRegexPattern> regex_patterns_;
349   std::unordered_set<int> regex_approximate_match_pattern_ids_;
350 
351   // Indices into regex_patterns_ for the different modes.
352   std::vector<int> annotation_regex_patterns_, classification_regex_patterns_,
353       selection_regex_patterns_;
354 
355   std::unique_ptr<UniLib> owned_unilib_;
356   const UniLib* unilib_;
357 };
358 
359 namespace internal {
360 
361 // Helper function, which if the initial 'span' contains only white-spaces,
362 // moves the selection to a single-codepoint selection on the left side
363 // of this block of white-space.
364 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
365                                             const UnicodeText& context_unicode,
366                                             const UniLib& unilib);
367 
368 // Copies tokens from 'cached_tokens' that are
369 // 'tokens_around_selection_to_copy' (on the left, and right) tokens distant
370 // from the tokens that correspond to 'selection_indices'.
371 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
372                                     CodepointSpan selection_indices,
373                                     TokenSpan tokens_around_selection_to_copy);
374 }  // namespace internal
375 
376 // Interprets the buffer as a Model flatbuffer and returns it for reading.
377 const Model* ViewModel(const void* buffer, int size);
378 
379 }  // namespace libtextclassifier2
380 
381 #endif  // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_
382