/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models or parts of a model.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

// Builds the TensorFlow Lite op resolver used to execute the models.
std::unique_ptr<tflite::OpResolver> BuildOpResolver();

// Wraps an already loaded flatbuffer model spec; returns nullptr on failure.
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    const tflite::Model*);

// Builds a model from a serialized flatbuffer; returns nullptr on failure.
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    const flatbuffers::Vector<uint8_t>*);

// Executor for the text selection prediction and classification models.
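//
// A typical call sequence, as a sketch (`model_spec` and `input_values` are
// hypothetical stand-ins for a `const tflite::Model*` and a
// `std::vector<float>` provided by the caller):
//
//   auto executor = TfLiteModelExecutor::FromModelSpec(model_spec);
//   if (executor == nullptr) return;  // The model failed to load.
//   std::unique_ptr<tflite::Interpreter> interpreter =
//       executor->CreateInterpreter();
//   interpreter->AllocateTensors();
//   executor->SetInput<float>(/*input_index=*/0, input_values,
//                             interpreter.get());
//   if (interpreter->Invoke() != kTfLiteOk) return;  // Inference failed.
//   std::vector<float> result =
//       executor->Output<float>(/*output_index=*/0, interpreter.get());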
class TfLiteModelExecutor {
 public:
  // Creates an executor from a model spec; returns nullptr if the underlying
  // TfLite model cannot be built.
  static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  // Creates an executor from a serialized model buffer; returns nullptr if
  // the underlying TfLite model cannot be built.
  static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<TfLiteModelExecutor>(
        new TfLiteModelExecutor(std::move(model)));
  }

  // Creates an Interpreter for the model that serves as a scratch-pad for
  // inference. The returned Interpreter is NOT thread-safe: create a separate
  // one per thread (or per inference) rather than sharing it.
  std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;

  // Copies `input_data` into the input tensor at `input_index`.
  template <typename T>
  void SetInput(const int input_index, const TensorView<T>& input_data,
                tflite::Interpreter* interpreter) const {
    input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
                       input_data.size());
  }

  template <typename T>
  void SetInput(const int input_index, const std::vector<T>& input_data,
                tflite::Interpreter* interpreter) const {
    std::copy(input_data.begin(), input_data.end(),
              interpreter->typed_input_tensor<T>(input_index));
  }

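  // The scalar overload below writes `input_value` into the first element of
  // the input tensor, dispatching on the tensor's runtime type. A usage
  // sketch (the input index is hypothetical):
  //
  //   executor->SetInput(/*input_index=*/2, /*input_value=*/1.0f,
  //                      interpreter.get());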
  template <typename T>
  void SetInput(const int input_index, const T input_value,
                tflite::Interpreter* interpreter) const {
    TfLiteTensor* input_tensor =
        interpreter->tensor(interpreter->inputs()[input_index]);
    switch (input_tensor->type) {
      case kTfLiteFloat32:
        *(input_tensor->data.f) = input_value;
        break;
      case kTfLiteInt32:
        *(input_tensor->data.i32) = input_value;
        break;
      case kTfLiteUInt8:
        *(input_tensor->data.uint8) = input_value;
        break;
      case kTfLiteInt64:
        *(input_tensor->data.i64) = input_value;
        break;
      case kTfLiteBool:
        *(input_tensor->data.b) = input_value;
        break;
      case kTfLiteInt16:
        *(input_tensor->data.i16) = input_value;
        break;
      case kTfLiteInt8:
        *(input_tensor->data.int8) = input_value;
        break;
      default:
        // Unsupported tensor types are silently ignored.
        break;
    }
  }

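  // The two output accessors below differ in ownership. A usage sketch
  // (assumes Invoke() has already been called on `interpreter`):
  //
  //   // A view into the interpreter-owned buffer; the data may be
  //   // overwritten by the next Invoke() call:
  //   TensorView<float> view =
  //       executor->OutputView<float>(/*output_index=*/0, interpreter.get());
  //
  //   // An owning copy that outlives the interpreter:
  //   std::vector<float> copy =
  //       executor->Output<float>(/*output_index=*/0, interpreter.get());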
  template <typename T>
  TensorView<T> OutputView(const int output_index,
                           const tflite::Interpreter* interpreter) const {
    const TfLiteTensor* output_tensor =
        interpreter->tensor(interpreter->outputs()[output_index]);
    return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
                         std::vector<int>(output_tensor->dims->data,
                                          output_tensor->dims->data +
                                              output_tensor->dims->size));
  }

  template <typename T>
  std::vector<T> Output(const int output_index,
                        const tflite::Interpreter* interpreter) const {
    TensorView<T> output_view = OutputView<T>(output_index, interpreter);
    return std::vector<T>(output_view.data(),
                          output_view.data() + output_view.size());
  }

 protected:
  explicit TfLiteModelExecutor(
      std::unique_ptr<const tflite::FlatBufferModel> model);

  std::unique_ptr<const tflite::FlatBufferModel> model_;
  std::unique_ptr<tflite::OpResolver> resolver_;
};

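// Specializations for string tensors, which use TfLite's packed string
// format and therefore cannot go through the plain typed accessors above.
// A usage sketch (the indices are hypothetical):
//
//   executor->SetInput(/*input_index=*/0,
//                      std::vector<std::string>{"foo", "bar"},
//                      interpreter.get());
//   ...
//   std::vector<std::string> texts =
//       executor->Output<std::string>(/*output_index=*/0, interpreter.get());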
template <>
void TfLiteModelExecutor::SetInput(const int input_index,
                                   const std::vector<std::string>& input_data,
                                   tflite::Interpreter* interpreter) const;

template <>
std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

template <>
std::vector<std::string> TfLiteModelExecutor::Output(
    const int output_index, const tflite::Interpreter* interpreter) const;

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_