• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite-model-executor.h"
18 
19 #include "utils/base/logging.h"
20 #include "tensorflow/lite/kernels/register.h"
21 
22 // Forward declaration of custom TensorFlow Lite ops for registration.
23 namespace tflite {
24 namespace ops {
25 namespace builtin {
26 TfLiteRegistration* Register_ADD();
27 TfLiteRegistration* Register_CONCATENATION();
28 TfLiteRegistration* Register_CONV_2D();
29 TfLiteRegistration* Register_FULLY_CONNECTED();
30 TfLiteRegistration* Register_L2_NORMALIZATION();
31 TfLiteRegistration* Register_MUL();
32 TfLiteRegistration* Register_RESHAPE();
33 TfLiteRegistration* Register_SOFTMAX();
34 TfLiteRegistration* Register_GATHER();
35 TfLiteRegistration* Register_TRANSPOSE();
36 TfLiteRegistration* Register_SUB();
37 TfLiteRegistration* Register_DIV();
38 TfLiteRegistration* Register_STRIDED_SLICE();
39 TfLiteRegistration* Register_EXP();
40 TfLiteRegistration* Register_TOPK_V2();
41 TfLiteRegistration* Register_SPLIT();
42 TfLiteRegistration* Register_CAST();
43 TfLiteRegistration* Register_MAXIMUM();
44 TfLiteRegistration* Register_MINIMUM();
45 TfLiteRegistration* Register_NEG();
46 TfLiteRegistration* Register_SLICE();
47 TfLiteRegistration* Register_LOG();
48 TfLiteRegistration* Register_SUM();
49 TfLiteRegistration* Register_PACK();
50 TfLiteRegistration* Register_DEQUANTIZE();
51 TfLiteRegistration* Register_MEAN();
52 }  // namespace builtin
53 }  // namespace ops
54 }  // namespace tflite
55 
56 #ifdef TC3_WITH_ACTIONS_OPS
57 #include "utils/tflite/dist_diversification.h"
58 #include "utils/tflite/text_encoder.h"
59 #include "utils/tflite/token_encoder.h"
60 
RegisterSelectedOps(tflite::MutableOpResolver * resolver)61 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
62   resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
63                        tflite::ops::builtin::Register_ADD(),
64                        /*min_version=*/1,
65                        /*max_version=*/2);
66   resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
67                        tflite::ops::builtin::Register_CONCATENATION(),
68                        /*min_version=*/1,
69                        /*max_version=*/2);
70   resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
71                        tflite::ops::builtin::Register_CONV_2D(),
72                        /*min_version=*/1,
73                        /*max_version=*/3);
74   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
75                        tflite::ops::builtin::Register_FULLY_CONNECTED(),
76                        /*min_version=*/1,
77                        /*max_version=*/4);
78   resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
79                        tflite::ops::builtin::Register_L2_NORMALIZATION(),
80                        /*min_version=*/1,
81                        /*max_version=*/2);
82   resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
83                        tflite::ops::builtin::Register_MUL());
84   resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
85                        tflite::ops::builtin::Register_RESHAPE());
86   resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
87                        tflite::ops::builtin::Register_SOFTMAX(),
88                        /*min_version=*/1,
89                        /*max_version=*/2);
90   resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
91                        tflite::ops::builtin::Register_GATHER(),
92                        /*min_version=*/1,
93                        /*max_version=*/2);
94   resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
95                        tflite::ops::builtin::Register_TRANSPOSE(),
96                        /*min_version=*/1,
97                        /*max_version=*/2);
98   resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
99                        tflite::ops::builtin::Register_SUB(),
100                        /*min_version=*/1,
101                        /*max_version=*/2);
102   resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
103                        tflite::ops::builtin::Register_DIV());
104   resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
105                        tflite::ops::builtin::Register_STRIDED_SLICE(),
106                        /*min_version=*/1,
107                        /*max_version=*/2);
108   resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
109                        tflite::ops::builtin::Register_EXP());
110   resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
111                        tflite::ops::builtin::Register_TOPK_V2(),
112                        /*min_version=*/1,
113                        /*max_version=*/2);
114   resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
115                        tflite::ops::builtin::Register_SPLIT(),
116                        /*min_version=*/1,
117                        /*max_version=*/3);
118   resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
119                        tflite::ops::builtin::Register_CAST());
120   resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
121                        tflite::ops::builtin::Register_MAXIMUM(),
122                        /*min_version=*/1,
123                        /*max_version=*/2);
124   resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
125                        tflite::ops::builtin::Register_MINIMUM(),
126                        /*min_version=*/1,
127                        /*max_version=*/2);
128   resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
129                        tflite::ops::builtin::Register_NEG());
130   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
131                        tflite::ops::builtin::Register_SLICE(),
132                        /*min_version=*/1,
133                        /*max_version=*/2);
134   resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
135                        tflite::ops::builtin::Register_LOG());
136   resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
137                        tflite::ops::builtin::Register_SUM());
138   resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
139                        tflite::ops::builtin::Register_PACK(),
140                        /*min_version=*/1,
141                        /*max_version=*/2);
142   resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
143                        tflite::ops::builtin::Register_DEQUANTIZE(),
144                        /*min_version=*/1,
145                        /*max_version=*/2);
146   resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
147                        tflite::ops::builtin::Register_MEAN());
148 }
149 #else
RegisterSelectedOps(tflite::MutableOpResolver * resolver)150 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
151   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
152                        tflite::ops::builtin::Register_FULLY_CONNECTED());
153 }
154 #endif  // TC3_WITH_ACTIONS_OPS
155 
156 namespace libtextclassifier3 {
157 
BuildOpResolver()158 inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
159 #ifdef TC3_USE_SELECTIVE_REGISTRATION
160   std::unique_ptr<tflite::MutableOpResolver> resolver(
161       new tflite::MutableOpResolver);
162   RegisterSelectedOps(resolver.get());
163 #else
164   std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
165       new tflite::ops::builtin::BuiltinOpResolver);
166 #endif
167 #ifdef TC3_WITH_ACTIONS_OPS
168   resolver->AddCustom("DistanceDiversification",
169                       tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
170   resolver->AddCustom("TextEncoder",
171                       tflite::ops::custom::Register_TEXT_ENCODER());
172   resolver->AddCustom("TokenEncoder",
173                       tflite::ops::custom::Register_TOKEN_ENCODER());
174 #endif  // TC3_WITH_ACTIONS_OPS
175   return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
176 }
177 
TfLiteModelFromModelSpec(const tflite::Model * model_spec)178 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
179     const tflite::Model* model_spec) {
180   std::unique_ptr<const tflite::FlatBufferModel> model(
181       tflite::FlatBufferModel::BuildFromModel(model_spec));
182   if (!model || !model->initialized()) {
183     TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
184     return nullptr;
185   }
186   return model;
187 }
188 
TfLiteModelFromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)189 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
190     const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
191   const tflite::Model* model =
192       flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
193   flatbuffers::Verifier verifier(model_spec_buffer->data(),
194                                  model_spec_buffer->Length());
195   if (!model->Verify(verifier)) {
196     return nullptr;
197   }
198   return TfLiteModelFromModelSpec(model);
199 }
200 
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)201 TfLiteModelExecutor::TfLiteModelExecutor(
202     std::unique_ptr<const tflite::FlatBufferModel> model)
203     : model_(std::move(model)), resolver_(BuildOpResolver()) {}
204 
CreateInterpreter() const205 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
206     const {
207   std::unique_ptr<tflite::Interpreter> interpreter;
208   tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
209   return interpreter;
210 }
211 
212 template <>
SetInput(const int input_index,const std::vector<std::string> & input_data,tflite::Interpreter * interpreter) const213 void TfLiteModelExecutor::SetInput(const int input_index,
214                                    const std::vector<std::string>& input_data,
215                                    tflite::Interpreter* interpreter) const {
216   tflite::DynamicBuffer buf;
217   for (const std::string& s : input_data) {
218     buf.AddString(s.data(), s.length());
219   }
220   buf.WriteToTensorAsVector(
221       interpreter->tensor(interpreter->inputs()[input_index]));
222 }
223 
224 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const225 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
226     const int output_index, const tflite::Interpreter* interpreter) const {
227   const TfLiteTensor* output_tensor =
228       interpreter->tensor(interpreter->outputs()[output_index]);
229   const int num_strings = tflite::GetStringCount(output_tensor);
230   std::vector<tflite::StringRef> output(num_strings);
231   for (int i = 0; i < num_strings; i++) {
232     output[i] = tflite::GetString(output_tensor, i);
233   }
234   return output;
235 }
236 
237 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const238 std::vector<std::string> TfLiteModelExecutor::Output(
239     const int output_index, const tflite::Interpreter* interpreter) const {
240   std::vector<std::string> output;
241   for (const tflite::StringRef& s :
242        Output<tflite::StringRef>(output_index, interpreter)) {
243     output.push_back(std::string(s.str, s.len));
244   }
245   return output;
246 }
247 
248 }  // namespace libtextclassifier3
249