1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/model_builder.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <memory>
21 #include <string>
22 #include <utility>
23
24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
25 #include "tensorflow/lite/allocation.h"
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/core/api/verifier.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/stderr_reporter.h"
30 #include "tensorflow/lite/string_type.h"
31
32 namespace tflite {
33
34 namespace {
35
36 // Ensure that ErrorReporter is non-null.
ValidateErrorReporter(ErrorReporter * e)37 ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
38 return e ? e : DefaultErrorReporter();
39 }
40
41 } // namespace
42
43 #ifndef TFLITE_MCU
44 // Loads a model from `filename`. If `mmap_file` is true then use mmap,
45 // otherwise make a copy of the model in a buffer.
GetAllocationFromFile(const char * filename,bool mmap_file,ErrorReporter * error_reporter,bool use_nnapi)46 std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
47 bool mmap_file,
48 ErrorReporter* error_reporter,
49 bool use_nnapi) {
50 std::unique_ptr<Allocation> allocation;
51 if (mmap_file && MMAPAllocation::IsSupported()) {
52 allocation.reset(new MMAPAllocation(filename, error_reporter));
53 } else {
54 allocation.reset(new FileCopyAllocation(filename, error_reporter));
55 }
56 return allocation;
57 }
58
BuildFromFile(const char * filename,ErrorReporter * error_reporter)59 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
60 const char* filename, ErrorReporter* error_reporter) {
61 error_reporter = ValidateErrorReporter(error_reporter);
62
63 std::unique_ptr<FlatBufferModel> model;
64 auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
65 error_reporter, /*use_nnapi=*/true);
66 model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
67 if (!model->initialized()) model.reset();
68 return model;
69 }
70
VerifyAndBuildFromFile(const char * filename,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)71 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
72 const char* filename, TfLiteVerifier* extra_verifier,
73 ErrorReporter* error_reporter) {
74 error_reporter = ValidateErrorReporter(error_reporter);
75
76 std::unique_ptr<FlatBufferModel> model;
77 auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
78 error_reporter, /*use_nnapi=*/true);
79
80 flatbuffers::Verifier base_verifier(
81 reinterpret_cast<const uint8_t*>(allocation->base()),
82 allocation->bytes());
83 if (!VerifyModelBuffer(base_verifier)) {
84 TF_LITE_REPORT_ERROR(error_reporter,
85 "The model is not a valid Flatbuffer file");
86 return nullptr;
87 }
88
89 if (extra_verifier &&
90 !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
91 allocation->bytes(), error_reporter)) {
92 return model;
93 }
94 model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
95 if (!model->initialized()) model.reset();
96 return model;
97 }
98 #endif
99
BuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,ErrorReporter * error_reporter)100 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
101 const char* caller_owned_buffer, size_t buffer_size,
102 ErrorReporter* error_reporter) {
103 error_reporter = ValidateErrorReporter(error_reporter);
104
105 std::unique_ptr<FlatBufferModel> model;
106 std::unique_ptr<Allocation> allocation(
107 new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
108 model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
109 if (!model->initialized()) model.reset();
110 return model;
111 }
112
VerifyAndBuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)113 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
114 const char* caller_owned_buffer, size_t buffer_size,
115 TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
116 error_reporter = ValidateErrorReporter(error_reporter);
117
118 flatbuffers::Verifier base_verifier(
119 reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
120 if (!VerifyModelBuffer(base_verifier)) {
121 TF_LITE_REPORT_ERROR(error_reporter,
122 "The model is not a valid Flatbuffer buffer");
123 return nullptr;
124 }
125
126 if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
127 buffer_size, error_reporter)) {
128 return nullptr;
129 }
130
131 return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
132 }
133
BuildFromModel(const tflite::Model * caller_owned_model_spec,ErrorReporter * error_reporter)134 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
135 const tflite::Model* caller_owned_model_spec,
136 ErrorReporter* error_reporter) {
137 error_reporter = ValidateErrorReporter(error_reporter);
138
139 std::unique_ptr<FlatBufferModel> model;
140 model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
141 if (!model->initialized()) model.reset();
142 return model;
143 }
144
GetMinimumRuntime() const145 string FlatBufferModel::GetMinimumRuntime() const {
146 if (!model_ || !model_->metadata()) return "";
147
148 for (int i = 0; i < model_->metadata()->size(); ++i) {
149 auto metadata = model_->metadata()->Get(i);
150 if (metadata->name()->str() == "min_runtime_version") {
151 auto buf = metadata->buffer();
152 auto* buffer = (*model_->buffers())[buf];
153 auto* array = buffer->data();
154 // Get the real length of the runtime string, since there might be
155 // trailing
156 // '\0's in the buffer.
157 for (int len = 0; len < array->size(); ++len) {
158 if (array->data()[len] == '\0') {
159 return string(reinterpret_cast<const char*>(array->data()), len);
160 }
161 }
162 // If there is no '\0' in the buffer, this indicates that the flatbuffer
163 // is malformed.
164 TF_LITE_REPORT_ERROR(
165 error_reporter_,
166 "Min_runtime_version in model metadata is malformed");
167 break;
168 }
169 }
170 return "";
171 }
172
CheckModelIdentifier() const173 bool FlatBufferModel::CheckModelIdentifier() const {
174 if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
175 const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
176 error_reporter_->Report(
177 "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
178 ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
179 return false;
180 }
181 return true;
182 }
183
FlatBufferModel(const Model * model,ErrorReporter * error_reporter)184 FlatBufferModel::FlatBufferModel(const Model* model,
185 ErrorReporter* error_reporter)
186 : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
187
FlatBufferModel(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)188 FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
189 ErrorReporter* error_reporter)
190 : error_reporter_(ValidateErrorReporter(error_reporter)),
191 allocation_(std::move(allocation)) {
192 if (!allocation_->valid() || !CheckModelIdentifier()) return;
193
194 model_ = ::tflite::GetModel(allocation_->base());
195 }
196
~FlatBufferModel()197 FlatBufferModel::~FlatBufferModel() {}
198
199 } // namespace tflite
200