1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ 18 #define LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "model-executor.h" 24 #include "model_generated.h" 25 #include "types.h" 26 27 namespace libtextclassifier2 { 28 29 // Holds state for extracting features across multiple calls and reusing them. 30 // Assumes that features for each Token are independent. 31 class CachedFeatures { 32 public: 33 static std::unique_ptr<CachedFeatures> Create( 34 const TokenSpan& extraction_span, 35 std::unique_ptr<std::vector<float>> features, 36 std::unique_ptr<std::vector<float>> padding_features, 37 const FeatureProcessorOptions* options, int feature_vector_size); 38 39 // Appends the click context features for the given click position to 40 // 'output_features'. 41 void AppendClickContextFeaturesForClick( 42 int click_pos, std::vector<float>* output_features) const; 43 44 // Appends the bounds-sensitive features for the given token span to 45 // 'output_features'. 46 void AppendBoundsSensitiveFeaturesForSpan( 47 TokenSpan selected_span, std::vector<float>* output_features) const; 48 49 // Returns number of features that 'AppendFeaturesForSpan' appends. OutputFeaturesSize()50 int OutputFeaturesSize() const { return output_features_size_; } 51 52 private: CachedFeatures()53 CachedFeatures() {} 54 55 // Appends token features to the output. The intended_span specifies which 56 // tokens' features should be used in principle. The read_mask_span restricts 57 // which tokens are actually read. For tokens outside of the read_mask_span, 58 // padding tokens are used instead. 59 void AppendFeaturesInternal(const TokenSpan& intended_span, 60 const TokenSpan& read_mask_span, 61 std::vector<float>* output_features) const; 62 63 // Appends features of one padding token to the output. 64 void AppendPaddingFeatures(std::vector<float>* output_features) const; 65 66 // Appends the features of tokens from the given span to the output. The 67 // features are averaged so that the appended features have the size 68 // corresponding to one token. 69 void AppendBagFeatures(const TokenSpan& bag_span, 70 std::vector<float>* output_features) const; 71 72 int NumFeaturesPerToken() const; 73 74 TokenSpan extraction_span_; 75 const FeatureProcessorOptions* options_; 76 int output_features_size_; 77 std::unique_ptr<std::vector<float>> features_; 78 std::unique_ptr<std::vector<float>> padding_features_; 79 }; 80 81 } // namespace libtextclassifier2 82 83 #endif // LIBTEXTCLASSIFIER_CACHED_FEATURES_H_ 84