1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Utility functions for working with FlatBuffers.
18
19 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
20 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
21
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "annotator/model_generated.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"
31
32 namespace libtextclassifier3 {
33
34 // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
35 // integrity.
36 template <typename FlatbufferMessage>
LoadAndVerifyFlatbuffer(const void * buffer,int size)37 const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
38 const FlatbufferMessage* message =
39 flatbuffers::GetRoot<FlatbufferMessage>(buffer);
40 if (message == nullptr) {
41 return nullptr;
42 }
43 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
44 size);
45 if (message->Verify(verifier)) {
46 return message;
47 } else {
48 return nullptr;
49 }
50 }
51
// Convenience overload of LoadAndVerifyFlatbuffer() that reads the serialized
// message from a string. The returned message points into the string's storage,
// so `buffer` must outlive it and stay unmodified while it is in use.
template <typename FlatbufferMessage>
const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
  const void* const data = buffer.data();
  const int size = buffer.size();
  return LoadAndVerifyFlatbuffer<FlatbufferMessage>(data, size);
}
58
59 // Loads and interprets the buffer as 'FlatbufferMessage', verifies its
60 // integrity and returns its mutable version.
61 template <typename FlatbufferMessage>
62 std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const void * buffer,int size)63 LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
64 const FlatbufferMessage* message =
65 LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
66 if (message == nullptr) {
67 return nullptr;
68 }
69 return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
70 message->UnPack());
71 }
72
// Convenience overload of LoadAndVerifyMutableFlatbuffer() that reads the
// serialized message from a string. The result is an independent copy, so it
// does not reference `buffer` after returning.
template <typename FlatbufferMessage>
std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
  const int size = buffer.size();
  return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(buffer.data(), size);
}
80
// File identifier that PackFlatbuffer() embeds when finishing a
// `FlatbufferMessage` buffer. The generic version returns nullptr, meaning
// "no identifier"; message types that need one provide an explicit
// specialization.
template <typename FlatbufferMessage>
const char* FlatbufferFileIdentifier() { return nullptr; }
85
86 template <>
87 const char* FlatbufferFileIdentifier<Model>();
88
89 // Packs the mutable flatbuffer message to string.
90 template <typename FlatbufferMessage>
PackFlatbuffer(const typename FlatbufferMessage::NativeTableType * mutable_message)91 std::string PackFlatbuffer(
92 const typename FlatbufferMessage::NativeTableType* mutable_message) {
93 flatbuffers::FlatBufferBuilder builder;
94 builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
95 FlatbufferFileIdentifier<FlatbufferMessage>());
96 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
97 builder.GetSize());
98 }
99
100 // A flatbuffer that can be built using flatbuffer reflection data of the
101 // schema.
102 // Normally, field information is hard-coded in code generated from a flatbuffer
103 // schema. Here we lookup the necessary information for building a flatbuffer
104 // from the provided reflection meta data.
105 // When serializing a flatbuffer, the library requires that the sub messages
106 // are already serialized, therefore we explicitly keep the field values and
107 // serialize the message in (reverse) topological dependency order.
108 class ReflectiveFlatbuffer {
109 public:
ReflectiveFlatbuffer(const reflection::Schema * schema,const reflection::Object * type)110 ReflectiveFlatbuffer(const reflection::Schema* schema,
111 const reflection::Object* type)
112 : schema_(schema), type_(type) {}
113
114 // Encapsulates a repeated field.
115 // Serves as a common base class for repeated fields.
116 class RepeatedField {
117 public:
~RepeatedField()118 virtual ~RepeatedField() {}
119
120 virtual flatbuffers::uoffset_t Serialize(
121 flatbuffers::FlatBufferBuilder* builder) const = 0;
122 };
123
124 // Represents a repeated field of particular type.
125 template <typename T>
126 class TypedRepeatedField : public RepeatedField {
127 public:
Add(const T value)128 void Add(const T value) { items_.push_back(value); }
129
Serialize(flatbuffers::FlatBufferBuilder * builder)130 flatbuffers::uoffset_t Serialize(
131 flatbuffers::FlatBufferBuilder* builder) const override {
132 return builder->CreateVector(items_).o;
133 }
134
135 private:
136 std::vector<T> items_;
137 };
138
139 // Specialization for strings.
140 template <>
141 class TypedRepeatedField<std::string> : public RepeatedField {
142 public:
Add(const std::string & value)143 void Add(const std::string& value) { items_.push_back(value); }
144
Serialize(flatbuffers::FlatBufferBuilder * builder)145 flatbuffers::uoffset_t Serialize(
146 flatbuffers::FlatBufferBuilder* builder) const override {
147 std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
148 items_.size());
149 for (int i = 0; i < items_.size(); i++) {
150 offsets[i] = builder->CreateString(items_[i]);
151 }
152 return builder->CreateVector(offsets).o;
153 }
154
155 private:
156 std::vector<std::string> items_;
157 };
158
159 // Specialization for repeated sub-messages.
160 template <>
161 class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
162 public:
163 TypedRepeatedField<ReflectiveFlatbuffer>(
164 const reflection::Schema* const schema,
165 const reflection::Type* const type)
schema_(schema)166 : schema_(schema), type_(type) {}
167
Add()168 ReflectiveFlatbuffer* Add() {
169 items_.emplace_back(new ReflectiveFlatbuffer(
170 schema_, schema_->objects()->Get(type_->index())));
171 return items_.back().get();
172 }
173
Serialize(flatbuffers::FlatBufferBuilder * builder)174 flatbuffers::uoffset_t Serialize(
175 flatbuffers::FlatBufferBuilder* builder) const override {
176 std::vector<flatbuffers::Offset<void>> offsets(items_.size());
177 for (int i = 0; i < items_.size(); i++) {
178 offsets[i] = items_[i]->Serialize(builder);
179 }
180 return builder->CreateVector(offsets).o;
181 }
182
183 private:
184 const reflection::Schema* const schema_;
185 const reflection::Type* const type_;
186 std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
187 };
188
189 // Gets the field information for a field name, returns nullptr if the
190 // field was not defined.
191 const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
192 const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
193 const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
194
195 // Gets a nested field and the message it is defined on.
196 bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
197 ReflectiveFlatbuffer** parent,
198 reflection::Field const** field);
199
200 // Checks whether a variant value type agrees with a field type.
201 bool IsMatchingType(const reflection::Field* field,
202 const Variant& value) const;
203
204 // Sets a (primitive) field to a specific value.
205 // Returns true if successful, and false if the field was not found or the
206 // expected type doesn't match.
207 template <typename T>
Set(StringPiece field_name,T value)208 bool Set(StringPiece field_name, T value) {
209 if (const reflection::Field* field = GetFieldOrNull(field_name)) {
210 return Set<T>(field, value);
211 }
212 return false;
213 }
214
215 // Sets a (primitive) field to a specific value.
216 // Returns true if successful, and false if the expected type doesn't match.
217 // Expects `field` to be non-null.
218 template <typename T>
Set(const reflection::Field * field,T value)219 bool Set(const reflection::Field* field, T value) {
220 if (field == nullptr) {
221 TC3_LOG(ERROR) << "Expected non-null field.";
222 return false;
223 }
224 Variant variant_value(value);
225 if (!IsMatchingType(field, variant_value)) {
226 TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
227 << "`, expected: " << field->type()->base_type()
228 << ", got: " << variant_value.GetType();
229 return false;
230 }
231 fields_[field] = variant_value;
232 return true;
233 }
234
235 template <typename T>
Set(const FlatbufferFieldPath * path,T value)236 bool Set(const FlatbufferFieldPath* path, T value) {
237 ReflectiveFlatbuffer* parent;
238 const reflection::Field* field;
239 if (!GetFieldWithParent(path, &parent, &field)) {
240 return false;
241 }
242 return parent->Set<T>(field, value);
243 }
244
245 // Sets a (primitive) field to a specific value.
246 // Parses the string value according to the field type.
247 bool ParseAndSet(const reflection::Field* field, const std::string& value);
248 bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
249
250 // Gets the reflective flatbuffer for a table field.
251 // Returns nullptr if the field was not found, or the field type was not a
252 // table.
253 ReflectiveFlatbuffer* Mutable(StringPiece field_name);
254 ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
255
256 // Gets the reflective flatbuffer for a repeated field.
257 // Returns nullptr if the field was not found, or the field type was not a
258 // vector.
259 RepeatedField* Repeated(StringPiece field_name);
260 RepeatedField* Repeated(const reflection::Field* field);
261
262 template <typename T>
Repeated(const reflection::Field * field)263 TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
264 return static_cast<TypedRepeatedField<T>*>(Repeated(field));
265 }
266
267 template <typename T>
Repeated(StringPiece field_name)268 TypedRepeatedField<T>* Repeated(StringPiece field_name) {
269 return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
270 }
271
272 // Serializes the flatbuffer.
273 flatbuffers::uoffset_t Serialize(
274 flatbuffers::FlatBufferBuilder* builder) const;
275 std::string Serialize() const;
276
277 // Merges the fields from the given flatbuffer table into this flatbuffer.
278 // Scalar fields will be overwritten, if present in `from`.
279 // Embedded messages will be merged.
280 bool MergeFrom(const flatbuffers::Table* from);
281 bool MergeFromSerializedFlatbuffer(StringPiece from);
282
283 // Flattens the flatbuffer as a flat map.
284 // (Nested) fields names are joined by `key_separator`.
285 std::map<std::string, Variant> AsFlatMap(
286 const std::string& key_separator = ".") const {
287 std::map<std::string, Variant> result;
288 AsFlatMap(key_separator, /*key_prefix=*/"", &result);
289 return result;
290 }
291
292 private:
293 const reflection::Schema* const schema_;
294 const reflection::Object* const type_;
295
296 // Cached primitive fields (scalars and strings).
297 std::map<const reflection::Field*, Variant> fields_;
298
299 // Cached sub-messages.
300 std::map<const reflection::Field*, std::unique_ptr<ReflectiveFlatbuffer>>
301 children_;
302
303 // Cached repeated fields.
304 std::map<const reflection::Field*, std::unique_ptr<RepeatedField>>
305 repeated_fields_;
306
307 // Flattens the flatbuffer as a flat map.
308 // (Nested) fields names are joined by `key_separator` and prefixed by
309 // `key_prefix`.
310 void AsFlatMap(const std::string& key_separator,
311 const std::string& key_prefix,
312 std::map<std::string, Variant>* result) const;
313 };
314
315 // A helper class to build flatbuffers based on schema reflection data.
316 // Can be used to a `ReflectiveFlatbuffer` for the root message of the
317 // schema, or any defined table via name.
318 class ReflectiveFlatbufferBuilder {
319 public:
ReflectiveFlatbufferBuilder(const reflection::Schema * schema)320 explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
321 : schema_(schema) {}
322
323 // Starts a new root table message.
324 std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
325
326 // Starts a new table message. Returns nullptr if no table with given name is
327 // found in the schema.
328 std::unique_ptr<ReflectiveFlatbuffer> NewTable(
329 const StringPiece table_name) const;
330
331 private:
332 const reflection::Schema* const schema_;
333 };
334
335 } // namespace libtextclassifier3
336
337 #endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
338