• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Inference code for the feed-forward text classification models.
18 
19 #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
20 #define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
21 
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
25 
26 #include "base.h"
27 #include "common/embedding-network.h"
28 #include "common/feature-extractor.h"
29 #include "common/memory_image/embedding-network-params-from-image.h"
30 #include "common/mmap.h"
31 #include "smartselect/feature-processor.h"
32 #include "smartselect/model-params.h"
33 #include "smartselect/text-classification-model.pb.h"
34 #include "smartselect/types.h"
35 
36 namespace libtextclassifier {
37 
38 // SmartSelection/Sharing feed-forward model.
39 class TextClassificationModel {
40  public:
41   // Loads TextClassificationModel from given file given by an int
42   // file descriptor.
43   explicit TextClassificationModel(int fd);
44 
45   // Bit flags for the input selection.
46   enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };
47 
48   // Runs inference for given a context and current selection (i.e. index
49   // of the first and one past last selected characters (utf8 codepoint
50   // offsets)). Returns the indices (utf8 codepoint offsets) of the selection
51   // beginning character and one past selection end character.
52   // Returns the original click_indices if an error occurs.
53   // NOTE: The selection indices are passed in and returned in terms of
54   // UTF8 codepoints (not bytes).
55   // Requires that the model is a smart selection model.
56   CodepointSpan SuggestSelection(const std::string& context,
57                                  CodepointSpan click_indices) const;
58 
59   // Classifies the selected text given the context string.
60   // Requires that the model is a smart sharing model.
61   // Returns an empty result if an error occurs.
62   std::vector<std::pair<std::string, float>> ClassifyText(
63       const std::string& context, CodepointSpan click_indices,
64       int input_flags = 0) const;
65 
66  protected:
67   // Removes punctuation from the beginning and end of the selection and returns
68   // the new selection span.
69   CodepointSpan StripPunctuation(CodepointSpan selection,
70                                  const std::string& context) const;
71 
72   // During evaluation we need access to the feature processor.
SelectionFeatureProcessor()73   FeatureProcessor* SelectionFeatureProcessor() const {
74     return selection_feature_processor_.get();
75   }
76 
77   // Collection name when url hint is accepted.
78   const std::string kUrlHintCollection = "url";
79 
80   // Collection name when email hint is accepted.
81   const std::string kEmailHintCollection = "email";
82 
83   // Collection name for other.
84   const std::string kOtherCollection = "other";
85 
86   // Collection name for phone.
87   const std::string kPhoneCollection = "phone";
88 
89   SelectionModelOptions selection_options_;
90   SharingModelOptions sharing_options_;
91 
92  private:
93   bool LoadModels(const nlp_core::MmapHandle& mmap_handle);
94 
95   nlp_core::EmbeddingNetwork::Vector InferInternal(
96       const std::string& context, CodepointSpan span,
97       const FeatureProcessor& feature_processor,
98       const nlp_core::EmbeddingNetwork& network,
99       const FeatureVectorFn& feature_vector_fn,
100       std::vector<CodepointSpan>* selection_label_spans) const;
101 
102   // Returns a selection suggestion with a score.
103   std::pair<CodepointSpan, float> SuggestSelectionInternal(
104       const std::string& context, CodepointSpan click_indices) const;
105 
106   // Returns a selection suggestion and makes sure it's symmetric. Internally
107   // runs several times SuggestSelectionInternal.
108   CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
109                                             CodepointSpan click_indices) const;
110 
111   bool initialized_;
112   nlp_core::ScopedMmap mmap_;
113   std::unique_ptr<ModelParams> selection_params_;
114   std::unique_ptr<FeatureProcessor> selection_feature_processor_;
115   std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
116   FeatureVectorFn selection_feature_fn_;
117   std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
118   std::unique_ptr<ModelParams> sharing_params_;
119   std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
120   FeatureVectorFn sharing_feature_fn_;
121 
122   std::set<int> punctuation_to_strip_;
123 };
124 
125 // Parses the merged image given as a file descriptor, and reads
126 // the ModelOptions proto from the selection model.
127 bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
128 
129 }  // namespace libtextclassifier
130 
131 #endif  // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
132