/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

#include <functional>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "lib/zip.h"  // from @org_libzip
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api.h"
#else
#include "tensorflow/lite/model_builder.h"
#endif

namespace tflite {
namespace metadata {

namespace {
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";

using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;

// Helper class that takes a callback function, and invokes it in its
// destructor.
class SimpleCleanUp {
 public:
  explicit SimpleCleanUp(std::function<void()> callback)
      : callback_(std::move(callback)) {}

  ~SimpleCleanUp() {
    if (callback_ != nullptr) callback_();
  }

  // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
  // executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed,
  // it may not be read from again.
  void Cancel() && { callback_ = nullptr; }

 private:
  std::function<void()> callback_;
};

// Utility to get the item at `index` from `src_vector`, or nullptr if the
// vector is null or the index is out of range.
template <typename T>
const T* GetItemFromVector(
    const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
  if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
    return nullptr;
  }
  return src_vector->Get(index);
}
}  // namespace

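// A minimal usage sketch for the factory below. `model_buffer` is a
// hypothetical std::string holding a serialized TFLite model, and
// ASSIGN_OR_RETURN is assumed to be provided by
// "tensorflow_lite_support/cc/port/status_macros.h":
//
//   ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> extractor,
//                    ModelMetadataExtractor::CreateFromModelBuffer(
//                        model_buffer.data(), model_buffer.size()));
//   int num_inputs = extractor->GetInputTensorCount();
//
// Note: the extractor keeps pointers into `buffer_data`, so the buffer must
// outlive the extractor.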
/* static */
tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
                                              size_t buffer_size) {
  // Use absl::WrapUnique() to call private constructor:
  // https://abseil.io/tips/126.
  std::unique_ptr<ModelMetadataExtractor> extractor =
      absl::WrapUnique(new ModelMetadataExtractor());
  RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
  return extractor;
}

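// Returns the unique ProcessUnit of the given `type` in `tensor_metadata`, or
// nullptr if there is none. Returns an InvalidArgument error if more than one
// ProcessUnit of that type is found.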
/* static */
tflite::support::StatusOr<const tflite::ProcessUnit*>
ModelMetadataExtractor::FindFirstProcessUnit(
    const tflite::TensorMetadata& tensor_metadata,
    tflite::ProcessUnitOptions type) {
  const tflite::ProcessUnit* result = nullptr;
  if (tensor_metadata.process_units() == nullptr) {
    return result;
  }
  for (const auto process_unit : *tensor_metadata.process_units()) {
    if (process_unit->options_type() == type) {
      if (result != nullptr) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrCat("Found multiple ProcessUnits with type=",
                         tflite::EnumNameProcessUnitOptions(type),
                         ", expected at most one."),
            TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
      }
      result = process_unit;
    }
  }
  return result;
}

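// Returns the name of the first associated file of the given `type` in
// `tensor_metadata`, restricted to the provided `locale` if non-empty, or an
// empty string if no such file is found.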
/* static */
std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
    const tflite::TensorMetadata& tensor_metadata,
    tflite::AssociatedFileType type, absl::string_view locale) {
  if (tensor_metadata.associated_files() == nullptr) {
    return std::string();
  }
  for (const auto associated_file : *tensor_metadata.associated_files()) {
    if (associated_file->type() != type || associated_file->name() == nullptr) {
      continue;
    }
    if (locale.empty() || (associated_file->locale() != nullptr &&
                           locale == associated_file->locale()->str())) {
      return associated_file->name()->str();
    }
  }
  return std::string();
}

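// Verifies that `buffer_data` is a valid TFLite FlatBuffer, then looks for a
// metadata entry named "TFLITE_METADATA". If one is found, checks its schema
// identifier, stores the parsed ModelMetadata and extracts any associated
// files packed with the model. Models without metadata are valid: in that
// case this simply returns OK and `GetModelMetadata()` returns nullptr.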
absl::Status ModelMetadataExtractor::InitFromModelBuffer(
    const char* buffer_data, size_t buffer_size) {
  // Rely on the simplest, base flatbuffers verifier. This is not the place to
  // e.g. use an OpResolver: we just want to make sure the buffer is valid for
  // accessing the metadata.
  flatbuffers::Verifier verifier = flatbuffers::Verifier(
      reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
  if (!tflite::VerifyModelBuffer(verifier)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "The model is not a valid FlatBuffer buffer.",
        TfLiteSupportStatus::kInvalidFlatBufferError);
  }
  model_ = tflite::GetModel(buffer_data);
  if (model_->metadata() == nullptr) {
    // Not all models have metadata, which is OK. `GetModelMetadata()` then
    // returns nullptr.
    return absl::OkStatus();
  }
  // Look for the "TFLITE_METADATA" field, if any.
  for (int i = 0; i < model_->metadata()->size(); ++i) {
    const auto metadata = model_->metadata()->Get(i);
    if (metadata->name()->str() != kMetadataBufferName) {
      continue;
    }
    const auto buffer_index = metadata->buffer();
    const auto metadata_buffer =
        model_->buffers()->Get(buffer_index)->data()->data();
    if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Invalid metadata schema version: expected %s, got %s",
              absl::string_view(tflite::ModelMetadataIdentifier())
                  .substr(
                      0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
              // The returned identifier is not null terminated; it has to be
              // truncated.
              absl::string_view(
                  flatbuffers::GetBufferIdentifier(metadata_buffer))
                  .substr(
                      0,
                      flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
          TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
    }
    model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
    if (model_metadata_ == nullptr) {
      return CreateStatusWithPayload(StatusCode::kInternal,
                                     "Expected Model Metadata not to be null.");
    }
    return ExtractAssociatedFiles(buffer_data, buffer_size);
  }
  return absl::OkStatus();
}

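// Associated files are packed with the model as a zip archive sharing the same
// buffer. This attempts to open `buffer_data` as a zip archive and copies the
// contents of every entry into `associated_files_`, keyed by filename. Failing
// to open the buffer as a zip archive is not an error: it simply means the
// model has no associated files.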
absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
    const char* buffer_data, size_t buffer_size) {
  // Set up libzip error reporting.
  zip_error_t error;
  zip_error_init(&error);
  auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); });

  // Initialize zip source.
  zip_source_t* src =
      zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error);
  if (src == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kUnknown,
        absl::StrFormat("Can't create zip source from model buffer: %s",
                        zip_error_strerror(&error)),
        TfLiteSupportStatus::kMetadataAssociatedFileZipError);
  }
  auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); });

  // Try opening zip source.
  zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error);
  if (zip_archive == nullptr) {
    // It's OK if it fails: this means there are no associated files with this
    // model.
    return absl::OkStatus();
  }
  auto zip_archive_cleanup =
      SimpleCleanUp([zip_archive] { zip_close(zip_archive); });
  // As per the documentation [1] for zip_source_free, it should not be called
  // after a successful call to zip_open_from_source.
  //
  // [1]: https://libzip.org/documentation/zip_source_free.html
  std::move(zip_source_cleanup).Cancel();

  const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0);
  for (int index = 0; index < num_files; ++index) {
    // Get file stats.
    struct zip_stat zip_file_stat;
    zip_stat_init(&zip_file_stat);
    zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat);
    absl::string_view filename = zip_file_stat.name;
    const auto unzip_filesize = zip_file_stat.size;

    // Open file.
    zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0);
    if (zip_file == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrFormat("Unable to open associated file with name: %s",
                          zip_file_stat.name),
          TfLiteSupportStatus::kMetadataAssociatedFileZipError);
    }
    auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); });

    // Unzip file.
    char* unzip_buffer = new char[unzip_filesize];
    auto unzip_buffer_cleanup =
        SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; });
    if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrFormat("Unzipping failed for file: %s.", filename),
          TfLiteSupportStatus::kMetadataAssociatedFileZipError);
    }

    // Copy the file contents into the map.
    associated_files_[filename] = std::string(unzip_buffer, unzip_filesize);
  }
  return absl::OkStatus();
}

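// Returns a view into the contents of the associated file named `filename`, as
// previously unpacked by ExtractAssociatedFiles(), or a NotFound error if no
// such file exists. The view is valid as long as this extractor is alive.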
tflite::support::StatusOr<absl::string_view>
ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
  auto it = associated_files_.find(filename);
  if (it == associated_files_.end()) {
    return CreateStatusWithPayload(
        StatusCode::kNotFound,
        absl::StrFormat("No associated file with name: %s", filename),
        TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
  }
  return it->second;
}

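// The accessors below all read from the default subgraph's SubGraphMetadata
// (at kDefaultSubgraphIndex) and return nullptr (or a count of 0) when the
// model has no metadata or no subgraph metadata.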
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
ModelMetadataExtractor::GetInputTensorMetadata() const {
  if (model_metadata_ == nullptr ||
      model_metadata_->subgraph_metadata() == nullptr) {
    return nullptr;
  }
  return model_metadata_->subgraph_metadata()
      ->Get(kDefaultSubgraphIndex)
      ->input_tensor_metadata();
}

const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
    int index) const {
  return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
                                                   index);
}

int ModelMetadataExtractor::GetInputTensorCount() const {
  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      input_tensor_metadata = GetInputTensorMetadata();
  return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
}

const Vector<Offset<TensorMetadata>>*
ModelMetadataExtractor::GetOutputTensorMetadata() const {
  if (model_metadata_ == nullptr ||
      model_metadata_->subgraph_metadata() == nullptr) {
    return nullptr;
  }
  return model_metadata_->subgraph_metadata()
      ->Get(kDefaultSubgraphIndex)
      ->output_tensor_metadata();
}

const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
    int index) const {
  return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
                                                   index);
}

int ModelMetadataExtractor::GetOutputTensorCount() const {
  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      output_tensor_metadata = GetOutputTensorMetadata();
  return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
}

const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
ModelMetadataExtractor::GetInputProcessUnits() const {
  if (model_metadata_ == nullptr ||
      model_metadata_->subgraph_metadata() == nullptr) {
    return nullptr;
  }
  return model_metadata_->subgraph_metadata()
      ->Get(kDefaultSubgraphIndex)
      ->input_process_units();
}

const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
    int index) const {
  return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
}

int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
  const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
      GetInputProcessUnits();
  return input_process_units == nullptr ? 0 : input_process_units->size();
}

const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
ModelMetadataExtractor::GetOutputProcessUnits() const {
  if (model_metadata_ == nullptr ||
      model_metadata_->subgraph_metadata() == nullptr) {
    return nullptr;
  }
  return model_metadata_->subgraph_metadata()
      ->Get(kDefaultSubgraphIndex)
      ->output_process_units();
}

const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
    int index) const {
  return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
}

int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
  const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
      GetOutputProcessUnits();
  return output_process_units == nullptr ? 0 : output_process_units->size();
}

}  // namespace metadata
}  // namespace tflite