1 // Copyright 2016 Google Inc. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SRC_FIELD_INSTANCE_H_ 16 #define SRC_FIELD_INSTANCE_H_ 17 18 #include <memory> 19 #include <string> 20 21 #include "port/protobuf.h" 22 23 namespace protobuf_mutator { 24 25 // Helper class for common protobuf fields operations. 26 class ConstFieldInstance { 27 public: 28 static const size_t kInvalidIndex = -1; 29 30 struct Enum { 31 size_t index; 32 size_t count; 33 }; 34 ConstFieldInstance()35 ConstFieldInstance() 36 : message_(nullptr), descriptor_(nullptr), index_(kInvalidIndex) {} 37 ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)38 ConstFieldInstance(const protobuf::Message* message, 39 const protobuf::FieldDescriptor* field, size_t index) 40 : message_(message), descriptor_(field), index_(index) { 41 assert(message_); 42 assert(descriptor_); 43 assert(index_ != kInvalidIndex); 44 assert(descriptor_->is_repeated()); 45 } 46 ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field)47 ConstFieldInstance(const protobuf::Message* message, 48 const protobuf::FieldDescriptor* field) 49 : message_(message), descriptor_(field), index_(kInvalidIndex) { 50 assert(message_); 51 assert(descriptor_); 52 assert(!descriptor_->is_repeated()); 53 } 54 GetDefault(int32_t * out)55 void GetDefault(int32_t* out) const { 56 *out = descriptor_->default_value_int32(); 57 } 58 GetDefault(int64_t * out)59 void GetDefault(int64_t* out) const { 60 *out = descriptor_->default_value_int64(); 61 } 62 GetDefault(uint32_t * out)63 void GetDefault(uint32_t* out) const { 64 *out = descriptor_->default_value_uint32(); 65 } 66 GetDefault(uint64_t * out)67 void GetDefault(uint64_t* out) const { 68 *out = descriptor_->default_value_uint64(); 69 } 70 GetDefault(double * out)71 void GetDefault(double* out) const { 72 *out = descriptor_->default_value_double(); 73 } 74 GetDefault(float * out)75 void GetDefault(float* out) const { 76 *out = descriptor_->default_value_float(); 77 } 78 GetDefault(bool * out)79 void GetDefault(bool* out) const { *out = descriptor_->default_value_bool(); } 80 GetDefault(Enum * out)81 void GetDefault(Enum* out) const { 82 const protobuf::EnumValueDescriptor* value = 83 descriptor_->default_value_enum(); 84 const protobuf::EnumDescriptor* type = value->type(); 85 *out = {static_cast<size_t>(value->index()), 86 static_cast<size_t>(type->value_count())}; 87 } 88 GetDefault(std::string * out)89 void GetDefault(std::string* out) const { 90 *out = descriptor_->default_value_string(); 91 } 92 GetDefault(std::unique_ptr<protobuf::Message> * out)93 void GetDefault(std::unique_ptr<protobuf::Message>* out) const { 94 out->reset(reflection() 95 .GetMessageFactory() 96 ->GetPrototype(descriptor_->message_type()) 97 ->New()); 98 } 99 Load(int32_t * value)100 void Load(int32_t* value) const { 101 *value = is_repeated() 102 ? reflection().GetRepeatedInt32(*message_, descriptor_, index_) 103 : reflection().GetInt32(*message_, descriptor_); 104 } 105 Load(int64_t * value)106 void Load(int64_t* value) const { 107 *value = is_repeated() 108 ? reflection().GetRepeatedInt64(*message_, descriptor_, index_) 109 : reflection().GetInt64(*message_, descriptor_); 110 } 111 Load(uint32_t * value)112 void Load(uint32_t* value) const { 113 *value = is_repeated() ? reflection().GetRepeatedUInt32(*message_, 114 descriptor_, index_) 115 : reflection().GetUInt32(*message_, descriptor_); 116 } 117 Load(uint64_t * value)118 void Load(uint64_t* value) const { 119 *value = is_repeated() ? reflection().GetRepeatedUInt64(*message_, 120 descriptor_, index_) 121 : reflection().GetUInt64(*message_, descriptor_); 122 } 123 Load(double * value)124 void Load(double* value) const { 125 *value = is_repeated() ? reflection().GetRepeatedDouble(*message_, 126 descriptor_, index_) 127 : reflection().GetDouble(*message_, descriptor_); 128 } 129 Load(float * value)130 void Load(float* value) const { 131 *value = is_repeated() 132 ? reflection().GetRepeatedFloat(*message_, descriptor_, index_) 133 : reflection().GetFloat(*message_, descriptor_); 134 } 135 Load(bool * value)136 void Load(bool* value) const { 137 *value = is_repeated() 138 ? reflection().GetRepeatedBool(*message_, descriptor_, index_) 139 : reflection().GetBool(*message_, descriptor_); 140 } 141 Load(Enum * value)142 void Load(Enum* value) const { 143 const protobuf::EnumValueDescriptor* value_descriptor = 144 is_repeated() 145 ? reflection().GetRepeatedEnum(*message_, descriptor_, index_) 146 : reflection().GetEnum(*message_, descriptor_); 147 *value = {static_cast<size_t>(value_descriptor->index()), 148 static_cast<size_t>(value_descriptor->type()->value_count())}; 149 if (value->index >= value->count) GetDefault(value); 150 } 151 Load(std::string * value)152 void Load(std::string* value) const { 153 *value = is_repeated() ? reflection().GetRepeatedString(*message_, 154 descriptor_, index_) 155 : reflection().GetString(*message_, descriptor_); 156 } 157 Load(std::unique_ptr<protobuf::Message> * value)158 void Load(std::unique_ptr<protobuf::Message>* value) const { 159 const protobuf::Message& source = 160 is_repeated() 161 ? reflection().GetRepeatedMessage(*message_, descriptor_, index_) 162 : reflection().GetMessage(*message_, descriptor_); 163 value->reset(source.New()); 164 (*value)->CopyFrom(source); 165 } 166 167 template <class T> CanStore(const T & value)168 bool CanStore(const T& value) const { 169 return true; 170 } 171 CanStore(const std::string & value)172 bool CanStore(const std::string& value) const { 173 if (!EnforceUtf8()) return true; 174 using protobuf::internal::WireFormatLite; 175 return WireFormatLite::VerifyUtf8String(value.data(), value.length(), 176 WireFormatLite::PARSE, ""); 177 } 178 name()179 std::string name() const { return descriptor_->name(); } 180 cpp_type()181 protobuf::FieldDescriptor::CppType cpp_type() const { 182 return descriptor_->cpp_type(); 183 } 184 enum_type()185 const protobuf::EnumDescriptor* enum_type() const { 186 return descriptor_->enum_type(); 187 } 188 message_type()189 const protobuf::Descriptor* message_type() const { 190 return descriptor_->message_type(); 191 } 192 EnforceUtf8()193 bool EnforceUtf8() const { 194 return descriptor_->type() == protobuf::FieldDescriptor::TYPE_STRING && 195 descriptor()->file()->syntax() == 196 protobuf::FileDescriptor::SYNTAX_PROTO3; 197 } 198 descriptor()199 const protobuf::FieldDescriptor* descriptor() const { return descriptor_; } 200 DebugString()201 std::string DebugString() const { 202 std::string s = descriptor_->DebugString(); 203 if (is_repeated()) s += "[" + std::to_string(index_) + "]"; 204 return s + " of\n" + message_->DebugString(); 205 } 206 207 protected: is_repeated()208 bool is_repeated() const { return descriptor_->is_repeated(); } 209 reflection()210 const protobuf::Reflection& reflection() const { 211 return *message_->GetReflection(); 212 } 213 index()214 size_t index() const { return index_; } 215 216 private: 217 template <class Fn, class T> 218 friend struct FieldFunction; 219 220 const protobuf::Message* message_; 221 const protobuf::FieldDescriptor* descriptor_; 222 size_t index_; 223 }; 224 225 class FieldInstance : public ConstFieldInstance { 226 public: 227 static const size_t kInvalidIndex = -1; 228 FieldInstance()229 FieldInstance() : ConstFieldInstance(), message_(nullptr) {} 230 FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)231 FieldInstance(protobuf::Message* message, 232 const protobuf::FieldDescriptor* field, size_t index) 233 : ConstFieldInstance(message, field, index), message_(message) {} 234 FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field)235 FieldInstance(protobuf::Message* message, 236 const protobuf::FieldDescriptor* field) 237 : ConstFieldInstance(message, field), message_(message) {} 238 Delete()239 void Delete() const { 240 if (!is_repeated()) return reflection().ClearField(message_, descriptor()); 241 int field_size = reflection().FieldSize(*message_, descriptor()); 242 // API has only method to delete the last message, so we move method from 243 // the 244 // middle to the end. 245 for (int i = index() + 1; i < field_size; ++i) 246 reflection().SwapElements(message_, descriptor(), i, i - 1); 247 reflection().RemoveLast(message_, descriptor()); 248 } 249 250 template <class T> Create(const T & value)251 void Create(const T& value) const { 252 if (!is_repeated()) return Store(value); 253 InsertRepeated(value); 254 } 255 Store(int32_t value)256 void Store(int32_t value) const { 257 if (is_repeated()) 258 reflection().SetRepeatedInt32(message_, descriptor(), index(), value); 259 else 260 reflection().SetInt32(message_, descriptor(), value); 261 } 262 Store(int64_t value)263 void Store(int64_t value) const { 264 if (is_repeated()) 265 reflection().SetRepeatedInt64(message_, descriptor(), index(), value); 266 else 267 reflection().SetInt64(message_, descriptor(), value); 268 } 269 Store(uint32_t value)270 void Store(uint32_t value) const { 271 if (is_repeated()) 272 reflection().SetRepeatedUInt32(message_, descriptor(), index(), value); 273 else 274 reflection().SetUInt32(message_, descriptor(), value); 275 } 276 Store(uint64_t value)277 void Store(uint64_t value) const { 278 if (is_repeated()) 279 reflection().SetRepeatedUInt64(message_, descriptor(), index(), value); 280 else 281 reflection().SetUInt64(message_, descriptor(), value); 282 } 283 Store(double value)284 void Store(double value) const { 285 if (is_repeated()) 286 reflection().SetRepeatedDouble(message_, descriptor(), index(), value); 287 else 288 reflection().SetDouble(message_, descriptor(), value); 289 } 290 Store(float value)291 void Store(float value) const { 292 if (is_repeated()) 293 reflection().SetRepeatedFloat(message_, descriptor(), index(), value); 294 else 295 reflection().SetFloat(message_, descriptor(), value); 296 } 297 Store(bool value)298 void Store(bool value) const { 299 if (is_repeated()) 300 reflection().SetRepeatedBool(message_, descriptor(), index(), value); 301 else 302 reflection().SetBool(message_, descriptor(), value); 303 } 304 Store(const Enum & value)305 void Store(const Enum& value) const { 306 assert(value.index < value.count); 307 const protobuf::EnumValueDescriptor* enum_value = 308 descriptor()->enum_type()->value(value.index); 309 if (is_repeated()) 310 reflection().SetRepeatedEnum(message_, descriptor(), index(), enum_value); 311 else 312 reflection().SetEnum(message_, descriptor(), enum_value); 313 } 314 Store(const std::string & value)315 void Store(const std::string& value) const { 316 if (is_repeated()) 317 reflection().SetRepeatedString(message_, descriptor(), index(), value); 318 else 319 reflection().SetString(message_, descriptor(), value); 320 } 321 Store(const std::unique_ptr<protobuf::Message> & value)322 void Store(const std::unique_ptr<protobuf::Message>& value) const { 323 protobuf::Message* mutable_message = 324 is_repeated() ? reflection().MutableRepeatedMessage( 325 message_, descriptor(), index()) 326 : reflection().MutableMessage(message_, descriptor()); 327 mutable_message->Clear(); 328 if (value) mutable_message->CopyFrom(*value); 329 } 330 331 private: 332 template <class T> InsertRepeated(const T & value)333 void InsertRepeated(const T& value) const { 334 PushBackRepeated(value); 335 size_t field_size = reflection().FieldSize(*message_, descriptor()); 336 if (field_size == 1) return; 337 // API has only method to add field to the end of the list. So we add 338 // descriptor() 339 // and move it into the middle. 340 for (size_t i = field_size - 1; i > index(); --i) 341 reflection().SwapElements(message_, descriptor(), i, i - 1); 342 } 343 PushBackRepeated(int32_t value)344 void PushBackRepeated(int32_t value) const { 345 assert(is_repeated()); 346 reflection().AddInt32(message_, descriptor(), value); 347 } 348 PushBackRepeated(int64_t value)349 void PushBackRepeated(int64_t value) const { 350 assert(is_repeated()); 351 reflection().AddInt64(message_, descriptor(), value); 352 } 353 PushBackRepeated(uint32_t value)354 void PushBackRepeated(uint32_t value) const { 355 assert(is_repeated()); 356 reflection().AddUInt32(message_, descriptor(), value); 357 } 358 PushBackRepeated(uint64_t value)359 void PushBackRepeated(uint64_t value) const { 360 assert(is_repeated()); 361 reflection().AddUInt64(message_, descriptor(), value); 362 } 363 PushBackRepeated(double value)364 void PushBackRepeated(double value) const { 365 assert(is_repeated()); 366 reflection().AddDouble(message_, descriptor(), value); 367 } 368 PushBackRepeated(float value)369 void PushBackRepeated(float value) const { 370 assert(is_repeated()); 371 reflection().AddFloat(message_, descriptor(), value); 372 } 373 PushBackRepeated(bool value)374 void PushBackRepeated(bool value) const { 375 assert(is_repeated()); 376 reflection().AddBool(message_, descriptor(), value); 377 } 378 PushBackRepeated(const Enum & value)379 void PushBackRepeated(const Enum& value) const { 380 assert(value.index < value.count); 381 const protobuf::EnumValueDescriptor* enum_value = 382 descriptor()->enum_type()->value(value.index); 383 assert(is_repeated()); 384 reflection().AddEnum(message_, descriptor(), enum_value); 385 } 386 PushBackRepeated(const std::string & value)387 void PushBackRepeated(const std::string& value) const { 388 assert(is_repeated()); 389 reflection().AddString(message_, descriptor(), value); 390 } 391 PushBackRepeated(const std::unique_ptr<protobuf::Message> & value)392 void PushBackRepeated(const std::unique_ptr<protobuf::Message>& value) const { 393 assert(is_repeated()); 394 protobuf::Message* mutable_message = 395 reflection().AddMessage(message_, descriptor()); 396 mutable_message->Clear(); 397 if (value) mutable_message->CopyFrom(*value); 398 } 399 400 protobuf::Message* message_; 401 }; 402 403 template <class Fn, class R = void> 404 struct FieldFunction { 405 template <class Field, class... Args> operatorFieldFunction406 R operator()(const Field& field, const Args&... args) const { 407 assert(field.descriptor()); 408 using protobuf::FieldDescriptor; 409 switch (field.cpp_type()) { 410 case FieldDescriptor::CPPTYPE_INT32: 411 return static_cast<const Fn*>(this)->template ForType<int32_t>(field, 412 args...); 413 case FieldDescriptor::CPPTYPE_INT64: 414 return static_cast<const Fn*>(this)->template ForType<int64_t>(field, 415 args...); 416 case FieldDescriptor::CPPTYPE_UINT32: 417 return static_cast<const Fn*>(this)->template ForType<uint32_t>( 418 field, args...); 419 case FieldDescriptor::CPPTYPE_UINT64: 420 return static_cast<const Fn*>(this)->template ForType<uint64_t>( 421 field, args...); 422 case FieldDescriptor::CPPTYPE_DOUBLE: 423 return static_cast<const Fn*>(this)->template ForType<double>(field, 424 args...); 425 case FieldDescriptor::CPPTYPE_FLOAT: 426 return static_cast<const Fn*>(this)->template ForType<float>(field, 427 args...); 428 case FieldDescriptor::CPPTYPE_BOOL: 429 return static_cast<const Fn*>(this)->template ForType<bool>(field, 430 args...); 431 case FieldDescriptor::CPPTYPE_ENUM: 432 return static_cast<const Fn*>(this) 433 ->template ForType<ConstFieldInstance::Enum>(field, args...); 434 case FieldDescriptor::CPPTYPE_STRING: 435 return static_cast<const Fn*>(this)->template ForType<std::string>( 436 field, args...); 437 case FieldDescriptor::CPPTYPE_MESSAGE: 438 return static_cast<const Fn*>(this) 439 ->template ForType<std::unique_ptr<protobuf::Message>>(field, 440 args...); 441 } 442 assert(false && "Unknown type"); 443 abort(); 444 } 445 }; 446 447 } // namespace protobuf_mutator 448 449 #endif // SRC_FIELD_INSTANCE_H_ 450