• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SRC_FIELD_INSTANCE_H_
16 #define SRC_FIELD_INSTANCE_H_
17 
18 #include <memory>
19 #include <string>
20 
21 #include "port/protobuf.h"
22 
23 namespace protobuf_mutator {
24 
25 // Helper class for common protobuf fields operations.
26 class ConstFieldInstance {
27  public:
28   static const size_t kInvalidIndex = -1;
29 
30   struct Enum {
31     size_t index;
32     size_t count;
33   };
34 
ConstFieldInstance()35   ConstFieldInstance()
36       : message_(nullptr), descriptor_(nullptr), index_(kInvalidIndex) {}
37 
ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)38   ConstFieldInstance(const protobuf::Message* message,
39                      const protobuf::FieldDescriptor* field, size_t index)
40       : message_(message), descriptor_(field), index_(index) {
41     assert(message_);
42     assert(descriptor_);
43     assert(index_ != kInvalidIndex);
44     assert(descriptor_->is_repeated());
45   }
46 
ConstFieldInstance(const protobuf::Message * message,const protobuf::FieldDescriptor * field)47   ConstFieldInstance(const protobuf::Message* message,
48                      const protobuf::FieldDescriptor* field)
49       : message_(message), descriptor_(field), index_(kInvalidIndex) {
50     assert(message_);
51     assert(descriptor_);
52     assert(!descriptor_->is_repeated());
53   }
54 
GetDefault(int32_t * out)55   void GetDefault(int32_t* out) const {
56     *out = descriptor_->default_value_int32();
57   }
58 
GetDefault(int64_t * out)59   void GetDefault(int64_t* out) const {
60     *out = descriptor_->default_value_int64();
61   }
62 
GetDefault(uint32_t * out)63   void GetDefault(uint32_t* out) const {
64     *out = descriptor_->default_value_uint32();
65   }
66 
GetDefault(uint64_t * out)67   void GetDefault(uint64_t* out) const {
68     *out = descriptor_->default_value_uint64();
69   }
70 
GetDefault(double * out)71   void GetDefault(double* out) const {
72     *out = descriptor_->default_value_double();
73   }
74 
GetDefault(float * out)75   void GetDefault(float* out) const {
76     *out = descriptor_->default_value_float();
77   }
78 
GetDefault(bool * out)79   void GetDefault(bool* out) const { *out = descriptor_->default_value_bool(); }
80 
GetDefault(Enum * out)81   void GetDefault(Enum* out) const {
82     const protobuf::EnumValueDescriptor* value =
83         descriptor_->default_value_enum();
84     const protobuf::EnumDescriptor* type = value->type();
85     *out = {static_cast<size_t>(value->index()),
86             static_cast<size_t>(type->value_count())};
87   }
88 
GetDefault(std::string * out)89   void GetDefault(std::string* out) const {
90     *out = descriptor_->default_value_string();
91   }
92 
GetDefault(std::unique_ptr<protobuf::Message> * out)93   void GetDefault(std::unique_ptr<protobuf::Message>* out) const {
94     out->reset(reflection()
95                    .GetMessageFactory()
96                    ->GetPrototype(descriptor_->message_type())
97                    ->New());
98   }
99 
Load(int32_t * value)100   void Load(int32_t* value) const {
101     *value = is_repeated()
102                  ? reflection().GetRepeatedInt32(*message_, descriptor_, index_)
103                  : reflection().GetInt32(*message_, descriptor_);
104   }
105 
Load(int64_t * value)106   void Load(int64_t* value) const {
107     *value = is_repeated()
108                  ? reflection().GetRepeatedInt64(*message_, descriptor_, index_)
109                  : reflection().GetInt64(*message_, descriptor_);
110   }
111 
Load(uint32_t * value)112   void Load(uint32_t* value) const {
113     *value = is_repeated() ? reflection().GetRepeatedUInt32(*message_,
114                                                             descriptor_, index_)
115                            : reflection().GetUInt32(*message_, descriptor_);
116   }
117 
Load(uint64_t * value)118   void Load(uint64_t* value) const {
119     *value = is_repeated() ? reflection().GetRepeatedUInt64(*message_,
120                                                             descriptor_, index_)
121                            : reflection().GetUInt64(*message_, descriptor_);
122   }
123 
Load(double * value)124   void Load(double* value) const {
125     *value = is_repeated() ? reflection().GetRepeatedDouble(*message_,
126                                                             descriptor_, index_)
127                            : reflection().GetDouble(*message_, descriptor_);
128   }
129 
Load(float * value)130   void Load(float* value) const {
131     *value = is_repeated()
132                  ? reflection().GetRepeatedFloat(*message_, descriptor_, index_)
133                  : reflection().GetFloat(*message_, descriptor_);
134   }
135 
Load(bool * value)136   void Load(bool* value) const {
137     *value = is_repeated()
138                  ? reflection().GetRepeatedBool(*message_, descriptor_, index_)
139                  : reflection().GetBool(*message_, descriptor_);
140   }
141 
Load(Enum * value)142   void Load(Enum* value) const {
143     const protobuf::EnumValueDescriptor* value_descriptor =
144         is_repeated()
145             ? reflection().GetRepeatedEnum(*message_, descriptor_, index_)
146             : reflection().GetEnum(*message_, descriptor_);
147     *value = {static_cast<size_t>(value_descriptor->index()),
148               static_cast<size_t>(value_descriptor->type()->value_count())};
149     if (value->index >= value->count) GetDefault(value);
150   }
151 
Load(std::string * value)152   void Load(std::string* value) const {
153     *value = is_repeated() ? reflection().GetRepeatedString(*message_,
154                                                             descriptor_, index_)
155                            : reflection().GetString(*message_, descriptor_);
156   }
157 
Load(std::unique_ptr<protobuf::Message> * value)158   void Load(std::unique_ptr<protobuf::Message>* value) const {
159     const protobuf::Message& source =
160         is_repeated()
161             ? reflection().GetRepeatedMessage(*message_, descriptor_, index_)
162             : reflection().GetMessage(*message_, descriptor_);
163     value->reset(source.New());
164     (*value)->CopyFrom(source);
165   }
166 
167   template <class T>
CanStore(const T & value)168   bool CanStore(const T& value) const {
169     return true;
170   }
171 
CanStore(const std::string & value)172   bool CanStore(const std::string& value) const {
173     if (!EnforceUtf8()) return true;
174     using protobuf::internal::WireFormatLite;
175     return WireFormatLite::VerifyUtf8String(value.data(), value.length(),
176                                             WireFormatLite::PARSE, "");
177   }
178 
name()179   std::string name() const { return descriptor_->name(); }
180 
cpp_type()181   protobuf::FieldDescriptor::CppType cpp_type() const {
182     return descriptor_->cpp_type();
183   }
184 
enum_type()185   const protobuf::EnumDescriptor* enum_type() const {
186     return descriptor_->enum_type();
187   }
188 
message_type()189   const protobuf::Descriptor* message_type() const {
190     return descriptor_->message_type();
191   }
192 
EnforceUtf8()193   bool EnforceUtf8() const {
194     return descriptor_->type() == protobuf::FieldDescriptor::TYPE_STRING &&
195            descriptor()->file()->syntax() ==
196                protobuf::FileDescriptor::SYNTAX_PROTO3;
197   }
198 
descriptor()199   const protobuf::FieldDescriptor* descriptor() const { return descriptor_; }
200 
DebugString()201   std::string DebugString() const {
202     std::string s = descriptor_->DebugString();
203     if (is_repeated()) s += "[" + std::to_string(index_) + "]";
204     return s + " of\n" + message_->DebugString();
205   }
206 
207  protected:
is_repeated()208   bool is_repeated() const { return descriptor_->is_repeated(); }
209 
reflection()210   const protobuf::Reflection& reflection() const {
211     return *message_->GetReflection();
212   }
213 
index()214   size_t index() const { return index_; }
215 
216  private:
217   template <class Fn, class T>
218   friend struct FieldFunction;
219 
220   const protobuf::Message* message_;
221   const protobuf::FieldDescriptor* descriptor_;
222   size_t index_;
223 };
224 
225 class FieldInstance : public ConstFieldInstance {
226  public:
227   static const size_t kInvalidIndex = -1;
228 
FieldInstance()229   FieldInstance() : ConstFieldInstance(), message_(nullptr) {}
230 
FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field,size_t index)231   FieldInstance(protobuf::Message* message,
232                 const protobuf::FieldDescriptor* field, size_t index)
233       : ConstFieldInstance(message, field, index), message_(message) {}
234 
FieldInstance(protobuf::Message * message,const protobuf::FieldDescriptor * field)235   FieldInstance(protobuf::Message* message,
236                 const protobuf::FieldDescriptor* field)
237       : ConstFieldInstance(message, field), message_(message) {}
238 
Delete()239   void Delete() const {
240     if (!is_repeated()) return reflection().ClearField(message_, descriptor());
241     int field_size = reflection().FieldSize(*message_, descriptor());
242     // API has only method to delete the last message, so we move method from
243     // the
244     // middle to the end.
245     for (int i = index() + 1; i < field_size; ++i)
246       reflection().SwapElements(message_, descriptor(), i, i - 1);
247     reflection().RemoveLast(message_, descriptor());
248   }
249 
250   template <class T>
Create(const T & value)251   void Create(const T& value) const {
252     if (!is_repeated()) return Store(value);
253     InsertRepeated(value);
254   }
255 
Store(int32_t value)256   void Store(int32_t value) const {
257     if (is_repeated())
258       reflection().SetRepeatedInt32(message_, descriptor(), index(), value);
259     else
260       reflection().SetInt32(message_, descriptor(), value);
261   }
262 
Store(int64_t value)263   void Store(int64_t value) const {
264     if (is_repeated())
265       reflection().SetRepeatedInt64(message_, descriptor(), index(), value);
266     else
267       reflection().SetInt64(message_, descriptor(), value);
268   }
269 
Store(uint32_t value)270   void Store(uint32_t value) const {
271     if (is_repeated())
272       reflection().SetRepeatedUInt32(message_, descriptor(), index(), value);
273     else
274       reflection().SetUInt32(message_, descriptor(), value);
275   }
276 
Store(uint64_t value)277   void Store(uint64_t value) const {
278     if (is_repeated())
279       reflection().SetRepeatedUInt64(message_, descriptor(), index(), value);
280     else
281       reflection().SetUInt64(message_, descriptor(), value);
282   }
283 
Store(double value)284   void Store(double value) const {
285     if (is_repeated())
286       reflection().SetRepeatedDouble(message_, descriptor(), index(), value);
287     else
288       reflection().SetDouble(message_, descriptor(), value);
289   }
290 
Store(float value)291   void Store(float value) const {
292     if (is_repeated())
293       reflection().SetRepeatedFloat(message_, descriptor(), index(), value);
294     else
295       reflection().SetFloat(message_, descriptor(), value);
296   }
297 
Store(bool value)298   void Store(bool value) const {
299     if (is_repeated())
300       reflection().SetRepeatedBool(message_, descriptor(), index(), value);
301     else
302       reflection().SetBool(message_, descriptor(), value);
303   }
304 
Store(const Enum & value)305   void Store(const Enum& value) const {
306     assert(value.index < value.count);
307     const protobuf::EnumValueDescriptor* enum_value =
308         descriptor()->enum_type()->value(value.index);
309     if (is_repeated())
310       reflection().SetRepeatedEnum(message_, descriptor(), index(), enum_value);
311     else
312       reflection().SetEnum(message_, descriptor(), enum_value);
313   }
314 
Store(const std::string & value)315   void Store(const std::string& value) const {
316     if (is_repeated())
317       reflection().SetRepeatedString(message_, descriptor(), index(), value);
318     else
319       reflection().SetString(message_, descriptor(), value);
320   }
321 
Store(const std::unique_ptr<protobuf::Message> & value)322   void Store(const std::unique_ptr<protobuf::Message>& value) const {
323     protobuf::Message* mutable_message =
324         is_repeated() ? reflection().MutableRepeatedMessage(
325                             message_, descriptor(), index())
326                       : reflection().MutableMessage(message_, descriptor());
327     mutable_message->Clear();
328     if (value) mutable_message->CopyFrom(*value);
329   }
330 
331  private:
332   template <class T>
InsertRepeated(const T & value)333   void InsertRepeated(const T& value) const {
334     PushBackRepeated(value);
335     size_t field_size = reflection().FieldSize(*message_, descriptor());
336     if (field_size == 1) return;
337     // API has only method to add field to the end of the list. So we add
338     // descriptor()
339     // and move it into the middle.
340     for (size_t i = field_size - 1; i > index(); --i)
341       reflection().SwapElements(message_, descriptor(), i, i - 1);
342   }
343 
PushBackRepeated(int32_t value)344   void PushBackRepeated(int32_t value) const {
345     assert(is_repeated());
346     reflection().AddInt32(message_, descriptor(), value);
347   }
348 
PushBackRepeated(int64_t value)349   void PushBackRepeated(int64_t value) const {
350     assert(is_repeated());
351     reflection().AddInt64(message_, descriptor(), value);
352   }
353 
PushBackRepeated(uint32_t value)354   void PushBackRepeated(uint32_t value) const {
355     assert(is_repeated());
356     reflection().AddUInt32(message_, descriptor(), value);
357   }
358 
PushBackRepeated(uint64_t value)359   void PushBackRepeated(uint64_t value) const {
360     assert(is_repeated());
361     reflection().AddUInt64(message_, descriptor(), value);
362   }
363 
PushBackRepeated(double value)364   void PushBackRepeated(double value) const {
365     assert(is_repeated());
366     reflection().AddDouble(message_, descriptor(), value);
367   }
368 
PushBackRepeated(float value)369   void PushBackRepeated(float value) const {
370     assert(is_repeated());
371     reflection().AddFloat(message_, descriptor(), value);
372   }
373 
PushBackRepeated(bool value)374   void PushBackRepeated(bool value) const {
375     assert(is_repeated());
376     reflection().AddBool(message_, descriptor(), value);
377   }
378 
PushBackRepeated(const Enum & value)379   void PushBackRepeated(const Enum& value) const {
380     assert(value.index < value.count);
381     const protobuf::EnumValueDescriptor* enum_value =
382         descriptor()->enum_type()->value(value.index);
383     assert(is_repeated());
384     reflection().AddEnum(message_, descriptor(), enum_value);
385   }
386 
PushBackRepeated(const std::string & value)387   void PushBackRepeated(const std::string& value) const {
388     assert(is_repeated());
389     reflection().AddString(message_, descriptor(), value);
390   }
391 
PushBackRepeated(const std::unique_ptr<protobuf::Message> & value)392   void PushBackRepeated(const std::unique_ptr<protobuf::Message>& value) const {
393     assert(is_repeated());
394     protobuf::Message* mutable_message =
395         reflection().AddMessage(message_, descriptor());
396     mutable_message->Clear();
397     if (value) mutable_message->CopyFrom(*value);
398   }
399 
400   protobuf::Message* message_;
401 };
402 
403 template <class Fn, class R = void>
404 struct FieldFunction {
405   template <class Field, class... Args>
operatorFieldFunction406   R operator()(const Field& field, const Args&... args) const {
407     assert(field.descriptor());
408     using protobuf::FieldDescriptor;
409     switch (field.cpp_type()) {
410       case FieldDescriptor::CPPTYPE_INT32:
411         return static_cast<const Fn*>(this)->template ForType<int32_t>(field,
412                                                                        args...);
413       case FieldDescriptor::CPPTYPE_INT64:
414         return static_cast<const Fn*>(this)->template ForType<int64_t>(field,
415                                                                        args...);
416       case FieldDescriptor::CPPTYPE_UINT32:
417         return static_cast<const Fn*>(this)->template ForType<uint32_t>(
418             field, args...);
419       case FieldDescriptor::CPPTYPE_UINT64:
420         return static_cast<const Fn*>(this)->template ForType<uint64_t>(
421             field, args...);
422       case FieldDescriptor::CPPTYPE_DOUBLE:
423         return static_cast<const Fn*>(this)->template ForType<double>(field,
424                                                                       args...);
425       case FieldDescriptor::CPPTYPE_FLOAT:
426         return static_cast<const Fn*>(this)->template ForType<float>(field,
427                                                                      args...);
428       case FieldDescriptor::CPPTYPE_BOOL:
429         return static_cast<const Fn*>(this)->template ForType<bool>(field,
430                                                                     args...);
431       case FieldDescriptor::CPPTYPE_ENUM:
432         return static_cast<const Fn*>(this)
433             ->template ForType<ConstFieldInstance::Enum>(field, args...);
434       case FieldDescriptor::CPPTYPE_STRING:
435         return static_cast<const Fn*>(this)->template ForType<std::string>(
436             field, args...);
437       case FieldDescriptor::CPPTYPE_MESSAGE:
438         return static_cast<const Fn*>(this)
439             ->template ForType<std::unique_ptr<protobuf::Message>>(field,
440                                                                    args...);
441     }
442     assert(false && "Unknown type");
443     abort();
444   }
445 };
446 
447 }  // namespace protobuf_mutator
448 
449 #endif  // SRC_FIELD_INSTANCE_H_
450