• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Utility functions for working with FlatBuffers.
18 
19 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
20 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
21 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "annotator/model_generated.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"
31 
32 namespace libtextclassifier3 {
33 
34 // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
35 // integrity.
36 template <typename FlatbufferMessage>
LoadAndVerifyFlatbuffer(const void * buffer,int size)37 const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
38   const FlatbufferMessage* message =
39       flatbuffers::GetRoot<FlatbufferMessage>(buffer);
40   if (message == nullptr) {
41     return nullptr;
42   }
43   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
44                                  size);
45   if (message->Verify(verifier)) {
46     return message;
47   } else {
48     return nullptr;
49   }
50 }
51 
// Convenience overload of the above for data held in a std::string.
// The returned message points into `buffer`'s storage.
template <typename FlatbufferMessage>
const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
  const char* const bytes = buffer.c_str();
  // The pointer/size overload takes an `int` length.
  const int num_bytes = static_cast<int>(buffer.size());
  return LoadAndVerifyFlatbuffer<FlatbufferMessage>(bytes, num_bytes);
}
58 
59 // Loads and interprets the buffer as 'FlatbufferMessage', verifies its
60 // integrity and returns its mutable version.
61 template <typename FlatbufferMessage>
62 std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const void * buffer,int size)63 LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
64   const FlatbufferMessage* message =
65       LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
66   if (message == nullptr) {
67     return nullptr;
68   }
69   return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
70       message->UnPack());
71 }
72 
// Convenience overload of the above for data held in a std::string.
template <typename FlatbufferMessage>
std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
  const char* const bytes = buffer.c_str();
  // The pointer/size overload takes an `int` length.
  const int num_bytes = static_cast<int>(buffer.size());
  return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(bytes, num_bytes);
}
80 
// Returns the file identifier to use when finishing a buffer of type
// 'FlatbufferMessage'. The generic version has no identifier and yields
// nullptr; message types may provide their own specialization.
template <typename FlatbufferMessage>
const char* FlatbufferFileIdentifier() {
  const char* const kNoIdentifier = nullptr;
  return kNoIdentifier;
}
85 
86 template <>
87 const char* FlatbufferFileIdentifier<Model>();
88 
89 // Packs the mutable flatbuffer message to string.
90 template <typename FlatbufferMessage>
PackFlatbuffer(const typename FlatbufferMessage::NativeTableType * mutable_message)91 std::string PackFlatbuffer(
92     const typename FlatbufferMessage::NativeTableType* mutable_message) {
93   flatbuffers::FlatBufferBuilder builder;
94   builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
95                  FlatbufferFileIdentifier<FlatbufferMessage>());
96   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
97                      builder.GetSize());
98 }
99 
100 // A flatbuffer that can be built using flatbuffer reflection data of the
101 // schema.
102 // Normally, field information is hard-coded in code generated from a flatbuffer
103 // schema. Here we lookup the necessary information for building a flatbuffer
104 // from the provided reflection meta data.
105 // When serializing a flatbuffer, the library requires that the sub messages
106 // are already serialized, therefore we explicitly keep the field values and
107 // serialize the message in (reverse) topological dependency order.
108 class ReflectiveFlatbuffer {
109  public:
ReflectiveFlatbuffer(const reflection::Schema * schema,const reflection::Object * type)110   ReflectiveFlatbuffer(const reflection::Schema* schema,
111                        const reflection::Object* type)
112       : schema_(schema), type_(type) {}
113 
114   // Encapsulates a repeated field.
115   // Serves as a common base class for repeated fields.
116   class RepeatedField {
117    public:
~RepeatedField()118     virtual ~RepeatedField() {}
119 
120     virtual flatbuffers::uoffset_t Serialize(
121         flatbuffers::FlatBufferBuilder* builder) const = 0;
122   };
123 
124   // Represents a repeated field of particular type.
125   template <typename T>
126   class TypedRepeatedField : public RepeatedField {
127    public:
Add(const T value)128     void Add(const T value) { items_.push_back(value); }
129 
Serialize(flatbuffers::FlatBufferBuilder * builder)130     flatbuffers::uoffset_t Serialize(
131         flatbuffers::FlatBufferBuilder* builder) const override {
132       return builder->CreateVector(items_).o;
133     }
134 
135    private:
136     std::vector<T> items_;
137   };
138 
139   // Specialization for strings.
140   template <>
141   class TypedRepeatedField<std::string> : public RepeatedField {
142    public:
Add(const std::string & value)143     void Add(const std::string& value) { items_.push_back(value); }
144 
Serialize(flatbuffers::FlatBufferBuilder * builder)145     flatbuffers::uoffset_t Serialize(
146         flatbuffers::FlatBufferBuilder* builder) const override {
147       std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
148           items_.size());
149       for (int i = 0; i < items_.size(); i++) {
150         offsets[i] = builder->CreateString(items_[i]);
151       }
152       return builder->CreateVector(offsets).o;
153     }
154 
155    private:
156     std::vector<std::string> items_;
157   };
158 
159   // Specialization for repeated sub-messages.
160   template <>
161   class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
162    public:
163     TypedRepeatedField<ReflectiveFlatbuffer>(
164         const reflection::Schema* const schema,
165         const reflection::Type* const type)
schema_(schema)166         : schema_(schema), type_(type) {}
167 
Add()168     ReflectiveFlatbuffer* Add() {
169       items_.emplace_back(new ReflectiveFlatbuffer(
170           schema_, schema_->objects()->Get(type_->index())));
171       return items_.back().get();
172     }
173 
Serialize(flatbuffers::FlatBufferBuilder * builder)174     flatbuffers::uoffset_t Serialize(
175         flatbuffers::FlatBufferBuilder* builder) const override {
176       std::vector<flatbuffers::Offset<void>> offsets(items_.size());
177       for (int i = 0; i < items_.size(); i++) {
178         offsets[i] = items_[i]->Serialize(builder);
179       }
180       return builder->CreateVector(offsets).o;
181     }
182 
183    private:
184     const reflection::Schema* const schema_;
185     const reflection::Type* const type_;
186     std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
187   };
188 
189   // Gets the field information for a field name, returns nullptr if the
190   // field was not defined.
191   const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
192   const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
193   const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
194 
195   // Gets a nested field and the message it is defined on.
196   bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
197                           ReflectiveFlatbuffer** parent,
198                           reflection::Field const** field);
199 
200   // Checks whether a variant value type agrees with a field type.
201   bool IsMatchingType(const reflection::Field* field,
202                       const Variant& value) const;
203 
204   // Sets a (primitive) field to a specific value.
205   // Returns true if successful, and false if the field was not found or the
206   // expected type doesn't match.
207   template <typename T>
Set(StringPiece field_name,T value)208   bool Set(StringPiece field_name, T value) {
209     if (const reflection::Field* field = GetFieldOrNull(field_name)) {
210       return Set<T>(field, value);
211     }
212     return false;
213   }
214 
215   // Sets a (primitive) field to a specific value.
216   // Returns true if successful, and false if the expected type doesn't match.
217   // Expects `field` to be non-null.
218   template <typename T>
Set(const reflection::Field * field,T value)219   bool Set(const reflection::Field* field, T value) {
220     if (field == nullptr) {
221       TC3_LOG(ERROR) << "Expected non-null field.";
222       return false;
223     }
224     Variant variant_value(value);
225     if (!IsMatchingType(field, variant_value)) {
226       TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
227                      << "`, expected: " << field->type()->base_type()
228                      << ", got: " << variant_value.GetType();
229       return false;
230     }
231     fields_[field] = variant_value;
232     return true;
233   }
234 
235   template <typename T>
Set(const FlatbufferFieldPath * path,T value)236   bool Set(const FlatbufferFieldPath* path, T value) {
237     ReflectiveFlatbuffer* parent;
238     const reflection::Field* field;
239     if (!GetFieldWithParent(path, &parent, &field)) {
240       return false;
241     }
242     return parent->Set<T>(field, value);
243   }
244 
245   // Sets a (primitive) field to a specific value.
246   // Parses the string value according to the field type.
247   bool ParseAndSet(const reflection::Field* field, const std::string& value);
248   bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
249 
250   // Gets the reflective flatbuffer for a table field.
251   // Returns nullptr if the field was not found, or the field type was not a
252   // table.
253   ReflectiveFlatbuffer* Mutable(StringPiece field_name);
254   ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
255 
256   // Gets the reflective flatbuffer for a repeated field.
257   // Returns nullptr if the field was not found, or the field type was not a
258   // vector.
259   RepeatedField* Repeated(StringPiece field_name);
260   RepeatedField* Repeated(const reflection::Field* field);
261 
262   template <typename T>
Repeated(const reflection::Field * field)263   TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
264     return static_cast<TypedRepeatedField<T>*>(Repeated(field));
265   }
266 
267   template <typename T>
Repeated(StringPiece field_name)268   TypedRepeatedField<T>* Repeated(StringPiece field_name) {
269     return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
270   }
271 
272   // Serializes the flatbuffer.
273   flatbuffers::uoffset_t Serialize(
274       flatbuffers::FlatBufferBuilder* builder) const;
275   std::string Serialize() const;
276 
277   // Merges the fields from the given flatbuffer table into this flatbuffer.
278   // Scalar fields will be overwritten, if present in `from`.
279   // Embedded messages will be merged.
280   bool MergeFrom(const flatbuffers::Table* from);
281   bool MergeFromSerializedFlatbuffer(StringPiece from);
282 
283   // Flattens the flatbuffer as a flat map.
284   // (Nested) fields names are joined by `key_separator`.
285   std::map<std::string, Variant> AsFlatMap(
286       const std::string& key_separator = ".") const {
287     std::map<std::string, Variant> result;
288     AsFlatMap(key_separator, /*key_prefix=*/"", &result);
289     return result;
290   }
291 
292  private:
293   const reflection::Schema* const schema_;
294   const reflection::Object* const type_;
295 
296   // Cached primitive fields (scalars and strings).
297   std::map<const reflection::Field*, Variant> fields_;
298 
299   // Cached sub-messages.
300   std::map<const reflection::Field*, std::unique_ptr<ReflectiveFlatbuffer>>
301       children_;
302 
303   // Cached repeated fields.
304   std::map<const reflection::Field*, std::unique_ptr<RepeatedField>>
305       repeated_fields_;
306 
307   // Flattens the flatbuffer as a flat map.
308   // (Nested) fields names are joined by `key_separator` and prefixed by
309   // `key_prefix`.
310   void AsFlatMap(const std::string& key_separator,
311                  const std::string& key_prefix,
312                  std::map<std::string, Variant>* result) const;
313 };
314 
315 // A helper class to build flatbuffers based on schema reflection data.
316 // Can be used to a `ReflectiveFlatbuffer` for the root message of the
317 // schema, or any defined table via name.
318 class ReflectiveFlatbufferBuilder {
319  public:
ReflectiveFlatbufferBuilder(const reflection::Schema * schema)320   explicit ReflectiveFlatbufferBuilder(const reflection::Schema* schema)
321       : schema_(schema) {}
322 
323   // Starts a new root table message.
324   std::unique_ptr<ReflectiveFlatbuffer> NewRoot() const;
325 
326   // Starts a new table message. Returns nullptr if no table with given name is
327   // found in the schema.
328   std::unique_ptr<ReflectiveFlatbuffer> NewTable(
329       const StringPiece table_name) const;
330 
331  private:
332   const reflection::Schema* const schema_;
333 };
334 
335 }  // namespace libtextclassifier3
336 
337 #endif  // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
338