• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
18 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
19 
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
23 
24 #include "annotator/model_generated.h"
25 #include "utils/base/logging.h"
26 #include "utils/flatbuffers/flatbuffers_generated.h"
27 #include "utils/flatbuffers/reflection.h"
28 #include "utils/strings/stringpiece.h"
29 #include "utils/variant.h"
30 #include "flatbuffers/flatbuffers.h"
31 #include "flatbuffers/reflection.h"
32 #include "flatbuffers/reflection_generated.h"
33 
34 namespace libtextclassifier3 {
35 
36 class MutableFlatbuffer;
37 class RepeatedField;
38 
39 template <typename T>
IsStringType()40 constexpr bool IsStringType() {
41   return std::is_same<T, std::string>::value ||
42          std::is_same<T, StringPiece>::value ||
43          std::is_same<T, const char*>::value;
44 }
45 
46 // Checks whether a variant value type agrees with a field type.
47 template <typename T>
IsMatchingType(const reflection::BaseType type)48 bool IsMatchingType(const reflection::BaseType type) {
49   switch (type) {
50     case reflection::String:
51       return IsStringType<T>();
52     case reflection::Obj:
53       return std::is_same<T, MutableFlatbuffer>::value;
54     default:
55       return type == flatbuffers_base_type<T>::value;
56   }
57 }
58 
59 // A mutable flatbuffer that can be built using flatbuffer reflection data of
60 // the schema. Normally, field information is hard-coded in code generated from
61 // a flatbuffer schema. Here we lookup the necessary information for building a
62 // flatbuffer from the provided reflection meta data. When serializing a
63 // flatbuffer, the library requires that the sub messages are already
64 // serialized, therefore we explicitly keep the field values and serialize the
65 // message in (reverse) topological dependency order.
66 class MutableFlatbuffer {
67  public:
MutableFlatbuffer(const reflection::Schema * schema,const reflection::Object * type)68   MutableFlatbuffer(const reflection::Schema* schema,
69                     const reflection::Object* type)
70       : schema_(schema), type_(type) {}
71 
72   // Gets the field information for a field name, returns nullptr if the
73   // field was not defined.
74   const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
75   const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
76   const reflection::Field* GetFieldOrNull(const int field_offset) const;
77 
78   // Gets a nested field and the message it is defined on.
79   bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
80                           MutableFlatbuffer** parent,
81                           reflection::Field const** field);
82 
83   // Sets a field to a specific value.
84   // Returns true if successful, and false if the field was not found or the
85   // expected type doesn't match.
86   template <typename T>
87   bool Set(StringPiece field_name, T value);
88 
89   // Sets a field to a specific value.
90   // Returns true if successful, and false if the expected type doesn't match.
91   // Expects `field` to be non-null.
92   template <typename T>
93   bool Set(const reflection::Field* field, T value);
94 
95   // Sets a field to a specific value. Field is specified by path.
96   template <typename T>
97   bool Set(const FlatbufferFieldPath* path, T value);
98 
99   // Sets an enum field from an enum value name.
100   // Returns true if the value could be successfully parsed.
101   bool SetFromEnumValueName(StringPiece field_name, StringPiece value_name);
102 
103   // Sets an enum field from an enum value name.
104   // Returns true if the value could be successfully parsed.
105   bool SetFromEnumValueName(const reflection::Field* field,
106                             StringPiece value_name);
107 
108   // Sets an enum field from an enum value name. Field is specified by path.
109   // Returns true if the value could be successfully parsed.
110   bool SetFromEnumValueName(const FlatbufferFieldPath* path,
111                             StringPiece value_name);
112 
113   // Sets sub-message field (if not set yet), and returns a pointer to it.
114   // Returns nullptr if the field was not found, or the field type was not a
115   // table.
116   MutableFlatbuffer* Mutable(StringPiece field_name);
117   MutableFlatbuffer* Mutable(const reflection::Field* field);
118 
119   // Sets a sub-message field (if not set yet) specified by path, and returns a
120   // pointer to it. Returns nullptr if the field was not found, or the field
121   // type was not a table.
122   MutableFlatbuffer* Mutable(const FlatbufferFieldPath* path);
123 
124   // Parses the value (according to the type) and sets a primitive field to the
125   // parsed value.
126   bool ParseAndSet(const reflection::Field* field, const std::string& value);
127   bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
128 
129   // Adds a primitive value to the repeated field.
130   template <typename T>
131   bool Add(StringPiece field_name, T value);
132 
133   // Add a sub-message to the repeated field.
134   MutableFlatbuffer* Add(StringPiece field_name);
135 
136   template <typename T>
137   bool Add(const reflection::Field* field, T value);
138 
139   MutableFlatbuffer* Add(const reflection::Field* field);
140 
141   // Gets the reflective flatbuffer for a repeated field.
142   // Returns nullptr if the field was not found, or the field type was not a
143   // vector.
144   RepeatedField* Repeated(StringPiece field_name);
145   RepeatedField* Repeated(const reflection::Field* field);
146 
147   // Gets a repeated field specified by path.
148   // Returns nullptr if the field was not found, or the field
149   // type was not a repeated field.
150   RepeatedField* Repeated(const FlatbufferFieldPath* path);
151 
152   // Serializes the flatbuffer.
153   flatbuffers::uoffset_t Serialize(
154       flatbuffers::FlatBufferBuilder* builder) const;
155   std::string Serialize() const;
156 
157   // Merges the fields from the given flatbuffer table into this flatbuffer.
158   // Scalar fields will be overwritten, if present in `from`.
159   // Embedded messages will be merged.
160   bool MergeFrom(const flatbuffers::Table* from);
161   bool MergeFromSerializedFlatbuffer(StringPiece from);
162 
163   // Flattens the flatbuffer as a flat map.
164   // (Nested) fields names are joined by `key_separator`.
165   std::map<std::string, Variant> AsFlatMap(
166       const std::string& key_separator = ".") const {
167     std::map<std::string, Variant> result;
168     AsFlatMap(key_separator, /*key_prefix=*/"", &result);
169     return result;
170   }
171 
172   // Converts the flatbuffer's content to a human-readable textproto
173   // representation.
174   std::string ToTextProto() const;
175 
HasExplicitlySetFields()176   bool HasExplicitlySetFields() const {
177     return !fields_.empty() || !children_.empty() || !repeated_fields_.empty();
178   }
179 
type()180   const reflection::Object* type() const { return type_; }
181 
182  private:
183   // Helper function for merging given repeated field from given flatbuffer
184   // table. Appends the elements.
185   template <typename T>
186   bool AppendFromVector(const flatbuffers::Table* from,
187                         const reflection::Field* field);
188 
189   // Flattens the flatbuffer as a flat map.
190   // (Nested) fields names are joined by `key_separator` and prefixed by
191   // `key_prefix`.
192   void AsFlatMap(const std::string& key_separator,
193                  const std::string& key_prefix,
194                  std::map<std::string, Variant>* result) const;
195 
196   const reflection::Schema* const schema_;
197   const reflection::Object* const type_;
198 
199   // Cached primitive fields (scalars and strings).
200   std::unordered_map<const reflection::Field*, Variant> fields_;
201 
202   // Cached sub-messages.
203   std::unordered_map<const reflection::Field*,
204                      std::unique_ptr<MutableFlatbuffer>>
205       children_;
206 
207   // Cached repeated fields.
208   std::unordered_map<const reflection::Field*, std::unique_ptr<RepeatedField>>
209       repeated_fields_;
210 };
211 
212 // A helper class to build flatbuffers based on schema reflection data.
213 // Can be used to a `MutableFlatbuffer` for the root message of the
214 // schema, or any defined table via name.
215 class MutableFlatbufferBuilder {
216  public:
MutableFlatbufferBuilder(const reflection::Schema * schema)217   explicit MutableFlatbufferBuilder(const reflection::Schema* schema)
218       : schema_(schema), root_type_(schema->root_table()) {}
219   explicit MutableFlatbufferBuilder(const reflection::Schema* schema,
220                                     StringPiece root_type);
221 
222   // Starts a new root table message.
223   std::unique_ptr<MutableFlatbuffer> NewRoot() const;
224 
225   // Creates a new table message. Returns nullptr if no table with given name is
226   // found in the schema.
227   std::unique_ptr<MutableFlatbuffer> NewTable(
228       const StringPiece table_name) const;
229 
230   // Creates a new message for the given type id. Returns nullptr if the type is
231   // invalid.
232   std::unique_ptr<MutableFlatbuffer> NewTable(int type_id) const;
233 
234   // Creates a new message for the given type.
235   std::unique_ptr<MutableFlatbuffer> NewTable(
236       const reflection::Object* type) const;
237 
238  private:
239   const reflection::Schema* const schema_;
240   const reflection::Object* const root_type_;
241 };
242 
243 // Encapsulates a repeated field.
244 // Serves as a common base class for repeated fields.
245 class RepeatedField {
246  public:
RepeatedField(const reflection::Schema * const schema,const reflection::Field * field)247   RepeatedField(const reflection::Schema* const schema,
248                 const reflection::Field* field)
249       : schema_(schema),
250         field_(field),
251         is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
252 
253   template <typename T>
254   bool Add(const T value);
255 
256   MutableFlatbuffer* Add();
257 
258   template <typename T>
Get(int index)259   T Get(int index) const {
260     return items_.at(index).Value<T>();
261   }
262 
263   template <>
Get(int index)264   MutableFlatbuffer* Get(int index) const {
265     if (is_primitive_) {
266       TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
267                         "repeated field.";
268       return nullptr;
269     }
270     return object_items_.at(index).get();
271   }
272 
Size()273   int Size() const {
274     if (is_primitive_) {
275       return items_.size();
276     } else {
277       return object_items_.size();
278     }
279   }
280 
281   bool Extend(const flatbuffers::Table* from);
282 
283   flatbuffers::uoffset_t Serialize(
284       flatbuffers::FlatBufferBuilder* builder) const;
285 
286   std::string ToTextProto() const;
287 
288  private:
289   template <typename T>
290   bool AppendFromVector(const flatbuffers::Table* from);
291 
292   flatbuffers::uoffset_t SerializeString(
293       flatbuffers::FlatBufferBuilder* builder) const;
294   flatbuffers::uoffset_t SerializeObject(
295       flatbuffers::FlatBufferBuilder* builder) const;
296 
297   const reflection::Schema* const schema_;
298   const reflection::Field* field_;
299   bool is_primitive_;
300 
301   std::vector<Variant> items_;
302   std::vector<std::unique_ptr<MutableFlatbuffer>> object_items_;
303 };
304 
305 template <typename T>
Set(StringPiece field_name,T value)306 bool MutableFlatbuffer::Set(StringPiece field_name, T value) {
307   if (const reflection::Field* field = GetFieldOrNull(field_name)) {
308     if (field->type()->base_type() == reflection::BaseType::Vector ||
309         field->type()->base_type() == reflection::BaseType::Obj) {
310       TC3_LOG(ERROR)
311           << "Trying to set a primitive value on a non-scalar field.";
312       return false;
313     }
314     return Set<T>(field, value);
315   }
316   TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
317   return false;
318 }
319 
320 template <typename T>
Set(const reflection::Field * field,T value)321 bool MutableFlatbuffer::Set(const reflection::Field* field, T value) {
322   if (field == nullptr) {
323     TC3_LOG(ERROR) << "Expected non-null field.";
324     return false;
325   }
326   Variant variant_value(value);
327   if (!IsMatchingType<T>(field->type()->base_type())) {
328     TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
329                    << "`, expected: "
330                    << EnumNameBaseType(field->type()->base_type())
331                    << ", got: " << variant_value.GetType();
332     return false;
333   }
334   fields_[field] = variant_value;
335   return true;
336 }
337 
338 template <typename T>
Set(const FlatbufferFieldPath * path,T value)339 bool MutableFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
340   MutableFlatbuffer* parent;
341   const reflection::Field* field;
342   if (!GetFieldWithParent(path, &parent, &field)) {
343     return false;
344   }
345   return parent->Set<T>(field, value);
346 }
347 
348 template <typename T>
Add(StringPiece field_name,T value)349 bool MutableFlatbuffer::Add(StringPiece field_name, T value) {
350   const reflection::Field* field = GetFieldOrNull(field_name);
351   if (field == nullptr) {
352     return false;
353   }
354 
355   if (field->type()->base_type() != reflection::BaseType::Vector) {
356     return false;
357   }
358 
359   return Add<T>(field, value);
360 }
361 
362 template <typename T>
Add(const reflection::Field * field,T value)363 bool MutableFlatbuffer::Add(const reflection::Field* field, T value) {
364   if (field == nullptr) {
365     return false;
366   }
367   Repeated(field)->Add(value);
368   return true;
369 }
370 
371 template <typename T>
Add(const T value)372 bool RepeatedField::Add(const T value) {
373   if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
374     TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
375     return false;
376   }
377   items_.push_back(Variant{value});
378   return true;
379 }
380 
381 template <typename T>
AppendFromVector(const flatbuffers::Table * from)382 bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) {
383   const flatbuffers::Vector<T>* values =
384       from->GetPointer<const flatbuffers::Vector<T>*>(field_->offset());
385   if (values == nullptr) {
386     return false;
387   }
388   for (const T element : *values) {
389     Add(element);
390   }
391   return true;
392 }
393 
394 template <>
395 inline bool RepeatedField::AppendFromVector<std::string>(
396     const flatbuffers::Table* from) {
397   auto* values = from->GetPointer<
398       const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*>(
399       field_->offset());
400   if (values == nullptr) {
401     return false;
402   }
403   for (const flatbuffers::String* element : *values) {
404     Add(element->str());
405   }
406   return true;
407 }
408 
409 template <>
410 inline bool RepeatedField::AppendFromVector<MutableFlatbuffer>(
411     const flatbuffers::Table* from) {
412   auto* values = from->GetPointer<const flatbuffers::Vector<
413       flatbuffers::Offset<const flatbuffers::Table>>*>(field_->offset());
414   if (values == nullptr) {
415     return false;
416   }
417   for (const flatbuffers::Table* const from_element : *values) {
418     MutableFlatbuffer* to_element = Add();
419     if (to_element == nullptr) {
420       return false;
421     }
422     to_element->MergeFrom(from_element);
423   }
424   return true;
425 }
426 
427 }  // namespace libtextclassifier3
428 
429 #endif  // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_
430