1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/metadata/cc/metadata_version.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <array>
21 #include <ostream>
22 #include <string>
23 #include <vector>
24
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/tools/logging.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32
33 namespace tflite {
34 namespace metadata {
35 namespace {
36
// Metadata schema members introduced after the initial 1.0.0 release of the
// schema. GetMemberVersion() maps each enumerator to the schema version that
// added it. The explicit values are listed for readability; they follow
// declaration order starting at 0.
enum class SchemaMembers : int {
  kAssociatedFileTypeVocabulary = 0,
  kSubGraphMetadataInputProcessUnits = 1,
  kSubGraphMetadataOutputProcessUnits = 2,
  kProcessUnitOptionsBertTokenizerOptions = 3,
  kProcessUnitOptionsSentencePieceTokenizerOptions = 4,
  kSubGraphMetadataInputTensorGroups = 5,
  kSubGraphMetadataOutputTensorGroups = 6,
  kProcessUnitOptionsRegexTokenizerOptions = 7,
};
49
50 // Helper class to compare semantic versions in terms of three integers, major,
51 // minor, and patch.
52 class Version {
53 public:
Version(int major,int minor=0,int patch=0)54 explicit Version(int major, int minor = 0, int patch = 0)
55 : version_({major, minor, patch}) {}
56
Version(const std::string & version)57 explicit Version(const std::string& version) {
58 const std::vector<std::string> vec = absl::StrSplit(version, '.');
59 // The version string should always be less than four numbers.
60 TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
61 version_[0] = std::stoi(vec[0]);
62 version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
63 version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
64 }
65
66 // Compares two semantic version numbers.
67 //
68 // Example results when comparing two versions strings:
69 // "1.9" precedes "1.14";
70 // "1.14" precedes "1.14.1";
71 // "1.14" and "1.14.0" are equal.
72 //
73 // Returns the value 0 if the two versions are equal; a value less than 0 if
74 // *this precedes v; a value greater than 0 if v precedes *this.
Compare(const Version & v)75 int Compare(const Version& v) {
76 for (int i = 0; i < kElementNumber; ++i) {
77 if (version_[i] != v.version_[i]) {
78 return version_[i] < v.version_[i] ? -1 : 1;
79 }
80 }
81 return 0;
82 }
83
84 // Converts version_ into a version string.
ToString()85 std::string ToString() { return absl::StrJoin(version_, "."); }
86
87 private:
88 static constexpr int kElementNumber = 3;
89 std::array<int, kElementNumber> version_;
90 };
91
GetMemberVersion(SchemaMembers member)92 Version GetMemberVersion(SchemaMembers member) {
93 switch (member) {
94 case SchemaMembers::kAssociatedFileTypeVocabulary:
95 return Version(1, 0, 1);
96 case SchemaMembers::kSubGraphMetadataInputProcessUnits:
97 return Version(1, 1, 0);
98 case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
99 return Version(1, 1, 0);
100 case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
101 return Version(1, 1, 0);
102 case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
103 return Version(1, 1, 0);
104 case SchemaMembers::kSubGraphMetadataInputTensorGroups:
105 return Version(1, 2, 0);
106 case SchemaMembers::kSubGraphMetadataOutputTensorGroups:
107 return Version(1, 2, 0);
108 case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions:
109 return Version(1, 2, 1);
110 default:
111 // Should never happen.
112 TFLITE_LOG(FATAL) << "Unsupported schema member: "
113 << static_cast<int>(member);
114 }
115 // Should never happen.
116 return Version(0, 0, 0);
117 }
118
119 // Updates min_version if it precedes the new_version.
UpdateMinimumVersion(const Version & new_version,Version * min_version)120 inline void UpdateMinimumVersion(const Version& new_version,
121 Version* min_version) {
122 if (min_version->Compare(new_version) < 0) {
123 *min_version = new_version;
124 }
125 }
126
127 template <typename T>
128 void UpdateMinimumVersionForTable(const T* table, Version* min_version);
129
130 template <typename T>
UpdateMinimumVersionForArray(const flatbuffers::Vector<flatbuffers::Offset<T>> * array,Version * min_version)131 void UpdateMinimumVersionForArray(
132 const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
133 Version* min_version) {
134 if (array == nullptr) return;
135
136 for (int i = 0; i < array->size(); ++i) {
137 UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
138 }
139 }
140
141 template <>
UpdateMinimumVersionForTable(const tflite::AssociatedFile * table,Version * min_version)142 void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
143 const tflite::AssociatedFile* table, Version* min_version) {
144 if (table == nullptr) return;
145
146 if (table->type() == AssociatedFileType_VOCABULARY) {
147 UpdateMinimumVersion(
148 GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
149 min_version);
150 }
151 }
152
153 template <>
UpdateMinimumVersionForTable(const tflite::ProcessUnit * table,Version * min_version)154 void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
155 const tflite::ProcessUnit* table, Version* min_version) {
156 if (table == nullptr) return;
157
158 tflite::ProcessUnitOptions process_unit_type = table->options_type();
159 if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
160 UpdateMinimumVersion(
161 GetMemberVersion(
162 SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
163 min_version);
164 }
165 if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
166 UpdateMinimumVersion(
167 GetMemberVersion(
168 SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
169 min_version);
170 }
171 if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) {
172 UpdateMinimumVersion(
173 GetMemberVersion(
174 SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions),
175 min_version);
176 }
177 }
178
179 template <>
UpdateMinimumVersionForTable(const tflite::TensorMetadata * table,Version * min_version)180 void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
181 const tflite::TensorMetadata* table, Version* min_version) {
182 if (table == nullptr) return;
183
184 // Checks the associated_files field.
185 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
186 table->associated_files(), min_version);
187
188 // Checks the process_units field.
189 UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
190 min_version);
191 }
192
193 template <>
UpdateMinimumVersionForTable(const tflite::SubGraphMetadata * table,Version * min_version)194 void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
195 const tflite::SubGraphMetadata* table, Version* min_version) {
196 if (table == nullptr) return;
197
198 // Checks in the input/output metadata arrays.
199 UpdateMinimumVersionForArray<tflite::TensorMetadata>(
200 table->input_tensor_metadata(), min_version);
201 UpdateMinimumVersionForArray<tflite::TensorMetadata>(
202 table->output_tensor_metadata(), min_version);
203
204 // Checks the associated_files field.
205 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
206 table->associated_files(), min_version);
207
208 // Checks for the input_process_units field.
209 if (table->input_process_units() != nullptr) {
210 UpdateMinimumVersion(
211 GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
212 min_version);
213 UpdateMinimumVersionForArray<tflite::ProcessUnit>(
214 table->input_process_units(), min_version);
215 }
216
217 // Checks for the output_process_units field.
218 if (table->output_process_units() != nullptr) {
219 UpdateMinimumVersion(
220 GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
221 min_version);
222 UpdateMinimumVersionForArray<tflite::ProcessUnit>(
223 table->output_process_units(), min_version);
224 }
225
226 // Checks for the input_tensor_groups field.
227 if (table->input_tensor_groups() != nullptr) {
228 UpdateMinimumVersion(
229 GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups),
230 min_version);
231 }
232
233 // Checks for the output_tensor_groups field.
234 if (table->output_tensor_groups() != nullptr) {
235 UpdateMinimumVersion(
236 GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
237 min_version);
238 }
239 }
240
241 template <>
UpdateMinimumVersionForTable(const tflite::ModelMetadata * table,Version * min_version)242 void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
243 const tflite::ModelMetadata* table, Version* min_version) {
244 if (table == nullptr) {
245 // Should never happen, because VerifyModelMetadataBuffer has verified it.
246 TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
247 return;
248 }
249
250 // Checks the subgraph_metadata field.
251 if (table->subgraph_metadata() != nullptr) {
252 for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
253 UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
254 table->subgraph_metadata()->Get(i), min_version);
255 }
256 }
257
258 // Checks the associated_files field.
259 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
260 table->associated_files(), min_version);
261 }
262
263 } // namespace
264
GetMinimumMetadataParserVersion(const uint8_t * buffer_data,size_t buffer_size,std::string * min_version_str)265 TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
266 size_t buffer_size,
267 std::string* min_version_str) {
268 flatbuffers::Verifier verifier =
269 flatbuffers::Verifier(buffer_data, buffer_size);
270 if (!tflite::VerifyModelMetadataBuffer(verifier)) {
271 TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
272 return kTfLiteError;
273 }
274
275 static constexpr char kDefaultVersion[] = "1.0.0";
276 Version min_version = Version(kDefaultVersion);
277
278 // Checks if any member declared after 1.0.0 (such as those in
279 // SchemaMembers) exists, and updates min_version accordingly. The minimum
280 // metadata parser version will be the largest version number of all fields
281 // that has been added to a metadata flatbuffer
282 const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
283
284 // All tables in the metadata schema should have their dedicated
285 // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
286 // add these methods when new fields show up in later schema versions.
287 //
288 // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The
289 // pointer can be a nullptr if Foo is not populated into the corresponding
290 // table of the Flatbuffer object. In this case,
291 // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is
292 // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root
293 // table, and it won't be null.
294 UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
295 &min_version);
296
297 *min_version_str = min_version.ToString();
298 return kTfLiteOk;
299 }
300
301 } // namespace metadata
302 } // namespace tflite
303