• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/metadata/cc/metadata_version.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <array>
21 #include <ostream>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/tools/logging.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32 
33 namespace tflite {
34 namespace metadata {
35 namespace {
36 
// Members that were added to the metadata schema after the initial 1.0.0
// release. GetMemberVersion() maps each member to the schema version that
// introduced it; the trailing comments mirror that mapping.
enum class SchemaMembers : int {
  kAssociatedFileTypeVocabulary = 0,                     // 1.0.1
  kSubGraphMetadataInputProcessUnits = 1,                // 1.1.0
  kSubGraphMetadataOutputProcessUnits = 2,               // 1.1.0
  kProcessUnitOptionsBertTokenizerOptions = 3,           // 1.1.0
  kProcessUnitOptionsSentencePieceTokenizerOptions = 4,  // 1.1.0
  kSubGraphMetadataInputTensorGroups = 5,                // 1.2.0
  kSubGraphMetadataOutputTensorGroups = 6,               // 1.2.0
  kProcessUnitOptionsRegexTokenizerOptions = 7,          // 1.2.1
  kContentPropertiesAudioProperties = 8,                 // 1.3.0
  kAssociatedFileTypeScannIndexFile = 9,                 // 1.4.0
  kAssociatedFileVersion = 10,                           // 1.4.1
};
52 
53 // Helper class to compare semantic versions in terms of three integers, major,
54 // minor, and patch.
55 class Version {
56  public:
Version(int major,int minor=0,int patch=0)57   explicit Version(int major, int minor = 0, int patch = 0)
58       : version_({major, minor, patch}) {}
59 
Version(const std::string & version)60   explicit Version(const std::string& version) {
61     const std::vector<std::string> vec = absl::StrSplit(version, '.');
62     // The version string should always be less than four numbers.
63     TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
64     version_[0] = std::stoi(vec[0]);
65     version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
66     version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
67   }
68 
69   // Compares two semantic version numbers.
70   //
71   // Example results when comparing two versions strings:
72   //   "1.9" precedes "1.14";
73   //   "1.14" precedes "1.14.1";
74   //   "1.14" and "1.14.0" are equal.
75   //
76   // Returns the value 0 if the two versions are equal; a value less than 0 if
77   // *this precedes v; a value greater than 0 if v precedes *this.
Compare(const Version & v)78   int Compare(const Version& v) {
79     for (int i = 0; i < kElementNumber; ++i) {
80       if (version_[i] != v.version_[i]) {
81         return version_[i] < v.version_[i] ? -1 : 1;
82       }
83     }
84     return 0;
85   }
86 
87   // Converts version_ into a version string.
ToString()88   std::string ToString() { return absl::StrJoin(version_, "."); }
89 
90  private:
91   static constexpr int kElementNumber = 3;
92   std::array<int, kElementNumber> version_;
93 };
94 
GetMemberVersion(SchemaMembers member)95 Version GetMemberVersion(SchemaMembers member) {
96   switch (member) {
97     case SchemaMembers::kAssociatedFileTypeVocabulary:
98       return Version(1, 0, 1);
99     case SchemaMembers::kSubGraphMetadataInputProcessUnits:
100       return Version(1, 1, 0);
101     case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
102       return Version(1, 1, 0);
103     case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
104       return Version(1, 1, 0);
105     case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
106       return Version(1, 1, 0);
107     case SchemaMembers::kSubGraphMetadataInputTensorGroups:
108       return Version(1, 2, 0);
109     case SchemaMembers::kSubGraphMetadataOutputTensorGroups:
110       return Version(1, 2, 0);
111     case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions:
112       return Version(1, 2, 1);
113     case SchemaMembers::kContentPropertiesAudioProperties:
114       return Version(1, 3, 0);
115     case SchemaMembers::kAssociatedFileTypeScannIndexFile:
116       return Version(1, 4, 0);
117     case SchemaMembers::kAssociatedFileVersion:
118       return Version(1, 4, 1);
119     default:
120       // Should never happen.
121       TFLITE_LOG(FATAL) << "Unsupported schema member: "
122                         << static_cast<int>(member);
123   }
124   // Should never happen.
125   return Version(0, 0, 0);
126 }
127 
128 // Updates min_version if it precedes the new_version.
UpdateMinimumVersion(const Version & new_version,Version * min_version)129 inline void UpdateMinimumVersion(const Version& new_version,
130                                  Version* min_version) {
131   if (min_version->Compare(new_version) < 0) {
132     *min_version = new_version;
133   }
134 }
135 
136 template <typename T>
137 void UpdateMinimumVersionForTable(const T* table, Version* min_version);
138 
139 template <typename T>
UpdateMinimumVersionForArray(const flatbuffers::Vector<flatbuffers::Offset<T>> * array,Version * min_version)140 void UpdateMinimumVersionForArray(
141     const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
142     Version* min_version) {
143   if (array == nullptr) return;
144 
145   for (int i = 0; i < array->size(); ++i) {
146     UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
147   }
148 }
149 
150 template <>
UpdateMinimumVersionForTable(const tflite::AssociatedFile * table,Version * min_version)151 void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
152     const tflite::AssociatedFile* table, Version* min_version) {
153   if (table == nullptr) return;
154 
155   if (table->type() == AssociatedFileType_VOCABULARY) {
156     UpdateMinimumVersion(
157         GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
158         min_version);
159   }
160 
161   if (table->type() == AssociatedFileType_SCANN_INDEX_FILE) {
162     UpdateMinimumVersion(
163         GetMemberVersion(SchemaMembers::kAssociatedFileTypeScannIndexFile),
164         min_version);
165   }
166 
167   if (table->version() != nullptr) {
168     UpdateMinimumVersion(
169         GetMemberVersion(SchemaMembers::kAssociatedFileVersion),
170         min_version);
171   }
172 }
173 
174 template <>
UpdateMinimumVersionForTable(const tflite::ProcessUnit * table,Version * min_version)175 void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
176     const tflite::ProcessUnit* table, Version* min_version) {
177   if (table == nullptr) return;
178 
179   tflite::ProcessUnitOptions process_unit_type = table->options_type();
180   if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
181     UpdateMinimumVersion(
182         GetMemberVersion(
183             SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
184         min_version);
185   }
186   if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
187     UpdateMinimumVersion(
188         GetMemberVersion(
189             SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
190         min_version);
191   }
192   if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) {
193     UpdateMinimumVersion(
194         GetMemberVersion(
195             SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions),
196         min_version);
197   }
198 }
199 
200 template <>
UpdateMinimumVersionForTable(const tflite::Content * table,Version * min_version)201 void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
202                                                    Version* min_version) {
203   if (table == nullptr) return;
204 
205   // Checks the ContenProperties field.
206   if (table->content_properties_type() == ContentProperties_AudioProperties) {
207     UpdateMinimumVersion(
208         GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
209         min_version);
210   }
211 }
212 
213 template <>
UpdateMinimumVersionForTable(const tflite::TensorMetadata * table,Version * min_version)214 void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
215     const tflite::TensorMetadata* table, Version* min_version) {
216   if (table == nullptr) return;
217 
218   // Checks the associated_files field.
219   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
220       table->associated_files(), min_version);
221 
222   // Checks the process_units field.
223   UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
224                                                     min_version);
225 
226   // Check the content field.
227   UpdateMinimumVersionForTable<tflite::Content>(table->content(), min_version);
228 }
229 
230 template <>
UpdateMinimumVersionForTable(const tflite::SubGraphMetadata * table,Version * min_version)231 void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
232     const tflite::SubGraphMetadata* table, Version* min_version) {
233   if (table == nullptr) return;
234 
235   // Checks in the input/output metadata arrays.
236   UpdateMinimumVersionForArray<tflite::TensorMetadata>(
237       table->input_tensor_metadata(), min_version);
238   UpdateMinimumVersionForArray<tflite::TensorMetadata>(
239       table->output_tensor_metadata(), min_version);
240 
241   // Checks the associated_files field.
242   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
243       table->associated_files(), min_version);
244 
245   // Checks for the input_process_units field.
246   if (table->input_process_units() != nullptr) {
247     UpdateMinimumVersion(
248         GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
249         min_version);
250     UpdateMinimumVersionForArray<tflite::ProcessUnit>(
251         table->input_process_units(), min_version);
252   }
253 
254   // Checks for the output_process_units field.
255   if (table->output_process_units() != nullptr) {
256     UpdateMinimumVersion(
257         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
258         min_version);
259     UpdateMinimumVersionForArray<tflite::ProcessUnit>(
260         table->output_process_units(), min_version);
261   }
262 
263   // Checks for the input_tensor_groups field.
264   if (table->input_tensor_groups() != nullptr) {
265     UpdateMinimumVersion(
266         GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups),
267         min_version);
268   }
269 
270   // Checks for the output_tensor_groups field.
271   if (table->output_tensor_groups() != nullptr) {
272     UpdateMinimumVersion(
273         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
274         min_version);
275   }
276 }
277 
278 template <>
UpdateMinimumVersionForTable(const tflite::ModelMetadata * table,Version * min_version)279 void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
280     const tflite::ModelMetadata* table, Version* min_version) {
281   if (table == nullptr) {
282     // Should never happen, because VerifyModelMetadataBuffer has verified it.
283     TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
284     return;
285   }
286 
287   // Checks the subgraph_metadata field.
288   if (table->subgraph_metadata() != nullptr) {
289     for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
290       UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
291           table->subgraph_metadata()->Get(i), min_version);
292     }
293   }
294 
295   // Checks the associated_files field.
296   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
297       table->associated_files(), min_version);
298 }
299 
300 }  // namespace
301 
GetMinimumMetadataParserVersion(const uint8_t * buffer_data,size_t buffer_size,std::string * min_version_str)302 TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
303                                              size_t buffer_size,
304                                              std::string* min_version_str) {
305   flatbuffers::Verifier verifier =
306       flatbuffers::Verifier(buffer_data, buffer_size);
307   if (!tflite::VerifyModelMetadataBuffer(verifier)) {
308     TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
309     return kTfLiteError;
310   }
311 
312   static constexpr char kDefaultVersion[] = "1.0.0";
313   Version min_version = Version(kDefaultVersion);
314 
315   // Checks if any member declared after 1.0.0 (such as those in
316   // SchemaMembers) exists, and updates min_version accordingly. The minimum
317   // metadata parser version will be the largest version number of all fields
318   // that has been added to a metadata flatbuffer
319   const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
320 
321   // All tables in the metadata schema should have their dedicated
322   // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
323   // add these methods when new fields show up in later schema versions.
324   //
325   // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The
326   // pointer can be a nullptr if Foo is not populated into the corresponding
327   // table of the Flatbuffer object. In this case,
328   // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is
329   // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root
330   // table, and it won't be null.
331   UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
332                                                       &min_version);
333 
334   *min_version_str = min_version.ToString();
335   return kTfLiteOk;
336 }
337 
338 }  // namespace metadata
339 }  // namespace tflite
340