1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_ 18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "lang_id/common/fel/feature-extractor.h" 25 #include "lang_id/common/fel/task-context.h" 26 #include "lang_id/common/fel/workspace.h" 27 #include "lang_id/common/lite_base/attributes.h" 28 #include "absl/strings/str_cat.h" 29 #include "absl/strings/string_view.h" 30 31 namespace libtextclassifier3 { 32 namespace mobile { 33 34 // An EmbeddingFeatureExtractor manages the extraction of features for 35 // embedding-based models. It wraps a sequence of underlying classes of feature 36 // extractors, along with associated predicate maps. Each class of feature 37 // extractors is associated with a name, e.g., "words", "labels", "tags". 38 // 39 // The class is split between a generic abstract version, 40 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the 41 // signature of the ExtractFeatures method) and a typed version. 42 // 43 // The predicate maps must be initialized before use: they can be loaded using 44 // Read() or updated via UpdateMapsForExample. 45 class GenericEmbeddingFeatureExtractor { 46 public: 47 // Constructs this GenericEmbeddingFeatureExtractor. 48 // 49 // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to 50 // avoid name clashes. See GetParamName(). GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix)51 explicit GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix) 52 : arg_prefix_(arg_prefix) {} 53 ~GenericEmbeddingFeatureExtractor()54 virtual ~GenericEmbeddingFeatureExtractor() {} 55 56 // Sets/inits up predicate maps and embedding space names that are common for 57 // all embedding based feature extractors. 58 // 59 // Returns true on success, false otherwise. 60 SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context); 61 SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context); 62 63 // Requests workspace for the underlying feature extractors. This is 64 // implemented in the typed class. 65 virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0; 66 67 // Returns number of embedding spaces. NumEmbeddings()68 int NumEmbeddings() const { return embedding_dims_.size(); } 69 embedding_fml()70 const std::vector<std::string> &embedding_fml() const { 71 return embedding_fml_; 72 } 73 74 // Get parameter name by concatenating the prefix and the original name. GetParamName(absl::string_view param_name)75 std::string GetParamName(absl::string_view param_name) const { 76 return absl::StrCat(arg_prefix_, "_", param_name); 77 } 78 79 private: 80 // Prefix for TaskContext parameters. 81 const std::string arg_prefix_; 82 83 // Embedding space names for parameter sharing. 84 std::vector<std::string> embedding_names_; 85 86 // FML strings for each feature extractor. 87 std::vector<std::string> embedding_fml_; 88 89 // Size of each of the embedding spaces (maximum predicate id). 90 std::vector<int> embedding_sizes_; 91 92 // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.) 93 std::vector<int> embedding_dims_; 94 }; 95 96 // Templated, object-specific implementation of the 97 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ, 98 // ARGS...> class that has the appropriate FeatureTraits() to ensure that 99 // locator type features work. 100 // 101 // Note: for backwards compatibility purposes, this always reads the FML spec 102 // from "<prefix>_features". 103 template <class EXTRACTOR, class OBJ, class... ARGS> 104 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { 105 public: 106 // Constructs this EmbeddingFeatureExtractor. 107 // 108 // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to 109 // avoid name clashes. See GetParamName(). EmbeddingFeatureExtractor(absl::string_view arg_prefix)110 explicit EmbeddingFeatureExtractor(absl::string_view arg_prefix) 111 : GenericEmbeddingFeatureExtractor(arg_prefix) {} 112 113 // Sets up all predicate maps, feature extractors, and flags. Setup(TaskContext * context)114 SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override { 115 if (!GenericEmbeddingFeatureExtractor::Setup(context)) { 116 return false; 117 } 118 feature_extractors_.resize(embedding_fml().size()); 119 for (size_t i = 0; i < embedding_fml().size(); ++i) { 120 feature_extractors_[i].reset(new EXTRACTOR()); 121 if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false; 122 if (!feature_extractors_[i]->Setup(context)) return false; 123 } 124 return true; 125 } 126 127 // Initializes resources needed by the feature extractors. Init(TaskContext * context)128 SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override { 129 if (!GenericEmbeddingFeatureExtractor::Init(context)) return false; 130 for (auto &feature_extractor : feature_extractors_) { 131 if (!feature_extractor->Init(context)) return false; 132 } 133 return true; 134 } 135 136 // Requests workspaces from the registry. Must be called after Init(), and 137 // before Preprocess(). RequestWorkspaces(WorkspaceRegistry * registry)138 void RequestWorkspaces(WorkspaceRegistry *registry) override { 139 for (auto &feature_extractor : feature_extractors_) { 140 feature_extractor->RequestWorkspaces(registry); 141 } 142 } 143 144 // Must be called on the object one state for each sentence, before any 145 // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures). Preprocess(WorkspaceSet * workspaces,OBJ * obj)146 void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const { 147 for (auto &feature_extractor : feature_extractors_) { 148 feature_extractor->Preprocess(workspaces, obj); 149 } 150 } 151 152 // Extracts features using the extractors. Note that features must already 153 // be initialized to the correct number of feature extractors. No predicate 154 // mapping is applied. ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & obj,ARGS...args,std::vector<FeatureVector> * features)155 void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj, 156 ARGS... args, 157 std::vector<FeatureVector> *features) const { 158 // DCHECK(features != nullptr); 159 // DCHECK_EQ(features->size(), feature_extractors_.size()); 160 for (size_t i = 0; i < feature_extractors_.size(); ++i) { 161 (*features)[i].clear(); 162 feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., 163 &(*features)[i]); 164 } 165 } 166 167 private: 168 // Templated feature extractor class. 169 std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_; 170 }; 171 172 } // namespace mobile 173 } // namespace nlp_saft 174 175 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_ 176