/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_

#include <memory>
#include <utility>
#include <vector>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"

namespace libtextclassifier3 {

// Executor for the text selection prediction and classification models.
32 class ModelExecutor : public TfLiteModelExecutor { 33 public: FromModelSpec(const tflite::Model * model_spec)34 static std::unique_ptr<ModelExecutor> FromModelSpec( 35 const tflite::Model* model_spec) { 36 auto model = TfLiteModelFromModelSpec(model_spec); 37 if (!model) { 38 return nullptr; 39 } 40 return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); 41 } 42 FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)43 static std::unique_ptr<ModelExecutor> FromBuffer( 44 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 45 auto model = TfLiteModelFromBuffer(model_spec_buffer); 46 if (!model) { 47 return nullptr; 48 } 49 return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); 50 } 51 52 TensorView<float> ComputeLogits(const TensorView<float>& features, 53 tflite::Interpreter* interpreter) const; 54 55 protected: ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)56 explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) 57 : TfLiteModelExecutor(std::move(model)) {} 58 59 static const int kInputIndexFeatures = 0; 60 static const int kOutputIndexLogits = 0; 61 }; 62 63 // Executor for embedding sparse features into a dense vector. 64 class EmbeddingExecutor { 65 public: ~EmbeddingExecutor()66 virtual ~EmbeddingExecutor() {} 67 68 // Embeds the sparse_features into a dense embedding and adds (+) it 69 // element-wise to the dest vector. 70 virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 71 int dest_size) const = 0; 72 73 // Returns true when the model is ready to be used, false otherwise. 
IsReady()74 virtual bool IsReady() const { return true; } 75 }; 76 77 class TFLiteEmbeddingExecutor : public EmbeddingExecutor { 78 public: 79 static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer( 80 const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, 81 int quantization_bits, 82 const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); 83 84 // Embeds the sparse_features into a dense embedding and adds (+) it 85 // element-wise to the dest vector. 86 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 87 int dest_size) const; 88 89 // Auxiliary function for computing prefixes used in implementation of 90 // efficient mask indexing data structure. 91 void ComputePrefixCounts(); 92 93 // Function implementing mask indexing based on efficient data structure 94 int PruneBucketId(int bucket_id) const; 95 96 protected: 97 explicit TFLiteEmbeddingExecutor( 98 std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, 99 int num_buckets, int bytes_per_embedding, int output_embedding_size, 100 const TfLiteTensor* scales, const TfLiteTensor* embeddings, 101 std::unique_ptr<tflite::Interpreter> interpreter, 102 const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); 103 104 std::unique_ptr<TfLiteModelExecutor> executor_; 105 106 int quantization_bits_; 107 int num_buckets_ = -1; 108 int bytes_per_embedding_ = -1; 109 int output_embedding_size_ = -1; 110 const TfLiteTensor* scales_ = nullptr; 111 const TfLiteTensor* embeddings_ = nullptr; 112 113 // NOTE: This interpreter is used in a read-only way (as a storage for the 114 // model params), thus is still thread-safe. 115 std::unique_ptr<tflite::Interpreter> interpreter_; 116 117 std::vector<uint64> pruning_mask_; 118 std::vector<uint16> prefix_counts_; 119 int full_num_buckets_ = -1; 120 121 // Index of row of embedding table corresponding to all pruned buckets. 
122 int pruned_row_bucket_id_ = -1; 123 }; 124 125 } // namespace libtextclassifier3 126 127 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ 128