• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow_lite_support/metadata/cc/metadata_version.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <array>
21 #include <ostream>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/str_split.h"
27 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/tools/logging.h"
31 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
32 
33 namespace tflite {
34 namespace metadata {
35 namespace {
36 
// Schema elements that were introduced after the initial 1.0.0 release of the
// metadata schema. GetMemberVersion() maps each entry to the schema version
// associated with it.
enum class SchemaMembers : int {
  // The VOCABULARY value of the associated file type.
  kAssociatedFileTypeVocabulary = 0,
  // The input_process_units field of SubGraphMetadata.
  kSubGraphMetadataInputProcessUnits = 1,
  // The output_process_units field of SubGraphMetadata.
  kSubGraphMetadataOutputProcessUnits = 2,
  // The BertTokenizerOptions alternative of ProcessUnitOptions.
  kProcessUnitOptionsBertTokenizerOptions = 3,
  // The SentencePieceTokenizerOptions alternative of ProcessUnitOptions.
  kProcessUnitOptionsSentencePieceTokenizerOptions = 4,
  // The input_tensor_groups field of SubGraphMetadata.
  kSubGraphMetadataInputTensorGroups = 5,
  // The output_tensor_groups field of SubGraphMetadata.
  kSubGraphMetadataOutputTensorGroups = 6,
  // The RegexTokenizerOptions alternative of ProcessUnitOptions.
  kProcessUnitOptionsRegexTokenizerOptions = 7,
};
49 
50 // Helper class to compare semantic versions in terms of three integers, major,
51 // minor, and patch.
52 class Version {
53  public:
Version(int major,int minor=0,int patch=0)54   explicit Version(int major, int minor = 0, int patch = 0)
55       : version_({major, minor, patch}) {}
56 
Version(const std::string & version)57   explicit Version(const std::string& version) {
58     const std::vector<std::string> vec = absl::StrSplit(version, '.');
59     // The version string should always be less than four numbers.
60     TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
61     version_[0] = std::stoi(vec[0]);
62     version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
63     version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
64   }
65 
66   // Compares two semantic version numbers.
67   //
68   // Example results when comparing two versions strings:
69   //   "1.9" precedes "1.14";
70   //   "1.14" precedes "1.14.1";
71   //   "1.14" and "1.14.0" are equal.
72   //
73   // Returns the value 0 if the two versions are equal; a value less than 0 if
74   // *this precedes v; a value greater than 0 if v precedes *this.
Compare(const Version & v)75   int Compare(const Version& v) {
76     for (int i = 0; i < kElementNumber; ++i) {
77       if (version_[i] != v.version_[i]) {
78         return version_[i] < v.version_[i] ? -1 : 1;
79       }
80     }
81     return 0;
82   }
83 
84   // Converts version_ into a version string.
ToString()85   std::string ToString() { return absl::StrJoin(version_, "."); }
86 
87  private:
88   static constexpr int kElementNumber = 3;
89   std::array<int, kElementNumber> version_;
90 };
91 
GetMemberVersion(SchemaMembers member)92 Version GetMemberVersion(SchemaMembers member) {
93   switch (member) {
94     case SchemaMembers::kAssociatedFileTypeVocabulary:
95       return Version(1, 0, 1);
96     case SchemaMembers::kSubGraphMetadataInputProcessUnits:
97       return Version(1, 1, 0);
98     case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
99       return Version(1, 1, 0);
100     case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
101       return Version(1, 1, 0);
102     case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
103       return Version(1, 1, 0);
104     case SchemaMembers::kSubGraphMetadataInputTensorGroups:
105       return Version(1, 2, 0);
106     case SchemaMembers::kSubGraphMetadataOutputTensorGroups:
107       return Version(1, 2, 0);
108     case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions:
109       return Version(1, 2, 1);
110     default:
111       // Should never happen.
112       TFLITE_LOG(FATAL) << "Unsupported schema member: "
113                         << static_cast<int>(member);
114   }
115   // Should never happen.
116   return Version(0, 0, 0);
117 }
118 
119 // Updates min_version if it precedes the new_version.
UpdateMinimumVersion(const Version & new_version,Version * min_version)120 inline void UpdateMinimumVersion(const Version& new_version,
121                                  Version* min_version) {
122   if (min_version->Compare(new_version) < 0) {
123     *min_version = new_version;
124   }
125 }
126 
127 template <typename T>
128 void UpdateMinimumVersionForTable(const T* table, Version* min_version);
129 
130 template <typename T>
UpdateMinimumVersionForArray(const flatbuffers::Vector<flatbuffers::Offset<T>> * array,Version * min_version)131 void UpdateMinimumVersionForArray(
132     const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
133     Version* min_version) {
134   if (array == nullptr) return;
135 
136   for (int i = 0; i < array->size(); ++i) {
137     UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
138   }
139 }
140 
141 template <>
UpdateMinimumVersionForTable(const tflite::AssociatedFile * table,Version * min_version)142 void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
143     const tflite::AssociatedFile* table, Version* min_version) {
144   if (table == nullptr) return;
145 
146   if (table->type() == AssociatedFileType_VOCABULARY) {
147     UpdateMinimumVersion(
148         GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
149         min_version);
150   }
151 }
152 
153 template <>
UpdateMinimumVersionForTable(const tflite::ProcessUnit * table,Version * min_version)154 void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
155     const tflite::ProcessUnit* table, Version* min_version) {
156   if (table == nullptr) return;
157 
158   tflite::ProcessUnitOptions process_unit_type = table->options_type();
159   if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
160     UpdateMinimumVersion(
161         GetMemberVersion(
162             SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
163         min_version);
164   }
165   if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
166     UpdateMinimumVersion(
167         GetMemberVersion(
168             SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
169         min_version);
170   }
171   if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) {
172     UpdateMinimumVersion(
173         GetMemberVersion(
174             SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions),
175         min_version);
176   }
177 }
178 
179 template <>
UpdateMinimumVersionForTable(const tflite::TensorMetadata * table,Version * min_version)180 void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
181     const tflite::TensorMetadata* table, Version* min_version) {
182   if (table == nullptr) return;
183 
184   // Checks the associated_files field.
185   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
186       table->associated_files(), min_version);
187 
188   // Checks the process_units field.
189   UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
190                                                     min_version);
191 }
192 
193 template <>
UpdateMinimumVersionForTable(const tflite::SubGraphMetadata * table,Version * min_version)194 void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
195     const tflite::SubGraphMetadata* table, Version* min_version) {
196   if (table == nullptr) return;
197 
198   // Checks in the input/output metadata arrays.
199   UpdateMinimumVersionForArray<tflite::TensorMetadata>(
200       table->input_tensor_metadata(), min_version);
201   UpdateMinimumVersionForArray<tflite::TensorMetadata>(
202       table->output_tensor_metadata(), min_version);
203 
204   // Checks the associated_files field.
205   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
206       table->associated_files(), min_version);
207 
208   // Checks for the input_process_units field.
209   if (table->input_process_units() != nullptr) {
210     UpdateMinimumVersion(
211         GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
212         min_version);
213     UpdateMinimumVersionForArray<tflite::ProcessUnit>(
214         table->input_process_units(), min_version);
215   }
216 
217   // Checks for the output_process_units field.
218   if (table->output_process_units() != nullptr) {
219     UpdateMinimumVersion(
220         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
221         min_version);
222     UpdateMinimumVersionForArray<tflite::ProcessUnit>(
223         table->output_process_units(), min_version);
224   }
225 
226   // Checks for the input_tensor_groups field.
227   if (table->input_tensor_groups() != nullptr) {
228     UpdateMinimumVersion(
229         GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups),
230         min_version);
231   }
232 
233   // Checks for the output_tensor_groups field.
234   if (table->output_tensor_groups() != nullptr) {
235     UpdateMinimumVersion(
236         GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
237         min_version);
238   }
239 }
240 
241 template <>
UpdateMinimumVersionForTable(const tflite::ModelMetadata * table,Version * min_version)242 void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
243     const tflite::ModelMetadata* table, Version* min_version) {
244   if (table == nullptr) {
245     // Should never happen, because VerifyModelMetadataBuffer has verified it.
246     TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
247     return;
248   }
249 
250   // Checks the subgraph_metadata field.
251   if (table->subgraph_metadata() != nullptr) {
252     for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
253       UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
254           table->subgraph_metadata()->Get(i), min_version);
255     }
256   }
257 
258   // Checks the associated_files field.
259   UpdateMinimumVersionForArray<tflite::AssociatedFile>(
260       table->associated_files(), min_version);
261 }
262 
263 }  // namespace
264 
GetMinimumMetadataParserVersion(const uint8_t * buffer_data,size_t buffer_size,std::string * min_version_str)265 TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
266                                              size_t buffer_size,
267                                              std::string* min_version_str) {
268   flatbuffers::Verifier verifier =
269       flatbuffers::Verifier(buffer_data, buffer_size);
270   if (!tflite::VerifyModelMetadataBuffer(verifier)) {
271     TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
272     return kTfLiteError;
273   }
274 
275   static constexpr char kDefaultVersion[] = "1.0.0";
276   Version min_version = Version(kDefaultVersion);
277 
278   // Checks if any member declared after 1.0.0 (such as those in
279   // SchemaMembers) exists, and updates min_version accordingly. The minimum
280   // metadata parser version will be the largest version number of all fields
281   // that has been added to a metadata flatbuffer
282   const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
283 
284   // All tables in the metadata schema should have their dedicated
285   // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
286   // add these methods when new fields show up in later schema versions.
287   //
288   // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The
289   // pointer can be a nullptr if Foo is not populated into the corresponding
290   // table of the Flatbuffer object. In this case,
291   // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is
292   // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root
293   // table, and it won't be null.
294   UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
295                                                       &min_version);
296 
297   *min_version_str = min_version.ToString();
298   return kTfLiteOk;
299 }
300 
301 }  // namespace metadata
302 }  // namespace tflite
303