1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
17
18 #include <functional>
19
20 #include "absl/memory/memory.h"
21 #include "absl/status/status.h"
22 #include "absl/strings/str_format.h"
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "lib/zip.h" // from @org_libzip
25 #include "tensorflow/lite/schema/schema_generated.h"
26 #include "tensorflow_lite_support/cc/common.h"
27 #include "tensorflow_lite_support/cc/port/status_macros.h"
28 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
29
30 #if TFLITE_USE_C_API
31 #include "tensorflow/lite/c/c_api.h"
32 #else
33 #include "tensorflow/lite/model_builder.h"
34 #endif
35
36 namespace tflite {
37 namespace metadata {
38
39 namespace {
40 constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
41
42 using ::absl::StatusCode;
43 using ::flatbuffers::Offset;
44 using ::flatbuffers::Vector;
45 using ::tflite::TensorMetadata;
46 using ::tflite::support::CreateStatusWithPayload;
47 using ::tflite::support::TfLiteSupportStatus;
48
// Scope guard: stores a callback and invokes it from the destructor unless it
// has been cancelled.
//
// The guard is move-only. Copying is disabled because two live copies would
// each run the same callback on destruction (e.g. releasing a resource twice).
class SimpleCleanUp {
 public:
  // Takes ownership of `callback`. An empty callback makes the guard a no-op.
  explicit SimpleCleanUp(std::function<void()> callback)
      : callback_(std::move(callback)) {}

  // Transfers the pending callback from `other`. `other` is explicitly
  // disarmed, since a moved-from std::function is only guaranteed to be in a
  // valid but unspecified state.
  SimpleCleanUp(SimpleCleanUp&& other) : callback_(std::move(other.callback_)) {
    other.callback_ = nullptr;
  }

  SimpleCleanUp(const SimpleCleanUp&) = delete;
  SimpleCleanUp& operator=(const SimpleCleanUp&) = delete;
  SimpleCleanUp& operator=(SimpleCleanUp&&) = delete;

  // Runs the callback, if one is still armed.
  ~SimpleCleanUp() {
    if (callback_ != nullptr) callback_();
  }

  // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
  // executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed,
  // it may not be read from again.
  void Cancel() && { callback_ = nullptr; }

 private:
  std::function<void()> callback_;
};
68
69 // Util to get item from src_vector specified by index.
70 template <typename T>
GetItemFromVector(const flatbuffers::Vector<flatbuffers::Offset<T>> * src_vector,int index)71 const T* GetItemFromVector(
72 const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
73 if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
74 return nullptr;
75 }
76 return src_vector->Get(index);
77 }
78 } // namespace
79
80 /* static */
81 tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
CreateFromModelBuffer(const char * buffer_data,size_t buffer_size)82 ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
83 size_t buffer_size) {
84 // Use absl::WrapUnique() to call private constructor:
85 // https://abseil.io/tips/126.
86 std::unique_ptr<ModelMetadataExtractor> extractor =
87 absl::WrapUnique(new ModelMetadataExtractor());
88 RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
89 return extractor;
90 }
91
92 /* static */
93 tflite::support::StatusOr<const tflite::ProcessUnit*>
FindFirstProcessUnit(const tflite::TensorMetadata & tensor_metadata,tflite::ProcessUnitOptions type)94 ModelMetadataExtractor::FindFirstProcessUnit(
95 const tflite::TensorMetadata& tensor_metadata,
96 tflite::ProcessUnitOptions type) {
97 const tflite::ProcessUnit* result = nullptr;
98 if (tensor_metadata.process_units() == nullptr) {
99 return result;
100 }
101 for (const auto process_unit : *tensor_metadata.process_units()) {
102 if (process_unit->options_type() == type) {
103 if (result != nullptr) {
104 return CreateStatusWithPayload(
105 StatusCode::kInvalidArgument,
106 absl::StrCat("Found multiple ProcessUnits with type=",
107 tflite::EnumNameProcessUnitOptions(type),
108 ", expected at most one."),
109 TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
110 }
111 result = process_unit;
112 }
113 }
114 return result;
115 }
116
117 /* static */
FindFirstAssociatedFileName(const tflite::TensorMetadata & tensor_metadata,tflite::AssociatedFileType type,absl::string_view locale)118 std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
119 const tflite::TensorMetadata& tensor_metadata,
120 tflite::AssociatedFileType type, absl::string_view locale) {
121 if (tensor_metadata.associated_files() == nullptr) {
122 return std::string();
123 }
124 for (const auto associated_file : *tensor_metadata.associated_files()) {
125 if (associated_file->type() != type || associated_file->name() == nullptr) {
126 continue;
127 }
128 if (locale.empty() || (associated_file->locale() != nullptr &&
129 locale == associated_file->locale()->str())) {
130 return associated_file->name()->str();
131 }
132 }
133 return std::string();
134 }
135
InitFromModelBuffer(const char * buffer_data,size_t buffer_size)136 absl::Status ModelMetadataExtractor::InitFromModelBuffer(
137 const char* buffer_data, size_t buffer_size) {
138 // Rely on the simplest, base flatbuffers verifier. Here is not the place to
139 // e.g. use an OpResolver: we just want to make sure the buffer is valid to
140 // access the metadata.
141 flatbuffers::Verifier verifier = flatbuffers::Verifier(
142 reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
143 if (!tflite::VerifyModelBuffer(verifier)) {
144 return CreateStatusWithPayload(
145 StatusCode::kInvalidArgument,
146 "The model is not a valid FlatBuffer buffer.",
147 TfLiteSupportStatus::kInvalidFlatBufferError);
148 }
149 model_ = tflite::GetModel(buffer_data);
150 if (model_->metadata() == nullptr) {
151 // Not all models have metadata, which is OK. `GetModelMetadata()` then
152 // returns nullptr.
153 return absl::OkStatus();
154 }
155 // Look for the "TFLITE_METADATA" field, if any.
156 for (int i = 0; i < model_->metadata()->size(); ++i) {
157 const auto metadata = model_->metadata()->Get(i);
158 if (metadata->name()->str() != kMetadataBufferName) {
159 continue;
160 }
161 const auto buffer_index = metadata->buffer();
162 const auto metadata_buffer =
163 model_->buffers()->Get(buffer_index)->data()->data();
164 if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
165 return CreateStatusWithPayload(
166 StatusCode::kInvalidArgument,
167 absl::StrFormat(
168 "Invalid metadata schema version: expected %s, got %s",
169 absl::string_view(tflite::ModelMetadataIdentifier())
170 .substr(
171 0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
172 // Returned identifier is not null terminated; has to be
173 // truncated.
174 absl::string_view(
175 flatbuffers::GetBufferIdentifier(metadata_buffer))
176 .substr(
177 0,
178 flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
179 TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
180 }
181 model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
182 if (model_metadata_ == nullptr) {
183 return CreateStatusWithPayload(StatusCode::kInternal,
184 "Expected Model Metadata not to be null.");
185 }
186 return ExtractAssociatedFiles(buffer_data, buffer_size);
187 break;
188 }
189 return absl::OkStatus();
190 }
191
ExtractAssociatedFiles(const char * buffer_data,size_t buffer_size)192 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
193 const char* buffer_data, size_t buffer_size) {
194 // Setup libzip error reporting.
195 zip_error_t error;
196 zip_error_init(&error);
197 auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); });
198
199 // Initialize zip source.
200 zip_source_t* src =
201 zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error);
202 if (src == nullptr) {
203 return CreateStatusWithPayload(
204 StatusCode::kUnknown,
205 absl::StrFormat("Can't create zip source from model buffer: %s",
206 zip_error_strerror(&error)),
207 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
208 }
209 auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); });
210
211 // Try opening zip source.
212 zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error);
213 if (zip_archive == nullptr) {
214 // It's OK if it fails: this means there are no associated files with this
215 // model.
216 return absl::OkStatus();
217 }
218 auto zip_archive_cleanup =
219 SimpleCleanUp([zip_archive] { zip_close(zip_archive); });
220 // As per the documentation [1] for zip_source_free, it should not be called
221 // after a successful call to zip_open_from_source.
222 //
223 // [1]: https://libzip.org/documentation/zip_source_free.html
224 std::move(zip_source_cleanup).Cancel();
225
226 const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0);
227 for (int index = 0; index < num_files; ++index) {
228 // Get file stats.
229 struct zip_stat zip_file_stat;
230 zip_stat_init(&zip_file_stat);
231 zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat);
232 absl::string_view filename = zip_file_stat.name;
233 const auto unzip_filesize = zip_file_stat.size;
234
235 // Open file.
236 zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0);
237 if (zip_file == nullptr) {
238 return CreateStatusWithPayload(
239 StatusCode::kUnknown,
240 absl::StrFormat("Unable to open associated file with name: %s",
241 zip_file_stat.name),
242 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
243 }
244 auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); });
245
246 // Unzip file.
247 char* unzip_buffer = new char[unzip_filesize];
248 auto unzip_buffer_cleanup =
249 SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; });
250 if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) {
251 return CreateStatusWithPayload(
252 StatusCode::kUnknown,
253 absl::StrFormat("Unzipping failed for file: %s.", filename),
254 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
255 }
256
257 // Copy file contents in map.
258 associated_files_[filename] = std::string(unzip_buffer, unzip_filesize);
259 }
260 return absl::OkStatus();
261 }
262
263 tflite::support::StatusOr<absl::string_view>
GetAssociatedFile(const std::string & filename) const264 ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
265 auto it = associated_files_.find(filename);
266 if (it == associated_files_.end()) {
267 return CreateStatusWithPayload(
268 StatusCode::kNotFound,
269 absl::StrFormat("No associated file with name: %s", filename),
270 TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
271 }
272 return it->second;
273 }
274
275 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
GetInputTensorMetadata() const276 ModelMetadataExtractor::GetInputTensorMetadata() const {
277 if (model_metadata_ == nullptr ||
278 model_metadata_->subgraph_metadata() == nullptr) {
279 return nullptr;
280 }
281 return model_metadata_->subgraph_metadata()
282 ->Get(kDefaultSubgraphIndex)
283 ->input_tensor_metadata();
284 }
285
GetInputTensorMetadata(int index) const286 const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
287 int index) const {
288 return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
289 index);
290 }
291
GetInputTensorCount() const292 int ModelMetadataExtractor::GetInputTensorCount() const {
293 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
294 input_tensor_metadata = GetInputTensorMetadata();
295 return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
296 }
297
298 const Vector<Offset<TensorMetadata>>*
GetOutputTensorMetadata() const299 ModelMetadataExtractor::GetOutputTensorMetadata() const {
300 if (model_metadata_ == nullptr ||
301 model_metadata_->subgraph_metadata() == nullptr) {
302 return nullptr;
303 }
304 return model_metadata_->subgraph_metadata()
305 ->Get(kDefaultSubgraphIndex)
306 ->output_tensor_metadata();
307 }
308
GetOutputTensorMetadata(int index) const309 const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
310 int index) const {
311 return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
312 index);
313 }
314
GetOutputTensorCount() const315 int ModelMetadataExtractor::GetOutputTensorCount() const {
316 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
317 output_tensor_metadata = GetOutputTensorMetadata();
318 return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
319 }
320
321 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetInputProcessUnits() const322 ModelMetadataExtractor::GetInputProcessUnits() const {
323 if (model_metadata_ == nullptr ||
324 model_metadata_->subgraph_metadata() == nullptr) {
325 return nullptr;
326 }
327 return model_metadata_->subgraph_metadata()
328 ->Get(kDefaultSubgraphIndex)
329 ->input_process_units();
330 }
331
GetInputProcessUnit(int index) const332 const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
333 int index) const {
334 return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
335 }
336
GetInputProcessUnitsCount() const337 int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
338 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
339 GetInputProcessUnits();
340 return input_process_units == nullptr ? 0 : input_process_units->size();
341 }
342
343 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetOutputProcessUnits() const344 ModelMetadataExtractor::GetOutputProcessUnits() const {
345 if (model_metadata_ == nullptr ||
346 model_metadata_->subgraph_metadata() == nullptr) {
347 return nullptr;
348 }
349 return model_metadata_->subgraph_metadata()
350 ->Get(kDefaultSubgraphIndex)
351 ->output_process_units();
352 }
353
GetOutputProcessUnit(int index) const354 const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
355 int index) const {
356 return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
357 }
358
GetOutputProcessUnitsCount() const359 int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
360 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
361 GetOutputProcessUnits();
362 return output_process_units == nullptr ? 0 : output_process_units->size();
363 }
364
365 } // namespace metadata
366 } // namespace tflite
367