1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/metadata/cc/metadata_version.h"
16
17 #include <stddef.h>
18 #include <stdint.h>
19
20 #include <array>
21 #include <ostream>
22 #include <string>
23 #include <vector>
24
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/tools/logging.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32
33 namespace tflite {
34 namespace metadata {
35 namespace {
36
// Metadata schema members that were introduced after the initial 1.0.0
// schema. GetMemberVersion() maps each enumerator to the schema version that
// first introduced it; the grouping comments below mirror that mapping.
enum class SchemaMembers : int {
  // Schema 1.0.1.
  kAssociatedFileTypeVocabulary = 0,
  // Schema 1.1.0.
  kSubGraphMetadataInputProcessUnits = 1,
  kSubGraphMetadataOutputProcessUnits = 2,
  kProcessUnitOptionsBertTokenizerOptions = 3,
  kProcessUnitOptionsSentencePieceTokenizerOptions = 4,
  // Schema 1.2.0.
  kSubGraphMetadataInputTensorGroups = 5,
  kSubGraphMetadataOutputTensorGroups = 6,
  // Schema 1.2.1.
  kProcessUnitOptionsRegexTokenizerOptions = 7,
  // Schema 1.3.0.
  kContentPropertiesAudioProperties = 8,
  // Schema 1.4.0.
  kAssociatedFileTypeScannIndexFile = 9,
  // Schema 1.4.1.
  kAssociatedFileVersion = 10,
};
52
53 // Helper class to compare semantic versions in terms of three integers, major,
54 // minor, and patch.
55 class Version {
56 public:
Version(int major,int minor=0,int patch=0)57 explicit Version(int major, int minor = 0, int patch = 0)
58 : version_({major, minor, patch}) {}
59
Version(const std::string & version)60 explicit Version(const std::string& version) {
61 const std::vector<std::string> vec = absl::StrSplit(version, '.');
62 // The version string should always be less than four numbers.
63 TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
64 version_[0] = std::stoi(vec[0]);
65 version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
66 version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
67 }
68
69 // Compares two semantic version numbers.
70 //
71 // Example results when comparing two versions strings:
72 // "1.9" precedes "1.14";
73 // "1.14" precedes "1.14.1";
74 // "1.14" and "1.14.0" are equal.
75 //
76 // Returns the value 0 if the two versions are equal; a value less than 0 if
77 // *this precedes v; a value greater than 0 if v precedes *this.
Compare(const Version & v)78 int Compare(const Version& v) {
79 for (int i = 0; i < kElementNumber; ++i) {
80 if (version_[i] != v.version_[i]) {
81 return version_[i] < v.version_[i] ? -1 : 1;
82 }
83 }
84 return 0;
85 }
86
87 // Converts version_ into a version string.
ToString()88 std::string ToString() { return absl::StrJoin(version_, "."); }
89
90 private:
91 static constexpr int kElementNumber = 3;
92 std::array<int, kElementNumber> version_;
93 };
94
GetMemberVersion(SchemaMembers member)95 Version GetMemberVersion(SchemaMembers member) {
96 switch (member) {
97 case SchemaMembers::kAssociatedFileTypeVocabulary:
98 return Version(1, 0, 1);
99 case SchemaMembers::kSubGraphMetadataInputProcessUnits:
100 return Version(1, 1, 0);
101 case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
102 return Version(1, 1, 0);
103 case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
104 return Version(1, 1, 0);
105 case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
106 return Version(1, 1, 0);
107 case SchemaMembers::kSubGraphMetadataInputTensorGroups:
108 return Version(1, 2, 0);
109 case SchemaMembers::kSubGraphMetadataOutputTensorGroups:
110 return Version(1, 2, 0);
111 case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions:
112 return Version(1, 2, 1);
113 case SchemaMembers::kContentPropertiesAudioProperties:
114 return Version(1, 3, 0);
115 case SchemaMembers::kAssociatedFileTypeScannIndexFile:
116 return Version(1, 4, 0);
117 case SchemaMembers::kAssociatedFileVersion:
118 return Version(1, 4, 1);
119 default:
120 // Should never happen.
121 TFLITE_LOG(FATAL) << "Unsupported schema member: "
122 << static_cast<int>(member);
123 }
124 // Should never happen.
125 return Version(0, 0, 0);
126 }
127
128 // Updates min_version if it precedes the new_version.
UpdateMinimumVersion(const Version & new_version,Version * min_version)129 inline void UpdateMinimumVersion(const Version& new_version,
130 Version* min_version) {
131 if (min_version->Compare(new_version) < 0) {
132 *min_version = new_version;
133 }
134 }
135
136 template <typename T>
137 void UpdateMinimumVersionForTable(const T* table, Version* min_version);
138
139 template <typename T>
UpdateMinimumVersionForArray(const flatbuffers::Vector<flatbuffers::Offset<T>> * array,Version * min_version)140 void UpdateMinimumVersionForArray(
141 const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
142 Version* min_version) {
143 if (array == nullptr) return;
144
145 for (int i = 0; i < array->size(); ++i) {
146 UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
147 }
148 }
149
150 template <>
UpdateMinimumVersionForTable(const tflite::AssociatedFile * table,Version * min_version)151 void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
152 const tflite::AssociatedFile* table, Version* min_version) {
153 if (table == nullptr) return;
154
155 if (table->type() == AssociatedFileType_VOCABULARY) {
156 UpdateMinimumVersion(
157 GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
158 min_version);
159 }
160
161 if (table->type() == AssociatedFileType_SCANN_INDEX_FILE) {
162 UpdateMinimumVersion(
163 GetMemberVersion(SchemaMembers::kAssociatedFileTypeScannIndexFile),
164 min_version);
165 }
166
167 if (table->version() != nullptr) {
168 UpdateMinimumVersion(
169 GetMemberVersion(SchemaMembers::kAssociatedFileVersion),
170 min_version);
171 }
172 }
173
174 template <>
UpdateMinimumVersionForTable(const tflite::ProcessUnit * table,Version * min_version)175 void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
176 const tflite::ProcessUnit* table, Version* min_version) {
177 if (table == nullptr) return;
178
179 tflite::ProcessUnitOptions process_unit_type = table->options_type();
180 if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
181 UpdateMinimumVersion(
182 GetMemberVersion(
183 SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
184 min_version);
185 }
186 if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
187 UpdateMinimumVersion(
188 GetMemberVersion(
189 SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
190 min_version);
191 }
192 if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) {
193 UpdateMinimumVersion(
194 GetMemberVersion(
195 SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions),
196 min_version);
197 }
198 }
199
200 template <>
UpdateMinimumVersionForTable(const tflite::Content * table,Version * min_version)201 void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
202 Version* min_version) {
203 if (table == nullptr) return;
204
205 // Checks the ContenProperties field.
206 if (table->content_properties_type() == ContentProperties_AudioProperties) {
207 UpdateMinimumVersion(
208 GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
209 min_version);
210 }
211 }
212
213 template <>
UpdateMinimumVersionForTable(const tflite::TensorMetadata * table,Version * min_version)214 void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
215 const tflite::TensorMetadata* table, Version* min_version) {
216 if (table == nullptr) return;
217
218 // Checks the associated_files field.
219 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
220 table->associated_files(), min_version);
221
222 // Checks the process_units field.
223 UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
224 min_version);
225
226 // Check the content field.
227 UpdateMinimumVersionForTable<tflite::Content>(table->content(), min_version);
228 }
229
230 template <>
UpdateMinimumVersionForTable(const tflite::SubGraphMetadata * table,Version * min_version)231 void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
232 const tflite::SubGraphMetadata* table, Version* min_version) {
233 if (table == nullptr) return;
234
235 // Checks in the input/output metadata arrays.
236 UpdateMinimumVersionForArray<tflite::TensorMetadata>(
237 table->input_tensor_metadata(), min_version);
238 UpdateMinimumVersionForArray<tflite::TensorMetadata>(
239 table->output_tensor_metadata(), min_version);
240
241 // Checks the associated_files field.
242 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
243 table->associated_files(), min_version);
244
245 // Checks for the input_process_units field.
246 if (table->input_process_units() != nullptr) {
247 UpdateMinimumVersion(
248 GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
249 min_version);
250 UpdateMinimumVersionForArray<tflite::ProcessUnit>(
251 table->input_process_units(), min_version);
252 }
253
254 // Checks for the output_process_units field.
255 if (table->output_process_units() != nullptr) {
256 UpdateMinimumVersion(
257 GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
258 min_version);
259 UpdateMinimumVersionForArray<tflite::ProcessUnit>(
260 table->output_process_units(), min_version);
261 }
262
263 // Checks for the input_tensor_groups field.
264 if (table->input_tensor_groups() != nullptr) {
265 UpdateMinimumVersion(
266 GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups),
267 min_version);
268 }
269
270 // Checks for the output_tensor_groups field.
271 if (table->output_tensor_groups() != nullptr) {
272 UpdateMinimumVersion(
273 GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
274 min_version);
275 }
276 }
277
278 template <>
UpdateMinimumVersionForTable(const tflite::ModelMetadata * table,Version * min_version)279 void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
280 const tflite::ModelMetadata* table, Version* min_version) {
281 if (table == nullptr) {
282 // Should never happen, because VerifyModelMetadataBuffer has verified it.
283 TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
284 return;
285 }
286
287 // Checks the subgraph_metadata field.
288 if (table->subgraph_metadata() != nullptr) {
289 for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
290 UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
291 table->subgraph_metadata()->Get(i), min_version);
292 }
293 }
294
295 // Checks the associated_files field.
296 UpdateMinimumVersionForArray<tflite::AssociatedFile>(
297 table->associated_files(), min_version);
298 }
299
300 } // namespace
301
GetMinimumMetadataParserVersion(const uint8_t * buffer_data,size_t buffer_size,std::string * min_version_str)302 TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
303 size_t buffer_size,
304 std::string* min_version_str) {
305 flatbuffers::Verifier verifier =
306 flatbuffers::Verifier(buffer_data, buffer_size);
307 if (!tflite::VerifyModelMetadataBuffer(verifier)) {
308 TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
309 return kTfLiteError;
310 }
311
312 static constexpr char kDefaultVersion[] = "1.0.0";
313 Version min_version = Version(kDefaultVersion);
314
315 // Checks if any member declared after 1.0.0 (such as those in
316 // SchemaMembers) exists, and updates min_version accordingly. The minimum
317 // metadata parser version will be the largest version number of all fields
318 // that has been added to a metadata flatbuffer
319 const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
320
321 // All tables in the metadata schema should have their dedicated
322 // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
323 // add these methods when new fields show up in later schema versions.
324 //
325 // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The
326 // pointer can be a nullptr if Foo is not populated into the corresponding
327 // table of the Flatbuffer object. In this case,
328 // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is
329 // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root
330 // table, and it won't be null.
331 UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
332 &min_version);
333
334 *min_version_str = min_version.ToString();
335 return kTfLiteOk;
336 }
337
338 } // namespace metadata
339 } // namespace tflite
340