• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/model_builder.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
25 #include "tensorflow/lite/allocation.h"
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/core/api/verifier.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow/lite/stderr_reporter.h"
30 #include "tensorflow/lite/string_type.h"
31 
32 namespace tflite {
33 
34 namespace {
35 
36 // Ensure that ErrorReporter is non-null.
ValidateErrorReporter(ErrorReporter * e)37 ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
38   return e ? e : DefaultErrorReporter();
39 }
40 
41 }  // namespace
42 
43 #ifndef TFLITE_MCU
44 // Loads a model from `filename`. If `mmap_file` is true then use mmap,
45 // otherwise make a copy of the model in a buffer.
GetAllocationFromFile(const char * filename,bool mmap_file,ErrorReporter * error_reporter,bool use_nnapi)46 std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
47                                                   bool mmap_file,
48                                                   ErrorReporter* error_reporter,
49                                                   bool use_nnapi) {
50   std::unique_ptr<Allocation> allocation;
51   if (mmap_file && MMAPAllocation::IsSupported()) {
52     allocation.reset(new MMAPAllocation(filename, error_reporter));
53   } else {
54     allocation.reset(new FileCopyAllocation(filename, error_reporter));
55   }
56   return allocation;
57 }
58 
BuildFromFile(const char * filename,ErrorReporter * error_reporter)59 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
60     const char* filename, ErrorReporter* error_reporter) {
61   error_reporter = ValidateErrorReporter(error_reporter);
62 
63   std::unique_ptr<FlatBufferModel> model;
64   auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
65                                           error_reporter, /*use_nnapi=*/true);
66   model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
67   if (!model->initialized()) model.reset();
68   return model;
69 }
70 
VerifyAndBuildFromFile(const char * filename,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)71 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
72     const char* filename, TfLiteVerifier* extra_verifier,
73     ErrorReporter* error_reporter) {
74   error_reporter = ValidateErrorReporter(error_reporter);
75 
76   std::unique_ptr<FlatBufferModel> model;
77   auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
78                                           error_reporter, /*use_nnapi=*/true);
79 
80   flatbuffers::Verifier base_verifier(
81       reinterpret_cast<const uint8_t*>(allocation->base()),
82       allocation->bytes());
83   if (!VerifyModelBuffer(base_verifier)) {
84     TF_LITE_REPORT_ERROR(error_reporter,
85                          "The model is not a valid Flatbuffer file");
86     return nullptr;
87   }
88 
89   if (extra_verifier &&
90       !extra_verifier->Verify(static_cast<const char*>(allocation->base()),
91                               allocation->bytes(), error_reporter)) {
92     return model;
93   }
94   model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
95   if (!model->initialized()) model.reset();
96   return model;
97 }
98 #endif
99 
BuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,ErrorReporter * error_reporter)100 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
101     const char* caller_owned_buffer, size_t buffer_size,
102     ErrorReporter* error_reporter) {
103   error_reporter = ValidateErrorReporter(error_reporter);
104 
105   std::unique_ptr<FlatBufferModel> model;
106   std::unique_ptr<Allocation> allocation(
107       new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
108   model.reset(new FlatBufferModel(std::move(allocation), error_reporter));
109   if (!model->initialized()) model.reset();
110   return model;
111 }
112 
VerifyAndBuildFromBuffer(const char * caller_owned_buffer,size_t buffer_size,TfLiteVerifier * extra_verifier,ErrorReporter * error_reporter)113 std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromBuffer(
114     const char* caller_owned_buffer, size_t buffer_size,
115     TfLiteVerifier* extra_verifier, ErrorReporter* error_reporter) {
116   error_reporter = ValidateErrorReporter(error_reporter);
117 
118   flatbuffers::Verifier base_verifier(
119       reinterpret_cast<const uint8_t*>(caller_owned_buffer), buffer_size);
120   if (!VerifyModelBuffer(base_verifier)) {
121     TF_LITE_REPORT_ERROR(error_reporter,
122                          "The model is not a valid Flatbuffer buffer");
123     return nullptr;
124   }
125 
126   if (extra_verifier && !extra_verifier->Verify(caller_owned_buffer,
127                                                 buffer_size, error_reporter)) {
128     return nullptr;
129   }
130 
131   return BuildFromBuffer(caller_owned_buffer, buffer_size, error_reporter);
132 }
133 
BuildFromModel(const tflite::Model * caller_owned_model_spec,ErrorReporter * error_reporter)134 std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
135     const tflite::Model* caller_owned_model_spec,
136     ErrorReporter* error_reporter) {
137   error_reporter = ValidateErrorReporter(error_reporter);
138 
139   std::unique_ptr<FlatBufferModel> model;
140   model.reset(new FlatBufferModel(caller_owned_model_spec, error_reporter));
141   if (!model->initialized()) model.reset();
142   return model;
143 }
144 
GetMinimumRuntime() const145 string FlatBufferModel::GetMinimumRuntime() const {
146   if (!model_ || !model_->metadata()) return "";
147 
148   for (int i = 0; i < model_->metadata()->size(); ++i) {
149     auto metadata = model_->metadata()->Get(i);
150     if (metadata->name()->str() == "min_runtime_version") {
151       auto buf = metadata->buffer();
152       auto* buffer = (*model_->buffers())[buf];
153       auto* array = buffer->data();
154       // Get the real length of the runtime string, since there might be
155       // trailing
156       // '\0's in the buffer.
157       for (int len = 0; len < array->size(); ++len) {
158         if (array->data()[len] == '\0') {
159           return string(reinterpret_cast<const char*>(array->data()), len);
160         }
161       }
162       // If there is no '\0' in the buffer, this indicates that the flatbuffer
163       // is malformed.
164       TF_LITE_REPORT_ERROR(
165           error_reporter_,
166           "Min_runtime_version in model metadata is malformed");
167       break;
168     }
169   }
170   return "";
171 }
172 
CheckModelIdentifier() const173 bool FlatBufferModel::CheckModelIdentifier() const {
174   if (!tflite::ModelBufferHasIdentifier(allocation_->base())) {
175     const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base());
176     error_reporter_->Report(
177         "Model provided has model identifier '%c%c%c%c', should be '%s'\n",
178         ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier());
179     return false;
180   }
181   return true;
182 }
183 
FlatBufferModel(const Model * model,ErrorReporter * error_reporter)184 FlatBufferModel::FlatBufferModel(const Model* model,
185                                  ErrorReporter* error_reporter)
186     : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {}
187 
FlatBufferModel(std::unique_ptr<Allocation> allocation,ErrorReporter * error_reporter)188 FlatBufferModel::FlatBufferModel(std::unique_ptr<Allocation> allocation,
189                                  ErrorReporter* error_reporter)
190     : error_reporter_(ValidateErrorReporter(error_reporter)),
191       allocation_(std::move(allocation)) {
192   if (!allocation_->valid() || !CheckModelIdentifier()) return;
193 
194   model_ = ::tflite::GetModel(allocation_->base());
195 }
196 
~FlatBufferModel()197 FlatBufferModel::~FlatBufferModel() {}
198 
199 }  // namespace tflite
200