/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#include "tensorflow/contrib/lite/version.h"

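// Typical usage of the classes defined in this file (a minimal sketch only:
// the model path is hypothetical, BuiltinOpResolver is assumed to come from
// the kernels/register.h header, and error checking is omitted):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   interpreter->AllocateTensors();
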
namespace tflite {

const char* kEmptyTensorName = "";

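// Loads a model from `filename`, memory-mapping it and preferring NNAPI when
// it is available. Returns nullptr if the file cannot be mapped or its
// flatbuffer identifier does not match.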
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
    const char* filename, ErrorReporter* error_reporter) {
  std::unique_ptr<FlatBufferModel> model;
  model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter,
                                  /*use_nnapi=*/true));
  if (!model->initialized()) model.reset();
  return model;
}

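// Builds a model from a flatbuffer held in a caller-owned buffer, wrapped via
// MemoryAllocation. Returns nullptr if the allocation is not valid.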
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
    const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
  std::unique_ptr<FlatBufferModel> model;
  model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}

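// Builds a model from an already-deserialized flatbuffer Model object. No
// allocation is created, so the caller retains ownership of `model_spec`.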
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
    const tflite::Model* model_spec, ErrorReporter* error_reporter) {
  std::unique_ptr<FlatBufferModel> model;
  model.reset(new FlatBufferModel(model_spec, error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}

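// Backs the model with an allocation chosen from the constructor flags:
// NNAPIAllocation when NNAPI is requested and present, MMAPAllocation for a
// plain memory mapping, or FileCopyAllocation when mmap is disabled. model_
// stays null if the allocation fails or the file identifier is wrong.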
FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
                                 ErrorReporter* error_reporter, bool use_nnapi)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  if (mmap_file) {
    if (use_nnapi && NNAPIExists())
      allocation_ = new NNAPIAllocation(filename, error_reporter);
    else
      allocation_ = new MMAPAllocation(filename, error_reporter);
  } else {
    allocation_ = new FileCopyAllocation(filename, error_reporter);
  }
  if (!allocation_->valid() || !CheckModelIdentifier()) return;

  model_ = ::tflite::GetModel(allocation_->base());
}

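// Verifies that the buffer starts with the TFLite flatbuffer file identifier,
// reporting the identifier actually found when it does not.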
bool FlatBufferModel::CheckModelIdentifier() const {
  if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
    const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
    error_reporter_->Report(
        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
        ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
    return false;
  }
  return true;
}

FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
                                 ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
  if (!allocation_->valid()) return;

  model_ = ::tflite::GetModel(allocation_->base());
}

FlatBufferModel::FlatBufferModel(const Model* model,
                                 ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  model_ = model;
}

FlatBufferModel::~FlatBufferModel() { delete allocation_; }

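// The builders below keep non-owning references to the model, the op resolver
// and the error reporter; none of these are copied.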
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
                                       const OpResolver& op_resolver)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(model.error_reporter()),
      allocation_(model.allocation()) {}

InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {}

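// Resolves every operator code listed in the model to a TfLiteRegistration via
// the OpResolver, recording both the registration and its builtin type so that
// nodes can later refer to them by opcode index. Missing builtin or custom ops
// are reported and turn the returned status into kTfLiteError.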
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
  TfLiteStatus status = kTfLiteOk;
  auto opcodes = model_->operator_codes();
  for (const OperatorCode* opcode : *opcodes) {
    TfLiteRegistration* registration = nullptr;

    if (opcode->builtin_code() != BuiltinOperator_CUSTOM) {
      auto x = opcode->builtin_code();
      flatbuffer_op_index_to_registration_types_.push_back(x);
      registration = op_resolver_.FindOp(x);
      if (registration == nullptr) {
        error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
                                EnumNameBuiltinOperator(x));
        status = kTfLiteError;
      }
    } else if (!opcode->custom_code()) {
      error_reporter_->Report(
          "Operator with CUSTOM builtin_code has no custom_code.\n");
      status = kTfLiteError;
    } else {
      const char* name = opcode->custom_code()->c_str();
      registration = op_resolver_.FindOp(name);
      flatbuffer_op_index_to_registration_types_.push_back(
          BuiltinOperator_CUSTOM);
      if (registration == nullptr) {
        error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
        status = kTfLiteError;
      }
    }
    flatbuffer_op_index_to_registration_.push_back(registration);
  }
  return status;
}

namespace {
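// Copies a flatbuffer int array into a std::vector<int>.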
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  std::vector<int> ret(flat_array->Length());
  for (int i = 0; i < flat_array->Length(); i++) {
    ret[i] = flat_array->Get(i);
  }
  return ret;
}

// Copies the contents of the flatbuffer int vector `flat_vector` into the int
// array `buffer`. Both represent the same configuration for a given operation.
void FlatBufferIntVectorToArray(int max_size_of_buffer,
                                const flatbuffers::Vector<int32_t>* flat_vector,
                                int* buffer, ErrorReporter* error_reporter) {
  if (!flat_vector) {
    error_reporter->Report("Input array not provided for operation.\n");
  } else {
    int num_dimensions = flat_vector->Length();
    if (num_dimensions > max_size_of_buffer / sizeof(int)) {
      error_reporter->Report(
          "Found too many dimensions in the operation's input array.\n");
    } else {
      for (int i = 0; i < num_dimensions; ++i) {
        buffer[i] = flat_vector->Get(i);
      }
    }
  }
}

// Allocates a structure using C malloc, making sure the structure is a POD
// type that does not require constructors to run. We do this because the
// Interpreter's C layer takes ownership of builtin data and releases it with
// free().
template <class T>
T* MallocPOD() {
  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
  return static_cast<T*>(malloc(sizeof(T)));
}

// Parses the appropriate builtin data out of the op.
//
// This handles builtin data explicitly, since builtin options have their own
// flatbuffer schemas.
//
// Returns memory that must be freed.
//
// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program
// crashes if error reporter is called.
void* ParseOpData(const Operator* op, BuiltinOperator op_type,
                  ErrorReporter* error_reporter) {
  auto parse_padding = [](Padding padding) {
    switch (padding) {
      case Padding_SAME:
        return kTfLitePaddingSame;
      case Padding_VALID:
        return kTfLitePaddingValid;
    }
    return kTfLitePaddingUnknown;
  };
  auto parse_activation = [](ActivationFunctionType activation) {
    switch (activation) {
      case ActivationFunctionType_NONE:
        return kTfLiteActNone;
      case ActivationFunctionType_RELU:
        return kTfLiteActRelu;
      case ActivationFunctionType_RELU_N1_TO_1:
        return kTfLiteActRelu1;
      case ActivationFunctionType_RELU6:
        return kTfLiteActRelu6;
      case ActivationFunctionType_TANH:
        return kTfLiteActTanh;
      case ActivationFunctionType_SIGN_BIT:
        return kTfLiteActSignBit;
    }
    return kTfLiteActNone;
  };
  auto parseLSHProjectionType = [](LSHProjectionType type) {
    switch (type) {
      case LSHProjectionType_SPARSE:
        return kTfLiteLshProjectionSparse;
      case LSHProjectionType_DENSE:
        return kTfLiteLshProjectionDense;
      default:
        return kTfLiteLshProjectionUnknown;
    }
  };
  auto parseCombinerType = [](CombinerType type) {
    switch (type) {
      case CombinerType_MEAN:
        return kTfLiteCombinerTypeMean;
      case CombinerType_SQRTN:
        return kTfLiteCombinerTypeSqrtn;
      case CombinerType_SUM:
      default:
        return kTfLiteCombinerTypeSum;
    }
  };

  void* builtin_data = nullptr;
  switch (op_type) {
    case BuiltinOperator_CALL:
      // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
      // ok for now, since there is no call implementation either.
      break;
    case BuiltinOperator_CUSTOM:
      break;
    case BuiltinOperator_CONV_2D: {
      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
        params->padding = parse_padding(conv_params->padding());
        params->stride_width = conv_params->stride_w();
        params->stride_height = conv_params->stride_h();
        params->activation =
            parse_activation(conv_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_TANH:
    case BuiltinOperator_LOGISTIC:
    case BuiltinOperator_RELU:
    case BuiltinOperator_RELU_N1_TO_1:
    case BuiltinOperator_RELU6:
    case BuiltinOperator_CONCAT_EMBEDDINGS:
    case BuiltinOperator_EXP:
    case BuiltinOperator_TOPK_V2:
      break;
    case BuiltinOperator_LSH_PROJECTION: {
      TfLiteLSHProjectionParams* params =
          MallocPOD<TfLiteLSHProjectionParams>();
      if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
        params->type = parseLSHProjectionType(lshParams->type());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_AVERAGE_POOL_2D:
    case BuiltinOperator_MAX_POOL_2D:
    case BuiltinOperator_L2_POOL_2D: {
      TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
      if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
        params->padding = parse_padding(pool_params->padding());
        params->stride_width = pool_params->stride_w();
        params->stride_height = pool_params->stride_h();
        params->filter_width = pool_params->filter_width();
        params->filter_height = pool_params->filter_height();
        params->activation =
            parse_activation(pool_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_DEPTHWISE_CONV_2D: {
      TfLiteDepthwiseConvParams* params =
          MallocPOD<TfLiteDepthwiseConvParams>();
      if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
        params->padding = parse_padding(conv_params->padding());
        params->stride_width = conv_params->stride_w();
        params->stride_height = conv_params->stride_h();
        params->depth_multiplier = conv_params->depth_multiplier();
        params->activation =
            parse_activation(conv_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SVDF: {
      TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
      if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
        params->rank = svdf_params->rank();
        params->activation =
            parse_activation(svdf_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
      TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
      if (auto* sequence_rnn_params =
              op->builtin_options_as_SequenceRNNOptions()) {
        params->activation =
            parse_activation(sequence_rnn_params->fused_activation_function());
        params->time_major = sequence_rnn_params->time_major();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_RNN: {
      TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
      if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
        params->activation =
            parse_activation(rnn_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_EMBEDDING_LOOKUP:
      // no-op.
      break;
    case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
      TfLiteEmbeddingLookupSparseParams* params =
          MallocPOD<TfLiteEmbeddingLookupSparseParams>();
      if (auto* embedding_params =
              op->builtin_options_as_EmbeddingLookupSparseOptions()) {
        params->combiner = parseCombinerType(embedding_params->combiner());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_FULLY_CONNECTED: {
      TfLiteFullyConnectedParams* params =
          MallocPOD<TfLiteFullyConnectedParams>();
      if (auto* fully_connected_params =
              op->builtin_options_as_FullyConnectedOptions()) {
        params->activation = parse_activation(
            fully_connected_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_HASHTABLE_LOOKUP:
      // no-op.
      break;
    case BuiltinOperator_SOFTMAX: {
      TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
      if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
        params->beta = softmax_params->beta();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_CONCATENATION: {
      TfLiteConcatenationParams* params =
          MallocPOD<TfLiteConcatenationParams>();
      if (auto* concatenation_params =
              op->builtin_options_as_ConcatenationOptions()) {
        params->activation =
            parse_activation(concatenation_params->fused_activation_function());
        params->axis = concatenation_params->axis();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_MUL: {
      auto* params = MallocPOD<TfLiteMulParams>();
      if (auto* schema_params = op->builtin_options_as_MulOptions()) {
        params->activation =
            parse_activation(schema_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_ADD: {
      auto* params = MallocPOD<TfLiteAddParams>();
      if (auto* schema_params = op->builtin_options_as_AddOptions()) {
        params->activation =
            parse_activation(schema_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_DIV: {
      auto* params = MallocPOD<TfLiteDivParams>();
      if (auto* schema_params = op->builtin_options_as_DivOptions()) {
        params->activation =
            parse_activation(schema_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SUB: {
      auto* params = MallocPOD<TfLiteSubParams>();
      if (auto* schema_params = op->builtin_options_as_SubOptions()) {
        params->activation =
            parse_activation(schema_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_L2_NORMALIZATION: {
      auto* params = MallocPOD<TfLiteL2NormParams>();
      if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
        params->activation =
            parse_activation(schema_params->fused_activation_function());
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
      auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
      if (auto* schema_params =
              op->builtin_options_as_LocalResponseNormalizationOptions()) {
        params->radius = schema_params->radius();
        params->bias = schema_params->bias();
        params->alpha = schema_params->alpha();
        params->beta = schema_params->beta();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
    case BuiltinOperator_LSTM: {
      TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
      if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
        params->activation =
            parse_activation(lstm_params->fused_activation_function());
        params->cell_clip = lstm_params->cell_clip();
        params->proj_clip = lstm_params->proj_clip();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_RESIZE_BILINEAR: {
      auto* params = MallocPOD<TfLiteResizeBilinearParams>();
      if (auto* schema_params =
              op->builtin_options_as_ResizeBilinearOptions()) {
        params->align_corners = schema_params->align_corners();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_PAD: {
      break;
    }
    case BuiltinOperator_RESHAPE: {
      auto* params = MallocPOD<TfLiteReshapeParams>();
      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
        auto* new_shape = schema_params->new_shape();
        FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
                                   params->shape, error_reporter);
        params->num_dimensions = new_shape->Length();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SKIP_GRAM: {
      TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
      if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
        params->ngram_size = skip_gram_params->ngram_size();
        params->max_skip_size = skip_gram_params->max_skip_size();
        params->include_all_ngrams = skip_gram_params->include_all_ngrams();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SPACE_TO_DEPTH: {
      auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
      if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
        params->block_size = schema_params->block_size();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_GATHER: {
      TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
      params->axis = 0;
      if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
        params->axis = gather_params->axis();
      }

      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SPACE_TO_BATCH_ND: {
      break;
    }
    case BuiltinOperator_BATCH_TO_SPACE_ND: {
      break;
    }
    case BuiltinOperator_TRANSPOSE: {
      break;
    }
    case BuiltinOperator_MEAN: {
      auto* params = MallocPOD<TfLiteMeanParams>();
      if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
        params->keep_dims = schema_params->keep_dims();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SPLIT: {
      auto* params = MallocPOD<TfLiteSplitParams>();
      if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
        params->num_splits = schema_params->num_splits();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SQUEEZE: {
      auto* params = MallocPOD<TfLiteSqueezeParams>();
      if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
        const auto& squeeze_dims = schema_params->squeeze_dims();
        FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
                                   params->squeeze_dims, error_reporter);
        params->num_squeeze_dims = squeeze_dims->Length();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_STRIDED_SLICE: {
      auto* params = MallocPOD<TfLiteStridedSliceParams>();
      if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
        params->begin_mask = schema_params->begin_mask();
        params->end_mask = schema_params->end_mask();
        params->ellipsis_mask = schema_params->ellipsis_mask();
        params->new_axis_mask = schema_params->new_axis_mask();
        params->shrink_axis_mask = schema_params->shrink_axis_mask();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
  }
  return builtin_data;
}

}  // namespace

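// Walks the operators of the subgraph and adds one node per operator to the
// interpreter, attaching either the raw custom_options bytes or the builtin
// params produced by ParseOpData. Operators whose opcode index cannot be
// resolved are reported and skipped.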
TfLiteStatus InterpreterBuilder::ParseNodes(
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
    Interpreter* interpreter) {
  TfLiteStatus status = kTfLiteOk;
  for (int i = 0; i < operators->Length(); ++i) {
    const auto* op = operators->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      status = kTfLiteError;
      continue;
    }
    const TfLiteRegistration* reg =
        flatbuffer_op_index_to_registration_[op->opcode_index()];
    if (reg == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      status = kTfLiteError;
      continue;
    }

    auto op_type =
        flatbuffer_op_index_to_registration_types_[op->opcode_index()];
    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Found builtin operator %s with custom options.\n",
          EnumNameBuiltinOperator(op_type));
    }
    if (op->custom_options()) {
      interpreter->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()),
          reinterpret_cast<const char*>(op->custom_options()->data()),
          op->custom_options()->size(), nullptr, reg);
    } else {
      interpreter->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()), nullptr, 0,
          ParseOpData(op, op_type, error_reporter_), reg);
    }
  }

  return status;
}

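// Registers every tensor of the subgraph with the interpreter: tensors backed
// by a non-empty buffer become read-only tensors pointing into the flatbuffer
// allocation, all others become read-write tensors.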
TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Interpreter* interpreter) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that they
  // must outlive the interpreter.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  for (int i = 0; i < tensors->Length(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    TfLiteQuantizationParams quantization;
    quantization.scale = 0;
    quantization.zero_point = 0;
    auto* q_params = tensor->quantization();
    if (q_params) {
      // Note that the schema could hold per-channel quantization parameters,
      // but we really only support one value for the whole tensor.
      // TODO(aselle): This breaks as well if these are nullptrs.
      // TODO(aselle): This assumes non per-channel quantization.
      if (q_params->scale()) quantization.scale = q_params->scale()->Get(0);
      if (q_params->zero_point())
        quantization.zero_point = q_params->zero_point()->Get(0);
    }

    TfLiteType type;
    switch (tensor->type()) {
      case TensorType_FLOAT32:
        type = kTfLiteFloat32;
        break;
      case TensorType_INT32:
        type = kTfLiteInt32;
        break;
      case TensorType_UINT8:
        type = kTfLiteUInt8;
        break;
      case TensorType_INT64:
        type = kTfLiteInt64;
        break;
      case TensorType_STRING:
        type = kTfLiteString;
        break;
      default:
        // tensorType = ArrayType::NONE;
        error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n",
                                EnumNameTensorType(tensor->type()),
                                tensor->type());
        status = kTfLiteError;
        continue;
    }
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          if (size_t size = array->size()) {
            *buffer_size = size;
            *buffer_data = reinterpret_cast<const char*>(array->data());
            return kTfLiteOk;
          }
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    if (buffer_ptr) {
      if (interpreter->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    } else {
      if (interpreter->SetTensorParametersReadWrite(
              i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}

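// Builds a complete Interpreter from the model: validates the schema version,
// resolves operator registrations, then adds tensors, inputs/outputs, nodes
// and tensor parameters. On any failure the output pointer is reset and
// kTfLiteError is returned.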
TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting the partially created interpreter, to reduce
  // verbosity on error conditions. Use via `return cleanup_and_error();`.
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. The mapping built above resolves those opcodes to registrations
  // once, so custom ops incur a string lookup only once per custom op rather
  // than once per custom op invocation in the model graph.
  // Now construct the interpreter with the correct number of tensors and
  // operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();
  if (subgraphs->size() != 1) {
    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
    return cleanup_and_error();
  }
  const tflite::SubGraph* subgraph = (*subgraphs)[0];
  auto operators = subgraph->operators();
  auto tensors = subgraph->tensors();
  if (!operators || !tensors || !buffers) {
    error_reporter_->Report(
        "Did not get operators, tensors, or buffers in input flat buffer.\n");
    return cleanup_and_error();
  }
  interpreter->reset(new Interpreter(error_reporter_));
  if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
    return cleanup_and_error();
  }

  // Parse inputs/outputs.
  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

  // Finally, set up nodes and tensors.
  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();
  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

  return kTfLiteOk;
}

}  // namespace tflite