1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"
#include "contrib/minizip/ioapi.h"
#include "contrib/minizip/unzip.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32
33 namespace tflite {
34 namespace metadata {
35
36 namespace {
// Name of the entry in the model's `metadata` table whose referenced model
// buffer holds the serialized ModelMetadata FlatBuffer.
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";

// Short aliases used throughout this file.
using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;
45
46 // Util to get item from src_vector specified by index.
47 template <typename T>
GetItemFromVector(const flatbuffers::Vector<flatbuffers::Offset<T>> * src_vector,int index)48 const T* GetItemFromVector(
49 const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
50 if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
51 return nullptr;
52 }
53 return src_vector->Get(index);
54 }
55
56 // Wrapper function around calls to unzip to avoid repeating conversion logic
57 // from error code to Status.
UnzipErrorToStatus(int error)58 absl::Status UnzipErrorToStatus(int error) {
59 if (error != UNZ_OK) {
60 return CreateStatusWithPayload(
61 StatusCode::kUnknown, "Unable to read associated file in zip archive.",
62 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
63 }
64 return absl::OkStatus();
65 }
66
// Describes one entry of the zip archive read from the model buffer.
struct ZipFileInfo {
  // File name as recorded in the zip archive.
  std::string name;
  // Position of the entry's data, as reported by
  // unzGetCurrentFileZStreamPos64(); used as an offset into the model buffer
  // when the associated file is stored.
  ZPOS64_T position;
  // Uncompressed size of the entry's data, in bytes.
  ZPOS64_T size;
};
73
74 // Returns the ZipFileInfo corresponding to the current file in the provided
75 // unzFile object.
GetCurrentZipFileInfo(const unzFile & zf)76 tflite::support::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
77 // Open file in raw mode, as data is expected to be uncompressed.
78 int method;
79 RETURN_IF_ERROR(UnzipErrorToStatus(
80 unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
81 if (method != Z_NO_COMPRESSION) {
82 return CreateStatusWithPayload(
83 StatusCode::kUnknown, "Expected uncompressed zip archive.",
84 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
85 }
86
87 // Get file info a first time to get filename size.
88 unz_file_info64 file_info;
89 RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
90 zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
91 /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
92 /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
93
94 // Second call to get file name.
95 auto file_name_size = file_info.size_filename;
96 char* c_file_name = (char*)malloc(file_name_size);
97 RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
98 zf, &file_info, c_file_name, file_name_size,
99 /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
100 /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
101 std::string file_name = std::string(c_file_name, file_name_size);
102 free(c_file_name);
103
104 // Get position in file.
105 auto position = unzGetCurrentFileZStreamPos64(zf);
106 if (position == 0) {
107 return CreateStatusWithPayload(
108 StatusCode::kUnknown, "Unable to read file in zip archive.",
109 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
110 }
111 ZipFileInfo result = {.name = file_name,
112 .position = position,
113 .size = file_info.uncompressed_size};
114
115 // Close file and return.
116 RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
117 return result;
118 }
119 } // namespace
120
121 /* static */
122 tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
CreateFromModelBuffer(const char * buffer_data,size_t buffer_size)123 ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
124 size_t buffer_size) {
125 // Use absl::WrapUnique() to call private constructor:
126 // https://abseil.io/tips/126.
127 std::unique_ptr<ModelMetadataExtractor> extractor =
128 absl::WrapUnique(new ModelMetadataExtractor());
129 RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
130 return extractor;
131 }
132
133 /* static */
134 tflite::support::StatusOr<const tflite::ProcessUnit*>
FindFirstProcessUnit(const tflite::TensorMetadata & tensor_metadata,tflite::ProcessUnitOptions type)135 ModelMetadataExtractor::FindFirstProcessUnit(
136 const tflite::TensorMetadata& tensor_metadata,
137 tflite::ProcessUnitOptions type) {
138 const tflite::ProcessUnit* result = nullptr;
139 if (tensor_metadata.process_units() == nullptr) {
140 return result;
141 }
142 for (const auto process_unit : *tensor_metadata.process_units()) {
143 if (process_unit->options_type() == type) {
144 if (result != nullptr) {
145 return CreateStatusWithPayload(
146 StatusCode::kInvalidArgument,
147 absl::StrCat("Found multiple ProcessUnits with type=",
148 tflite::EnumNameProcessUnitOptions(type),
149 ", expected at most one."),
150 TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
151 }
152 result = process_unit;
153 }
154 }
155 return result;
156 }
157
158 /* static */
FindFirstAssociatedFileName(const tflite::TensorMetadata & tensor_metadata,tflite::AssociatedFileType type,absl::string_view locale)159 std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
160 const tflite::TensorMetadata& tensor_metadata,
161 tflite::AssociatedFileType type, absl::string_view locale) {
162 if (tensor_metadata.associated_files() == nullptr) {
163 return std::string();
164 }
165 for (const auto associated_file : *tensor_metadata.associated_files()) {
166 if (associated_file->type() != type || associated_file->name() == nullptr) {
167 continue;
168 }
169 if (locale.empty() || (associated_file->locale() != nullptr &&
170 locale == associated_file->locale()->str())) {
171 return associated_file->name()->str();
172 }
173 }
174 return std::string();
175 }
176
InitFromModelBuffer(const char * buffer_data,size_t buffer_size)177 absl::Status ModelMetadataExtractor::InitFromModelBuffer(
178 const char* buffer_data, size_t buffer_size) {
179 // Rely on the simplest, base flatbuffers verifier. Here is not the place to
180 // e.g. use an OpResolver: we just want to make sure the buffer is valid to
181 // access the metadata.
182 flatbuffers::Verifier verifier = flatbuffers::Verifier(
183 reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
184 if (!tflite::VerifyModelBuffer(verifier)) {
185 return CreateStatusWithPayload(
186 StatusCode::kInvalidArgument,
187 "The model is not a valid FlatBuffer buffer.",
188 TfLiteSupportStatus::kInvalidFlatBufferError);
189 }
190 model_ = tflite::GetModel(buffer_data);
191 if (model_->metadata() == nullptr) {
192 // Not all models have metadata, which is OK. `GetModelMetadata()` then
193 // returns nullptr.
194 return absl::OkStatus();
195 }
196 // Look for the "TFLITE_METADATA" field, if any.
197 for (int i = 0; i < model_->metadata()->size(); ++i) {
198 const auto metadata = model_->metadata()->Get(i);
199 if (!metadata->name()) {
200 continue;
201 }
202 if (metadata->name()->str() != kMetadataBufferName) {
203 continue;
204 }
205 const auto buffer_index = metadata->buffer();
206 const auto metadata_buffer =
207 model_->buffers()->Get(buffer_index)->data()->data();
208 if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
209 return CreateStatusWithPayload(
210 StatusCode::kInvalidArgument,
211 absl::StrFormat(
212 "Invalid metadata schema version: expected %s, got %s",
213 absl::string_view(tflite::ModelMetadataIdentifier())
214 .substr(
215 0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
216 // Returned identifier is not null terminated; has to be
217 // truncated.
218 absl::string_view(
219 flatbuffers::GetBufferIdentifier(metadata_buffer))
220 .substr(
221 0,
222 flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
223 TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
224 }
225 model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
226 if (model_metadata_ == nullptr) {
227 return CreateStatusWithPayload(StatusCode::kInternal,
228 "Expected Model Metadata not to be null.");
229 }
230 return ExtractAssociatedFiles(buffer_data, buffer_size);
231 break;
232 }
233 return absl::OkStatus();
234 }
235
ExtractAssociatedFiles(const char * buffer_data,size_t buffer_size)236 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
237 const char* buffer_data, size_t buffer_size) {
238 // Create in-memory read-only zip file.
239 ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
240 // Open zip.
241 unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
242 if (zf == nullptr) {
243 // It's OK if it fails: this means there are no associated files with this
244 // model.
245 return absl::OkStatus();
246 }
247 // Get number of files.
248 unz_global_info global_info;
249 if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
250 return CreateStatusWithPayload(
251 StatusCode::kUnknown, "Unable to get zip archive info.",
252 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
253 }
254
255 // Browse through files in archive.
256 if (global_info.number_entry > 0) {
257 int error = unzGoToFirstFile(zf);
258 while (error == UNZ_OK) {
259 ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
260 // Store result in map.
261 associated_files_[zip_file_info.name] = absl::string_view(
262 buffer_data + zip_file_info.position, zip_file_info.size);
263 error = unzGoToNextFile(zf);
264 }
265 if (error != UNZ_END_OF_LIST_OF_FILE) {
266 return CreateStatusWithPayload(
267 StatusCode::kUnknown,
268 "Unable to read associated file in zip archive.",
269 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
270 }
271 }
272 // Close zip.
273 if (unzClose(zf) != UNZ_OK) {
274 return CreateStatusWithPayload(
275 StatusCode::kUnknown, "Unable to close zip archive.",
276 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
277 }
278 return absl::OkStatus();
279 }
280
281 tflite::support::StatusOr<absl::string_view>
GetAssociatedFile(const std::string & filename) const282 ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
283 auto it = associated_files_.find(filename);
284 if (it == associated_files_.end()) {
285 return CreateStatusWithPayload(
286 StatusCode::kNotFound,
287 absl::StrFormat("No associated file with name: %s", filename),
288 TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
289 }
290 return it->second;
291 }
292
293 tflite::support::StatusOr<std::string>
GetModelVersion() const294 ModelMetadataExtractor::GetModelVersion() const {
295 if (model_metadata_ == nullptr) {
296 return CreateStatusWithPayload(
297 StatusCode::kFailedPrecondition,
298 "No model metadata",
299 TfLiteSupportStatus::kMetadataNotFoundError);
300 }
301 if (model_metadata_->version() == nullptr) {
302 return CreateStatusWithPayload(
303 StatusCode::kNotFound,
304 "No version in model metadata",
305 TfLiteSupportStatus::kMetadataNotFoundError);
306 }
307 return model_metadata_->version()->str();
308 }
309
310 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
GetInputTensorMetadata() const311 ModelMetadataExtractor::GetInputTensorMetadata() const {
312 if (model_metadata_ == nullptr ||
313 model_metadata_->subgraph_metadata() == nullptr) {
314 return nullptr;
315 }
316 return model_metadata_->subgraph_metadata()
317 ->Get(kDefaultSubgraphIndex)
318 ->input_tensor_metadata();
319 }
320
GetInputTensorMetadata(int index) const321 const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
322 int index) const {
323 return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
324 index);
325 }
326
GetInputTensorCount() const327 int ModelMetadataExtractor::GetInputTensorCount() const {
328 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
329 input_tensor_metadata = GetInputTensorMetadata();
330 return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
331 }
332
333 const Vector<Offset<TensorMetadata>>*
GetOutputTensorMetadata() const334 ModelMetadataExtractor::GetOutputTensorMetadata() const {
335 if (model_metadata_ == nullptr ||
336 model_metadata_->subgraph_metadata() == nullptr) {
337 return nullptr;
338 }
339 return model_metadata_->subgraph_metadata()
340 ->Get(kDefaultSubgraphIndex)
341 ->output_tensor_metadata();
342 }
343
GetOutputTensorMetadata(int index) const344 const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
345 int index) const {
346 return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
347 index);
348 }
349
GetOutputTensorCount() const350 int ModelMetadataExtractor::GetOutputTensorCount() const {
351 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
352 output_tensor_metadata = GetOutputTensorMetadata();
353 return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
354 }
355
356 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetInputProcessUnits() const357 ModelMetadataExtractor::GetInputProcessUnits() const {
358 if (model_metadata_ == nullptr ||
359 model_metadata_->subgraph_metadata() == nullptr) {
360 return nullptr;
361 }
362 return model_metadata_->subgraph_metadata()
363 ->Get(kDefaultSubgraphIndex)
364 ->input_process_units();
365 }
366
GetInputProcessUnit(int index) const367 const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
368 int index) const {
369 return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
370 }
371
GetInputProcessUnitsCount() const372 int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
373 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
374 GetInputProcessUnits();
375 return input_process_units == nullptr ? 0 : input_process_units->size();
376 }
377
378 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetOutputProcessUnits() const379 ModelMetadataExtractor::GetOutputProcessUnits() const {
380 if (model_metadata_ == nullptr ||
381 model_metadata_->subgraph_metadata() == nullptr) {
382 return nullptr;
383 }
384 return model_metadata_->subgraph_metadata()
385 ->Get(kDefaultSubgraphIndex)
386 ->output_process_units();
387 }
388
GetOutputProcessUnit(int index) const389 const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
390 int index) const {
391 return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
392 }
393
GetOutputProcessUnitsCount() const394 int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
395 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
396 GetOutputProcessUnits();
397 return output_process_units == nullptr ? 0 : output_process_units->size();
398 }
399
400 } // namespace metadata
401 } // namespace tflite