• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "common/feature-extractor.h"
25 #include "common/task-context.h"
26 #include "common/workspace.h"
27 #include "util/base/logging.h"
28 #include "util/base/macros.h"
29 
30 namespace libtextclassifier {
31 namespace nlp_core {
32 
// An EmbeddingFeatureExtractor manages the extraction of features for
// embedding-based models. It wraps a sequence of underlying classes of feature
// extractors, along with associated predicate maps. Each class of feature
// extractors is associated with a name, e.g., "words", "labels", "tags".
//
// The class is split between a generic abstract version,
// GenericEmbeddingFeatureExtractor (which can be initialized without knowing
// the signature of the ExtractFeatures method), and a typed version,
// EmbeddingFeatureExtractor.
//
// The predicate maps must be initialized before use: they can be loaded using
// Read() or updated via UpdateMapsForExample.
44 class GenericEmbeddingFeatureExtractor {
45  public:
GenericEmbeddingFeatureExtractor()46   GenericEmbeddingFeatureExtractor() {}
~GenericEmbeddingFeatureExtractor()47   virtual ~GenericEmbeddingFeatureExtractor() {}
48 
49   // Get the prefix std::string to put in front of all arguments, so they don't
50   // conflict with other embedding models.
51   virtual const std::string ArgPrefix() const = 0;
52 
53   // Initializes predicate maps and embedding space names that are common for
54   // all embedding-based feature extractors.
55   virtual bool Init(TaskContext *context);
56 
57   // Requests workspace for the underlying feature extractors. This is
58   // implemented in the typed class.
59   virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
60 
61   // Returns number of embedding spaces.
NumEmbeddings()62   int NumEmbeddings() const { return embedding_dims_.size(); }
63 
64   // Number of predicates for the embedding at a given index (vocabulary size).
65   // Returns -1 if index is out of bounds.
EmbeddingSize(int index)66   int EmbeddingSize(int index) const {
67     const GenericFeatureExtractor *extractor = generic_feature_extractor(index);
68     return (extractor == nullptr) ? -1 : extractor->GetDomainSize();
69   }
70 
71   // Returns the dimensionality of the embedding space.
EmbeddingDims(int index)72   int EmbeddingDims(int index) const { return embedding_dims_[index]; }
73 
74   // Accessor for embedding dims (dimensions of the embedding spaces).
embedding_dims()75   const std::vector<int> &embedding_dims() const { return embedding_dims_; }
76 
embedding_fml()77   const std::vector<std::string> &embedding_fml() const {
78     return embedding_fml_;
79   }
80 
81   // Get parameter name by concatenating the prefix and the original name.
GetParamName(const std::string & param_name)82   std::string GetParamName(const std::string &param_name) const {
83     std::string full_name = ArgPrefix();
84     full_name.push_back('_');
85     full_name.append(param_name);
86     return full_name;
87   }
88 
89  protected:
90   // Provides the generic class with access to the templated extractors. This is
91   // used to get the type information out of the feature extractor without
92   // knowing the specific calling arguments of the extractor itself.
93   // Returns nullptr for an out-of-bounds idx.
94   virtual const GenericFeatureExtractor *generic_feature_extractor(
95       int idx) const = 0;
96 
97  private:
98   // Embedding space names for parameter sharing.
99   std::vector<std::string> embedding_names_;
100 
101   // FML strings for each feature extractor.
102   std::vector<std::string> embedding_fml_;
103 
104   // Size of each of the embedding spaces (maximum predicate id).
105   std::vector<int> embedding_sizes_;
106 
107   // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
108   std::vector<int> embedding_dims_;
109 
110   TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor);
111 };
112 
// Templated, object-specific implementation of the
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
// locator type features work.
//
// Note: for backwards compatibility purposes, this always reads the FML spec
// from "<prefix>_features".
120 template <class EXTRACTOR, class OBJ, class... ARGS>
121 class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
122  public:
123   // Initializes all predicate maps, feature extractors, etc.
Init(TaskContext * context)124   bool Init(TaskContext *context) override {
125     if (!GenericEmbeddingFeatureExtractor::Init(context)) {
126       return false;
127     }
128     feature_extractors_.resize(embedding_fml().size());
129     for (int i = 0; i < embedding_fml().size(); ++i) {
130       feature_extractors_[i].reset(new EXTRACTOR());
131       if (!feature_extractors_[i]->Parse(embedding_fml()[i])) {
132         return false;
133       }
134       if (!feature_extractors_[i]->Setup(context)) {
135         return false;
136       }
137     }
138     for (auto &feature_extractor : feature_extractors_) {
139       if (!feature_extractor->Init(context)) {
140         return false;
141       }
142     }
143     return true;
144   }
145 
146   // Requests workspaces from the registry. Must be called after Init(), and
147   // before Preprocess().
RequestWorkspaces(WorkspaceRegistry * registry)148   void RequestWorkspaces(WorkspaceRegistry *registry) override {
149     for (auto &feature_extractor : feature_extractors_) {
150       feature_extractor->RequestWorkspaces(registry);
151     }
152   }
153 
154   // Must be called on the object one state for each sentence, before any
155   // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
Preprocess(WorkspaceSet * workspaces,OBJ * obj)156   void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
157     for (auto &feature_extractor : feature_extractors_) {
158       feature_extractor->Preprocess(workspaces, obj);
159     }
160   }
161 
162   // Extracts features using the extractors. Note that features must already
163   // be initialized to the correct number of feature extractors. No predicate
164   // mapping is applied.
ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & obj,ARGS...args,std::vector<FeatureVector> * features)165   void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
166                        ARGS... args,
167                        std::vector<FeatureVector> *features) const {
168     TC_DCHECK(features != nullptr);
169     TC_DCHECK_EQ(features->size(), feature_extractors_.size());
170     for (int i = 0; i < feature_extractors_.size(); ++i) {
171       (*features)[i].clear();
172       feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
173                                               &(*features)[i]);
174     }
175   }
176 
177  protected:
178   // Provides generic access to the feature extractors.
generic_feature_extractor(int idx)179   const GenericFeatureExtractor *generic_feature_extractor(
180       int idx) const override {
181     if ((idx < 0) || (idx >= feature_extractors_.size())) {
182       TC_LOG(ERROR) << "Out of bounds index " << idx;
183       TC_DCHECK(false);  // Crash in debug mode.
184       return nullptr;
185     }
186     return feature_extractors_[idx].get();
187   }
188 
189  private:
190   // Templated feature extractor class.
191   std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
192 };
193 
194 }  // namespace nlp_core
195 }  // namespace libtextclassifier
196 
197 #endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_
198