/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

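// Builds the TensorFlow Lite op resolver used when constructing interpreters
// for the models handled by this library.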
std::unique_ptr<tflite::OpResolver> BuildOpResolver();
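
// Wrap a model given either as an in-memory tflite::Model table or as a
// serialized flatbuffer buffer in a FlatBufferModel; may return nullptr if
// the model cannot be built.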
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    const tflite::Model*);
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    const flatbuffers::Vector<uint8_t>*);

// Executor for the text selection prediction and classification models.
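//
// Typical usage (an illustrative sketch; `model_spec` and `input_values` are
// placeholders for caller-provided data):
//
//   auto executor = TfLiteModelExecutor::FromModelSpec(model_spec);
//   std::unique_ptr<tflite::Interpreter> interpreter =
//       executor->CreateInterpreter();
//   interpreter->AllocateTensors();
//   executor->SetInput<float>(/*input_index=*/0, input_values,
//                             interpreter.get());
//   interpreter->Invoke();
//   std::vector<float> output =
//       executor->Output<float>(/*output_index=*/0, interpreter.get());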
class TfLiteModelExecutor {
 public:
  static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  // Creates an Interpreter for the model that serves as a scratch-pad for the
  // inference. The Interpreter is NOT thread-safe.
  std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;

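  // Copies `input_data` into the interpreter's input tensor at `input_index`.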
  template <typename T>
  void SetInput(const int input_index, const TensorView<T>& input_data,
                tflite::Interpreter* interpreter) const {
    input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
                       input_data.size());
  }

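  // Copies the elements of `input_data` into the interpreter's input tensor
  // at `input_index`.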
  template <typename T>
  void SetInput(const int input_index, const std::vector<T>& input_data,
                tflite::Interpreter* interpreter) const {
    std::copy(input_data.begin(), input_data.end(),
              interpreter->typed_input_tensor<T>(input_index));
  }

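  // Writes a single scalar `input_value` into the input tensor at
  // `input_index`, converting it to the tensor's element type. Tensor types
  // not covered by the switch below are left untouched.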
  template <typename T>
  void SetInput(const int input_index, const T input_value,
                tflite::Interpreter* interpreter) const {
    TfLiteTensor* input_tensor =
        interpreter->tensor(interpreter->inputs()[input_index]);
    switch (input_tensor->type) {
      case kTfLiteFloat32:
        *tflite::GetTensorData<float>(input_tensor) = input_value;
        break;
      case kTfLiteInt32:
        *tflite::GetTensorData<int32_t>(input_tensor) = input_value;
        break;
      case kTfLiteUInt8:
        *tflite::GetTensorData<uint8_t>(input_tensor) = input_value;
        break;
      case kTfLiteInt64:
        *tflite::GetTensorData<int64_t>(input_tensor) = input_value;
        break;
      case kTfLiteBool:
        *tflite::GetTensorData<bool>(input_tensor) = input_value;
        break;
      case kTfLiteInt16:
        *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
        break;
      case kTfLiteInt8:
        *tflite::GetTensorData<int8_t>(input_tensor) = input_value;
        break;
      default:
        break;
    }
  }

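  // Returns a view over the output tensor at `output_index`. The view aliases
  // the interpreter's buffer, so it is only valid as long as the interpreter
  // is alive and its tensors have not been reallocated.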
  template <typename T>
  TensorView<T> OutputView(const int output_index,
                           const tflite::Interpreter* interpreter) const {
    const TfLiteTensor* output_tensor =
        interpreter->tensor(interpreter->outputs()[output_index]);
    return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
                         std::vector<int>(output_tensor->dims->data,
                                          output_tensor->dims->data +
                                              output_tensor->dims->size));
  }

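  // Copies the output tensor at `output_index` into a newly allocated
  // std::vector.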
  template <typename T>
  std::vector<T> Output(const int output_index,
                        const tflite::Interpreter* interpreter) const {
    TensorView<T> output_view = OutputView<T>(output_index, interpreter);
    return std::vector<T>(output_view.data(),
                          output_view.data() + output_view.size());
  }

 protected:
  explicit TfLiteModelExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model);
  TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,
                      std::unique_ptr<tflite::OpResolver> resolver);

  std::unique_ptr<const tflite::FlatBufferModel> model_;
  std::unique_ptr<tflite::OpResolver> resolver_;
};

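// Specializations for string tensors: TF Lite stores strings in a packed,
// variable-length format, so they cannot be read or written through the
// generic templates above.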
template <>
void TfLiteModelExecutor::SetInput(const int input_index,
                                   const std::vector<std::string>& input_data,
                                   tflite::Interpreter* interpreter) const;

template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

template <>
std::vector<std::string> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_