• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/model_builder.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/allocation.h"
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/core/api/verifier.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/stderr_reporter.h"
30 #include "tensorflow/lite/string_type.h"
31 
32 namespace tflite {
33 
34 namespace {
35 
36 // Ensure that ErrorReporter is non-null.
ValidateErrorReporter(ErrorReporter * e)37 ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
38   return e ? e : DefaultErrorReporter();
39 }
40 
41 }  // namespace
42 
43 #ifndef TFLITE_MCU
44 // Loads a model from `filename`. If `mmap_file` is true then use mmap,
45 // otherwise make a copy of the model in a buffer.
GetAllocationFromFile(const char * filename,ErrorReporter * error_reporter)46 std::unique_ptr<Allocation> GetAllocationFromFile(
47     const char* filename, ErrorReporter* error_reporter) {
48   std::unique_ptr<Allocation> allocation;
49   if (MMAPAllocation::IsSupported()) {
50     allocation = std::make_unique<MMAPAllocation>(filename, error_reporter);
51   } else {
52     allocation = std::make_unique<FileCopyAllocation>(filename, error_reporter);
53   }
54   return allocation;
55 }
56 
BuildFromFile(const char * filename,ErrorReporter * error_reporter)57 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
58     const char* filename, ErrorReporter* error_reporter) {
59   error_reporter = ValidateErrorReporter(error_reporter);
60   return BuildFromAllocation(GetAllocationFromFile(filename, error_reporter),
61                              error_reporter);
62 }
63 
VerifyAndBuildFromFile(const char * filename,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)64 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
65     const char* filename, TfLiteVerifier* extra_verifier,
66     ErrorReporter* error_reporter) {
67   error_reporter = ValidateErrorReporter(error_reporter);
68   return VerifyAndBuildFromAllocation(
69       GetAllocationFromFile(filename, error_reporter), extra_verifier,
70       error_reporter);
71 }
72 #endif
73 
BuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,ErrorReporter * error_reporter)74 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
75     const char* caller_owned_buffer, size_t buffer_size,
76     ErrorReporter* error_reporter) {
77   error_reporter = ValidateErrorReporter(error_reporter);
78   std::unique_ptr<Allocation> allocation(
79       new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
80   return BuildFromAllocation(std::move(allocation), error_reporter);
81 }
82 
VerifyAndBuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)83 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
84     const char* caller_owned_buffer, size_t buffer_size,
85     TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
86   error_reporter = ValidateErrorReporter(error_reporter);
87   std::unique_ptr<Allocation> allocation(
88       new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
89   return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier,
90                                       error_reporter);
91 }
92 
BuildFromAllocation(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)93 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromAllocation(
94     std::unique_ptr<Allocation> allocation, ErrorReporter* error_reporter) {
95   std::unique_ptr<FlatBufferModel> model(new FlatBufferModel(
96       std::move(allocation), ValidateErrorReporter(error_reporter)));
97   if (!model->initialized()) {
98     model.reset();
99   }
100   return model;
101 }
102 
VerifyAndBuildFromAllocation(std::unique_ptr<Allocation> allocation,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)103 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromAllocation(
104     std::unique_ptr<Allocation> allocation, TfLiteVerifier* extra_verifier,
105     ErrorReporter* error_reporter) {
106   error_reporter = ValidateErrorReporter(error_reporter);
107   if (!allocation || !allocation->valid()) {
108     TF_LITE_REPORT_ERROR(error_reporter, "The model allocation is null/empty");
109     return nullptr;
110   }
111 
112   flatbuffers::Verifier base_verifier(
113       reinterpret_cast<const uint8_t*>(allocation->base()),
114       allocation->bytes());
115   if (!VerifyModelBuffer(base_verifier)) {
116     TF_LITE_REPORT_ERROR(error_reporter,
117                          "The model is not a valid Flatbuffer buffer");
118     return nullptr;
119   }
120 
121   if (extra_verifier &&
122       !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
123                               allocation->bytes(), error_reporter)) {
124     // The verifier will have already logged an appropriate error message.
125     return nullptr;
126   }
127 
128   return BuildFromAllocation(std::move(allocation), error_reporter);
129 }
130 
BuildFromModel(const tflite::Model * caller_owned_model_spec,ErrorReporter * error_reporter)131 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
132     const tflite::Model* caller_owned_model_spec,
133     ErrorReporter* error_reporter) {
134   error_reporter = ValidateErrorReporter(error_reporter);
135 
136   std::unique_ptr<FlatBufferModel> model(
137       new FlatBufferModel(caller_owned_model_spec, error_reporter));
138   if (!model->initialized()) {
139     model.reset();
140   }
141   return model;
142 }
143 
GetMinimumRuntime() const144 string FlatBufferModel::GetMinimumRuntime() const {
145   if (!model_ || !model_->metadata()) return "";
146 
147   for (int i = 0; i < model_->metadata()->size(); ++i) {
148     auto metadata = model_->metadata()->Get(i);
149     if (metadata->name()->str() == "min_runtime_version") {
150       auto buf = metadata->buffer();
151       auto* buffer = (*model_->buffers())[buf];
152       auto* array = buffer->data();
153       // Get the real length of the runtime string, since there might be
154       // trailing
155       // '\0's in the buffer.
156       for (int len = 0; len < array->size(); ++len) {
157         if (array->data()[len] == '\0') {
158           return string(reinterpret_cast<const char*>(array->data()), len);
159         }
160       }
161       // If there is no '\0' in the buffer, this indicates that the flatbuffer
162       // is malformed.
163       TF_LITE_REPORT_ERROR(
164           error_reporter_,
165           "Min_runtime_version in model metadata is malformed");
166       break;
167     }
168   }
169   return "";
170 }
171 
ReadAllMetadata() const172 std::map<std::string, std::string> FlatBufferModel::ReadAllMetadata() const {
173   std::map<std::string, std::string> keys_values;
174   if (!model_ || !model_->metadata() || !model_->buffers()) return keys_values;
175 
176   for (int i = 0; i < model_->metadata()->size(); ++i) {
177     auto metadata = model_->metadata()->Get(i);
178     auto buf = metadata->buffer();
179     const tflite::Buffer* buffer = (*model_->buffers())[buf];
180     if (!buffer || !buffer->data()) continue;
181     const flatbuffers::Vector<uint8_t>* array = buffer->data();
182     if (!array) continue;
183     std::string val =
184         string(reinterpret_cast<const char*>(array->data()), array->size());
185     // Skip if key or value of metadata is empty.
186     if (!metadata->name() || val.empty()) continue;
187     keys_values[metadata->name()->str()] = val;
188   }
189   return keys_values;
190 }
191 
CheckModelIdentifier() const192 bool FlatBufferModel::CheckModelIdentifier() const {
193   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
194     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
195     error_reporter_->Report(
196         "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
197         ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
198     return false;
199   }
200   return true;
201 }
202 
FlatBufferModel(const Model * model,ErrorReporter * error_reporter)203 FlatBufferModel::FlatBufferModel(const Model* model,
204                                  ErrorReporter* error_reporter)
205     : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
206 
FlatBufferModel(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)207 FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
208                                  ErrorReporter* error_reporter)
209     : error_reporter_(ValidateErrorReporter(error_reporter)),
210       allocation_(std::move(allocation)) {
211   if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) {
212     return;
213   }
214 
215   model_ = ::tflite::GetModel(allocation_->base());
216 }
217 
~FlatBufferModel()218 FlatBufferModel::~FlatBufferModel() {}
219 
220 }  // namespace tflite
221