• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "lang_id/common/fel/feature-extractor.h"
25 #include "lang_id/common/fel/task-context.h"
26 #include "lang_id/common/fel/workspace.h"
27 #include "lang_id/common/lite_base/attributes.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/string_view.h"
30 
31 namespace libtextclassifier3 {
32 namespace mobile {
33 
34 // An EmbeddingFeatureExtractor manages the extraction of features for
35 // embedding-based models. It wraps a sequence of underlying classes of feature
36 // extractors, along with associated predicate maps. Each class of feature
37 // extractors is associated with a name, e.g., "words", "labels", "tags".
38 //
39 // The class is split between a generic abstract version,
40 // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
41 // signature of the ExtractFeatures method) and a typed version.
42 //
43 // The predicate maps must be initialized before use: they can be loaded using
44 // Read() or updated via UpdateMapsForExample.
45 class GenericEmbeddingFeatureExtractor {
46  public:
47   // Constructs this GenericEmbeddingFeatureExtractor.
48   //
49   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
50   // avoid name clashes.  See GetParamName().
GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix)51   explicit GenericEmbeddingFeatureExtractor(absl::string_view arg_prefix)
52       : arg_prefix_(arg_prefix) {}
53 
~GenericEmbeddingFeatureExtractor()54   virtual ~GenericEmbeddingFeatureExtractor() {}
55 
56   // Sets/inits up predicate maps and embedding space names that are common for
57   // all embedding based feature extractors.
58   //
59   // Returns true on success, false otherwise.
60   SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
61   SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
62 
63   // Requests workspace for the underlying feature extractors. This is
64   // implemented in the typed class.
65   virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
66 
67   // Returns number of embedding spaces.
NumEmbeddings()68   int NumEmbeddings() const { return embedding_dims_.size(); }
69 
embedding_fml()70   const std::vector<std::string> &embedding_fml() const {
71     return embedding_fml_;
72   }
73 
74   // Get parameter name by concatenating the prefix and the original name.
GetParamName(absl::string_view param_name)75   std::string GetParamName(absl::string_view param_name) const {
76     return absl::StrCat(arg_prefix_, "_", param_name);
77   }
78 
79  private:
80   // Prefix for TaskContext parameters.
81   const std::string arg_prefix_;
82 
83   // Embedding space names for parameter sharing.
84   std::vector<std::string> embedding_names_;
85 
86   // FML strings for each feature extractor.
87   std::vector<std::string> embedding_fml_;
88 
89   // Size of each of the embedding spaces (maximum predicate id).
90   std::vector<int> embedding_sizes_;
91 
92   // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
93   std::vector<int> embedding_dims_;
94 };
95 
96 // Templated, object-specific implementation of the
97 // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
98 // ARGS...> class that has the appropriate FeatureTraits() to ensure that
99 // locator type features work.
100 //
101 // Note: for backwards compatibility purposes, this always reads the FML spec
102 // from "<prefix>_features".
103 template <class EXTRACTOR, class OBJ, class... ARGS>
104 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
105  public:
106   // Constructs this EmbeddingFeatureExtractor.
107   //
108   // |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
109   // avoid name clashes.  See GetParamName().
EmbeddingFeatureExtractor(absl::string_view arg_prefix)110   explicit EmbeddingFeatureExtractor(absl::string_view arg_prefix)
111       : GenericEmbeddingFeatureExtractor(arg_prefix) {}
112 
113   // Sets up all predicate maps, feature extractors, and flags.
Setup(TaskContext * context)114   SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
115     if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
116       return false;
117     }
118     feature_extractors_.resize(embedding_fml().size());
119     for (size_t i = 0; i < embedding_fml().size(); ++i) {
120       feature_extractors_[i].reset(new EXTRACTOR());
121       if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
122       if (!feature_extractors_[i]->Setup(context)) return false;
123     }
124     return true;
125   }
126 
127   // Initializes resources needed by the feature extractors.
Init(TaskContext * context)128   SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
129     if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
130     for (auto &feature_extractor : feature_extractors_) {
131       if (!feature_extractor->Init(context)) return false;
132     }
133     return true;
134   }
135 
136   // Requests workspaces from the registry. Must be called after Init(), and
137   // before Preprocess().
RequestWorkspaces(WorkspaceRegistry * registry)138   void RequestWorkspaces(WorkspaceRegistry *registry) override {
139     for (auto &feature_extractor : feature_extractors_) {
140       feature_extractor->RequestWorkspaces(registry);
141     }
142   }
143 
144   // Must be called on the object one state for each sentence, before any
145   // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
Preprocess(WorkspaceSet * workspaces,OBJ * obj)146   void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
147     for (auto &feature_extractor : feature_extractors_) {
148       feature_extractor->Preprocess(workspaces, obj);
149     }
150   }
151 
152   // Extracts features using the extractors. Note that features must already
153   // be initialized to the correct number of feature extractors. No predicate
154   // mapping is applied.
ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & obj,ARGS...args,std::vector<FeatureVector> * features)155   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
156                        ARGS... args,
157                        std::vector<FeatureVector> *features) const {
158     // DCHECK(features != nullptr);
159     // DCHECK_EQ(features->size(), feature_extractors_.size());
160     for (size_t i = 0; i < feature_extractors_.size(); ++i) {
161       (*features)[i].clear();
162       feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
163                                               &(*features)[i]);
164     }
165   }
166 
167  private:
168   // Templated feature extractor class.
169   std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
170 };
171 
172 }  // namespace mobile
}  // namespace libtextclassifier3
174 
175 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
176