/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Contains classes that can execute different models/parts of a model. #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ #define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ #include #include "annotator/types.h" #include "utils/base/logging.h" #include "utils/tensor-view.h" #include "utils/tflite-model-executor.h" namespace libtextclassifier3 { // Executor for the text selection prediction and classification models. class ModelExecutor : public TfLiteModelExecutor { public: static std::unique_ptr FromModelSpec( const tflite::Model* model_spec) { auto model = TfLiteModelFromModelSpec(model_spec); if (!model) { return nullptr; } return std::unique_ptr(new ModelExecutor(std::move(model))); } static std::unique_ptr FromBuffer( const flatbuffers::Vector* model_spec_buffer) { auto model = TfLiteModelFromBuffer(model_spec_buffer); if (!model) { return nullptr; } return std::unique_ptr(new ModelExecutor(std::move(model))); } TensorView ComputeLogits(const TensorView& features, tflite::Interpreter* interpreter) const; protected: explicit ModelExecutor(std::unique_ptr model) : TfLiteModelExecutor(std::move(model)) {} static const int kInputIndexFeatures = 0; static const int kOutputIndexLogits = 0; }; // Executor for embedding sparse features into a dense vector. class EmbeddingExecutor { public: virtual ~EmbeddingExecutor() {} // Embeds the sparse_features into a dense embedding and adds (+) it // element-wise to the dest vector. virtual bool AddEmbedding(const TensorView& sparse_features, float* dest, int dest_size) const = 0; // Returns true when the model is ready to be used, false otherwise. virtual bool IsReady() const { return true; } }; class TFLiteEmbeddingExecutor : public EmbeddingExecutor { public: static std::unique_ptr FromBuffer( const flatbuffers::Vector* model_spec_buffer, int embedding_size, int quantization_bits, const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); // Embeds the sparse_features into a dense embedding and adds (+) it // element-wise to the dest vector. bool AddEmbedding(const TensorView& sparse_features, float* dest, int dest_size) const; // Auxiliary function for computing prefixes used in implementation of // efficient mask indexing data structure. void ComputePrefixCounts(); // Function implementing mask indexing based on efficient data structure int PruneBucketId(int bucket_id) const; protected: explicit TFLiteEmbeddingExecutor( std::unique_ptr executor, int quantization_bits, int num_buckets, int bytes_per_embedding, int output_embedding_size, const TfLiteTensor* scales, const TfLiteTensor* embeddings, std::unique_ptr interpreter, const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); std::unique_ptr executor_; int quantization_bits_; int num_buckets_ = -1; int bytes_per_embedding_ = -1; int output_embedding_size_ = -1; const TfLiteTensor* scales_ = nullptr; const TfLiteTensor* embeddings_ = nullptr; // NOTE: This interpreter is used in a read-only way (as a storage for the // model params), thus is still thread-safe. std::unique_ptr interpreter_; std::vector pruning_mask_; std::vector prefix_counts_; int full_num_buckets_ = -1; // Index of row of embedding table corresponding to all pruned buckets. int pruned_row_bucket_id_ = -1; }; } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_