1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite-model-executor.h"
18
19 #include "utils/base/logging.h"
20 #include "tensorflow/lite/kernels/register.h"
21 #include "tensorflow/lite/schema/schema_generated.h"
22
23 // Forward declaration of custom TensorFlow Lite ops for registration.
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 TfLiteRegistration* Register_ADD();
28 TfLiteRegistration* Register_CONCATENATION();
29 TfLiteRegistration* Register_CONV_2D();
30 TfLiteRegistration* Register_EQUAL();
31 TfLiteRegistration* Register_FULLY_CONNECTED();
32 TfLiteRegistration* Register_GREATER_EQUAL();
33 TfLiteRegistration* Register_L2_NORMALIZATION();
34 TfLiteRegistration* Register_MUL();
35 TfLiteRegistration* Register_RESHAPE();
36 TfLiteRegistration* Register_REDUCE_MAX();
37 TfLiteRegistration* Register_REDUCE_MIN();
38 TfLiteRegistration* Register_REDUCE_ANY();
39 TfLiteRegistration* Register_SOFTMAX();
40 TfLiteRegistration* Register_GATHER();
41 TfLiteRegistration* Register_GATHER_ND();
42 TfLiteRegistration* Register_IF();
43 TfLiteRegistration* Register_ROUND();
44 TfLiteRegistration* Register_ZEROS_LIKE();
45 TfLiteRegistration* Register_TRANSPOSE();
46 TfLiteRegistration* Register_SUB();
47 TfLiteRegistration* Register_DIV();
48 TfLiteRegistration* Register_STRIDED_SLICE();
49 TfLiteRegistration* Register_EXP();
50 TfLiteRegistration* Register_TOPK_V2();
51 TfLiteRegistration* Register_SLICE();
52 TfLiteRegistration* Register_SPLIT();
53 TfLiteRegistration* Register_CAST();
54 TfLiteRegistration* Register_MAXIMUM();
55 TfLiteRegistration* Register_MINIMUM();
56 TfLiteRegistration* Register_NEG();
57 TfLiteRegistration* Register_SLICE();
58 TfLiteRegistration* Register_LOG();
59 TfLiteRegistration* Register_LOGISTIC();
60 TfLiteRegistration* Register_SUM();
61 TfLiteRegistration* Register_PACK();
62 TfLiteRegistration* Register_DEQUANTIZE();
63 TfLiteRegistration* Register_MEAN();
64 TfLiteRegistration* Register_LESS();
65 TfLiteRegistration* Register_TILE();
66 TfLiteRegistration* Register_SQUARED_DIFFERENCE();
67 TfLiteRegistration* Register_RSQRT();
68 TfLiteRegistration* Register_LOG_SOFTMAX();
69 TfLiteRegistration* Register_WHERE();
70 TfLiteRegistration* Register_ONE_HOT();
71 TfLiteRegistration* Register_POW();
72 TfLiteRegistration* Register_TANH();
73 TfLiteRegistration* Register_UNIQUE();
74 TfLiteRegistration* Register_REDUCE_PROD();
75 TfLiteRegistration* Register_SHAPE();
76 TfLiteRegistration* Register_NOT_EQUAL();
77 TfLiteRegistration* Register_CUMSUM();
78 TfLiteRegistration* Register_EXPAND_DIMS();
79 TfLiteRegistration* Register_FILL();
80 TfLiteRegistration* Register_PADV2();
81 } // namespace builtin
82 } // namespace ops
83 } // namespace tflite
84
85 #ifdef TC3_WITH_ACTIONS_OPS
86 #include "utils/tflite/blacklist.h"
87 #include "utils/tflite/dist_diversification.h"
88 #include "utils/tflite/string_projection.h"
89 #include "utils/tflite/text_encoder.h"
90 #include "utils/tflite/token_encoder.h"
91 namespace tflite {
92 namespace ops {
93 namespace custom {
94 TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
95 TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
96 TfLiteRegistration* Register_RAGGED_RANGE();
97 TfLiteRegistration* Register_RANDOM_UNIFORM();
98 } // namespace custom
99 } // namespace ops
100 } // namespace tflite
101
RegisterSelectedOps(tflite::MutableOpResolver * resolver)102 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
103 resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
104 tflite::ops::builtin::Register_ADD(),
105 /*min_version=*/1,
106 /*max_version=*/2);
107 resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
108 tflite::ops::builtin::Register_CONCATENATION(),
109 /*min_version=*/1,
110 /*max_version=*/2);
111 resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
112 tflite::ops::builtin::Register_CONV_2D(),
113 /*min_version=*/1,
114 /*max_version=*/5);
115 resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
116 ::tflite::ops::builtin::Register_EQUAL());
117
118 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
119 tflite::ops::builtin::Register_FULLY_CONNECTED(),
120 /*min_version=*/1,
121 /*max_version=*/9);
122 resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
123 ::tflite::ops::builtin::Register_GREATER_EQUAL());
124 resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
125 tflite::ops::builtin::Register_L2_NORMALIZATION(),
126 /*min_version=*/1,
127 /*max_version=*/2);
128 resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
129 tflite::ops::builtin::Register_MUL());
130 resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
131 tflite::ops::builtin::Register_RESHAPE());
132 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
133 ::tflite::ops::builtin::Register_REDUCE_MAX());
134 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MIN,
135 ::tflite::ops::builtin::Register_REDUCE_MIN());
136 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
137 ::tflite::ops::builtin::Register_REDUCE_ANY());
138 resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
139 tflite::ops::builtin::Register_SOFTMAX(),
140 /*min_version=*/1,
141 /*max_version=*/2);
142 resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
143 tflite::ops::builtin::Register_GATHER(),
144 /*min_version=*/1,
145 /*max_version=*/2);
146 resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
147 ::tflite::ops::builtin::Register_GATHER_ND(),
148 /*version=*/2);
149 resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
150 ::tflite::ops::builtin::Register_IF()),
151 resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
152 ::tflite::ops::builtin::Register_ROUND());
153 resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
154 ::tflite::ops::builtin::Register_ZEROS_LIKE());
155 resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
156 tflite::ops::builtin::Register_TRANSPOSE(),
157 /*min_version=*/1,
158 /*max_version=*/2);
159 resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
160 tflite::ops::builtin::Register_SUB(),
161 /*min_version=*/1,
162 /*max_version=*/2);
163 resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
164 tflite::ops::builtin::Register_DIV());
165 resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
166 tflite::ops::builtin::Register_STRIDED_SLICE(),
167 /*min_version=*/1,
168 /*max_version=*/2);
169 resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
170 tflite::ops::builtin::Register_EXP());
171 resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
172 tflite::ops::builtin::Register_TOPK_V2(),
173 /*min_version=*/1,
174 /*max_version=*/2);
175 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
176 tflite::ops::builtin::Register_SLICE(),
177 /*min_version=*/1,
178 /*max_version=*/3);
179 resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
180 tflite::ops::builtin::Register_SPLIT(),
181 /*min_version=*/1,
182 /*max_version=*/3);
183 resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
184 tflite::ops::builtin::Register_CAST());
185 resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
186 tflite::ops::builtin::Register_MAXIMUM(),
187 /*min_version=*/1,
188 /*max_version=*/2);
189 resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
190 tflite::ops::builtin::Register_MINIMUM(),
191 /*min_version=*/1,
192 /*max_version=*/2);
193 resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
194 tflite::ops::builtin::Register_NEG());
195 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
196 tflite::ops::builtin::Register_SLICE(),
197 /*min_version=*/1,
198 /*max_version=*/2);
199 resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
200 tflite::ops::builtin::Register_LOG());
201 resolver->AddBuiltin(tflite::BuiltinOperator_LOGISTIC,
202 tflite::ops::builtin::Register_LOGISTIC());
203 resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
204 tflite::ops::builtin::Register_SUM());
205 resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
206 tflite::ops::builtin::Register_PACK(),
207 /*min_version=*/1,
208 /*max_version=*/2);
209 resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
210 tflite::ops::builtin::Register_DEQUANTIZE(),
211 /*min_version=*/1,
212 /*max_version=*/2);
213 resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
214 tflite::ops::builtin::Register_MEAN());
215 resolver->AddBuiltin(tflite::BuiltinOperator_LESS,
216 tflite::ops::builtin::Register_LESS());
217 resolver->AddBuiltin(tflite::BuiltinOperator_TILE,
218 tflite::ops::builtin::Register_TILE());
219 resolver->AddBuiltin(tflite::BuiltinOperator_SQUARED_DIFFERENCE,
220 tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
221 resolver->AddBuiltin(tflite::BuiltinOperator_RSQRT,
222 tflite::ops::builtin::Register_RSQRT());
223 resolver->AddBuiltin(tflite::BuiltinOperator_LOG_SOFTMAX,
224 tflite::ops::builtin::Register_LOG_SOFTMAX());
225 resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
226 ::tflite::ops::builtin::Register_WHERE());
227 resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
228 tflite::ops::builtin::Register_ONE_HOT(),
229 /*min_version=*/1,
230 /*max_version=*/1);
231 resolver->AddBuiltin(tflite::BuiltinOperator_POW,
232 tflite::ops::builtin::Register_POW(),
233 /*min_version=*/1,
234 /*max_version=*/1);
235 resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
236 tflite::ops::builtin::Register_TANH(),
237 /*min_version=*/1,
238 /*max_version=*/1);
239 resolver->AddBuiltin(::tflite::BuiltinOperator_UNIQUE,
240 ::tflite::ops::builtin::Register_UNIQUE());
241 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
242 ::tflite::ops::builtin::Register_REDUCE_PROD());
243 resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
244 ::tflite::ops::builtin::Register_SHAPE());
245 resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
246 ::tflite::ops::builtin::Register_NOT_EQUAL());
247 resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
248 ::tflite::ops::builtin::Register_CUMSUM());
249 resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
250 ::tflite::ops::builtin::Register_EXPAND_DIMS());
251 resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
252 ::tflite::ops::builtin::Register_FILL());
253 resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
254 ::tflite::ops::builtin::Register_PADV2());
255 }
256 #else
RegisterSelectedOps(tflite::MutableOpResolver * resolver)257 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
258 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
259 tflite::ops::builtin::Register_FULLY_CONNECTED());
260 }
261 #endif // TC3_WITH_ACTIONS_OPS
262
263 namespace libtextclassifier3 {
264
BuildOpResolver()265 std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
266 return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
267 }
268
BuildOpResolver(const std::function<void (tflite::MutableOpResolver *)> & customize_fn)269 std::unique_ptr<tflite::OpResolver> BuildOpResolver(
270 const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
271 #ifdef TC3_USE_SELECTIVE_REGISTRATION
272 std::unique_ptr<tflite::MutableOpResolver> resolver(
273 new tflite::MutableOpResolver);
274 RegisterSelectedOps(resolver.get());
275 #else
276 std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
277 new tflite::ops::builtin::BuiltinOpResolver);
278 #endif
279 #ifdef TC3_WITH_ACTIONS_OPS
280 resolver->AddCustom("DistanceDiversification",
281 tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
282 resolver->AddCustom("TextEncoder",
283 tflite::ops::custom::Register_TEXT_ENCODER());
284 resolver->AddCustom("TokenEncoder",
285 tflite::ops::custom::Register_TOKEN_ENCODER());
286 resolver->AddCustom(
287 "TFSentencepieceTokenizeOp",
288 ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
289 resolver->AddCustom("RaggedRange",
290 ::tflite::ops::custom::Register_RAGGED_RANGE());
291 resolver->AddCustom(
292 "RaggedTensorToTensor",
293 ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
294 resolver->AddCustom(
295 "STRING_PROJECTION",
296 ::tflite::ops::custom::libtextclassifier3::Register_STRING_PROJECTION());
297 resolver->AddCustom(
298 "BLACKLIST",
299 ::tflite::ops::custom::libtextclassifier3::Register_BLACKLIST());
300 resolver->AddCustom("RandomUniform",
301 ::tflite::ops::custom::Register_RANDOM_UNIFORM());
302 #endif // TC3_WITH_ACTIONS_OPS
303 customize_fn(resolver.get());
304 return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
305 }
306
TfLiteModelFromModelSpec(const tflite::Model * model_spec)307 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
308 const tflite::Model* model_spec) {
309 std::unique_ptr<const tflite::FlatBufferModel> model(
310 tflite::FlatBufferModel::BuildFromModel(model_spec));
311 if (!model || !model->initialized()) {
312 TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
313 return nullptr;
314 }
315 return model;
316 }
317
TfLiteModelFromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)318 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
319 const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
320 const tflite::Model* model =
321 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
322 flatbuffers::Verifier verifier(model_spec_buffer->data(),
323 model_spec_buffer->size());
324 if (!model->Verify(verifier)) {
325 return nullptr;
326 }
327 return TfLiteModelFromModelSpec(model);
328 }
329
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)330 TfLiteModelExecutor::TfLiteModelExecutor(
331 std::unique_ptr<const tflite::FlatBufferModel> model)
332 : model_(std::move(model)), resolver_(BuildOpResolver()) {}
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,std::unique_ptr<tflite::OpResolver> resolver)333 TfLiteModelExecutor::TfLiteModelExecutor(
334 std::unique_ptr<const tflite::FlatBufferModel> model,
335 std::unique_ptr<tflite::OpResolver> resolver)
336 : model_(std::move(model)), resolver_(std::move(resolver)) {}
337
CreateInterpreter() const338 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
339 const {
340 std::unique_ptr<tflite::Interpreter> interpreter;
341 tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
342 return interpreter;
343 }
344
345 template <>
SetInput(const int input_index,const std::vector<std::string> & input_data,tflite::Interpreter * interpreter) const346 void TfLiteModelExecutor::SetInput(const int input_index,
347 const std::vector<std::string>& input_data,
348 tflite::Interpreter* interpreter) const {
349 tflite::DynamicBuffer buf;
350 for (const std::string& s : input_data) {
351 buf.AddString(s.data(), s.length());
352 }
353 buf.WriteToTensorAsVector(
354 interpreter->tensor(interpreter->inputs()[input_index]));
355 }
356
357 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const358 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
359 const int output_index, const tflite::Interpreter* interpreter) const {
360 const TfLiteTensor* output_tensor =
361 interpreter->tensor(interpreter->outputs()[output_index]);
362 const int num_strings = tflite::GetStringCount(output_tensor);
363 std::vector<tflite::StringRef> output(num_strings);
364 for (int i = 0; i < num_strings; i++) {
365 output[i] = tflite::GetString(output_tensor, i);
366 }
367 return output;
368 }
369
370 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const371 std::vector<std::string> TfLiteModelExecutor::Output(
372 const int output_index, const tflite::Interpreter* interpreter) const {
373 std::vector<std::string> output;
374 for (const tflite::StringRef& s :
375 Output<tflite::StringRef>(output_index, interpreter)) {
376 output.push_back(std::string(s.str, s.len));
377 }
378 return output;
379 }
380
381 } // namespace libtextclassifier3
382