/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Inference code for the feed-forward text classification models. #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ #define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_ #include #include #include #include "base.h" #include "common/embedding-network.h" #include "common/feature-extractor.h" #include "common/memory_image/embedding-network-params-from-image.h" #include "common/mmap.h" #include "smartselect/feature-processor.h" #include "smartselect/model-params.h" #include "smartselect/text-classification-model.pb.h" #include "smartselect/types.h" namespace libtextclassifier { // SmartSelection/Sharing feed-forward model. class TextClassificationModel { public: // Loads TextClassificationModel from given file given by an int // file descriptor. explicit TextClassificationModel(int fd); // Bit flags for the input selection. enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 }; // Runs inference for given a context and current selection (i.e. index // of the first and one past last selected characters (utf8 codepoint // offsets)). Returns the indices (utf8 codepoint offsets) of the selection // beginning character and one past selection end character. // Returns the original click_indices if an error occurs. // NOTE: The selection indices are passed in and returned in terms of // UTF8 codepoints (not bytes). // Requires that the model is a smart selection model. CodepointSpan SuggestSelection(const std::string& context, CodepointSpan click_indices) const; // Classifies the selected text given the context string. // Requires that the model is a smart sharing model. // Returns an empty result if an error occurs. std::vector> ClassifyText( const std::string& context, CodepointSpan click_indices, int input_flags = 0) const; protected: // Removes punctuation from the beginning and end of the selection and returns // the new selection span. CodepointSpan StripPunctuation(CodepointSpan selection, const std::string& context) const; // During evaluation we need access to the feature processor. FeatureProcessor* SelectionFeatureProcessor() const { return selection_feature_processor_.get(); } // Collection name when url hint is accepted. const std::string kUrlHintCollection = "url"; // Collection name when email hint is accepted. const std::string kEmailHintCollection = "email"; // Collection name for other. const std::string kOtherCollection = "other"; // Collection name for phone. const std::string kPhoneCollection = "phone"; SelectionModelOptions selection_options_; SharingModelOptions sharing_options_; private: bool LoadModels(const nlp_core::MmapHandle& mmap_handle); nlp_core::EmbeddingNetwork::Vector InferInternal( const std::string& context, CodepointSpan span, const FeatureProcessor& feature_processor, const nlp_core::EmbeddingNetwork& network, const FeatureVectorFn& feature_vector_fn, std::vector* selection_label_spans) const; // Returns a selection suggestion with a score. std::pair SuggestSelectionInternal( const std::string& context, CodepointSpan click_indices) const; // Returns a selection suggestion and makes sure it's symmetric. Internally // runs several times SuggestSelectionInternal. CodepointSpan SuggestSelectionSymmetrical(const std::string& context, CodepointSpan click_indices) const; bool initialized_; nlp_core::ScopedMmap mmap_; std::unique_ptr selection_params_; std::unique_ptr selection_feature_processor_; std::unique_ptr selection_network_; FeatureVectorFn selection_feature_fn_; std::unique_ptr sharing_feature_processor_; std::unique_ptr sharing_params_; std::unique_ptr sharing_network_; FeatureVectorFn sharing_feature_fn_; std::set punctuation_to_strip_; }; // Parses the merged image given as a file descriptor, and reads // the ModelOptions proto from the selection model. bool ReadSelectionModelOptions(int fd, ModelOptions* model_options); } // namespace libtextclassifier #endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_