// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef SRC_FIELD_INSTANCE_H_ #define SRC_FIELD_INSTANCE_H_ #include #include #include "port/protobuf.h" namespace protobuf_mutator { // Helper class for common protobuf fields operations. class ConstFieldInstance { public: static const size_t kInvalidIndex = -1; struct Enum { size_t index; size_t count; }; ConstFieldInstance() : message_(nullptr), descriptor_(nullptr), index_(kInvalidIndex) {} ConstFieldInstance(const protobuf::Message* message, const protobuf::FieldDescriptor* field, size_t index) : message_(message), descriptor_(field), index_(index) { assert(message_); assert(descriptor_); assert(index_ != kInvalidIndex); assert(descriptor_->is_repeated()); } ConstFieldInstance(const protobuf::Message* message, const protobuf::FieldDescriptor* field) : message_(message), descriptor_(field), index_(kInvalidIndex) { assert(message_); assert(descriptor_); assert(!descriptor_->is_repeated()); } void GetDefault(int32_t* out) const { *out = descriptor_->default_value_int32(); } void GetDefault(int64_t* out) const { *out = descriptor_->default_value_int64(); } void GetDefault(uint32_t* out) const { *out = descriptor_->default_value_uint32(); } void GetDefault(uint64_t* out) const { *out = descriptor_->default_value_uint64(); } void GetDefault(double* out) const { *out = descriptor_->default_value_double(); } void GetDefault(float* out) const { *out = descriptor_->default_value_float(); } void GetDefault(bool* out) const { *out = descriptor_->default_value_bool(); } void GetDefault(Enum* out) const { const protobuf::EnumValueDescriptor* value = descriptor_->default_value_enum(); const protobuf::EnumDescriptor* type = value->type(); *out = {static_cast(value->index()), static_cast(type->value_count())}; } void GetDefault(std::string* out) const { *out = descriptor_->default_value_string(); } void GetDefault(std::unique_ptr* out) const { out->reset(reflection() .GetMessageFactory() ->GetPrototype(descriptor_->message_type()) ->New()); } void Load(int32_t* value) const { *value = is_repeated() ? reflection().GetRepeatedInt32(*message_, descriptor_, index_) : reflection().GetInt32(*message_, descriptor_); } void Load(int64_t* value) const { *value = is_repeated() ? reflection().GetRepeatedInt64(*message_, descriptor_, index_) : reflection().GetInt64(*message_, descriptor_); } void Load(uint32_t* value) const { *value = is_repeated() ? reflection().GetRepeatedUInt32(*message_, descriptor_, index_) : reflection().GetUInt32(*message_, descriptor_); } void Load(uint64_t* value) const { *value = is_repeated() ? reflection().GetRepeatedUInt64(*message_, descriptor_, index_) : reflection().GetUInt64(*message_, descriptor_); } void Load(double* value) const { *value = is_repeated() ? reflection().GetRepeatedDouble(*message_, descriptor_, index_) : reflection().GetDouble(*message_, descriptor_); } void Load(float* value) const { *value = is_repeated() ? reflection().GetRepeatedFloat(*message_, descriptor_, index_) : reflection().GetFloat(*message_, descriptor_); } void Load(bool* value) const { *value = is_repeated() ? reflection().GetRepeatedBool(*message_, descriptor_, index_) : reflection().GetBool(*message_, descriptor_); } void Load(Enum* value) const { const protobuf::EnumValueDescriptor* value_descriptor = is_repeated() ? reflection().GetRepeatedEnum(*message_, descriptor_, index_) : reflection().GetEnum(*message_, descriptor_); *value = {static_cast(value_descriptor->index()), static_cast(value_descriptor->type()->value_count())}; } void Load(std::string* value) const { *value = is_repeated() ? reflection().GetRepeatedString(*message_, descriptor_, index_) : reflection().GetString(*message_, descriptor_); } void Load(std::unique_ptr* value) const { const protobuf::Message& source = is_repeated() ? reflection().GetRepeatedMessage(*message_, descriptor_, index_) : reflection().GetMessage(*message_, descriptor_); value->reset(source.New()); (*value)->CopyFrom(source); } std::string name() const { return descriptor_->name(); } protobuf::FieldDescriptor::CppType cpp_type() const { return descriptor_->cpp_type(); } const protobuf::EnumDescriptor* enum_type() const { return descriptor_->enum_type(); } const protobuf::Descriptor* message_type() const { return descriptor_->message_type(); } bool EnforceUtf8() const { return descriptor_->type() == protobuf::FieldDescriptor::TYPE_STRING && descriptor()->file()->syntax() == protobuf::FileDescriptor::SYNTAX_PROTO3; } protected: bool is_repeated() const { return descriptor_->is_repeated(); } const protobuf::Reflection& reflection() const { return *message_->GetReflection(); } const protobuf::FieldDescriptor* descriptor() const { return descriptor_; } size_t index() const { return index_; } private: template friend struct FieldFunction; const protobuf::Message* message_; const protobuf::FieldDescriptor* descriptor_; size_t index_; }; class FieldInstance : public ConstFieldInstance { public: static const size_t kInvalidIndex = -1; FieldInstance() : ConstFieldInstance(), message_(nullptr) {} FieldInstance(protobuf::Message* message, const protobuf::FieldDescriptor* field, size_t index) : ConstFieldInstance(message, field, index), message_(message) {} FieldInstance(protobuf::Message* message, const protobuf::FieldDescriptor* field) : ConstFieldInstance(message, field), message_(message) {} void Delete() const { if (!is_repeated()) return reflection().ClearField(message_, descriptor()); int field_size = reflection().FieldSize(*message_, descriptor()); // API has only method to delete the last message, so we move method from // the // middle to the end. for (int i = index() + 1; i < field_size; ++i) reflection().SwapElements(message_, descriptor(), i, i - 1); reflection().RemoveLast(message_, descriptor()); } template void Create(const T& value) const { if (!is_repeated()) return Store(value); InsertRepeated(value); } void Store(int32_t value) const { if (is_repeated()) reflection().SetRepeatedInt32(message_, descriptor(), index(), value); else reflection().SetInt32(message_, descriptor(), value); } void Store(int64_t value) const { if (is_repeated()) reflection().SetRepeatedInt64(message_, descriptor(), index(), value); else reflection().SetInt64(message_, descriptor(), value); } void Store(uint32_t value) const { if (is_repeated()) reflection().SetRepeatedUInt32(message_, descriptor(), index(), value); else reflection().SetUInt32(message_, descriptor(), value); } void Store(uint64_t value) const { if (is_repeated()) reflection().SetRepeatedUInt64(message_, descriptor(), index(), value); else reflection().SetUInt64(message_, descriptor(), value); } void Store(double value) const { if (is_repeated()) reflection().SetRepeatedDouble(message_, descriptor(), index(), value); else reflection().SetDouble(message_, descriptor(), value); } void Store(float value) const { if (is_repeated()) reflection().SetRepeatedFloat(message_, descriptor(), index(), value); else reflection().SetFloat(message_, descriptor(), value); } void Store(bool value) const { if (is_repeated()) reflection().SetRepeatedBool(message_, descriptor(), index(), value); else reflection().SetBool(message_, descriptor(), value); } void Store(const Enum& value) const { assert(value.index < value.count); const protobuf::EnumValueDescriptor* enum_value = descriptor()->enum_type()->value(value.index); if (is_repeated()) reflection().SetRepeatedEnum(message_, descriptor(), index(), enum_value); else reflection().SetEnum(message_, descriptor(), enum_value); } void Store(const std::string& value) const { if (is_repeated()) reflection().SetRepeatedString(message_, descriptor(), index(), value); else reflection().SetString(message_, descriptor(), value); } void Store(const std::unique_ptr& value) const { protobuf::Message* mutable_message = is_repeated() ? reflection().MutableRepeatedMessage( message_, descriptor(), index()) : reflection().MutableMessage(message_, descriptor()); mutable_message->Clear(); if (value) mutable_message->CopyFrom(*value); } private: template void InsertRepeated(const T& value) const { PushBackRepeated(value); size_t field_size = reflection().FieldSize(*message_, descriptor()); if (field_size == 1) return; // API has only method to add field to the end of the list. So we add // descriptor() // and move it into the middle. for (size_t i = field_size - 1; i > index(); --i) reflection().SwapElements(message_, descriptor(), i, i - 1); } void PushBackRepeated(int32_t value) const { assert(is_repeated()); reflection().AddInt32(message_, descriptor(), value); } void PushBackRepeated(int64_t value) const { assert(is_repeated()); reflection().AddInt64(message_, descriptor(), value); } void PushBackRepeated(uint32_t value) const { assert(is_repeated()); reflection().AddUInt32(message_, descriptor(), value); } void PushBackRepeated(uint64_t value) const { assert(is_repeated()); reflection().AddUInt64(message_, descriptor(), value); } void PushBackRepeated(double value) const { assert(is_repeated()); reflection().AddDouble(message_, descriptor(), value); } void PushBackRepeated(float value) const { assert(is_repeated()); reflection().AddFloat(message_, descriptor(), value); } void PushBackRepeated(bool value) const { assert(is_repeated()); reflection().AddBool(message_, descriptor(), value); } void PushBackRepeated(const Enum& value) const { assert(value.index < value.count); const protobuf::EnumValueDescriptor* enum_value = descriptor()->enum_type()->value(value.index); assert(is_repeated()); reflection().AddEnum(message_, descriptor(), enum_value); } void PushBackRepeated(const std::string& value) const { assert(is_repeated()); reflection().AddString(message_, descriptor(), value); } void PushBackRepeated(const std::unique_ptr& value) const { assert(is_repeated()); protobuf::Message* mutable_message = reflection().AddMessage(message_, descriptor()); mutable_message->Clear(); if (value) mutable_message->CopyFrom(*value); } protobuf::Message* message_; }; template struct FieldFunction { template R operator()(const Field& field, const Args&... args) const { assert(field.descriptor()); using protobuf::FieldDescriptor; switch (field.cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: return static_cast(this)->template ForType(field, args...); case FieldDescriptor::CPPTYPE_INT64: return static_cast(this)->template ForType(field, args...); case FieldDescriptor::CPPTYPE_UINT32: return static_cast(this)->template ForType( field, args...); case FieldDescriptor::CPPTYPE_UINT64: return static_cast(this)->template ForType( field, args...); case FieldDescriptor::CPPTYPE_DOUBLE: return static_cast(this)->template ForType(field, args...); case FieldDescriptor::CPPTYPE_FLOAT: return static_cast(this)->template ForType(field, args...); case FieldDescriptor::CPPTYPE_BOOL: return static_cast(this)->template ForType(field, args...); case FieldDescriptor::CPPTYPE_ENUM: return static_cast(this) ->template ForType(field, args...); case FieldDescriptor::CPPTYPE_STRING: return static_cast(this)->template ForType( field, args...); case FieldDescriptor::CPPTYPE_MESSAGE: return static_cast(this) ->template ForType>(field, args...); } assert(false && "Unknown type"); abort(); } }; } // namespace protobuf_mutator #endif // SRC_FIELD_INSTANCE_H_