1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite-model-executor.h"
18
19 #include "utils/base/logging.h"
20 #include "tensorflow/lite/kernels/register.h"
21 #include "tensorflow/lite/schema/schema_generated.h"
22
23 // Forward declaration of custom TensorFlow Lite ops for registration.
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 TfLiteRegistration* Register_ADD();
28 TfLiteRegistration* Register_CONCATENATION();
29 TfLiteRegistration* Register_CONV_2D();
30 TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
31 TfLiteRegistration* Register_AVERAGE_POOL_2D();
32 TfLiteRegistration* Register_EQUAL();
33 TfLiteRegistration* Register_FULLY_CONNECTED();
34 TfLiteRegistration* Register_GREATER_EQUAL();
35 TfLiteRegistration* Register_L2_NORMALIZATION();
36 TfLiteRegistration* Register_MUL();
37 TfLiteRegistration* Register_RESHAPE();
38 TfLiteRegistration* Register_REDUCE_MAX();
39 TfLiteRegistration* Register_REDUCE_MIN();
40 TfLiteRegistration* Register_REDUCE_ANY();
41 TfLiteRegistration* Register_SOFTMAX();
42 TfLiteRegistration* Register_GATHER();
43 TfLiteRegistration* Register_GATHER_ND();
44 TfLiteRegistration* Register_IF();
45 TfLiteRegistration* Register_ROUND();
46 TfLiteRegistration* Register_ZEROS_LIKE();
47 TfLiteRegistration* Register_TRANSPOSE();
48 TfLiteRegistration* Register_SUB();
49 TfLiteRegistration* Register_DIV();
50 TfLiteRegistration* Register_STRIDED_SLICE();
51 TfLiteRegistration* Register_EXP();
52 TfLiteRegistration* Register_TOPK_V2();
53 TfLiteRegistration* Register_SLICE();
54 TfLiteRegistration* Register_SPLIT();
55 TfLiteRegistration* Register_CAST();
56 TfLiteRegistration* Register_MAXIMUM();
57 TfLiteRegistration* Register_MINIMUM();
58 TfLiteRegistration* Register_NEG();
59 TfLiteRegistration* Register_SLICE();
60 TfLiteRegistration* Register_LOG();
61 TfLiteRegistration* Register_LOGISTIC();
62 TfLiteRegistration* Register_SUM();
63 TfLiteRegistration* Register_PACK();
64 TfLiteRegistration* Register_DEQUANTIZE();
65 TfLiteRegistration* Register_MEAN();
66 TfLiteRegistration* Register_LESS();
67 TfLiteRegistration* Register_TILE();
68 TfLiteRegistration* Register_SQUARED_DIFFERENCE();
69 TfLiteRegistration* Register_RSQRT();
70 TfLiteRegistration* Register_LOG_SOFTMAX();
71 TfLiteRegistration* Register_WHERE();
72 TfLiteRegistration* Register_ONE_HOT();
73 TfLiteRegistration* Register_POW();
74 TfLiteRegistration* Register_TANH();
75 TfLiteRegistration* Register_UNIQUE();
76 TfLiteRegistration* Register_REDUCE_PROD();
77 TfLiteRegistration* Register_SHAPE();
78 TfLiteRegistration* Register_NOT_EQUAL();
79 TfLiteRegistration* Register_CUMSUM();
80 TfLiteRegistration* Register_EXPAND_DIMS();
81 TfLiteRegistration* Register_FILL();
82 TfLiteRegistration* Register_PADV2();
83 TfLiteRegistration* Register_EMBEDDING_LOOKUP();
84 TfLiteRegistration* Register_GREATER();
85 } // namespace builtin
86 } // namespace ops
87 } // namespace tflite
88
89 #ifdef TC3_WITH_ACTIONS_OPS
90 #include "utils/tflite/blacklist.h"
91 #include "utils/tflite/dist_diversification.h"
92 #include "utils/tflite/string_projection.h"
93 #include "utils/tflite/text_encoder.h"
94 #include "utils/tflite/text_encoder3s.h"
95 #include "utils/tflite/token_encoder.h"
96
97 namespace tflite {
98 namespace ops {
99 namespace custom {
100 TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
101 TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
102 TfLiteRegistration* Register_RAGGED_RANGE();
103 TfLiteRegistration* Register_RANDOM_UNIFORM();
104 } // namespace custom
105 } // namespace ops
106 } // namespace tflite
107
RegisterSelectedOps(tflite::MutableOpResolver * resolver)108 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
109 resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
110 tflite::ops::builtin::Register_ADD(),
111 /*min_version=*/1,
112 /*max_version=*/2);
113 resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
114 tflite::ops::builtin::Register_CONCATENATION(),
115 /*min_version=*/1,
116 /*max_version=*/2);
117 resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
118 tflite::ops::builtin::Register_CONV_2D(),
119 /*min_version=*/1,
120 /*max_version=*/5);
121 resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
122 tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
123 /*min_version=*/1,
124 /*max_version=*/6);
125 resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
126 tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
127 /*min_version=*/1,
128 /*max_version=*/1);
129 resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
130 ::tflite::ops::builtin::Register_EQUAL());
131
132 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
133 tflite::ops::builtin::Register_FULLY_CONNECTED(),
134 /*min_version=*/1,
135 /*max_version=*/9);
136 resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER_EQUAL,
137 ::tflite::ops::builtin::Register_GREATER_EQUAL());
138 resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
139 tflite::ops::builtin::Register_L2_NORMALIZATION(),
140 /*min_version=*/1,
141 /*max_version=*/2);
142 resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
143 tflite::ops::builtin::Register_MUL());
144 resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
145 tflite::ops::builtin::Register_RESHAPE());
146 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MAX,
147 ::tflite::ops::builtin::Register_REDUCE_MAX());
148 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_MIN,
149 ::tflite::ops::builtin::Register_REDUCE_MIN());
150 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_ANY,
151 ::tflite::ops::builtin::Register_REDUCE_ANY());
152 resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
153 tflite::ops::builtin::Register_SOFTMAX(),
154 /*min_version=*/1,
155 /*max_version=*/2);
156 resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
157 tflite::ops::builtin::Register_GATHER(),
158 /*min_version=*/1,
159 /*max_version=*/2);
160 resolver->AddBuiltin(::tflite::BuiltinOperator_GATHER_ND,
161 ::tflite::ops::builtin::Register_GATHER_ND(),
162 /*version=*/2);
163 resolver->AddBuiltin(::tflite::BuiltinOperator_IF,
164 ::tflite::ops::builtin::Register_IF()),
165 resolver->AddBuiltin(::tflite::BuiltinOperator_ROUND,
166 ::tflite::ops::builtin::Register_ROUND());
167 resolver->AddBuiltin(::tflite::BuiltinOperator_ZEROS_LIKE,
168 ::tflite::ops::builtin::Register_ZEROS_LIKE());
169 resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
170 tflite::ops::builtin::Register_TRANSPOSE(),
171 /*min_version=*/1,
172 /*max_version=*/2);
173 resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
174 tflite::ops::builtin::Register_SUB(),
175 /*min_version=*/1,
176 /*max_version=*/2);
177 resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
178 tflite::ops::builtin::Register_DIV());
179 resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
180 tflite::ops::builtin::Register_STRIDED_SLICE(),
181 /*min_version=*/1,
182 /*max_version=*/2);
183 resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
184 tflite::ops::builtin::Register_EXP());
185 resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
186 tflite::ops::builtin::Register_TOPK_V2(),
187 /*min_version=*/1,
188 /*max_version=*/2);
189 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
190 tflite::ops::builtin::Register_SLICE(),
191 /*min_version=*/1,
192 /*max_version=*/3);
193 resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
194 tflite::ops::builtin::Register_SPLIT(),
195 /*min_version=*/1,
196 /*max_version=*/3);
197 resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
198 tflite::ops::builtin::Register_CAST());
199 resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
200 tflite::ops::builtin::Register_MAXIMUM(),
201 /*min_version=*/1,
202 /*max_version=*/2);
203 resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
204 tflite::ops::builtin::Register_MINIMUM(),
205 /*min_version=*/1,
206 /*max_version=*/2);
207 resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
208 tflite::ops::builtin::Register_NEG());
209 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
210 tflite::ops::builtin::Register_SLICE(),
211 /*min_version=*/1,
212 /*max_version=*/2);
213 resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
214 tflite::ops::builtin::Register_LOG());
215 resolver->AddBuiltin(tflite::BuiltinOperator_LOGISTIC,
216 tflite::ops::builtin::Register_LOGISTIC());
217 resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
218 tflite::ops::builtin::Register_SUM());
219 resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
220 tflite::ops::builtin::Register_PACK(),
221 /*min_version=*/1,
222 /*max_version=*/2);
223 resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
224 tflite::ops::builtin::Register_DEQUANTIZE(),
225 /*min_version=*/1,
226 /*max_version=*/2);
227 resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
228 tflite::ops::builtin::Register_MEAN());
229 resolver->AddBuiltin(tflite::BuiltinOperator_LESS,
230 tflite::ops::builtin::Register_LESS());
231 resolver->AddBuiltin(tflite::BuiltinOperator_TILE,
232 tflite::ops::builtin::Register_TILE());
233 resolver->AddBuiltin(tflite::BuiltinOperator_SQUARED_DIFFERENCE,
234 tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
235 resolver->AddBuiltin(tflite::BuiltinOperator_RSQRT,
236 tflite::ops::builtin::Register_RSQRT());
237 resolver->AddBuiltin(tflite::BuiltinOperator_LOG_SOFTMAX,
238 tflite::ops::builtin::Register_LOG_SOFTMAX());
239 resolver->AddBuiltin(::tflite::BuiltinOperator_WHERE,
240 ::tflite::ops::builtin::Register_WHERE());
241 resolver->AddBuiltin(tflite::BuiltinOperator_ONE_HOT,
242 tflite::ops::builtin::Register_ONE_HOT(),
243 /*min_version=*/1,
244 /*max_version=*/1);
245 resolver->AddBuiltin(tflite::BuiltinOperator_POW,
246 tflite::ops::builtin::Register_POW(),
247 /*min_version=*/1,
248 /*max_version=*/1);
249 resolver->AddBuiltin(tflite::BuiltinOperator_TANH,
250 tflite::ops::builtin::Register_TANH(),
251 /*min_version=*/1,
252 /*max_version=*/1);
253 resolver->AddBuiltin(::tflite::BuiltinOperator_UNIQUE,
254 ::tflite::ops::builtin::Register_UNIQUE());
255 resolver->AddBuiltin(::tflite::BuiltinOperator_REDUCE_PROD,
256 ::tflite::ops::builtin::Register_REDUCE_PROD());
257 resolver->AddBuiltin(::tflite::BuiltinOperator_SHAPE,
258 ::tflite::ops::builtin::Register_SHAPE());
259 resolver->AddBuiltin(::tflite::BuiltinOperator_NOT_EQUAL,
260 ::tflite::ops::builtin::Register_NOT_EQUAL());
261 resolver->AddBuiltin(::tflite::BuiltinOperator_CUMSUM,
262 ::tflite::ops::builtin::Register_CUMSUM());
263 resolver->AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
264 ::tflite::ops::builtin::Register_EXPAND_DIMS());
265 resolver->AddBuiltin(::tflite::BuiltinOperator_FILL,
266 ::tflite::ops::builtin::Register_FILL());
267 resolver->AddBuiltin(::tflite::BuiltinOperator_PADV2,
268 ::tflite::ops::builtin::Register_PADV2());
269 resolver->AddBuiltin(::tflite::BuiltinOperator_EMBEDDING_LOOKUP,
270 ::tflite::ops::builtin::Register_EMBEDDING_LOOKUP(),
271 /* min_version=*/1,
272 /*max_version=*/3);
273 resolver->AddBuiltin(::tflite::BuiltinOperator_GREATER,
274 ::tflite::ops::builtin::Register_GREATER());
275 }
276 #else
RegisterSelectedOps(tflite::MutableOpResolver * resolver)277 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
278 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
279 tflite::ops::builtin::Register_FULLY_CONNECTED());
280 }
281 #endif // TC3_WITH_ACTIONS_OPS
282
283 namespace libtextclassifier3 {
284
BuildOpResolver()285 std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
286 return BuildOpResolver([](tflite::MutableOpResolver* mutable_resolver) {});
287 }
288
BuildOpResolver(const std::function<void (tflite::MutableOpResolver *)> & customize_fn)289 std::unique_ptr<tflite::OpResolver> BuildOpResolver(
290 const std::function<void(tflite::MutableOpResolver*)>& customize_fn) {
291 #ifdef TC3_USE_SELECTIVE_REGISTRATION
292 std::unique_ptr<tflite::MutableOpResolver> resolver(
293 new tflite::MutableOpResolver);
294 RegisterSelectedOps(resolver.get());
295 #else
296 std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
297 new tflite::ops::builtin::BuiltinOpResolver);
298 #endif
299 #ifdef TC3_WITH_ACTIONS_OPS
300 resolver->AddCustom("DistanceDiversification",
301 tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
302 resolver->AddCustom("TextEncoder",
303 tflite::ops::custom::Register_TEXT_ENCODER());
304 resolver->AddCustom("TextEncoder3S",
305 tflite::ops::custom::Register_TEXT_ENCODER3S());
306 resolver->AddCustom("TokenEncoder",
307 tflite::ops::custom::Register_TOKEN_ENCODER());
308 resolver->AddCustom(
309 "TFSentencepieceTokenizeOp",
310 ::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
311 resolver->AddCustom("RaggedRange",
312 ::tflite::ops::custom::Register_RAGGED_RANGE());
313 resolver->AddCustom(
314 "RaggedTensorToTensor",
315 ::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
316 resolver->AddCustom(
317 "STRING_PROJECTION",
318 ::tflite::ops::custom::libtextclassifier3::Register_STRING_PROJECTION());
319 resolver->AddCustom(
320 "BLACKLIST",
321 ::tflite::ops::custom::libtextclassifier3::Register_BLACKLIST());
322 resolver->AddCustom("RandomUniform",
323 ::tflite::ops::custom::Register_RANDOM_UNIFORM());
324 #endif // TC3_WITH_ACTIONS_OPS
325 customize_fn(resolver.get());
326 return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
327 }
328
TfLiteModelFromModelSpec(const tflite::Model * model_spec)329 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
330 const tflite::Model* model_spec) {
331 std::unique_ptr<const tflite::FlatBufferModel> model(
332 tflite::FlatBufferModel::BuildFromModel(model_spec));
333 if (!model || !model->initialized()) {
334 TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
335 return nullptr;
336 }
337 return model;
338 }
339
TfLiteModelFromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)340 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
341 const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
342 const tflite::Model* model =
343 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
344 flatbuffers::Verifier verifier(model_spec_buffer->data(),
345 model_spec_buffer->size());
346 if (!model->Verify(verifier)) {
347 return nullptr;
348 }
349 return TfLiteModelFromModelSpec(model);
350 }
351
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)352 TfLiteModelExecutor::TfLiteModelExecutor(
353 std::unique_ptr<const tflite::FlatBufferModel> model)
354 : model_(std::move(model)), resolver_(BuildOpResolver()) {}
TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model,std::unique_ptr<tflite::OpResolver> resolver)355 TfLiteModelExecutor::TfLiteModelExecutor(
356 std::unique_ptr<const tflite::FlatBufferModel> model,
357 std::unique_ptr<tflite::OpResolver> resolver)
358 : model_(std::move(model)), resolver_(std::move(resolver)) {}
359
CreateInterpreter() const360 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
361 const {
362 std::unique_ptr<tflite::Interpreter> interpreter;
363 tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
364 return interpreter;
365 }
366
367 template <>
SetInput(const int input_index,const std::vector<std::string> & input_data,tflite::Interpreter * interpreter) const368 void TfLiteModelExecutor::SetInput(const int input_index,
369 const std::vector<std::string>& input_data,
370 tflite::Interpreter* interpreter) const {
371 tflite::DynamicBuffer buf;
372 for (const std::string& s : input_data) {
373 buf.AddString(s.data(), s.length());
374 }
375 buf.WriteToTensorAsVector(
376 interpreter->tensor(interpreter->inputs()[input_index]));
377 }
378
379 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const380 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
381 const int output_index, const tflite::Interpreter* interpreter) const {
382 const TfLiteTensor* output_tensor =
383 interpreter->tensor(interpreter->outputs()[output_index]);
384 const int num_strings = tflite::GetStringCount(output_tensor);
385 std::vector<tflite::StringRef> output(num_strings);
386 for (int i = 0; i < num_strings; i++) {
387 output[i] = tflite::GetString(output_tensor, i);
388 }
389 return output;
390 }
391
392 template <>
Output(const int output_index,const tflite::Interpreter * interpreter) const393 std::vector<std::string> TfLiteModelExecutor::Output(
394 const int output_index, const tflite::Interpreter* interpreter) const {
395 std::vector<std::string> output;
396 for (const tflite::StringRef& s :
397 Output<tflite::StringRef>(output_index, interpreter)) {
398 output.push_back(std::string(s.str, s.len));
399 }
400 return output;
401 }
402
403 } // namespace libtextclassifier3
404