/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "utils/flatbuffers/mutable.h" #include #include "utils/flatbuffers/reflection.h" #include "utils/strings/numbers.h" #include "utils/variant.h" #include "flatbuffers/reflection_generated.h" namespace libtextclassifier3 { namespace { bool Parse(const std::string& str_value, float* value) { double double_value; if (!ParseDouble(str_value.data(), &double_value)) { return false; } *value = static_cast(double_value); return true; } bool Parse(const std::string& str_value, double* value) { return ParseDouble(str_value.data(), value); } bool Parse(const std::string& str_value, int64* value) { return ParseInt64(str_value.data(), value); } bool Parse(const std::string& str_value, int32* value) { return ParseInt32(str_value.data(), value); } bool Parse(const std::string& str_value, std::string* value) { *value = str_value; return true; } template bool ParseAndSetField(const reflection::Field* field, const std::string& str_value, MutableFlatbuffer* buffer) { T value; if (!Parse(str_value, &value)) { TC3_LOG(ERROR) << "Could not parse '" << str_value << "'"; return false; } if (field->type()->base_type() == reflection::Vector) { buffer->Repeated(field)->Add(value); return true; } else { return buffer->Set(field, value); } } } // namespace MutableFlatbufferBuilder::MutableFlatbufferBuilder( const reflection::Schema* schema, StringPiece root_type) : schema_(schema), root_type_(TypeForName(schema, root_type)) {} std::unique_ptr MutableFlatbufferBuilder::NewRoot() const { return NewTable(root_type_); } std::unique_ptr MutableFlatbufferBuilder::NewTable( StringPiece table_name) const { return NewTable(TypeForName(schema_, table_name)); } std::unique_ptr MutableFlatbufferBuilder::NewTable( const int type_id) const { if (type_id < 0 || type_id >= schema_->objects()->size()) { TC3_LOG(ERROR) << "Invalid type id: " << type_id; return nullptr; } return NewTable(schema_->objects()->Get(type_id)); } std::unique_ptr MutableFlatbufferBuilder::NewTable( const reflection::Object* type) const { if (type == nullptr) { return nullptr; } return std::make_unique(schema_, type); } const reflection::Field* MutableFlatbuffer::GetFieldOrNull( const StringPiece field_name) const { return libtextclassifier3::GetFieldOrNull(type_, field_name); } const reflection::Field* MutableFlatbuffer::GetFieldOrNull( const FlatbufferField* field) const { return libtextclassifier3::GetFieldOrNull(type_, field); } bool MutableFlatbuffer::GetFieldWithParent( const FlatbufferFieldPath* field_path, MutableFlatbuffer** parent, reflection::Field const** field) { const auto* path = field_path->field(); if (path == nullptr || path->size() == 0) { return false; } for (int i = 0; i < path->size(); i++) { *parent = (i == 0 ? this : (*parent)->Mutable(*field)); if (*parent == nullptr) { return false; } *field = (*parent)->GetFieldOrNull(path->Get(i)); if (*field == nullptr) { return false; } } return true; } const reflection::Field* MutableFlatbuffer::GetFieldOrNull( const int field_offset) const { return libtextclassifier3::GetFieldOrNull(type_, field_offset); } bool MutableFlatbuffer::SetFromEnumValueName(const reflection::Field* field, StringPiece value_name) { if (!IsEnum(field->type())) { return false; } Variant variant_value = ParseEnumValue(schema_, field->type(), value_name); if (!variant_value.HasValue()) { return false; } fields_[field] = variant_value; return true; } bool MutableFlatbuffer::SetFromEnumValueName(StringPiece field_name, StringPiece value_name) { if (const reflection::Field* field = GetFieldOrNull(field_name)) { return SetFromEnumValueName(field, value_name); } return false; } bool MutableFlatbuffer::SetFromEnumValueName(const FlatbufferFieldPath* path, StringPiece value_name) { MutableFlatbuffer* parent; const reflection::Field* field; if (!GetFieldWithParent(path, &parent, &field)) { return false; } return parent->SetFromEnumValueName(field, value_name); } bool MutableFlatbuffer::ParseAndSet(const reflection::Field* field, const std::string& value) { // Try parsing as an enum value. if (IsEnum(field->type()) && SetFromEnumValueName(field, value)) { return true; } switch (field->type()->base_type() == reflection::Vector ? field->type()->element() : field->type()->base_type()) { case reflection::String: return ParseAndSetField(field, value, this); case reflection::Int: return ParseAndSetField(field, value, this); case reflection::Long: return ParseAndSetField(field, value, this); case reflection::Float: return ParseAndSetField(field, value, this); case reflection::Double: return ParseAndSetField(field, value, this); default: TC3_LOG(ERROR) << "Unhandled field type: " << field->type()->base_type(); return false; } } bool MutableFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path, const std::string& value) { MutableFlatbuffer* parent; const reflection::Field* field; if (!GetFieldWithParent(path, &parent, &field)) { return false; } return parent->ParseAndSet(field, value); } MutableFlatbuffer* MutableFlatbuffer::Add(StringPiece field_name) { const reflection::Field* field = GetFieldOrNull(field_name); if (field == nullptr) { return nullptr; } if (field->type()->base_type() != reflection::BaseType::Vector) { return nullptr; } return Add(field); } MutableFlatbuffer* MutableFlatbuffer::Add(const reflection::Field* field) { if (field == nullptr) { return nullptr; } return Repeated(field)->Add(); } MutableFlatbuffer* MutableFlatbuffer::Mutable(const StringPiece field_name) { if (const reflection::Field* field = GetFieldOrNull(field_name)) { return Mutable(field); } TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString(); return nullptr; } MutableFlatbuffer* MutableFlatbuffer::Mutable(const reflection::Field* field) { if (field->type()->base_type() != reflection::Obj) { return nullptr; } const auto entry = children_.find(field); if (entry != children_.end()) { return entry->second.get(); } const auto it = children_.insert( /*hint=*/entry, std::make_pair( field, std::unique_ptr(new MutableFlatbuffer( schema_, schema_->objects()->Get(field->type()->index()))))); return it->second.get(); } MutableFlatbuffer* MutableFlatbuffer::Mutable(const FlatbufferFieldPath* path) { const auto* field_path = path->field(); if (field_path == nullptr || field_path->size() == 0) { return this; } MutableFlatbuffer* object = this; for (int i = 0; i < field_path->size(); i++) { const reflection::Field* field = object->GetFieldOrNull(field_path->Get(i)); if (field == nullptr) { return nullptr; } object = object->Mutable(field); if (object == nullptr) { return nullptr; } } return object; } RepeatedField* MutableFlatbuffer::Repeated(StringPiece field_name) { if (const reflection::Field* field = GetFieldOrNull(field_name)) { return Repeated(field); } TC3_LOG(ERROR) << "Unknown field: " << field_name.ToString(); return nullptr; } RepeatedField* MutableFlatbuffer::Repeated(const reflection::Field* field) { if (field->type()->base_type() != reflection::Vector) { TC3_LOG(ERROR) << "Field is not of type Vector."; return nullptr; } // If the repeated field was already set, return its instance. const auto entry = repeated_fields_.find(field); if (entry != repeated_fields_.end()) { return entry->second.get(); } // Otherwise, create a new instance and store it. std::unique_ptr repeated_field( new RepeatedField(schema_, field)); const auto it = repeated_fields_.insert( /*hint=*/entry, std::make_pair(field, std::move(repeated_field))); return it->second.get(); } RepeatedField* MutableFlatbuffer::Repeated(const FlatbufferFieldPath* path) { MutableFlatbuffer* parent; const reflection::Field* field; if (!GetFieldWithParent(path, &parent, &field)) { return nullptr; } return parent->Repeated(field); } flatbuffers::uoffset_t MutableFlatbuffer::Serialize( flatbuffers::FlatBufferBuilder* builder) const { // Build all children before we can start with this table. std::vector< std::pair> offsets; offsets.reserve(children_.size() + repeated_fields_.size()); for (const auto& it : children_) { offsets.push_back({it.first->offset(), it.second->Serialize(builder)}); } // Create strings. for (const auto& it : fields_) { if (it.second.Has()) { offsets.push_back( {it.first->offset(), builder->CreateString(it.second.ConstRefValue()).o}); } } // Build the repeated fields. for (const auto& it : repeated_fields_) { offsets.push_back({it.first->offset(), it.second->Serialize(builder)}); } // Build the table now. const flatbuffers::uoffset_t table_start = builder->StartTable(); // Add scalar fields. for (const auto& it : fields_) { switch (it.second.GetType()) { case Variant::TYPE_BOOL_VALUE: builder->AddElement( it.first->offset(), static_cast(it.second.Value()), static_cast(it.first->default_integer())); continue; case Variant::TYPE_INT8_VALUE: builder->AddElement( it.first->offset(), static_cast(it.second.Value()), static_cast(it.first->default_integer())); continue; case Variant::TYPE_UINT8_VALUE: builder->AddElement( it.first->offset(), static_cast(it.second.Value()), static_cast(it.first->default_integer())); continue; case Variant::TYPE_INT_VALUE: builder->AddElement( it.first->offset(), it.second.Value(), static_cast(it.first->default_integer())); continue; case Variant::TYPE_UINT_VALUE: builder->AddElement( it.first->offset(), it.second.Value(), static_cast(it.first->default_integer())); continue; case Variant::TYPE_INT64_VALUE: builder->AddElement(it.first->offset(), it.second.Value(), it.first->default_integer()); continue; case Variant::TYPE_UINT64_VALUE: builder->AddElement(it.first->offset(), it.second.Value(), it.first->default_integer()); continue; case Variant::TYPE_FLOAT_VALUE: builder->AddElement( it.first->offset(), it.second.Value(), static_cast(it.first->default_real())); continue; case Variant::TYPE_DOUBLE_VALUE: builder->AddElement(it.first->offset(), it.second.Value(), it.first->default_real()); continue; default: continue; } } // Add strings, subtables and repeated fields. for (const auto& it : offsets) { builder->AddOffset(it.first, flatbuffers::Offset(it.second)); } return builder->EndTable(table_start); } std::string MutableFlatbuffer::Serialize() const { flatbuffers::FlatBufferBuilder builder; builder.Finish(flatbuffers::Offset(Serialize(&builder))); return std::string(reinterpret_cast(builder.GetBufferPointer()), builder.GetSize()); } bool MutableFlatbuffer::MergeFrom(const flatbuffers::Table* from) { // No fields to set. if (type_->fields() == nullptr) { return true; } for (const reflection::Field* field : *type_->fields()) { // Skip fields that are not explicitly set. if (!from->CheckField(field->offset())) { continue; } const reflection::BaseType type = field->type()->base_type(); switch (type) { case reflection::Bool: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::Byte: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::UByte: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::Int: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::UInt: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::Long: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::ULong: Set(field, from->GetField(field->offset(), field->default_integer())); break; case reflection::Float: Set(field, from->GetField(field->offset(), field->default_real())); break; case reflection::Double: Set(field, from->GetField(field->offset(), field->default_real())); break; case reflection::String: Set( field, from->GetPointer(field->offset()) ->str()); break; case reflection::Obj: if (MutableFlatbuffer* nested_field = Mutable(field); nested_field == nullptr || !nested_field->MergeFrom( from->GetPointer( field->offset()))) { return false; } break; case reflection::Vector: { if (RepeatedField* repeated_field = Repeated(field); repeated_field == nullptr || !repeated_field->Extend(from)) { return false; } break; } default: TC3_LOG(ERROR) << "Unsupported type: " << type << " for field: " << field->name()->str(); return false; } } return true; } bool MutableFlatbuffer::MergeFromSerializedFlatbuffer(StringPiece from) { return MergeFrom(flatbuffers::GetAnyRoot( reinterpret_cast(from.data()))); } void MutableFlatbuffer::AsFlatMap( const std::string& key_separator, const std::string& key_prefix, std::map* result) const { // Add direct fields. for (const auto& it : fields_) { (*result)[key_prefix + it.first->name()->str()] = it.second; } // Add nested messages. for (const auto& it : children_) { it.second->AsFlatMap(key_separator, key_prefix + it.first->name()->str() + key_separator, result); } } std::string RepeatedField::ToTextProto() const { std::string result = " ["; std::string current_field_separator; for (int index = 0; index < Size(); index++) { if (is_primitive_) { result.append(current_field_separator + items_.at(index).ToString()); } else { result.append(current_field_separator + "{" + Get(index)->ToTextProto() + "}"); } current_field_separator = ", "; } result.append("] "); return result; } std::string MutableFlatbuffer::ToTextProto() const { std::string result; std::string current_field_separator; // Add direct fields. for (const auto& field_value_pair : fields_) { const std::string field_name = field_value_pair.first->name()->str(); const Variant& value = field_value_pair.second; std::string quotes; if (value.GetType() == Variant::TYPE_STRING_VALUE) { quotes = "'"; } result.append(current_field_separator + field_name + ": " + quotes + value.ToString() + quotes); current_field_separator = ", "; } // Add repeated message for (const auto& repeated_fb_pair : repeated_fields_) { result.append(current_field_separator + repeated_fb_pair.first->name()->c_str() + ": " + repeated_fb_pair.second->ToTextProto()); current_field_separator = ", "; } // Add nested messages. for (const auto& field_flatbuffer_pair : children_) { const std::string field_name = field_flatbuffer_pair.first->name()->str(); result.append(current_field_separator + field_name + " {" + field_flatbuffer_pair.second->ToTextProto() + "}"); current_field_separator = ", "; } return result; } // // Repeated field methods. // MutableFlatbuffer* RepeatedField::Add() { if (is_primitive_) { TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field."; return nullptr; } object_items_.emplace_back(new MutableFlatbuffer( schema_, schema_->objects()->Get(field_->type()->index()))); return object_items_.back().get(); } namespace { template flatbuffers::uoffset_t TypedSerialize(const std::vector& values, flatbuffers::FlatBufferBuilder* builder) { std::vector typed_values; typed_values.reserve(values.size()); for (const Variant& item : values) { typed_values.push_back(item.Value()); } return builder->CreateVector(typed_values).o; } } // namespace bool RepeatedField::Extend(const flatbuffers::Table* from) { switch (field_->type()->element()) { case reflection::Int: AppendFromVector(from); return true; case reflection::UInt: AppendFromVector(from); return true; case reflection::Long: AppendFromVector(from); return true; case reflection::ULong: AppendFromVector(from); return true; case reflection::Byte: AppendFromVector(from); return true; case reflection::UByte: AppendFromVector(from); return true; case reflection::String: AppendFromVector(from); return true; case reflection::Obj: AppendFromVector(from); return true; case reflection::Double: AppendFromVector(from); return true; case reflection::Float: AppendFromVector(from); return true; default: TC3_LOG(ERROR) << "Repeated unsupported type: " << field_->type()->element() << " for field: " << field_->name()->str(); return false; } } flatbuffers::uoffset_t RepeatedField::Serialize( flatbuffers::FlatBufferBuilder* builder) const { switch (field_->type()->element()) { case reflection::String: return SerializeString(builder); break; case reflection::Obj: return SerializeObject(builder); break; case reflection::Bool: return TypedSerialize(items_, builder); break; case reflection::Byte: return TypedSerialize(items_, builder); break; case reflection::UByte: return TypedSerialize(items_, builder); break; case reflection::Int: return TypedSerialize(items_, builder); break; case reflection::UInt: return TypedSerialize(items_, builder); break; case reflection::Long: return TypedSerialize(items_, builder); break; case reflection::ULong: return TypedSerialize(items_, builder); break; case reflection::Float: return TypedSerialize(items_, builder); break; case reflection::Double: return TypedSerialize(items_, builder); break; default: TC3_LOG(FATAL) << "Unsupported type: " << field_->type()->element(); break; } TC3_LOG(FATAL) << "Invalid state."; return 0; } flatbuffers::uoffset_t RepeatedField::SerializeString( flatbuffers::FlatBufferBuilder* builder) const { std::vector> offsets(items_.size()); for (int i = 0; i < items_.size(); i++) { offsets[i] = builder->CreateString(items_[i].ConstRefValue()); } return builder->CreateVector(offsets).o; } flatbuffers::uoffset_t RepeatedField::SerializeObject( flatbuffers::FlatBufferBuilder* builder) const { std::vector> offsets(object_items_.size()); for (int i = 0; i < object_items_.size(); i++) { offsets[i] = object_items_[i]->Serialize(builder); } return builder->CreateVector(offsets).o; } } // namespace libtextclassifier3