1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 16 #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 17 18 #include "absl/container/flat_hash_map.h" 19 #include "absl/status/status.h" 20 #include "absl/strings/string_view.h" 21 #include "tensorflow/lite/schema/schema_generated.h" 22 #include "tensorflow_lite_support/cc/port/statusor.h" 23 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" 24 25 namespace tflite { 26 namespace metadata { 27 28 // Extracts and provides easy access to the TFLite ModelMetadata [1] and 29 // corresponding associated files packed into a TFLite FlatBuffer, if any. 30 // 31 // [1]: https://www.tensorflow.org/lite/convert/metadata 32 class ModelMetadataExtractor { 33 public: 34 // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer 35 // and returns a pointer to the new object. Ownership is transferred to the 36 // caller. Returns an error if the creation failed, which may happen e.g. if 37 // the provided buffer is not a valid TFLite FlatBuffer. 38 // 39 // Warning: Does not take ownership of the provided buffer, which must outlive 40 // this object. 41 // 42 // It is recommended to obtain and manage the buffer through an 43 // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having 44 // to load the entire buffer in memory when provided by path or file 45 // descriptor. 46 // 47 // [1]: 48 // tensorflow_lite_support/c/task/core/external_file_handler.h 49 static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>> 50 CreateFromModelBuffer(const char* buffer_data, size_t buffer_size); 51 52 // Returns the pointer to the *first* ProcessUnit with the provided type, or 53 // nullptr if none can be found. An error is returned if multiple 54 // ProcessUnit-s with the provided type are found. 55 static tflite::support::StatusOr<const tflite::ProcessUnit*> 56 FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata, 57 tflite::ProcessUnitOptions type); 58 59 // Returns the name of the *first* associated file with the provided type and 60 // (optional) locale in the provided TensorMetadata, or an empty string if no 61 // such associated file can be found (which is not necessarily an error: some 62 // models have no associated files at all) or its `name` field is unspecified. 63 // Note: see `GetAssociatedFile` to read the actual file contents. 64 static std::string FindFirstAssociatedFileName( 65 const tflite::TensorMetadata& tensor_metadata, 66 tflite::AssociatedFileType type, 67 absl::string_view locale = absl::string_view()); 68 69 // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no 70 // metadata was present in the Model FlatBuffer provided at creation time. GetModelMetadata()71 const tflite::ModelMetadata* GetModelMetadata() const { 72 return model_metadata_; 73 } 74 75 // Gets the contents of the associated file with the provided name packed into 76 // the model metadata. An error is returned if there is no such associated 77 // file. 78 tflite::support::StatusOr<absl::string_view> GetAssociatedFile( 79 const std::string& filename) const; 80 81 // Note: all methods below retrieves metadata of the *first* subgraph as 82 // default. 83 84 // Gets the metadata for input tensors. 85 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* 86 GetInputTensorMetadata() const; 87 88 // Gets the metadata for the input tensor specified by the given index, or 89 // nullptr in case there is no metadata or the index is out of range. 90 const tflite::TensorMetadata* GetInputTensorMetadata(int index) const; 91 92 // Gets the count of input tensors with metadata in the metadata FlatBuffer. 93 // In particular, 0 is returned when there is no metadata. 94 int GetInputTensorCount() const; 95 96 // Gets the metadata for output tensors. 97 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* 98 GetOutputTensorMetadata() const; 99 100 // Gets the metadata for the output tensor specified by the given index, or 101 // nullptr in case there is no metadata or the index is out of range. 102 const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const; 103 104 // Gets the count of output tensors with metadata in the metadata FlatBuffer. 105 // In particular, 0 is returned when there is no metadata. 106 int GetOutputTensorCount() const; 107 108 // Gets the input process units from SubgraphMetadata.input_process_units, 109 // could be nullptr. 110 const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* 111 GetInputProcessUnits() const; 112 113 // Gets the input process unit specified by the given index, or nullptr in 114 // case there is no input process unit or the index is out of range. 115 const tflite::ProcessUnit* GetInputProcessUnit(int index) const; 116 117 // Gets the count of input process units. In particular, 0 is returned when 118 // there is no input process units. 119 int GetInputProcessUnitsCount() const; 120 121 // Gets the output process units from SubgraphMetadata.output_process_units, 122 // could be nullptr. 123 const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* 124 GetOutputProcessUnits() const; 125 126 // Gets the output process unit specified by the given index, or nullptr in 127 // case there is no output process unit or the index is out of range. 128 const tflite::ProcessUnit* GetOutputProcessUnit(int index) const; 129 130 // Gets the count of output process units. In particular, 0 is returned when 131 // there is no output process units. 132 int GetOutputProcessUnitsCount() const; 133 134 private: 135 static constexpr int kDefaultSubgraphIndex = 0; 136 // Private default constructor, called from CreateFromModel(). 137 ModelMetadataExtractor() = default; 138 // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer. 139 absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size); 140 // Extracts and stores in associated_files_ the associated files (if present) 141 // packed into the model FlatBuffer data. 142 absl::Status ExtractAssociatedFiles(const char* buffer_data, 143 size_t buffer_size); 144 // Pointer to the TFLite Model object from which to read the ModelMetadata. 145 const tflite::Model* model_{nullptr}; 146 // Pointer to the extracted ModelMetadata, if any. 147 const tflite::ModelMetadata* model_metadata_{nullptr}; 148 // The files associated with the ModelMetadata, as a map with the filename 149 // (corresponding to a basename, e.g. "labels.txt") as key and the file 150 // contents as value. 151 absl::flat_hash_map<std::string, std::string> associated_files_; 152 }; 153 154 } // namespace metadata 155 } // namespace tflite 156 157 #endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ 158