1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Inference code for the feed-forward text classification models. 18 19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ 20 #define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ 21 22 #include <memory> 23 #include <set> 24 #include <string> 25 26 #include "base.h" 27 #include "common/embedding-network.h" 28 #include "common/feature-extractor.h" 29 #include "common/memory_image/embedding-network-params-from-image.h" 30 #include "common/mmap.h" 31 #include "smartselect/feature-processor.h" 32 #include "smartselect/model-params.h" 33 #include "smartselect/text-classification-model.pb.h" 34 #include "smartselect/types.h" 35 36 namespace libtextclassifier { 37 38 // SmartSelection/Sharing feed-forward model. 39 class TextClassificationModel { 40 public: 41 // Loads TextClassificationModel from given file given by an int 42 // file descriptor. 43 explicit TextClassificationModel(int fd); 44 45 // Bit flags for the input selection. 46 enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 }; 47 48 // Runs inference for given a context and current selection (i.e. index 49 // of the first and one past last selected characters (utf8 codepoint 50 // offsets)). Returns the indices (utf8 codepoint offsets) of the selection 51 // beginning character and one past selection end character. 52 // Returns the original click_indices if an error occurs. 53 // NOTE: The selection indices are passed in and returned in terms of 54 // UTF8 codepoints (not bytes). 55 // Requires that the model is a smart selection model. 56 CodepointSpan SuggestSelection(const std::string& context, 57 CodepointSpan click_indices) const; 58 59 // Classifies the selected text given the context string. 60 // Requires that the model is a smart sharing model. 61 // Returns an empty result if an error occurs. 62 std::vector<std::pair<std::string, float>> ClassifyText( 63 const std::string& context, CodepointSpan click_indices, 64 int input_flags = 0) const; 65 66 protected: 67 // Removes punctuation from the beginning and end of the selection and returns 68 // the new selection span. 69 CodepointSpan StripPunctuation(CodepointSpan selection, 70 const std::string& context) const; 71 72 // During evaluation we need access to the feature processor. SelectionFeatureProcessor()73 FeatureProcessor* SelectionFeatureProcessor() const { 74 return selection_feature_processor_.get(); 75 } 76 77 // Collection name when url hint is accepted. 78 const std::string kUrlHintCollection = "url"; 79 80 // Collection name when email hint is accepted. 81 const std::string kEmailHintCollection = "email"; 82 83 // Collection name for other. 84 const std::string kOtherCollection = "other"; 85 86 // Collection name for phone. 87 const std::string kPhoneCollection = "phone"; 88 89 SelectionModelOptions selection_options_; 90 SharingModelOptions sharing_options_; 91 92 private: 93 bool LoadModels(const nlp_core::MmapHandle& mmap_handle); 94 95 nlp_core::EmbeddingNetwork::Vector InferInternal( 96 const std::string& context, CodepointSpan span, 97 const FeatureProcessor& feature_processor, 98 const nlp_core::EmbeddingNetwork& network, 99 const FeatureVectorFn& feature_vector_fn, 100 std::vector<CodepointSpan>* selection_label_spans) const; 101 102 // Returns a selection suggestion with a score. 103 std::pair<CodepointSpan, float> SuggestSelectionInternal( 104 const std::string& context, CodepointSpan click_indices) const; 105 106 // Returns a selection suggestion and makes sure it's symmetric. Internally 107 // runs several times SuggestSelectionInternal. 108 CodepointSpan SuggestSelectionSymmetrical(const std::string& context, 109 CodepointSpan click_indices) const; 110 111 bool initialized_; 112 nlp_core::ScopedMmap mmap_; 113 std::unique_ptr<ModelParams> selection_params_; 114 std::unique_ptr<FeatureProcessor> selection_feature_processor_; 115 std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_; 116 FeatureVectorFn selection_feature_fn_; 117 std::unique_ptr<FeatureProcessor> sharing_feature_processor_; 118 std::unique_ptr<ModelParams> sharing_params_; 119 std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_; 120 FeatureVectorFn sharing_feature_fn_; 121 122 std::set<int> punctuation_to_strip_; 123 }; 124 125 // Parses the merged image given as a file descriptor, and reads 126 // the ModelOptions proto from the selection model. 127 bool ReadSelectionModelOptions(int fd, ModelOptions* model_options); 128 129 } // namespace libtextclassifier 130 131 #endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ 132