/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_ #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_ #include #include #include #include "annotator/model_generated.h" #include "utils/base/logging.h" #include "utils/flatbuffers/flatbuffers_generated.h" #include "utils/flatbuffers/reflection.h" #include "utils/strings/stringpiece.h" #include "utils/variant.h" #include "flatbuffers/flatbuffers.h" #include "flatbuffers/reflection.h" #include "flatbuffers/reflection_generated.h" namespace libtextclassifier3 { class MutableFlatbuffer; class RepeatedField; template constexpr bool IsStringType() { return std::is_same::value || std::is_same::value || std::is_same::value; } // Checks whether a variant value type agrees with a field type. template bool IsMatchingType(const reflection::BaseType type) { switch (type) { case reflection::String: return IsStringType(); case reflection::Obj: return std::is_same::value; default: return type == flatbuffers_base_type::value; } } // A mutable flatbuffer that can be built using flatbuffer reflection data of // the schema. Normally, field information is hard-coded in code generated from // a flatbuffer schema. Here we lookup the necessary information for building a // flatbuffer from the provided reflection meta data. When serializing a // flatbuffer, the library requires that the sub messages are already // serialized, therefore we explicitly keep the field values and serialize the // message in (reverse) topological dependency order. class MutableFlatbuffer { public: MutableFlatbuffer(const reflection::Schema* schema, const reflection::Object* type) : schema_(schema), type_(type) {} // Gets the field information for a field name, returns nullptr if the // field was not defined. const reflection::Field* GetFieldOrNull(const StringPiece field_name) const; const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const; const reflection::Field* GetFieldOrNull(const int field_offset) const; // Gets a nested field and the message it is defined on. bool GetFieldWithParent(const FlatbufferFieldPath* field_path, MutableFlatbuffer** parent, reflection::Field const** field); // Sets a field to a specific value. // Returns true if successful, and false if the field was not found or the // expected type doesn't match. template bool Set(StringPiece field_name, T value); // Sets a field to a specific value. // Returns true if successful, and false if the expected type doesn't match. // Expects `field` to be non-null. template bool Set(const reflection::Field* field, T value); // Sets a field to a specific value. Field is specified by path. template bool Set(const FlatbufferFieldPath* path, T value); // Sets an enum field from an enum value name. // Returns true if the value could be successfully parsed. bool SetFromEnumValueName(StringPiece field_name, StringPiece value_name); // Sets an enum field from an enum value name. // Returns true if the value could be successfully parsed. bool SetFromEnumValueName(const reflection::Field* field, StringPiece value_name); // Sets an enum field from an enum value name. Field is specified by path. // Returns true if the value could be successfully parsed. bool SetFromEnumValueName(const FlatbufferFieldPath* path, StringPiece value_name); // Sets sub-message field (if not set yet), and returns a pointer to it. // Returns nullptr if the field was not found, or the field type was not a // table. MutableFlatbuffer* Mutable(StringPiece field_name); MutableFlatbuffer* Mutable(const reflection::Field* field); // Sets a sub-message field (if not set yet) specified by path, and returns a // pointer to it. Returns nullptr if the field was not found, or the field // type was not a table. MutableFlatbuffer* Mutable(const FlatbufferFieldPath* path); // Parses the value (according to the type) and sets a primitive field to the // parsed value. bool ParseAndSet(const reflection::Field* field, const std::string& value); bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value); // Adds a primitive value to the repeated field. template bool Add(StringPiece field_name, T value); // Add a sub-message to the repeated field. MutableFlatbuffer* Add(StringPiece field_name); template bool Add(const reflection::Field* field, T value); MutableFlatbuffer* Add(const reflection::Field* field); // Gets the reflective flatbuffer for a repeated field. // Returns nullptr if the field was not found, or the field type was not a // vector. RepeatedField* Repeated(StringPiece field_name); RepeatedField* Repeated(const reflection::Field* field); // Gets a repeated field specified by path. // Returns nullptr if the field was not found, or the field // type was not a repeated field. RepeatedField* Repeated(const FlatbufferFieldPath* path); // Serializes the flatbuffer. flatbuffers::uoffset_t Serialize( flatbuffers::FlatBufferBuilder* builder) const; std::string Serialize() const; // Merges the fields from the given flatbuffer table into this flatbuffer. // Scalar fields will be overwritten, if present in `from`. // Embedded messages will be merged. bool MergeFrom(const flatbuffers::Table* from); bool MergeFromSerializedFlatbuffer(StringPiece from); // Flattens the flatbuffer as a flat map. // (Nested) fields names are joined by `key_separator`. std::map AsFlatMap( const std::string& key_separator = ".") const { std::map result; AsFlatMap(key_separator, /*key_prefix=*/"", &result); return result; } // Converts the flatbuffer's content to a human-readable textproto // representation. std::string ToTextProto() const; bool HasExplicitlySetFields() const { return !fields_.empty() || !children_.empty() || !repeated_fields_.empty(); } const reflection::Object* type() const { return type_; } private: // Helper function for merging given repeated field from given flatbuffer // table. Appends the elements. template bool AppendFromVector(const flatbuffers::Table* from, const reflection::Field* field); // Flattens the flatbuffer as a flat map. // (Nested) fields names are joined by `key_separator` and prefixed by // `key_prefix`. void AsFlatMap(const std::string& key_separator, const std::string& key_prefix, std::map* result) const; const reflection::Schema* const schema_; const reflection::Object* const type_; // Cached primitive fields (scalars and strings). std::unordered_map fields_; // Cached sub-messages. std::unordered_map> children_; // Cached repeated fields. std::unordered_map> repeated_fields_; }; // A helper class to build flatbuffers based on schema reflection data. // Can be used to a `MutableFlatbuffer` for the root message of the // schema, or any defined table via name. class MutableFlatbufferBuilder { public: explicit MutableFlatbufferBuilder(const reflection::Schema* schema) : schema_(schema), root_type_(schema->root_table()) {} explicit MutableFlatbufferBuilder(const reflection::Schema* schema, StringPiece root_type); // Starts a new root table message. std::unique_ptr NewRoot() const; // Creates a new table message. Returns nullptr if no table with given name is // found in the schema. std::unique_ptr NewTable( const StringPiece table_name) const; // Creates a new message for the given type id. Returns nullptr if the type is // invalid. std::unique_ptr NewTable(int type_id) const; // Creates a new message for the given type. std::unique_ptr NewTable( const reflection::Object* type) const; private: const reflection::Schema* const schema_; const reflection::Object* const root_type_; }; // Encapsulates a repeated field. // Serves as a common base class for repeated fields. class RepeatedField { public: RepeatedField(const reflection::Schema* const schema, const reflection::Field* field) : schema_(schema), field_(field), is_primitive_(field->type()->element() != reflection::BaseType::Obj) {} template bool Add(const T value); MutableFlatbuffer* Add(); template T Get(int index) const { return items_.at(index).Value(); } template <> MutableFlatbuffer* Get(int index) const { if (is_primitive_) { TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive " "repeated field."; return nullptr; } return object_items_.at(index).get(); } int Size() const { if (is_primitive_) { return items_.size(); } else { return object_items_.size(); } } bool Extend(const flatbuffers::Table* from); flatbuffers::uoffset_t Serialize( flatbuffers::FlatBufferBuilder* builder) const; std::string ToTextProto() const; private: template bool AppendFromVector(const flatbuffers::Table* from); flatbuffers::uoffset_t SerializeString( flatbuffers::FlatBufferBuilder* builder) const; flatbuffers::uoffset_t SerializeObject( flatbuffers::FlatBufferBuilder* builder) const; const reflection::Schema* const schema_; const reflection::Field* field_; bool is_primitive_; std::vector items_; std::vector> object_items_; }; template bool MutableFlatbuffer::Set(StringPiece field_name, T value) { if (const reflection::Field* field = GetFieldOrNull(field_name)) { if (field->type()->base_type() == reflection::BaseType::Vector || field->type()->base_type() == reflection::BaseType::Obj) { TC3_LOG(ERROR) << "Trying to set a primitive value on a non-scalar field."; return false; } return Set(field, value); } TC3_LOG(ERROR) << "Couldn't find a field: " << field_name; return false; } template bool MutableFlatbuffer::Set(const reflection::Field* field, T value) { if (field == nullptr) { TC3_LOG(ERROR) << "Expected non-null field."; return false; } Variant variant_value(value); if (!IsMatchingType(field->type()->base_type())) { TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str() << "`, expected: " << EnumNameBaseType(field->type()->base_type()) << ", got: " << variant_value.GetType(); return false; } fields_[field] = variant_value; return true; } template bool MutableFlatbuffer::Set(const FlatbufferFieldPath* path, T value) { MutableFlatbuffer* parent; const reflection::Field* field; if (!GetFieldWithParent(path, &parent, &field)) { return false; } return parent->Set(field, value); } template bool MutableFlatbuffer::Add(StringPiece field_name, T value) { const reflection::Field* field = GetFieldOrNull(field_name); if (field == nullptr) { return false; } if (field->type()->base_type() != reflection::BaseType::Vector) { return false; } return Add(field, value); } template bool MutableFlatbuffer::Add(const reflection::Field* field, T value) { if (field == nullptr) { return false; } Repeated(field)->Add(value); return true; } template bool RepeatedField::Add(const T value) { if (!is_primitive_ || !IsMatchingType(field_->type()->element())) { TC3_LOG(ERROR) << "Trying to add value of unmatching type."; return false; } items_.push_back(Variant{value}); return true; } template bool RepeatedField::AppendFromVector(const flatbuffers::Table* from) { const flatbuffers::Vector* values = from->GetPointer*>(field_->offset()); if (values == nullptr) { return false; } for (const T element : *values) { Add(element); } return true; } template <> inline bool RepeatedField::AppendFromVector( const flatbuffers::Table* from) { auto* values = from->GetPointer< const flatbuffers::Vector>*>( field_->offset()); if (values == nullptr) { return false; } for (const flatbuffers::String* element : *values) { Add(element->str()); } return true; } template <> inline bool RepeatedField::AppendFromVector( const flatbuffers::Table* from) { auto* values = from->GetPointer>*>(field_->offset()); if (values == nullptr) { return false; } for (const flatbuffers::Table* const from_element : *values) { MutableFlatbuffer* to_element = Add(); if (to_element == nullptr) { return false; } to_element->MergeFrom(from_element); } return true; } } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_MUTABLE_H_