1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_ 18 #define LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_ 19 20 #include <memory> 21 22 #include "actions/actions_model_generated.h" 23 #include "annotator/model-executor.h" 24 #include "annotator/types.h" 25 #include "utils/token-feature-extractor.h" 26 #include "utils/tokenizer.h" 27 #include "utils/utf8/unicodetext.h" 28 #include "utils/utf8/unilib.h" 29 30 namespace libtextclassifier3 { 31 32 // Create tokenizer from options. 33 std::unique_ptr<Tokenizer> CreateTokenizer( 34 const ActionsTokenizerOptions* options, const UniLib* unilib); 35 36 // Feature processor for the actions suggestions model. 37 class ActionsFeatureProcessor { 38 public: 39 explicit ActionsFeatureProcessor( 40 const ActionsTokenFeatureProcessorOptions* options, const UniLib* unilib); 41 42 // Embeds and appends features to the output vector. 43 bool AppendFeatures(const std::vector<int>& sparse_features, 44 const std::vector<float>& dense_features, 45 const EmbeddingExecutor* embedding_executor, 46 std::vector<float>* output_features) const; 47 48 // Extracts the features of a token and appends them to the output vector. 49 bool AppendTokenFeatures(const Token& token, 50 const EmbeddingExecutor* embedding_executor, 51 std::vector<float>* output_features) const; 52 53 // Extracts the features of a vector of tokens and appends each to the output 54 // vector. 55 bool AppendTokenFeatures(const std::vector<Token>& tokens, 56 const EmbeddingExecutor* embedding_executor, 57 std::vector<float>* output_features) const; 58 59 int GetTokenEmbeddingSize() const; 60 tokenizer()61 const Tokenizer* tokenizer() const { return tokenizer_.get(); } 62 63 private: 64 const ActionsTokenFeatureProcessorOptions* options_; 65 const std::unique_ptr<Tokenizer> tokenizer_; 66 const TokenFeatureExtractor token_feature_extractor_; 67 }; 68 69 } // namespace libtextclassifier3 70 71 #endif // LIBTEXTCLASSIFIER_ACTIONS_FEATURE_PROCESSOR_H_ 72