• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
16 #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
17 
#include <cstddef>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
24 
25 namespace tflite {
26 namespace metadata {
27 
28 // Extracts and provides easy access to the TFLite ModelMetadata [1] and
29 // corresponding associated files packed into a TFLite FlatBuffer, if any.
30 //
31 // [1]: https://www.tensorflow.org/lite/convert/metadata
32 class ModelMetadataExtractor {
33  public:
34   // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer
35   // and returns a pointer to the new object. Ownership is transferred to the
36   // caller. Returns an error if the creation failed, which may happen e.g. if
37   // the provided buffer is not a valid TFLite FlatBuffer.
38   //
39   // Warning: Does not take ownership of the provided buffer, which must outlive
40   // this object.
41   //
42   // It is recommended to obtain and manage the buffer through an
43   // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having
44   // to load the entire buffer in memory when provided by path or file
45   // descriptor.
46   //
47   // [1]:
48   // tensorflow_lite_support/c/task/core/external_file_handler.h
49   static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
50   CreateFromModelBuffer(const char* buffer_data, size_t buffer_size);
51 
52   // Returns the pointer to the *first* ProcessUnit with the provided type, or
53   // nullptr if none can be found. An error is returned if multiple
54   // ProcessUnit-s with the provided type are found.
55   static tflite::support::StatusOr<const tflite::ProcessUnit*>
56   FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata,
57                        tflite::ProcessUnitOptions type);
58 
59   // Returns the name of the *first* associated file with the provided type and
60   // (optional) locale in the provided TensorMetadata, or an empty string if no
61   // such associated file can be found (which is not necessarily an error: some
62   // models have no associated files at all) or its `name` field is unspecified.
63   // Note: see `GetAssociatedFile` to read the actual file contents.
64   static std::string FindFirstAssociatedFileName(
65       const tflite::TensorMetadata& tensor_metadata,
66       tflite::AssociatedFileType type,
67       absl::string_view locale = absl::string_view());
68 
69   // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no
70   // metadata was present in the Model FlatBuffer provided at creation time.
GetModelMetadata()71   const tflite::ModelMetadata* GetModelMetadata() const {
72     return model_metadata_;
73   }
74 
75   // Gets the contents of the associated file with the provided name packed into
76   // the model metadata. An error is returned if there is no such associated
77   // file.
78   tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
79       const std::string& filename) const;
80 
81   // Note: all methods below retrieves metadata of the *first* subgraph as
82   // default.
83 
84   // Gets the metadata for input tensors.
85   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
86   GetInputTensorMetadata() const;
87 
88   // Gets the metadata for the input tensor specified by the given index, or
89   // nullptr in case there is no metadata or the index is out of range.
90   const tflite::TensorMetadata* GetInputTensorMetadata(int index) const;
91 
92   // Gets the count of input tensors with metadata in the metadata FlatBuffer.
93   // In particular, 0 is returned when there is no metadata.
94   int GetInputTensorCount() const;
95 
96   // Gets the metadata for output tensors.
97   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
98   GetOutputTensorMetadata() const;
99 
100   // Gets the metadata for the output tensor specified by the given index, or
101   // nullptr in case there is no metadata or the index is out of range.
102   const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const;
103 
104   // Gets the count of output tensors with metadata in the metadata FlatBuffer.
105   // In particular, 0 is returned when there is no metadata.
106   int GetOutputTensorCount() const;
107 
108   // Gets the input process units from SubgraphMetadata.input_process_units,
109   // could be nullptr.
110   const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
111   GetInputProcessUnits() const;
112 
113   // Gets the input process unit specified by the given index, or nullptr in
114   // case there is no input process unit or the index is out of range.
115   const tflite::ProcessUnit* GetInputProcessUnit(int index) const;
116 
117   // Gets the count of input process units. In particular, 0 is returned when
118   // there is no input process units.
119   int GetInputProcessUnitsCount() const;
120 
121   // Gets the output process units from SubgraphMetadata.output_process_units,
122   // could be nullptr.
123   const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
124   GetOutputProcessUnits() const;
125 
126   // Gets the output process unit specified by the given index, or nullptr in
127   // case there is no output process unit or the index is out of range.
128   const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;
129 
130   // Gets the count of output process units. In particular, 0 is returned when
131   // there is no output process units.
132   int GetOutputProcessUnitsCount() const;
133 
134  private:
135   static constexpr int kDefaultSubgraphIndex = 0;
136   // Private default constructor, called from CreateFromModel().
137   ModelMetadataExtractor() = default;
138   // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer.
139   absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size);
140   // Extracts and stores in associated_files_ the associated files (if present)
141   // packed into the model FlatBuffer data.
142   absl::Status ExtractAssociatedFiles(const char* buffer_data,
143                                       size_t buffer_size);
144   // Pointer to the TFLite Model object from which to read the ModelMetadata.
145   const tflite::Model* model_{nullptr};
146   // Pointer to the extracted ModelMetadata, if any.
147   const tflite::ModelMetadata* model_metadata_{nullptr};
148   // The files associated with the ModelMetadata, as a map with the filename
149   // (corresponding to a basename, e.g. "labels.txt") as key and the file
150   // contents as value.
151   absl::flat_hash_map<std::string, std::string> associated_files_;
152 };
153 
154 }  // namespace metadata
155 }  // namespace tflite
156 
157 #endif  // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
158