1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Inference code for the text classification model. 18 19 #ifndef LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ 20 #define LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ 21 22 #include <memory> 23 #include <set> 24 #include <string> 25 #include <vector> 26 27 #include "datetime/parser.h" 28 #include "feature-processor.h" 29 #include "model-executor.h" 30 #include "model_generated.h" 31 #include "strip-unpaired-brackets.h" 32 #include "types.h" 33 #include "util/memory/mmap.h" 34 #include "util/utf8/unilib.h" 35 #include "zlib-utils.h" 36 37 namespace libtextclassifier2 { 38 39 struct SelectionOptions { 40 // Comma-separated list of locale specification for the input text (BCP 47 41 // tags). 42 std::string locales; 43 DefaultSelectionOptions44 static SelectionOptions Default() { return SelectionOptions(); } 45 }; 46 47 struct ClassificationOptions { 48 // For parsing relative datetimes, the reference now time against which the 49 // relative datetimes get resolved. 50 // UTC milliseconds since epoch. 51 int64 reference_time_ms_utc = 0; 52 53 // Timezone in which the input text was written (format as accepted by ICU). 54 std::string reference_timezone; 55 56 // Comma-separated list of locale specification for the input text (BCP 47 57 // tags). 58 std::string locales; 59 DefaultClassificationOptions60 static ClassificationOptions Default() { return ClassificationOptions(); } 61 }; 62 63 struct AnnotationOptions { 64 // For parsing relative datetimes, the reference now time against which the 65 // relative datetimes get resolved. 66 // UTC milliseconds since epoch. 67 int64 reference_time_ms_utc = 0; 68 69 // Timezone in which the input text was written (format as accepted by ICU). 70 std::string reference_timezone; 71 72 // Comma-separated list of locale specification for the input text (BCP 47 73 // tags). 74 std::string locales; 75 DefaultAnnotationOptions76 static AnnotationOptions Default() { return AnnotationOptions(); } 77 }; 78 79 // Holds TFLite interpreters for selection and classification models. 80 // NOTE: his class is not thread-safe, thus should NOT be re-used across 81 // threads. 82 class InterpreterManager { 83 public: 84 // The constructor can be called with nullptr for any of the executors, and is 85 // a defined behavior, as long as the corresponding *Interpreter() method is 86 // not called when the executor is null. InterpreterManager(const ModelExecutor * selection_executor,const ModelExecutor * classification_executor)87 InterpreterManager(const ModelExecutor* selection_executor, 88 const ModelExecutor* classification_executor) 89 : selection_executor_(selection_executor), 90 classification_executor_(classification_executor) {} 91 92 // Gets or creates and caches an interpreter for the selection model. 93 tflite::Interpreter* SelectionInterpreter(); 94 95 // Gets or creates and caches an interpreter for the classification model. 96 tflite::Interpreter* ClassificationInterpreter(); 97 98 private: 99 const ModelExecutor* selection_executor_; 100 const ModelExecutor* classification_executor_; 101 102 std::unique_ptr<tflite::Interpreter> selection_interpreter_; 103 std::unique_ptr<tflite::Interpreter> classification_interpreter_; 104 }; 105 106 // A text processing model that provides text classification, annotation, 107 // selection suggestion for various types. 108 // NOTE: This class is not thread-safe. 109 class TextClassifier { 110 public: 111 static std::unique_ptr<TextClassifier> FromUnownedBuffer( 112 const char* buffer, int size, const UniLib* unilib = nullptr); 113 // Takes ownership of the mmap. 114 static std::unique_ptr<TextClassifier> FromScopedMmap( 115 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib = nullptr); 116 static std::unique_ptr<TextClassifier> FromFileDescriptor( 117 int fd, int offset, int size, const UniLib* unilib = nullptr); 118 static std::unique_ptr<TextClassifier> FromFileDescriptor( 119 int fd, const UniLib* unilib = nullptr); 120 static std::unique_ptr<TextClassifier> FromPath( 121 const std::string& path, const UniLib* unilib = nullptr); 122 123 // Returns true if the model is ready for use. IsInitialized()124 bool IsInitialized() { return initialized_; } 125 126 // Runs inference for given a context and current selection (i.e. index 127 // of the first and one past last selected characters (utf8 codepoint 128 // offsets)). Returns the indices (utf8 codepoint offsets) of the selection 129 // beginning character and one past selection end character. 130 // Returns the original click_indices if an error occurs. 131 // NOTE: The selection indices are passed in and returned in terms of 132 // UTF8 codepoints (not bytes). 133 // Requires that the model is a smart selection model. 134 CodepointSpan SuggestSelection( 135 const std::string& context, CodepointSpan click_indices, 136 const SelectionOptions& options = SelectionOptions::Default()) const; 137 138 // Classifies the selected text given the context string. 139 // Returns an empty result if an error occurs. 140 std::vector<ClassificationResult> ClassifyText( 141 const std::string& context, CodepointSpan selection_indices, 142 const ClassificationOptions& options = 143 ClassificationOptions::Default()) const; 144 145 // Annotates given input text. The annotations are sorted by their position 146 // in the context string and exclude spans classified as 'other'. 147 std::vector<AnnotatedSpan> Annotate( 148 const std::string& context, 149 const AnnotationOptions& options = AnnotationOptions::Default()) const; 150 151 // Exposes the feature processor for tests and evaluations. 152 const FeatureProcessor* SelectionFeatureProcessorForTests() const; 153 const FeatureProcessor* ClassificationFeatureProcessorForTests() const; 154 155 // Exposes the date time parser for tests and evaluations. 156 const DatetimeParser* DatetimeParserForTests() const; 157 158 // String collection names for various classes. 159 static const std::string& kOtherCollection; 160 static const std::string& kPhoneCollection; 161 static const std::string& kAddressCollection; 162 static const std::string& kDateCollection; 163 164 protected: 165 struct ScoredChunk { 166 TokenSpan token_span; 167 float score; 168 }; 169 170 // Constructs and initializes text classifier from given model. 171 // Takes ownership of 'mmap', and thus owns the buffer that backs 'model'. TextClassifier(std::unique_ptr<ScopedMmap> * mmap,const Model * model,const UniLib * unilib)172 TextClassifier(std::unique_ptr<ScopedMmap>* mmap, const Model* model, 173 const UniLib* unilib) 174 : model_(model), 175 mmap_(std::move(*mmap)), 176 owned_unilib_(nullptr), 177 unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { 178 ValidateAndInitialize(); 179 } 180 181 // Constructs, validates and initializes text classifier from given model. 182 // Does not own the buffer that backs 'model'. TextClassifier(const Model * model,const UniLib * unilib)183 explicit TextClassifier(const Model* model, const UniLib* unilib) 184 : model_(model), 185 owned_unilib_(nullptr), 186 unilib_(internal::MaybeCreateUnilib(unilib, &owned_unilib_)) { 187 ValidateAndInitialize(); 188 } 189 190 // Checks that model contains all required fields, and initializes internal 191 // datastructures. 192 void ValidateAndInitialize(); 193 194 // Initializes regular expressions for the regex model. 195 bool InitializeRegexModel(ZlibDecompressor* decompressor); 196 197 // Resolves conflicts in the list of candidates by removing some overlapping 198 // ones. Returns indices of the surviving ones. 199 // NOTE: Assumes that the candidates are sorted according to their position in 200 // the span. 201 bool ResolveConflicts(const std::vector<AnnotatedSpan>& candidates, 202 const std::string& context, 203 const std::vector<Token>& cached_tokens, 204 InterpreterManager* interpreter_manager, 205 std::vector<int>* result) const; 206 207 // Resolves one conflict between candidates on indices 'start_index' 208 // (inclusive) and 'end_index' (exclusive). Assigns the winning candidate 209 // indices to 'chosen_indices'. Returns false if a problem arises. 210 bool ResolveConflict(const std::string& context, 211 const std::vector<Token>& cached_tokens, 212 const std::vector<AnnotatedSpan>& candidates, 213 int start_index, int end_index, 214 InterpreterManager* interpreter_manager, 215 std::vector<int>* chosen_indices) const; 216 217 // Gets selection candidates from the ML model. 218 // Provides the tokens produced during tokenization of the context string for 219 // reuse. 220 bool ModelSuggestSelection(const UnicodeText& context_unicode, 221 CodepointSpan click_indices, 222 InterpreterManager* interpreter_manager, 223 std::vector<Token>* tokens, 224 std::vector<AnnotatedSpan>* result) const; 225 226 // Classifies the selected text given the context string with the 227 // classification model. 228 // Returns true if no error occurred. 229 bool ModelClassifyText( 230 const std::string& context, const std::vector<Token>& cached_tokens, 231 CodepointSpan selection_indices, InterpreterManager* interpreter_manager, 232 FeatureProcessor::EmbeddingCache* embedding_cache, 233 std::vector<ClassificationResult>* classification_results) const; 234 235 bool ModelClassifyText( 236 const std::string& context, CodepointSpan selection_indices, 237 InterpreterManager* interpreter_manager, 238 FeatureProcessor::EmbeddingCache* embedding_cache, 239 std::vector<ClassificationResult>* classification_results) const; 240 241 // Returns a relative token span that represents how many tokens on the left 242 // from the selection and right from the selection are needed for the 243 // classifier input. 244 TokenSpan ClassifyTextUpperBoundNeededTokens() const; 245 246 // Classifies the selected text with the regular expressions models. 247 // Returns true if any regular expression matched and the result was set. 248 bool RegexClassifyText(const std::string& context, 249 CodepointSpan selection_indices, 250 ClassificationResult* classification_result) const; 251 252 // Classifies the selected text with the date time model. 253 // Returns true if there was a match and the result was set. 254 bool DatetimeClassifyText(const std::string& context, 255 CodepointSpan selection_indices, 256 const ClassificationOptions& options, 257 ClassificationResult* classification_result) const; 258 259 // Chunks given input text with the selection model and classifies the spans 260 // with the classification model. 261 // The annotations are sorted by their position in the context string and 262 // exclude spans classified as 'other'. 263 // Provides the tokens produced during tokenization of the context string for 264 // reuse. 265 bool ModelAnnotate(const std::string& context, 266 InterpreterManager* interpreter_manager, 267 std::vector<Token>* tokens, 268 std::vector<AnnotatedSpan>* result) const; 269 270 // Groups the tokens into chunks. A chunk is a token span that should be the 271 // suggested selection when any of its contained tokens is clicked. The chunks 272 // are non-overlapping and are sorted by their position in the context string. 273 // "num_tokens" is the total number of tokens available (as this method does 274 // not need the actual vector of tokens). 275 // "span_of_interest" is a span of all the tokens that could be clicked. 276 // The resulting chunks all have to overlap with it and they cover this span 277 // completely. The first and last chunk might extend beyond it. 278 // The chunks vector is cleared before filling. 279 bool ModelChunk(int num_tokens, const TokenSpan& span_of_interest, 280 tflite::Interpreter* selection_interpreter, 281 const CachedFeatures& cached_features, 282 std::vector<TokenSpan>* chunks) const; 283 284 // A helper method for ModelChunk(). It generates scored chunk candidates for 285 // a click context model. 286 // NOTE: The returned chunks can (and most likely do) overlap. 287 bool ModelClickContextScoreChunks( 288 int num_tokens, const TokenSpan& span_of_interest, 289 const CachedFeatures& cached_features, 290 tflite::Interpreter* selection_interpreter, 291 std::vector<ScoredChunk>* scored_chunks) const; 292 293 // A helper method for ModelChunk(). It generates scored chunk candidates for 294 // a bounds-sensitive model. 295 // NOTE: The returned chunks can (and most likely do) overlap. 296 bool ModelBoundsSensitiveScoreChunks( 297 int num_tokens, const TokenSpan& span_of_interest, 298 const TokenSpan& inference_span, const CachedFeatures& cached_features, 299 tflite::Interpreter* selection_interpreter, 300 std::vector<ScoredChunk>* scored_chunks) const; 301 302 // Produces chunks isolated by a set of regular expressions. 303 bool RegexChunk(const UnicodeText& context_unicode, 304 const std::vector<int>& rules, 305 std::vector<AnnotatedSpan>* result) const; 306 307 // Produces chunks from the datetime parser. 308 bool DatetimeChunk(const UnicodeText& context_unicode, 309 int64 reference_time_ms_utc, 310 const std::string& reference_timezone, 311 const std::string& locales, ModeFlag mode, 312 std::vector<AnnotatedSpan>* result) const; 313 314 // Returns whether a classification should be filtered. 315 bool FilteredForAnnotation(const AnnotatedSpan& span) const; 316 bool FilteredForClassification( 317 const ClassificationResult& classification) const; 318 bool FilteredForSelection(const AnnotatedSpan& span) const; 319 320 const Model* model_; 321 322 std::unique_ptr<const ModelExecutor> selection_executor_; 323 std::unique_ptr<const ModelExecutor> classification_executor_; 324 std::unique_ptr<const EmbeddingExecutor> embedding_executor_; 325 326 std::unique_ptr<const FeatureProcessor> selection_feature_processor_; 327 std::unique_ptr<const FeatureProcessor> classification_feature_processor_; 328 329 std::unique_ptr<const DatetimeParser> datetime_parser_; 330 331 private: 332 struct CompiledRegexPattern { 333 std::string collection_name; 334 float target_classification_score; 335 float priority_score; 336 std::unique_ptr<UniLib::RegexPattern> pattern; 337 }; 338 339 std::unique_ptr<ScopedMmap> mmap_; 340 bool initialized_ = false; 341 bool enabled_for_annotation_ = false; 342 bool enabled_for_classification_ = false; 343 bool enabled_for_selection_ = false; 344 std::unordered_set<std::string> filtered_collections_annotation_; 345 std::unordered_set<std::string> filtered_collections_classification_; 346 std::unordered_set<std::string> filtered_collections_selection_; 347 348 std::vector<CompiledRegexPattern> regex_patterns_; 349 std::unordered_set<int> regex_approximate_match_pattern_ids_; 350 351 // Indices into regex_patterns_ for the different modes. 352 std::vector<int> annotation_regex_patterns_, classification_regex_patterns_, 353 selection_regex_patterns_; 354 355 std::unique_ptr<UniLib> owned_unilib_; 356 const UniLib* unilib_; 357 }; 358 359 namespace internal { 360 361 // Helper function, which if the initial 'span' contains only white-spaces, 362 // moves the selection to a single-codepoint selection on the left side 363 // of this block of white-space. 364 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, 365 const UnicodeText& context_unicode, 366 const UniLib& unilib); 367 368 // Copies tokens from 'cached_tokens' that are 369 // 'tokens_around_selection_to_copy' (on the left, and right) tokens distant 370 // from the tokens that correspond to 'selection_indices'. 371 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, 372 CodepointSpan selection_indices, 373 TokenSpan tokens_around_selection_to_copy); 374 } // namespace internal 375 376 // Interprets the buffer as a Model flatbuffer and returns it for reading. 377 const Model* ViewModel(const void* buffer, int size); 378 379 } // namespace libtextclassifier2 380 381 #endif // LIBTEXTCLASSIFIER_TEXT_CLASSIFIER_H_ 382