1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 16 #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 17 18 #include <string> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "absl/status/status.h" 22 #include "absl/strings/string_view.h" 23 #include "tensorflow/lite/schema/schema_generated.h" 24 #include "tensorflow_lite_support/cc/port/statusor.h" 25 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" 26 27 namespace tflite { 28 namespace metadata { 29 30 // Extracts and provides easy access to the TFLite ModelMetadata [1] and 31 // corresponding associated files packed into a TFLite FlatBuffer, if any. 32 // 33 // [1]: https://www.tensorflow.org/lite/convert/metadata 34 class ModelMetadataExtractor { 35 public: 36 // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer 37 // and returns a pointer to the new object. Ownership is transferred to the 38 // caller. Returns an error if the creation failed, which may happen e.g. if 39 // the provided buffer is not a valid TFLite FlatBuffer. 40 // 41 // Warning: Does not take ownership of the provided buffer, which must outlive 42 // this object. 43 // 44 // It is recommended to obtain and manage the buffer through an 45 // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having 46 // to load the entire buffer in memory when provided by path or file 47 // descriptor. 48 // 49 // [1]: 50 // tensorflow_lite_support/c/task/core/external_file_handler.h 51 static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>> 52 CreateFromModelBuffer(const char* buffer_data, size_t buffer_size); 53 54 // Returns the pointer to the *first* ProcessUnit with the provided type, or 55 // nullptr if none can be found. An error is returned if multiple 56 // ProcessUnit-s with the provided type are found. 57 static tflite::support::StatusOr<const tflite::ProcessUnit*> 58 FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata, 59 tflite::ProcessUnitOptions type); 60 61 // Returns the name of the *first* associated file with the provided type and 62 // (optional) locale in the provided TensorMetadata, or an empty string if no 63 // such associated file can be found (which is not necessarily an error: some 64 // models have no associated files at all) or its `name` field is unspecified. 65 // Note: see `GetAssociatedFile` to read the actual file contents. 66 static std::string FindFirstAssociatedFileName( 67 const tflite::TensorMetadata& tensor_metadata, 68 tflite::AssociatedFileType type, 69 absl::string_view locale = absl::string_view()); 70 71 // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no 72 // metadata was present in the Model FlatBuffer provided at creation time. GetModelMetadata()73 const tflite::ModelMetadata* GetModelMetadata() const { 74 return model_metadata_; 75 } 76 77 // Gets the contents of the associated file with the provided name packed into 78 // the model metadata. An error is returned if there is no such associated 79 // file. 80 tflite::support::StatusOr<absl::string_view> GetAssociatedFile( 81 const std::string& filename) const; 82 83 // Gets the model version from the model metadata. An error is returned if 84 // either the metadata does not exist or no model version is present in it. 85 tflite::support::StatusOr<std::string> GetModelVersion() const; 86 87 // Note: all methods below retrieves metadata of the *first* subgraph as 88 // default. 89 90 // Gets the metadata for input tensors. 91 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* 92 GetInputTensorMetadata() const; 93 94 // Gets the metadata for the input tensor specified by the given index, or 95 // nullptr in case there is no metadata or the index is out of range. 96 const tflite::TensorMetadata* GetInputTensorMetadata(int index) const; 97 98 // Gets the count of input tensors with metadata in the metadata FlatBuffer. 99 // In particular, 0 is returned when there is no metadata. 100 int GetInputTensorCount() const; 101 102 // Gets the metadata for output tensors. 103 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* 104 GetOutputTensorMetadata() const; 105 106 // Gets the metadata for the output tensor specified by the given index, or 107 // nullptr in case there is no metadata or the index is out of range. 108 const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const; 109 110 // Gets the count of output tensors with metadata in the metadata FlatBuffer. 111 // In particular, 0 is returned when there is no metadata. 112 int GetOutputTensorCount() const; 113 114 // Gets the input process units from SubgraphMetadata.input_process_units, 115 // could be nullptr. 116 const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* 117 GetInputProcessUnits() const; 118 119 // Gets the input process unit specified by the given index, or nullptr in 120 // case there is no input process unit or the index is out of range. 121 const tflite::ProcessUnit* GetInputProcessUnit(int index) const; 122 123 // Gets the count of input process units. In particular, 0 is returned when 124 // there is no input process units. 125 int GetInputProcessUnitsCount() const; 126 127 // Gets the output process units from SubgraphMetadata.output_process_units, 128 // could be nullptr. 129 const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* 130 GetOutputProcessUnits() const; 131 132 // Gets the output process unit specified by the given index, or nullptr in 133 // case there is no output process unit or the index is out of range. 134 const tflite::ProcessUnit* GetOutputProcessUnit(int index) const; 135 136 // Gets the count of output process units. In particular, 0 is returned when 137 // there is no output process units. 138 int GetOutputProcessUnitsCount() const; 139 140 private: 141 static constexpr int kDefaultSubgraphIndex = 0; 142 // Private default constructor, called from CreateFromModel(). 143 ModelMetadataExtractor() = default; 144 // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer. 145 absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size); 146 // Extracts and stores in associated_files_ the associated files (if present) 147 // packed into the model FlatBuffer data. 148 absl::Status ExtractAssociatedFiles(const char* buffer_data, 149 size_t buffer_size); 150 // Pointer to the TFLite Model object from which to read the ModelMetadata. 151 const tflite::Model* model_{nullptr}; 152 // Pointer to the extracted ModelMetadata, if any. 153 const tflite::ModelMetadata* model_metadata_{nullptr}; 154 // The files associated with the ModelMetadata, as a map with the filename 155 // (corresponding to a basename, e.g. "labels.txt") as key and a pointer to 156 // the file contents as value. 157 absl::flat_hash_map<std::string, absl::string_view> associated_files_; 158 }; 159 160 } // namespace metadata 161 } // namespace tflite 162 163 #endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 164