• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
16 #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
17 
#include <cstddef>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
26 
27 namespace tflite {
28 namespace metadata {
29 
30 // Extracts and provides easy access to the TFLite ModelMetadata [1] and
31 // corresponding associated files packed into a TFLite FlatBuffer, if any.
32 //
33 // [1]: https://www.tensorflow.org/lite/convert/metadata
34 class ModelMetadataExtractor {
35  public:
36   // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer
37   // and returns a pointer to the new object. Ownership is transferred to the
38   // caller. Returns an error if the creation failed, which may happen e.g. if
39   // the provided buffer is not a valid TFLite FlatBuffer.
40   //
41   // Warning: Does not take ownership of the provided buffer, which must outlive
42   // this object.
43   //
44   // It is recommended to obtain and manage the buffer through an
45   // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having
46   // to load the entire buffer in memory when provided by path or file
47   // descriptor.
48   //
49   // [1]:
50   // tensorflow_lite_support/c/task/core/external_file_handler.h
51   static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
52   CreateFromModelBuffer(const char* buffer_data, size_t buffer_size);
53 
54   // Returns the pointer to the *first* ProcessUnit with the provided type, or
55   // nullptr if none can be found. An error is returned if multiple
56   // ProcessUnit-s with the provided type are found.
57   static tflite::support::StatusOr<const tflite::ProcessUnit*>
58   FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata,
59                        tflite::ProcessUnitOptions type);
60 
61   // Returns the name of the *first* associated file with the provided type and
62   // (optional) locale in the provided TensorMetadata, or an empty string if no
63   // such associated file can be found (which is not necessarily an error: some
64   // models have no associated files at all) or its `name` field is unspecified.
65   // Note: see `GetAssociatedFile` to read the actual file contents.
66   static std::string FindFirstAssociatedFileName(
67       const tflite::TensorMetadata& tensor_metadata,
68       tflite::AssociatedFileType type,
69       absl::string_view locale = absl::string_view());
70 
71   // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no
72   // metadata was present in the Model FlatBuffer provided at creation time.
GetModelMetadata()73   const tflite::ModelMetadata* GetModelMetadata() const {
74     return model_metadata_;
75   }
76 
77   // Gets the contents of the associated file with the provided name packed into
78   // the model metadata. An error is returned if there is no such associated
79   // file.
80   tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
81       const std::string& filename) const;
82 
83   // Gets the model version from the model metadata.  An error is returned if
84   // either the metadata does not exist or no model version is present in it.
85   tflite::support::StatusOr<std::string> GetModelVersion() const;
86 
87   // Note: all methods below retrieves metadata of the *first* subgraph as
88   // default.
89 
90   // Gets the metadata for input tensors.
91   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
92   GetInputTensorMetadata() const;
93 
94   // Gets the metadata for the input tensor specified by the given index, or
95   // nullptr in case there is no metadata or the index is out of range.
96   const tflite::TensorMetadata* GetInputTensorMetadata(int index) const;
97 
98   // Gets the count of input tensors with metadata in the metadata FlatBuffer.
99   // In particular, 0 is returned when there is no metadata.
100   int GetInputTensorCount() const;
101 
102   // Gets the metadata for output tensors.
103   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
104   GetOutputTensorMetadata() const;
105 
106   // Gets the metadata for the output tensor specified by the given index, or
107   // nullptr in case there is no metadata or the index is out of range.
108   const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const;
109 
110   // Gets the count of output tensors with metadata in the metadata FlatBuffer.
111   // In particular, 0 is returned when there is no metadata.
112   int GetOutputTensorCount() const;
113 
114   // Gets the input process units from SubgraphMetadata.input_process_units,
115   // could be nullptr.
116   const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
117   GetInputProcessUnits() const;
118 
119   // Gets the input process unit specified by the given index, or nullptr in
120   // case there is no input process unit or the index is out of range.
121   const tflite::ProcessUnit* GetInputProcessUnit(int index) const;
122 
123   // Gets the count of input process units. In particular, 0 is returned when
124   // there is no input process units.
125   int GetInputProcessUnitsCount() const;
126 
127   // Gets the output process units from SubgraphMetadata.output_process_units,
128   // could be nullptr.
129   const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
130   GetOutputProcessUnits() const;
131 
132   // Gets the output process unit specified by the given index, or nullptr in
133   // case there is no output process unit or the index is out of range.
134   const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;
135 
136   // Gets the count of output process units. In particular, 0 is returned when
137   // there is no output process units.
138   int GetOutputProcessUnitsCount() const;
139 
140  private:
141   static constexpr int kDefaultSubgraphIndex = 0;
142   // Private default constructor, called from CreateFromModel().
143   ModelMetadataExtractor() = default;
144   // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer.
145   absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size);
146   // Extracts and stores in associated_files_ the associated files (if present)
147   // packed into the model FlatBuffer data.
148   absl::Status ExtractAssociatedFiles(const char* buffer_data,
149                                       size_t buffer_size);
150   // Pointer to the TFLite Model object from which to read the ModelMetadata.
151   const tflite::Model* model_{nullptr};
152   // Pointer to the extracted ModelMetadata, if any.
153   const tflite::ModelMetadata* model_metadata_{nullptr};
154   // The files associated with the ModelMetadata, as a map with the filename
155   // (corresponding to a basename, e.g. "labels.txt") as key and a pointer to
156   // the file contents as value.
157   absl::flat_hash_map<std::string, absl::string_view> associated_files_;
158 };
159 
160 }  // namespace metadata
161 }  // namespace tflite
162 
163 #endif  // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
164