• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains classes that can execute different models/parts of a model.
18 
19 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
20 #define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
21 
#include <memory>
#include <utility>
#include <vector>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"
28 
29 namespace libtextclassifier3 {
30 
31 // Executor for the text selection prediction and classification models.
32 class ModelExecutor : public TfLiteModelExecutor {
33  public:
FromModelSpec(const tflite::Model * model_spec)34   static std::unique_ptr<ModelExecutor> FromModelSpec(
35       const tflite::Model* model_spec) {
36     auto model = TfLiteModelFromModelSpec(model_spec);
37     if (!model) {
38       return nullptr;
39     }
40     return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
41   }
42 
FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)43   static std::unique_ptr<ModelExecutor> FromBuffer(
44       const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
45     auto model = TfLiteModelFromBuffer(model_spec_buffer);
46     if (!model) {
47       return nullptr;
48     }
49     return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
50   }
51 
52   TensorView<float> ComputeLogits(const TensorView<float>& features,
53                                   tflite::Interpreter* interpreter) const;
54 
55  protected:
ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)56   explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
57       : TfLiteModelExecutor(std::move(model)) {}
58 
59   static const int kInputIndexFeatures = 0;
60   static const int kOutputIndexLogits = 0;
61 };
62 
63 // Executor for embedding sparse features into a dense vector.
64 class EmbeddingExecutor {
65  public:
~EmbeddingExecutor()66   virtual ~EmbeddingExecutor() {}
67 
68   // Embeds the sparse_features into a dense embedding and adds (+) it
69   // element-wise to the dest vector.
70   virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
71                             int dest_size) const = 0;
72 
73   // Returns true when the model is ready to be used, false otherwise.
IsReady()74   virtual bool IsReady() const { return true; }
75 };
76 
77 class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
78  public:
79   static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
80       const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
81       int quantization_bits,
82       const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
83 
84   // Embeds the sparse_features into a dense embedding and adds (+) it
85   // element-wise to the dest vector.
86   bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
87                     int dest_size) const;
88 
89   // Auxiliary function for computing prefixes used in implementation of
90   // efficient mask indexing data structure.
91   void ComputePrefixCounts();
92 
93   // Function implementing mask indexing based on efficient data structure
94   int PruneBucketId(int bucket_id) const;
95 
96  protected:
97   explicit TFLiteEmbeddingExecutor(
98       std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
99       int num_buckets, int bytes_per_embedding, int output_embedding_size,
100       const TfLiteTensor* scales, const TfLiteTensor* embeddings,
101       std::unique_ptr<tflite::Interpreter> interpreter,
102       const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
103 
104   std::unique_ptr<TfLiteModelExecutor> executor_;
105 
106   int quantization_bits_;
107   int num_buckets_ = -1;
108   int bytes_per_embedding_ = -1;
109   int output_embedding_size_ = -1;
110   const TfLiteTensor* scales_ = nullptr;
111   const TfLiteTensor* embeddings_ = nullptr;
112 
113   // NOTE: This interpreter is used in a read-only way (as a storage for the
114   // model params), thus is still thread-safe.
115   std::unique_ptr<tflite::Interpreter> interpreter_;
116 
117   std::vector<uint64> pruning_mask_;
118   std::vector<uint16> prefix_counts_;
119   int full_num_buckets_ = -1;
120 
121   // Index of row of embedding table corresponding to all pruned buckets.
122   int pruned_row_bucket_id_ = -1;
123 };
124 
125 }  // namespace libtextclassifier3
126 
127 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
128