1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
18 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
19
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "annotator/model_generated.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers/flatbuffers_generated.h"
#include "utils/flatbuffers/reflection.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"
#include "flatbuffers/reflection_generated.h"
33
34 namespace libtextclassifier3 {
35
36 class MutableFlatbuffer;
37 class RepeatedField;
38
39 template <typename T>
IsStringType()40 constexpr bool IsStringType() {
41 return std::is_same<T, std::string>::value ||
42 std::is_same<T, StringPiece>::value ||
43 std::is_same<T, const char*>::value;
44 }
45
46 // Checks whether a variant value type agrees with a field type.
47 template <typename T>
IsMatchingType(const reflection::BaseType type)48 bool IsMatchingType(const reflection::BaseType type) {
49 switch (type) {
50 case reflection::String:
51 return IsStringType<T>();
52 case reflection::Obj:
53 return std::is_same<T, MutableFlatbuffer>::value;
54 default:
55 return type == flatbuffers_base_type<T>::value;
56 }
57 }
58
59 // A mutable flatbuffer that can be built using flatbuffer reflection data of
60 // the schema. Normally, field information is hard-coded in code generated from
61 // a flatbuffer schema. Here we lookup the necessary information for building a
62 // flatbuffer from the provided reflection meta data. When serializing a
63 // flatbuffer, the library requires that the sub messages are already
64 // serialized, therefore we explicitly keep the field values and serialize the
65 // message in (reverse) topological dependency order.
66 class MutableFlatbuffer {
67 public:
MutableFlatbuffer(const reflection::Schema * schema,const reflection::Object * type)68 MutableFlatbuffer(const reflection::Schema* schema,
69 const reflection::Object* type)
70 : schema_(schema), type_(type) {}
71
72 // Gets the field information for a field name, returns nullptr if the
73 // field was not defined.
74 const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
75 const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
76 const reflection::Field* GetFieldOrNull(const int field_offset) const;
77
78 // Gets a nested field and the message it is defined on.
79 bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
80 MutableFlatbuffer** parent,
81 reflection::Field const** field);
82
83 // Sets a field to a specific value.
84 // Returns true if successful, and false if the field was not found or the
85 // expected type doesn't match.
86 template <typename T>
87 bool Set(StringPiece field_name, T value);
88
89 // Sets a field to a specific value.
90 // Returns true if successful, and false if the expected type doesn't match.
91 // Expects `field` to be non-null.
92 template <typename T>
93 bool Set(const reflection::Field* field, T value);
94
95 // Sets a field to a specific value. Field is specified by path.
96 template <typename T>
97 bool Set(const FlatbufferFieldPath* path, T value);
98
99 // Sets an enum field from an enum value name.
100 // Returns true if the value could be successfully parsed.
101 bool SetFromEnumValueName(StringPiece field_name, StringPiece value_name);
102
103 // Sets an enum field from an enum value name.
104 // Returns true if the value could be successfully parsed.
105 bool SetFromEnumValueName(const reflection::Field* field,
106 StringPiece value_name);
107
108 // Sets an enum field from an enum value name. Field is specified by path.
109 // Returns true if the value could be successfully parsed.
110 bool SetFromEnumValueName(const FlatbufferFieldPath* path,
111 StringPiece value_name);
112
113 // Sets sub-message field (if not set yet), and returns a pointer to it.
114 // Returns nullptr if the field was not found, or the field type was not a
115 // table.
116 MutableFlatbuffer* Mutable(StringPiece field_name);
117 MutableFlatbuffer* Mutable(const reflection::Field* field);
118
119 // Sets a sub-message field (if not set yet) specified by path, and returns a
120 // pointer to it. Returns nullptr if the field was not found, or the field
121 // type was not a table.
122 MutableFlatbuffer* Mutable(const FlatbufferFieldPath* path);
123
124 // Parses the value (according to the type) and sets a primitive field to the
125 // parsed value.
126 bool ParseAndSet(const reflection::Field* field, const std::string& value);
127 bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
128
129 // Adds a primitive value to the repeated field.
130 template <typename T>
131 bool Add(StringPiece field_name, T value);
132
133 // Add a sub-message to the repeated field.
134 MutableFlatbuffer* Add(StringPiece field_name);
135
136 template <typename T>
137 bool Add(const reflection::Field* field, T value);
138
139 MutableFlatbuffer* Add(const reflection::Field* field);
140
141 // Gets the reflective flatbuffer for a repeated field.
142 // Returns nullptr if the field was not found, or the field type was not a
143 // vector.
144 RepeatedField* Repeated(StringPiece field_name);
145 RepeatedField* Repeated(const reflection::Field* field);
146
147 // Gets a repeated field specified by path.
148 // Returns nullptr if the field was not found, or the field
149 // type was not a repeated field.
150 RepeatedField* Repeated(const FlatbufferFieldPath* path);
151
152 // Serializes the flatbuffer.
153 flatbuffers::uoffset_t Serialize(
154 flatbuffers::FlatBufferBuilder* builder) const;
155 std::string Serialize() const;
156
157 // Merges the fields from the given flatbuffer table into this flatbuffer.
158 // Scalar fields will be overwritten, if present in `from`.
159 // Embedded messages will be merged.
160 bool MergeFrom(const flatbuffers::Table* from);
161 bool MergeFromSerializedFlatbuffer(StringPiece from);
162
163 // Flattens the flatbuffer as a flat map.
164 // (Nested) fields names are joined by `key_separator`.
165 std::map<std::string, Variant> AsFlatMap(
166 const std::string& key_separator = ".") const {
167 std::map<std::string, Variant> result;
168 AsFlatMap(key_separator, /*key_prefix=*/"", &result);
169 return result;
170 }
171
172 // Converts the flatbuffer's content to a human-readable textproto
173 // representation.
174 std::string ToTextProto() const;
175
HasExplicitlySetFields()176 bool HasExplicitlySetFields() const {
177 return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
178 }
179
type()180 const reflection::Object* type() const { return type_; }
181
182 private:
183 // Helper function for merging given repeated field from given flatbuffer
184 // table. Appends the elements.
185 template <typename T>
186 bool AppendFromVector(const flatbuffers::Table* from,
187 const reflection::Field* field);
188
189 // Flattens the flatbuffer as a flat map.
190 // (Nested) fields names are joined by `key_separator` and prefixed by
191 // `key_prefix`.
192 void AsFlatMap(const std::string& key_separator,
193 const std::string& key_prefix,
194 std::map<std::string, Variant>* result) const;
195
196 const reflection::Schema* const schema_;
197 const reflection::Object* const type_;
198
199 // Cached primitive fields (scalars and strings).
200 std::unordered_map<const reflection::Field*, Variant> fields_;
201
202 // Cached sub-messages.
203 std::unordered_map<const reflection::Field*,
204 std::unique_ptr<MutableFlatbuffer>>
205 children_;
206
207 // Cached repeated fields.
208 std::unordered_map<const reflection::Field*, std::unique_ptr<RepeatedField>>
209 repeated_fields_;
210 };
211
212 // A helper class to build flatbuffers based on schema reflection data.
213 // Can be used to a `MutableFlatbuffer` for the root message of the
214 // schema, or any defined table via name.
215 class MutableFlatbufferBuilder {
216 public:
MutableFlatbufferBuilder(const reflection::Schema * schema)217 explicit MutableFlatbufferBuilder(const reflection::Schema* schema)
218 : schema_(schema), root_type_(schema->root_table()) {}
219 explicit MutableFlatbufferBuilder(const reflection::Schema* schema,
220 StringPiece root_type);
221
222 // Starts a new root table message.
223 std::unique_ptr<MutableFlatbuffer> NewRoot() const;
224
225 // Creates a new table message. Returns nullptr if no table with given name is
226 // found in the schema.
227 std::unique_ptr<MutableFlatbuffer> NewTable(
228 const StringPiece table_name) const;
229
230 // Creates a new message for the given type id. Returns nullptr if the type is
231 // invalid.
232 std::unique_ptr<MutableFlatbuffer> NewTable(int type_id) const;
233
234 // Creates a new message for the given type.
235 std::unique_ptr<MutableFlatbuffer> NewTable(
236 const reflection::Object* type) const;
237
238 private:
239 const reflection::Schema* const schema_;
240 const reflection::Object* const root_type_;
241 };
242
243 // Encapsulates a repeated field.
244 // Serves as a common base class for repeated fields.
245 class RepeatedField {
246 public:
RepeatedField(const reflection::Schema * const schema,const reflection::Field * field)247 RepeatedField(const reflection::Schema* const schema,
248 const reflection::Field* field)
249 : schema_(schema),
250 field_(field),
251 is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
252
253 template <typename T>
254 bool Add(const T value);
255
256 MutableFlatbuffer* Add();
257
258 template <typename T>
Get(int index)259 T Get(int index) const {
260 return items_.at(index).Value<T>();
261 }
262
263 template <>
Get(int index)264 MutableFlatbuffer* Get(int index) const {
265 if (is_primitive_) {
266 TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
267 "repeated field.";
268 return nullptr;
269 }
270 return object_items_.at(index).get();
271 }
272
Size()273 int Size() const {
274 if (is_primitive_) {
275 return items_.size();
276 } else {
277 return object_items_.size();
278 }
279 }
280
281 bool Extend(const flatbuffers::Table* from);
282
283 flatbuffers::uoffset_t Serialize(
284 flatbuffers::FlatBufferBuilder* builder) const;
285
286 std::string ToTextProto() const;
287
288 private:
289 template <typename T>
290 bool AppendFromVector(const flatbuffers::Table* from);
291
292 flatbuffers::uoffset_t SerializeString(
293 flatbuffers::FlatBufferBuilder* builder) const;
294 flatbuffers::uoffset_t SerializeObject(
295 flatbuffers::FlatBufferBuilder* builder) const;
296
297 const reflection::Schema* const schema_;
298 const reflection::Field* field_;
299 bool is_primitive_;
300
301 std::vector<Variant> items_;
302 std::vector<std::unique_ptr<MutableFlatbuffer>> object_items_;
303 };
304
305 template <typename T>
Set(StringPiece field_name,T value)306 bool MutableFlatbuffer::Set(StringPiece field_name, T value) {
307 if (const reflection::Field* field = GetFieldOrNull(field_name)) {
308 if (field->type()->base_type() == reflection::BaseType::Vector ||
309 field->type()->base_type() == reflection::BaseType::Obj) {
310 TC3_LOG(ERROR)
311 << "Trying to set a primitive value on a non-scalar field.";
312 return false;
313 }
314 return Set<T>(field, value);
315 }
316 TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
317 return false;
318 }
319
320 template <typename T>
Set(const reflection::Field * field,T value)321 bool MutableFlatbuffer::Set(const reflection::Field* field, T value) {
322 if (field == nullptr) {
323 TC3_LOG(ERROR) << "Expected non-null field.";
324 return false;
325 }
326 Variant variant_value(value);
327 if (!IsMatchingType<T>(field->type()->base_type())) {
328 TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
329 << "`, expected: "
330 << EnumNameBaseType(field->type()->base_type())
331 << ", got: " << variant_value.GetType();
332 return false;
333 }
334 fields_[field] = variant_value;
335 return true;
336 }
337
338 template <typename T>
Set(const FlatbufferFieldPath * path,T value)339 bool MutableFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
340 MutableFlatbuffer* parent;
341 const reflection::Field* field;
342 if (!GetFieldWithParent(path, &parent, &field)) {
343 return false;
344 }
345 return parent->Set<T>(field, value);
346 }
347
348 template <typename T>
Add(StringPiece field_name,T value)349 bool MutableFlatbuffer::Add(StringPiece field_name, T value) {
350 const reflection::Field* field = GetFieldOrNull(field_name);
351 if (field == nullptr) {
352 return false;
353 }
354
355 if (field->type()->base_type() != reflection::BaseType::Vector) {
356 return false;
357 }
358
359 return Add<T>(field, value);
360 }
361
362 template <typename T>
Add(const reflection::Field * field,T value)363 bool MutableFlatbuffer::Add(const reflection::Field* field, T value) {
364 if (field == nullptr) {
365 return false;
366 }
367 Repeated(field)->Add(value);
368 return true;
369 }
370
371 template <typename T>
Add(const T value)372 bool RepeatedField::Add(const T value) {
373 if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
374 TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
375 return false;
376 }
377 items_.push_back(Variant{value});
378 return true;
379 }
380
381 template <typename T>
AppendFromVector(const flatbuffers::Table * from)382 bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) {
383 const flatbuffers::Vector<T>* values =
384 from->GetPointer<const flatbuffers::Vector<T>*>(field_->offset());
385 if (values == nullptr) {
386 return false;
387 }
388 for (const T element : *values) {
389 Add(element);
390 }
391 return true;
392 }
393
394 template <>
395 inline bool RepeatedField::AppendFromVector<std::string>(
396 const flatbuffers::Table* from) {
397 auto* values = from->GetPointer<
398 const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
399 field_->offset());
400 if (values == nullptr) {
401 return false;
402 }
403 for (const flatbuffers::String* element : *values) {
404 Add(element->str());
405 }
406 return true;
407 }
408
409 template <>
410 inline bool RepeatedField::AppendFromVector<MutableFlatbuffer>(
411 const flatbuffers::Table* from) {
412 auto* values = from->GetPointer<const flatbuffers::Vector<
413 flatbuffers::Offset<const flatbuffers::Table>>*>(field_->offset());
414 if (values == nullptr) {
415 return false;
416 }
417 for (const flatbuffers::Table* const from_element : *values) {
418 MutableFlatbuffer* to_element = Add();
419 if (to_element == nullptr) {
420 return false;
421 }
422 to_element->MergeFrom(from_element);
423 }
424 return true;
425 }
426
427 } // namespace libtextclassifier3
428
429 #endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
430