/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_ #include #include #include #include "common/feature-extractor.h" #include "common/task-context.h" #include "common/workspace.h" #include "util/base/logging.h" #include "util/base/macros.h" namespace libtextclassifier { namespace nlp_core { // An EmbeddingFeatureExtractor manages the extraction of features for // embedding-based models. It wraps a sequence of underlying classes of feature // extractors, along with associated predicate maps. Each class of feature // extractors is associated with a name, e.g., "words", "labels", "tags". // // The class is split between a generic abstract version, // GenericEmbeddingFeatureExtractor (that can be initialized without knowing the // signature of the ExtractFeatures method) and a typed version. // // The predicate maps must be initialized before use: they can be loaded using // Read() or updated via UpdateMapsForExample. class GenericEmbeddingFeatureExtractor { public: GenericEmbeddingFeatureExtractor() {} virtual ~GenericEmbeddingFeatureExtractor() {} // Get the prefix std::string to put in front of all arguments, so they don't // conflict with other embedding models. virtual const std::string ArgPrefix() const = 0; // Initializes predicate maps and embedding space names that are common for // all embedding-based feature extractors. virtual bool Init(TaskContext *context); // Requests workspace for the underlying feature extractors. This is // implemented in the typed class. virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0; // Returns number of embedding spaces. int NumEmbeddings() const { return embedding_dims_.size(); } // Number of predicates for the embedding at a given index (vocabulary size). // Returns -1 if index is out of bounds. int EmbeddingSize(int index) const { const GenericFeatureExtractor *extractor = generic_feature_extractor(index); return (extractor == nullptr) ? -1 : extractor->GetDomainSize(); } // Returns the dimensionality of the embedding space. int EmbeddingDims(int index) const { return embedding_dims_[index]; } // Accessor for embedding dims (dimensions of the embedding spaces). const std::vector &embedding_dims() const { return embedding_dims_; } const std::vector &embedding_fml() const { return embedding_fml_; } // Get parameter name by concatenating the prefix and the original name. std::string GetParamName(const std::string ¶m_name) const { std::string full_name = ArgPrefix(); full_name.push_back('_'); full_name.append(param_name); return full_name; } protected: // Provides the generic class with access to the templated extractors. This is // used to get the type information out of the feature extractor without // knowing the specific calling arguments of the extractor itself. // Returns nullptr for an out-of-bounds idx. virtual const GenericFeatureExtractor *generic_feature_extractor( int idx) const = 0; private: // Embedding space names for parameter sharing. std::vector embedding_names_; // FML strings for each feature extractor. std::vector embedding_fml_; // Size of each of the embedding spaces (maximum predicate id). std::vector embedding_sizes_; // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.) std::vector embedding_dims_; TC_DISALLOW_COPY_AND_ASSIGN(GenericEmbeddingFeatureExtractor); }; // Templated, object-specific implementation of the // EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor class that has the appropriate FeatureTraits() to ensure that // locator type features work. // // Note: for backwards compatibility purposes, this always reads the FML spec // from "_features". template class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor { public: // Initializes all predicate maps, feature extractors, etc. bool Init(TaskContext *context) override { if (!GenericEmbeddingFeatureExtractor::Init(context)) { return false; } feature_extractors_.resize(embedding_fml().size()); for (int i = 0; i < embedding_fml().size(); ++i) { feature_extractors_[i].reset(new EXTRACTOR()); if (!feature_extractors_[i]->Parse(embedding_fml()[i])) { return false; } if (!feature_extractors_[i]->Setup(context)) { return false; } } for (auto &feature_extractor : feature_extractors_) { if (!feature_extractor->Init(context)) { return false; } } return true; } // Requests workspaces from the registry. Must be called after Init(), and // before Preprocess(). void RequestWorkspaces(WorkspaceRegistry *registry) override { for (auto &feature_extractor : feature_extractors_) { feature_extractor->RequestWorkspaces(registry); } } // Must be called on the object one state for each sentence, before any // feature extraction (e.g., UpdateMapsForExample, ExtractFeatures). void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const { for (auto &feature_extractor : feature_extractors_) { feature_extractor->Preprocess(workspaces, obj); } } // Extracts features using the extractors. Note that features must already // be initialized to the correct number of feature extractors. No predicate // mapping is applied. void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj, ARGS... args, std::vector *features) const { TC_DCHECK(features != nullptr); TC_DCHECK_EQ(features->size(), feature_extractors_.size()); for (int i = 0; i < feature_extractors_.size(); ++i) { (*features)[i].clear(); feature_extractors_[i]->ExtractFeatures(workspaces, obj, args..., &(*features)[i]); } } protected: // Provides generic access to the feature extractors. const GenericFeatureExtractor *generic_feature_extractor( int idx) const override { if ((idx < 0) || (idx >= feature_extractors_.size())) { TC_LOG(ERROR) << "Out of bounds index " << idx; TC_DCHECK(false); // Crash in debug mode. return nullptr; } return feature_extractors_[idx].get(); } private: // Templated feature extractor class. std::vector> feature_extractors_; }; } // namespace nlp_core } // namespace libtextclassifier #endif // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_FEATURE_EXTRACTOR_H_