1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ 18 #define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ 19 20 #include <string> 21 #include <vector> 22 23 #include "common/embedding-feature-extractor.h" 24 #include "common/feature-extractor.h" 25 #include "common/task-context.h" 26 #include "common/workspace.h" 27 #include "lang_id/light-sentence-features.h" 28 #include "lang_id/light-sentence.h" 29 #include "util/base/macros.h" 30 31 namespace libtextclassifier { 32 namespace nlp_core { 33 namespace lang_id { 34 35 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence. 36 class LangIdEmbeddingFeatureExtractor 37 : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> { 38 public: LangIdEmbeddingFeatureExtractor()39 LangIdEmbeddingFeatureExtractor() {} ArgPrefix()40 const std::string ArgPrefix() const override { return "language_identifier"; } 41 42 TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor); 43 }; 44 45 // Handles sentence -> numeric_features and numeric_prediction -> language 46 // conversions. 47 class LangIdBrainInterface { 48 public: LangIdBrainInterface()49 LangIdBrainInterface() {} 50 51 // Initializes resources and parameters. Init(TaskContext * context)52 bool Init(TaskContext *context) { 53 if (!feature_extractor_.Init(context)) { 54 return false; 55 } 56 feature_extractor_.RequestWorkspaces(&workspace_registry_); 57 return true; 58 } 59 60 // Extract features from sentence. On return, FeatureVector features[i] 61 // contains the features for the embedding space #i. GetFeatures(LightSentence * sentence,std::vector<FeatureVector> * features)62 void GetFeatures(LightSentence *sentence, 63 std::vector<FeatureVector> *features) const { 64 WorkspaceSet workspace; 65 workspace.Reset(workspace_registry_); 66 feature_extractor_.Preprocess(&workspace, sentence); 67 return feature_extractor_.ExtractFeatures(workspace, *sentence, features); 68 } 69 NumEmbeddings()70 int NumEmbeddings() const { 71 return feature_extractor_.NumEmbeddings(); 72 } 73 74 private: 75 // Typed feature extractor for embeddings. 76 LangIdEmbeddingFeatureExtractor feature_extractor_; 77 78 // The registry of shared workspaces in the feature extractor. 79 WorkspaceRegistry workspace_registry_; 80 81 TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface); 82 }; 83 84 } // namespace lang_id 85 } // namespace nlp_core 86 } // namespace libtextclassifier 87 88 #endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ 89