• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains classes that can execute different models/parts of a model.
18 
19 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
20 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
21 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
34 
35 namespace libtextclassifier3 {
36 
37 // Creates a TF.Lite Op resolver in default configuration, with ops for
38 // Annotator and Actions models.
39 std::unique_ptr<tflite::OpResolver> BuildOpResolver();
40 
41 // Like above, but allows passage of a function that can register additional
42 // ops.
43 std::unique_ptr<tflite::OpResolver> BuildOpResolver(
44     const std::function<void(tflite::MutableOpResolver*)>& customize_fn);
45 
46 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
47     const tflite::Model*);
48 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
49     const flatbuffers::Vector<uint8_t>*);
50 
51 // Executor for the text selection prediction and classification models.
52 class TfLiteModelExecutor {
53  public:
FromModelSpec(const tflite::Model * model_spec)54   static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
55       const tflite::Model* model_spec) {
56     auto model = TfLiteModelFromModelSpec(model_spec);
57     if (!model) {
58       return nullptr;
59     }
60     return std::unique_ptr<TfLiteModelExecutor>(
61         new TfLiteModelExecutor(std::move(model)));
62   }
63 
FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)64   static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
65       const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
66     auto model = TfLiteModelFromBuffer(model_spec_buffer);
67     if (!model) {
68       return nullptr;
69     }
70     return std::unique_ptr<TfLiteModelExecutor>(
71         new TfLiteModelExecutor(std::move(model)));
72   }
73 
74   // Creates an Interpreter for the model that serves as a scratch-pad for the
75   // inference. The Interpreter is NOT thread-safe.
76   std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
77 
78   template <typename T>
SetInput(const int input_index,const TensorView<T> & input_data,tflite::Interpreter * interpreter)79   void SetInput(const int input_index, const TensorView<T>& input_data,
80                 tflite::Interpreter* interpreter) const {
81     input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
82                        input_data.size());
83   }
84 
85   template <typename T>
SetInput(const int input_index,const std::vector<T> & input_data,tflite::Interpreter * interpreter)86   void SetInput(const int input_index, const std::vector<T>& input_data,
87                 tflite::Interpreter* interpreter) const {
88     std::copy(input_data.begin(), input_data.end(),
89               interpreter->typed_input_tensor<T>(input_index));
90   }
91 
92   template <typename T>
SetInput(const int input_index,const T input_value,tflite::Interpreter * interpreter)93   void SetInput(const int input_index, const T input_value,
94                 tflite::Interpreter* interpreter) const {
95     TfLiteTensor* input_tensor =
96         interpreter->tensor(interpreter->inputs()[input_index]);
97     switch (input_tensor->type) {
98       case kTfLiteFloat32:
99         *tflite::GetTensorData<float>(input_tensor) = input_value;
100         break;
101       case kTfLiteInt32:
102         *tflite::GetTensorData<int32_t>(input_tensor) = input_value;
103         break;
104       case kTfLiteUInt8:
105         *tflite::GetTensorData<uint8_t>(input_tensor) = input_value;
106         break;
107       case kTfLiteInt64:
108         *tflite::GetTensorData<int64_t>(input_tensor) = input_value;
109         break;
110       case kTfLiteBool:
111         *tflite::GetTensorData<bool>(input_tensor) = input_value;
112         break;
113       case kTfLiteInt16:
114         *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
115         break;
116       case kTfLiteInt8:
117         *tflite::GetTensorData<int8_t>(input_tensor) = input_value;
118         break;
119       default:
120         break;
121     }
122   }
123 
124   template <typename T>
OutputView(const int output_index,const tflite::Interpreter * interpreter)125   TensorView<T> OutputView(const int output_index,
126                            const tflite::Interpreter* interpreter) const {
127     const TfLiteTensor* output_tensor =
128         interpreter->tensor(interpreter->outputs()[output_index]);
129     return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
130                          std::vector<int>(output_tensor->dims->data,
131                                           output_tensor->dims->data +
132                                               output_tensor->dims->size));
133   }
134 
135   template <typename T>
Output(const int output_index,const tflite::Interpreter * interpreter)136   std::vector<T> Output(const int output_index,
137                         const tflite::Interpreter* interpreter) const {
138     TensorView<T> output_view = OutputView<T>(output_index, interpreter);
139     return std::vector<T>(output_view.data(),
140                           output_view.data() + output_view.size());
141   }
142 
143  protected:
144   explicit TfLiteModelExecutor(
145       std::unique_ptr<const tflite::FlatBufferModel> model);
146   TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,
147                       std::unique_ptr<tflite::OpResolver> resolver);
148 
149   std::unique_ptr<const tflite::FlatBufferModel> model_;
150   std::unique_ptr<tflite::OpResolver> resolver_;
151 };
152 
153 template <>
154 void TfLiteModelExecutor::SetInput(const int input_index,
155                                    const std::vector<std::string>& input_data,
156                                    tflite::Interpreter* interpreter) const;
157 
158 template <>
159 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
160     const int output_index, const tflite::Interpreter* interpreter) const;
161 
162 template <>
163 std::vector<std::string> TfLiteModelExecutor::Output(
164     const int output_index, const tflite::Interpreter* interpreter) const;
165 
166 }  // namespace libtextclassifier3
167 
168 #endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
169