/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/model.h"
#ifndef TFLITE_MCU
#include "tensorflow/lite/nnapi_delegate.h"
#endif
#include "tensorflow/lite/version.h"

namespace tflite {

namespace {
// Ensure that ErrorReporter is non-null.
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
  return e ? e : DefaultErrorReporter();
}
}  // namespace

const char* kEmptyTensorName = "";

// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
// we avoid the absl dependency for binary size reasons.
#ifdef __has_attribute
#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define TFLITE_HAS_ATTRIBUTE(x) 0
#endif

#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
// Using weak symbols for the flex delegate allows automatic injection of the
// delegate simply by adding it as a dependency. See also the strong override
// in lite/delegates/flex/delegate.cc.
__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
#else
Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
#endif
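
// Illustrative sketch of how a strong override of the weak default would look
// (an assumption for illustration only; the real override lives in
// lite/delegates/flex/delegate.cc and its body may differ):
//
//   Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
//     return /* a TfLiteDelegatePtr owning the real flex delegate */;
//   }
//
// Linking a target that provides such a strong definition replaces the weak
// default above, so flex ops become usable without any further wiring.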

#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true, uses mmap;
// otherwise makes a copy of the model in a buffer.
std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
                                                  bool mmap_file,
                                                  ErrorReporter* error_reporter,
                                                  bool use_nnapi) {
  std::unique_ptr<Allocation> allocation;
  if (mmap_file && MMAPAllocation::IsSupported()) {
    if (use_nnapi && NNAPIDelegate::IsSupported())
      allocation.reset(new NNAPIAllocation(filename, error_reporter));
    else
      allocation.reset(new MMAPAllocation(filename, error_reporter));
  } else {
    allocation.reset(new FileCopyAllocation(filename, error_reporter));
  }
  return allocation;
}

std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
    const char* filename, ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  std::unique_ptr<FlatBufferModel> model;
  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
                                          error_reporter, /*use_nnapi=*/true);
  model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}
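
// Example usage (a minimal sketch; "model.tflite" is a placeholder path and
// error handling is elided):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   if (model == nullptr) {
//     // Loading the file or reading the flatbuffer failed.
//   }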

std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
    const char* filename, TfLiteVerifier* extra_verifier,
    ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  std::unique_ptr<FlatBufferModel> model;
  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
                                          error_reporter, /*use_nnapi=*/true);

  flatbuffers::Verifier base_verifier(
      reinterpret_cast<const uint8_t*>(allocation->base()),
      allocation->bytes());
  if (!VerifyModelBuffer(base_verifier)) {
    error_reporter->Report("The model is not a valid Flatbuffer file");
    return nullptr;
  }

  if (extra_verifier &&
      !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
                              allocation->bytes(), error_reporter)) {
    return model;
  }
  model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}
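
// Example usage (a minimal sketch; `MyVerifier` is a hypothetical subclass of
// TfLiteVerifier, and the exact Verify() signature should be taken from
// lite/model.h rather than from this comment):
//
//   class MyVerifier : public tflite::TfLiteVerifier {
//    public:
//     bool Verify(const char* data, int length,
//                 tflite::ErrorReporter* reporter) override {
//       return length > 0;  // e.g. reject empty files
//     }
//   };
//
//   MyVerifier verifier;
//   auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
//       "model.tflite", &verifier);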
#endif

std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
    const char* caller_owned_buffer, size_t buffer_size,
    ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  std::unique_ptr<FlatBufferModel> model;
  std::unique_ptr<Allocation> allocation(
      new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
  model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}
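
// Example usage (a minimal sketch; `ReadFileToString` is a hypothetical
// helper). The buffer is not copied, so the caller must keep the backing
// storage alive for the lifetime of the returned FlatBufferModel and of any
// Interpreter built from it:
//
//   std::string contents = ReadFileToString("model.tflite");
//   auto model = tflite::FlatBufferModel::BuildFromBuffer(contents.data(),
//                                                         contents.size());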

std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
    const char* buffer, size_t buffer_size, TfLiteVerifier* extra_verifier,
    ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t*>(buffer),
                                      buffer_size);
  if (!VerifyModelBuffer(base_verifier)) {
    error_reporter->Report("The model is not a valid Flatbuffer buffer");
    return nullptr;
  }

  if (extra_verifier &&
      !extra_verifier->Verify(buffer, buffer_size, error_reporter)) {
    return nullptr;
  }

  return BuildFromBuffer(buffer, buffer_size, error_reporter);
}

std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
    const tflite::Model* caller_owned_model_spec,
    ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  std::unique_ptr<FlatBufferModel> model;
  model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}
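
// Example usage (a minimal sketch; assumes `buffer` points at a valid,
// already-verified flatbuffer that the caller keeps alive for the model's
// lifetime):
//
//   const tflite::Model* model_spec = ::tflite::GetModel(buffer);
//   auto model = tflite::FlatBufferModel::BuildFromModel(model_spec);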

bool FlatBufferModel::CheckModelIdentifier() const {
  if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
    const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
    error_reporter_->Report(
        "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
        ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
    return false;
  }
  return true;
}

FlatBufferModel::FlatBufferModel(const Model* model,
                                 ErrorReporter* error_reporter)
    : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}

FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
                                 ErrorReporter* error_reporter)
    : error_reporter_(ValidateErrorReporter(error_reporter)),
      allocation_(std::move(allocation)) {
  if (!allocation_->valid() || !CheckModelIdentifier()) return;

  model_ = ::tflite::GetModel(allocation_->base());
}

FlatBufferModel::~FlatBufferModel() {}

InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
                                       const OpResolver& op_resolver)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(model.error_reporter())),
      allocation_(model.allocation()) {}

InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(error_reporter)) {}

InterpreterBuilder::~InterpreterBuilder() {}

TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
  TfLiteStatus status = kTfLiteOk;
  auto opcodes = model_->operator_codes();
  for (const OperatorCode* opcode : *opcodes) {
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      return status;
    }
    flatbuffer_op_index_to_registration_.push_back(registration);
  }
  return status;
}

namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  // Initialize shape of tensors with null shape. Empty vectors are converted
  // to nullptr for models that are constructed via flatbuffers::Pack.
  if (flat_array == nullptr) {
    return {};
  }
  std::vector<int> ret(flat_array->Length());
  for (int i = 0; i < flat_array->Length(); i++) {
    ret[i] = flat_array->Get(i);
  }
  return ret;
}

// Used to determine how the op data parsing function creates its working
// space.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override { return malloc(size); }
  void Deallocate(void* data) override { free(data); }
};

}  // namespace

TfLiteStatus InterpreterBuilder::ParseNodes(
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // Reduce the number of redundant allocations.
  subgraph->ReserveNodes(operators->Length());

  for (int i = 0; i < operators->Length(); ++i) {
    const auto* op = operators->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      status = kTfLiteError;
      continue;
    }

    const TfLiteRegistration* registration =
        flatbuffer_op_index_to_registration_[index];
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      status = kTfLiteError;
      continue;
    }

    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Found builtin operator %s with custom options.\n",
          EnumNameBuiltinOperator(op_type));
    }

    if (op->custom_options()) {
      subgraph->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()),
          reinterpret_cast<const char*>(op->custom_options()->data()),
          op->custom_options()->size(), nullptr, registration);
    } else {
      void* builtin_data = nullptr;
      MallocDataAllocator malloc_allocator;
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &malloc_allocator, &builtin_data));
      subgraph->AddNodeWithParameters(FlatBufferIntArrayToVector(op->inputs()),
                                      FlatBufferIntArrayToVector(op->outputs()),
                                      nullptr, 0, builtin_data, registration);
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ParseQuantization(
    const QuantizationParameters* src_quantization,
    TfLiteQuantization* quantization) {
  quantization->type = kTfLiteNoQuantization;
  if (!src_quantization || !src_quantization->scale() ||
      src_quantization->scale()->size() == 0) {
    return kTfLiteOk;
  }
  if (!src_quantization->zero_point()) {
    error_reporter_->Report(
        "Quantization parameters has non-null scale but null zero_point.");
    return kTfLiteError;
  }

  // Ensure that the number of scales matches the number of zero_points.
  if (src_quantization->scale()->size() !=
      src_quantization->zero_point()->size()) {
    error_reporter_->Report(
        "QuantizationParam has %d zero_point values and %d scale values. Must "
        "have same number.",
        src_quantization->zero_point()->size(),
        src_quantization->scale()->size());
    return kTfLiteError;
  }

  // Affine quantization.
  quantization->type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  const size_t num_scales = src_quantization->scale()->size();
  affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
  affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
  for (size_t i = 0; i < num_scales; ++i) {
    affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
    affine_quantization->zero_point->data[i] =
        src_quantization->zero_point()->Get(i);
  }
  if (src_quantization->quantized_dimension() < 0 ||
      src_quantization->quantized_dimension() >= num_scales) {
    error_reporter_->Report(
        "quantized_dimension must be in range [0, %d). Was %d.", num_scales,
        src_quantization->quantized_dimension());
    return kTfLiteError;
  }
  affine_quantization->quantized_dimension =
      src_quantization->quantized_dimension();
  quantization->params = reinterpret_cast<void*>(affine_quantization);
  return kTfLiteOk;
}
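
// Background note (standard affine quantization as used by TFLite, not
// specific to this file): each real value is recovered as
//
//   real_value = scale[d] * (quantized_value - zero_point[d])
//
// where, for per-axis (per-channel) quantization, d indexes the slice along
// `quantized_dimension`; per-tensor quantization is the special case of a
// single scale/zero_point pair.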

TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that they
  // must outlive the subgraph.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  for (int i = 0; i < tensors->Length(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    const auto* src_quantization = tensor->quantization();
    TfLiteQuantization quantization;
    if (ParseQuantization(src_quantization, &quantization) != kTfLiteOk) {
      status = kTfLiteError;
      continue;
    }

    TfLiteType type;
    if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
        kTfLiteOk) {
      status = kTfLiteError;
      continue;
    }
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          if (size_t size = array->size()) {
            *buffer_size = size;
            *buffer_data = reinterpret_cast<const char*>(array->data());
            return kTfLiteOk;
          }
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    bool is_variable = tensor->is_variable();
    if (buffer_ptr) {
      if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with buffer. "
            "It's not supported now.\n",
            i);
        status = kTfLiteError;
      }

      if (subgraph->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    } else {
      if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
                                                 dims, quantization,
                                                 is_variable) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
  // TODO(b/117561550): Move flex delegate application to the OpResolver.
  if (AcquireFlexDelegate == nullptr) {
    return kTfLiteOk;
  }

  bool has_flex_op = false;
  for (const auto* registration : flatbuffer_op_index_to_registration_) {
    if ((registration->builtin_code == BuiltinOperator_CUSTOM) &&
        IsFlexOp(registration->custom_name)) {
      has_flex_op = true;
      break;
    }
  }

  if (!has_flex_op) {
    return kTfLiteOk;
  }

  if (auto flex_delegate = AcquireFlexDelegate()) {
    return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate));
  }

  return kTfLiteOk;
}
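
// Note: a flex ("select TensorFlow ops") kernel appears in the model as a
// CUSTOM op whose custom name IsFlexOp() recognizes (such names are
// conventionally prefixed, e.g. a hypothetical "FlexSomeTfOp"), so the flex
// delegate is only acquired and applied when at least one such op is present.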

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  return operator()(interpreter, /*num_threads=*/-1);
}

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting the partially created interpreter, to reduce
  // verbosity on error conditions. Use via `return cleanup_and_error();`.
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. We first map those to registrations. This reduces string lookups
  // for custom ops since we only do it once per custom op rather than once per
  // custom op invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();

  if (subgraphs->size() == 0) {
    error_reporter_->Report("No subgraph in the model.\n");
    return cleanup_and_error();
  }

  interpreter->reset(new Interpreter(error_reporter_));
  (*interpreter)->SetNumThreads(num_threads);
  if (subgraphs->Length() > 1) {
    (*interpreter)->AddSubgraphs(subgraphs->Length() - 1);
  }

  for (int subgraph_index = 0; subgraph_index < subgraphs->Length();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
    tflite::Subgraph* modified_subgraph =
        (*interpreter)->subgraph(subgraph_index);
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!operators || !tensors || !buffers) {
      error_reporter_->Report(
          "Did not get operators, tensors, or buffers in subgraph %d.\n",
          subgraph_index);
      return cleanup_and_error();
    }
    if (modified_subgraph->AddTensors(tensors->Length()) != kTfLiteOk) {
      return cleanup_and_error();
    }
    // Parse inputs/outputs.
    modified_subgraph->SetInputs(
        FlatBufferIntArrayToVector(subgraph->inputs()));
    modified_subgraph->SetOutputs(
        FlatBufferIntArrayToVector(subgraph->outputs()));

    // Finally, set up nodes and tensors.
    if (ParseNodes(operators, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();
    if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();

    std::vector<int> variables;
    for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
      auto* tensor = modified_subgraph->tensor(i);
      if (tensor->is_variable) {
        variables.push_back(i);
      }
    }
    modified_subgraph->SetVariables(std::move(variables));
  }

  if (ApplyDelegates(interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

  return kTfLiteOk;
}
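
// Example end-to-end usage (a minimal sketch; assumes a model with a single
// float input and output, omits error handling, and "model.tflite" is a
// placeholder path):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//   interpreter->AllocateTensors();
//   float* input = interpreter->typed_input_tensor<float>(0);
//   // ... fill input ...
//   interpreter->Invoke();
//   float* output = interpreter->typed_output_tensor<float>(0);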

}  // namespace tflite