1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/model_builder.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <memory>
21 #include <string>
22 #include <utility>
23
24 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
25 #include "tensorflow/lite/allocation.h"
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/core/api/verifier.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/stderr_reporter.h"
30 #include "tensorflow/lite/string_type.h"
31
32 namespace tflite {
33
34 namespace {
35
36 // Ensure that ErrorReporter is non-null.
ValidateErrorReporter(ErrorReporter * e)37 ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
38 return e ? e : DefaultErrorReporter();
39 }
40
41 } // namespace
42
43 #ifndef TFLITE_MCU
44 // Loads a model from `filename`. If `mmap_file` is true then use mmap,
45 // otherwise make a copy of the model in a buffer.
GetAllocationFromFile(const char * filename,ErrorReporter * error_reporter)46 std::unique_ptr<Allocation> GetAllocationFromFile(
47 const char* filename, ErrorReporter* error_reporter) {
48 std::unique_ptr<Allocation> allocation;
49 if (MMAPAllocation::IsSupported()) {
50 allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
51 } else {
52 allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
53 }
54 return allocation;
55 }
56
BuildFromFile(const char * filename,ErrorReporter * error_reporter)57 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
58 const char* filename, ErrorReporter* error_reporter) {
59 error_reporter = ValidateErrorReporter(error_reporter);
60 return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
61 error_reporter);
62 }
63
VerifyAndBuildFromFile(const char * filename,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)64 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
65 const char* filename, TfLiteVerifier* extra_verifier,
66 ErrorReporter* error_reporter) {
67 error_reporter = ValidateErrorReporter(error_reporter);
68 return VerifyAndBuildFromAllocation(
69 GetAllocationFromFile(filename, error_reporter), extra_verifier,
70 error_reporter);
71 }
72 #endif
73
BuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,ErrorReporter * error_reporter)74 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
75 const char* caller_owned_buffer, size_t buffer_size,
76 ErrorReporter* error_reporter) {
77 error_reporter = ValidateErrorReporter(error_reporter);
78 std::unique_ptr<Allocation> allocation(
79 new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
80 return BuildFromAllocation(std::move(allocation), error_reporter);
81 }
82
VerifyAndBuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)83 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
84 const char* caller_owned_buffer, size_t buffer_size,
85 TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
86 error_reporter = ValidateErrorReporter(error_reporter);
87 std::unique_ptr<Allocation> allocation(
88 new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
89 return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
90 error_reporter);
91 }
92
BuildFromAllocation(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)93 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
94 std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
95 std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
96 std::move(allocation), ValidateErrorReporter(error_reporter)));
97 if (!model->initialized()) {
98 model.reset();
99 }
100 return model;
101 }
102
VerifyAndBuildFromAllocation(std::unique_ptr<Allocation> allocation,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)103 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
104 std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
105 ErrorReporter* error_reporter) {
106 error_reporter = ValidateErrorReporter(error_reporter);
107 if (!allocation || !allocation->valid()) {
108 TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
109 return nullptr;
110 }
111
112 flatbuffers::Verifier base_verifier(
113 reinterpret_cast<const uint8_t*>(allocation->base()),
114 allocation->bytes());
115 if (!VerifyModelBuffer(base_verifier)) {
116 TF_LITE_REPORT_ERROR(error_reporter,
117 "The model is not a valid Flatbuffer buffer");
118 return nullptr;
119 }
120
121 if (extra_verifier &&
122 !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
123 allocation->bytes(), error_reporter)) {
124 // The verifier will have already logged an appropriate error message.
125 return nullptr;
126 }
127
128 return BuildFromAllocation(std::move(allocation), error_reporter);
129 }
130
BuildFromModel(const tflite::Model * caller_owned_model_spec,ErrorReporter * error_reporter)131 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
132 const tflite::Model* caller_owned_model_spec,
133 ErrorReporter* error_reporter) {
134 error_reporter = ValidateErrorReporter(error_reporter);
135
136 std::unique_ptr<FlatBufferModel> model(
137 new FlatBufferModel(caller_owned_model_spec, error_reporter));
138 if (!model->initialized()) {
139 model.reset();
140 }
141 return model;
142 }
143
GetMinimumRuntime() const144 string FlatBufferModel::GetMinimumRuntime() const {
145 if (!model_ || !model_->metadata()) return "";
146
147 for (int i = 0; i < model_->metadata()->size(); ++i) {
148 auto metadata = model_->metadata()->Get(i);
149 if (metadata->name()->str() == "min_runtime_version") {
150 auto buf = metadata->buffer();
151 auto* buffer = (*model_->buffers())[buf];
152 auto* array = buffer->data();
153 // Get the real length of the runtime string, since there might be
154 // trailing
155 // '\0's in the buffer.
156 for (int len = 0; len < array->size(); ++len) {
157 if (array->data()[len] == '\0') {
158 return string(reinterpret_cast<const char*>(array->data()), len);
159 }
160 }
161 // If there is no '\0' in the buffer, this indicates that the flatbuffer
162 // is malformed.
163 TF_LITE_REPORT_ERROR(
164 error_reporter_,
165 "Min_runtime_version in model metadata is malformed");
166 break;
167 }
168 }
169 return "";
170 }
171
ReadAllMetadata() const172 std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const {
173 std::map<std::string, std::string> keys_values;
174 if (!model_ || !model_->metadata() || !model_->buffers()) return keys_values;
175
176 for (int i = 0; i < model_->metadata()->size(); ++i) {
177 auto metadata = model_->metadata()->Get(i);
178 auto buf = metadata->buffer();
179 const tflite::Buffer* buffer = (*model_->buffers())[buf];
180 if (!buffer || !buffer->data()) continue;
181 const flatbuffers::Vector<uint8_t>* array = buffer->data();
182 if (!array) continue;
183 std::string val =
184 string(reinterpret_cast<const char*>(array->data()), array->size());
185 // Skip if key or value of metadata is empty.
186 if (!metadata->name() || val.empty()) continue;
187 keys_values[metadata->name()->str()] = val;
188 }
189 return keys_values;
190 }
191
CheckModelIdentifier() const192 bool FlatBufferModel::CheckModelIdentifier() const {
193 if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
194 const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
195 error_reporter_->Report(
196 "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
197 ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
198 return false;
199 }
200 return true;
201 }
202
FlatBufferModel(const Model * model,ErrorReporter * error_reporter)203 FlatBufferModel::FlatBufferModel(const Model* model,
204 ErrorReporter* error_reporter)
205 : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
206
FlatBufferModel(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)207 FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
208 ErrorReporter* error_reporter)
209 : error_reporter_(ValidateErrorReporter(error_reporter)),
210 allocation_(std::move(allocation)) {
211 if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
212 return;
213 }
214
215 model_ = ::tflite::GetModel(allocation_->base());
216 }
217
~FlatBufferModel()218 FlatBufferModel::~FlatBufferModel() {}
219
220 } // namespace tflite
221