• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
17 
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/status/status.h"
22 #include "absl/strings/str_format.h"
23 #include "absl/strings/string_view.h"
24 #include "flatbuffers/flatbuffers.h"
25 #include "contrib/minizip/ioapi.h"
26 #include "contrib/minizip/unzip.h"
27 #include "tensorflow/lite/schema/schema_generated.h"
28 #include "tensorflow_lite_support/cc/common.h"
29 #include "tensorflow_lite_support/cc/port/status_macros.h"
30 #include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32 
33 namespace tflite {
34 namespace metadata {
35 
36 namespace {
37 constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
38 
39 using ::absl::StatusCode;
40 using ::flatbuffers::Offset;
41 using ::flatbuffers::Vector;
42 using ::tflite::TensorMetadata;
43 using ::tflite::support::CreateStatusWithPayload;
44 using ::tflite::support::TfLiteSupportStatus;
45 
46 // Util to get item from src_vector specified by index.
47 template <typename T>
GetItemFromVector(const flatbuffers::Vector<flatbuffers::Offset<T>> * src_vector,int index)48 const T* GetItemFromVector(
49     const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
50   if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
51     return nullptr;
52   }
53   return src_vector->Get(index);
54 }
55 
56 // Wrapper function around calls to unzip to avoid repeating conversion logic
57 // from error code to Status.
UnzipErrorToStatus(int error)58 absl::Status UnzipErrorToStatus(int error) {
59   if (error != UNZ_OK) {
60     return CreateStatusWithPayload(
61         StatusCode::kUnknown, "Unable to read associated file in zip archive.",
62         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
63   }
64   return absl::OkStatus();
65 }
66 
67 // Stores a file name, position in zip buffer and size.
68 struct ZipFileInfo {
69   std::string name;
70   ZPOS64_T position;
71   ZPOS64_T size;
72 };
73 
74 // Returns the ZipFileInfo corresponding to the current file in the provided
75 // unzFile object.
GetCurrentZipFileInfo(const unzFile & zf)76 tflite::support::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
77   // Open file in raw mode, as data is expected to be uncompressed.
78   int method;
79   RETURN_IF_ERROR(UnzipErrorToStatus(
80       unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
81   if (method != Z_NO_COMPRESSION) {
82     return CreateStatusWithPayload(
83         StatusCode::kUnknown, "Expected uncompressed zip archive.",
84         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
85   }
86 
87   // Get file info a first time to get filename size.
88   unz_file_info64 file_info;
89   RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
90       zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
91       /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
92       /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
93 
94   // Second call to get file name.
95   auto file_name_size = file_info.size_filename;
96   char* c_file_name = (char*)malloc(file_name_size);
97   RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
98       zf, &file_info, c_file_name, file_name_size,
99       /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
100       /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
101   std::string file_name = std::string(c_file_name, file_name_size);
102   free(c_file_name);
103 
104   // Get position in file.
105   auto position = unzGetCurrentFileZStreamPos64(zf);
106   if (position == 0) {
107     return CreateStatusWithPayload(
108         StatusCode::kUnknown, "Unable to read file in zip archive.",
109         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
110   }
111   ZipFileInfo result = {.name = file_name,
112                         .position = position,
113                         .size = file_info.uncompressed_size};
114 
115   // Close file and return.
116   RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
117   return result;
118 }
119 }  // namespace
120 
121 /* static */
122 tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
CreateFromModelBuffer(const char * buffer_data,size_t buffer_size)123 ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
124                                               size_t buffer_size) {
125   // Use absl::WrapUnique() to call private constructor:
126   // https://abseil.io/tips/126.
127   std::unique_ptr<ModelMetadataExtractor> extractor =
128       absl::WrapUnique(new ModelMetadataExtractor());
129   RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
130   return extractor;
131 }
132 
133 /* static */
134 tflite::support::StatusOr<const tflite::ProcessUnit*>
FindFirstProcessUnit(const tflite::TensorMetadata & tensor_metadata,tflite::ProcessUnitOptions type)135 ModelMetadataExtractor::FindFirstProcessUnit(
136     const tflite::TensorMetadata& tensor_metadata,
137     tflite::ProcessUnitOptions type) {
138   const tflite::ProcessUnit* result = nullptr;
139   if (tensor_metadata.process_units() == nullptr) {
140     return result;
141   }
142   for (const auto process_unit : *tensor_metadata.process_units()) {
143     if (process_unit->options_type() == type) {
144       if (result != nullptr) {
145         return CreateStatusWithPayload(
146             StatusCode::kInvalidArgument,
147             absl::StrCat("Found multiple ProcessUnits with type=",
148                          tflite::EnumNameProcessUnitOptions(type),
149                          ", expected at most one."),
150             TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
151       }
152       result = process_unit;
153     }
154   }
155   return result;
156 }
157 
158 /* static */
FindFirstAssociatedFileName(const tflite::TensorMetadata & tensor_metadata,tflite::AssociatedFileType type,absl::string_view locale)159 std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
160     const tflite::TensorMetadata& tensor_metadata,
161     tflite::AssociatedFileType type, absl::string_view locale) {
162   if (tensor_metadata.associated_files() == nullptr) {
163     return std::string();
164   }
165   for (const auto associated_file : *tensor_metadata.associated_files()) {
166     if (associated_file->type() != type || associated_file->name() == nullptr) {
167       continue;
168     }
169     if (locale.empty() || (associated_file->locale() != nullptr &&
170                            locale == associated_file->locale()->str())) {
171       return associated_file->name()->str();
172     }
173   }
174   return std::string();
175 }
176 
InitFromModelBuffer(const char * buffer_data,size_t buffer_size)177 absl::Status ModelMetadataExtractor::InitFromModelBuffer(
178     const char* buffer_data, size_t buffer_size) {
179   // Rely on the simplest, base flatbuffers verifier. Here is not the place to
180   // e.g. use an OpResolver: we just want to make sure the buffer is valid to
181   // access the metadata.
182   flatbuffers::Verifier verifier = flatbuffers::Verifier(
183       reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
184   if (!tflite::VerifyModelBuffer(verifier)) {
185     return CreateStatusWithPayload(
186         StatusCode::kInvalidArgument,
187         "The model is not a valid FlatBuffer buffer.",
188         TfLiteSupportStatus::kInvalidFlatBufferError);
189   }
190   model_ = tflite::GetModel(buffer_data);
191   if (model_->metadata() == nullptr) {
192     // Not all models have metadata, which is OK. `GetModelMetadata()` then
193     // returns nullptr.
194     return absl::OkStatus();
195   }
196   // Look for the "TFLITE_METADATA" field, if any.
197   for (int i = 0; i < model_->metadata()->size(); ++i) {
198     const auto metadata = model_->metadata()->Get(i);
199     if (!metadata->name()) {
200       continue;
201     }
202     if (metadata->name()->str() != kMetadataBufferName) {
203       continue;
204     }
205     const auto buffer_index = metadata->buffer();
206     const auto metadata_buffer =
207         model_->buffers()->Get(buffer_index)->data()->data();
208     if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
209       return CreateStatusWithPayload(
210           StatusCode::kInvalidArgument,
211           absl::StrFormat(
212               "Invalid metadata schema version: expected %s, got %s",
213               absl::string_view(tflite::ModelMetadataIdentifier())
214                   .substr(
215                       0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
216               // Returned identifier is not null terminated; has to be
217               // truncated.
218               absl::string_view(
219                   flatbuffers::GetBufferIdentifier(metadata_buffer))
220                   .substr(
221                       0,
222                       flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
223           TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
224     }
225     model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
226     if (model_metadata_ == nullptr) {
227       return CreateStatusWithPayload(StatusCode::kInternal,
228                                      "Expected Model Metadata not to be null.");
229     }
230     return ExtractAssociatedFiles(buffer_data, buffer_size);
231     break;
232   }
233   return absl::OkStatus();
234 }
235 
ExtractAssociatedFiles(const char * buffer_data,size_t buffer_size)236 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
237     const char* buffer_data, size_t buffer_size) {
238   // Create in-memory read-only zip file.
239   ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
240   // Open zip.
241   unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
242   if (zf == nullptr) {
243     // It's OK if it fails: this means there are no associated files with this
244     // model.
245     return absl::OkStatus();
246   }
247   // Get number of files.
248   unz_global_info global_info;
249   if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
250     return CreateStatusWithPayload(
251         StatusCode::kUnknown, "Unable to get zip archive info.",
252         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
253   }
254 
255   // Browse through files in archive.
256   if (global_info.number_entry > 0) {
257     int error = unzGoToFirstFile(zf);
258     while (error == UNZ_OK) {
259       ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
260       // Store result in map.
261       associated_files_[zip_file_info.name] = absl::string_view(
262           buffer_data + zip_file_info.position, zip_file_info.size);
263       error = unzGoToNextFile(zf);
264     }
265     if (error != UNZ_END_OF_LIST_OF_FILE) {
266       return CreateStatusWithPayload(
267           StatusCode::kUnknown,
268           "Unable to read associated file in zip archive.",
269           TfLiteSupportStatus::kMetadataAssociatedFileZipError);
270     }
271   }
272   // Close zip.
273   if (unzClose(zf) != UNZ_OK) {
274     return CreateStatusWithPayload(
275         StatusCode::kUnknown, "Unable to close zip archive.",
276         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
277   }
278   return absl::OkStatus();
279 }
280 
281 tflite::support::StatusOr<absl::string_view>
GetAssociatedFile(const std::string & filename) const282 ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
283   auto it = associated_files_.find(filename);
284   if (it == associated_files_.end()) {
285     return CreateStatusWithPayload(
286         StatusCode::kNotFound,
287         absl::StrFormat("No associated file with name: %s", filename),
288         TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
289   }
290   return it->second;
291 }
292 
293 tflite::support::StatusOr<std::string>
GetModelVersion() const294 ModelMetadataExtractor::GetModelVersion() const {
295   if (model_metadata_ == nullptr) {
296     return CreateStatusWithPayload(
297       StatusCode::kFailedPrecondition,
298       "No model metadata",
299       TfLiteSupportStatus::kMetadataNotFoundError);
300   }
301   if (model_metadata_->version() == nullptr) {
302     return CreateStatusWithPayload(
303       StatusCode::kNotFound,
304       "No version in model metadata",
305       TfLiteSupportStatus::kMetadataNotFoundError);
306   }
307   return model_metadata_->version()->str();
308 }
309 
310 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
GetInputTensorMetadata() const311 ModelMetadataExtractor::GetInputTensorMetadata() const {
312   if (model_metadata_ == nullptr ||
313       model_metadata_->subgraph_metadata() == nullptr) {
314     return nullptr;
315   }
316   return model_metadata_->subgraph_metadata()
317       ->Get(kDefaultSubgraphIndex)
318       ->input_tensor_metadata();
319 }
320 
GetInputTensorMetadata(int index) const321 const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
322     int index) const {
323   return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
324                                                    index);
325 }
326 
GetInputTensorCount() const327 int ModelMetadataExtractor::GetInputTensorCount() const {
328   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
329       input_tensor_metadata = GetInputTensorMetadata();
330   return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
331 }
332 
333 const Vector<Offset<TensorMetadata>>*
GetOutputTensorMetadata() const334 ModelMetadataExtractor::GetOutputTensorMetadata() const {
335   if (model_metadata_ == nullptr ||
336       model_metadata_->subgraph_metadata() == nullptr) {
337     return nullptr;
338   }
339   return model_metadata_->subgraph_metadata()
340       ->Get(kDefaultSubgraphIndex)
341       ->output_tensor_metadata();
342 }
343 
GetOutputTensorMetadata(int index) const344 const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
345     int index) const {
346   return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
347                                                    index);
348 }
349 
GetOutputTensorCount() const350 int ModelMetadataExtractor::GetOutputTensorCount() const {
351   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
352       output_tensor_metadata = GetOutputTensorMetadata();
353   return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
354 }
355 
356 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetInputProcessUnits() const357 ModelMetadataExtractor::GetInputProcessUnits() const {
358   if (model_metadata_ == nullptr ||
359       model_metadata_->subgraph_metadata() == nullptr) {
360     return nullptr;
361   }
362   return model_metadata_->subgraph_metadata()
363       ->Get(kDefaultSubgraphIndex)
364       ->input_process_units();
365 }
366 
GetInputProcessUnit(int index) const367 const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
368     int index) const {
369   return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
370 }
371 
GetInputProcessUnitsCount() const372 int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
373   const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
374       GetInputProcessUnits();
375   return input_process_units == nullptr ? 0 : input_process_units->size();
376 }
377 
378 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetOutputProcessUnits() const379 ModelMetadataExtractor::GetOutputProcessUnits() const {
380   if (model_metadata_ == nullptr ||
381       model_metadata_->subgraph_metadata() == nullptr) {
382     return nullptr;
383   }
384   return model_metadata_->subgraph_metadata()
385       ->Get(kDefaultSubgraphIndex)
386       ->output_process_units();
387 }
388 
GetOutputProcessUnit(int index) const389 const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
390     int index) const {
391   return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
392 }
393 
GetOutputProcessUnitsCount() const394 int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
395   const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
396       GetOutputProcessUnits();
397   return output_process_units == nullptr ? 0 : output_process_units->size();
398 }
399 
400 }  // namespace metadata
401 }  // namespace tflite