• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/mutator.h"
16 
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
25 
26 #include "src/field_instance.h"
27 #include "src/utf8_fix.h"
28 #include "src/weighted_reservoir_sampler.h"
29 
30 namespace protobuf_mutator {
31 
32 using google::protobuf::Any;
33 using protobuf::Descriptor;
34 using protobuf::FieldDescriptor;
35 using protobuf::FileDescriptor;
36 using protobuf::Message;
37 using protobuf::OneofDescriptor;
38 using protobuf::Reflection;
39 using protobuf::util::MessageDifferencer;
40 using std::placeholders::_1;
41 
42 namespace {
43 
// Remaining-depth budget passed to PostProcessing::Run; once it is exhausted,
// optional message fields and Any payloads are cleared to avoid stack
// overflow.
const int kMaxInitializeDepth = 200;
// Weight given to every candidate (field, mutation) pair in MutationSampler;
// all candidates are offered with this same weight.
const uint64_t kDefaultMutateWeight = 1000000;

// Kinds of mutation that MutationSampler can select.
enum class Mutation : uint8_t {
  None,
  Add,     // Adds new field with default value.
  Mutate,  // Mutates field contents.
  Delete,  // Deletes field.
  Copy,    // Copy values copied from another field.
  Clone,   // Create new field with value copied from another.

  Last = Clone,
};

// One bit per Mutation value, indexed by static_cast<size_t>(Mutation).
// Bitsets are indexed from zero, so the size must be Last + 1; a bitset of
// exactly `Last` bits would make the `Clone` bit out of range (undefined
// behavior through operator[]).
using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last) + 1>;
static_assert(static_cast<size_t>(Mutation::Last) < MutationBitset().size(),
              "MutationBitset must have a bit for every Mutation value");
59 
60 using Messages = std::vector<Message*>;
61 using ConstMessages = std::vector<const Message*>;
62 
63 // Return random integer from [0, count)
GetRandomIndex(RandomEngine * random,size_t count)64 size_t GetRandomIndex(RandomEngine* random, size_t count) {
65   assert(count > 0);
66   if (count == 1) return 0;
67   return std::uniform_int_distribution<size_t>(0, count - 1)(*random);
68 }
69 
70 // Flips random bit in the buffer.
FlipBit(size_t size,uint8_t * bytes,RandomEngine * random)71 void FlipBit(size_t size, uint8_t* bytes, RandomEngine* random) {
72   size_t bit = GetRandomIndex(random, size * 8);
73   bytes[bit / 8] ^= (1u << (bit % 8));
74 }
75 
76 // Flips random bit in the value.
77 template <class T>
FlipBit(T value,RandomEngine * random)78 T FlipBit(T value, RandomEngine* random) {
79   FlipBit(sizeof(value), reinterpret_cast<uint8_t*>(&value), random);
80   return value;
81 }
82 
83 // Return true with probability about 1-of-n.
GetRandomBool(RandomEngine * random,size_t n=2)84 bool GetRandomBool(RandomEngine* random, size_t n = 2) {
85   return GetRandomIndex(random, n) == 0;
86 }
87 
IsProto3SimpleField(const FieldDescriptor & field)88 bool IsProto3SimpleField(const FieldDescriptor& field) {
89   assert(field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 ||
90          field.file()->syntax() == FileDescriptor::SYNTAX_PROTO2);
91   return field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 &&
92          field.cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE &&
93          !field.containing_oneof() && !field.is_repeated();
94 }
95 
96 struct CreateDefaultField : public FieldFunction<CreateDefaultField> {
97   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::CreateDefaultField98   void ForType(const FieldInstance& field) const {
99     T value;
100     field.GetDefault(&value);
101     field.Create(value);
102   }
103 };
104 
105 struct DeleteField : public FieldFunction<DeleteField> {
106   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::DeleteField107   void ForType(const FieldInstance& field) const {
108     field.Delete();
109   }
110 };
111 
112 struct CopyField : public FieldFunction<CopyField> {
113   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::CopyField114   void ForType(const ConstFieldInstance& source,
115                const FieldInstance& field) const {
116     T value;
117     source.Load(&value);
118     field.Store(value);
119   }
120 };
121 
122 struct AppendField : public FieldFunction<AppendField> {
123   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::AppendField124   void ForType(const ConstFieldInstance& source,
125                const FieldInstance& field) const {
126     T value;
127     source.Load(&value);
128     field.Create(value);
129   }
130 };
131 
132 class CanCopyAndDifferentField
133     : public FieldFunction<CanCopyAndDifferentField, bool> {
134  public:
135   template <class T>
ForType(const ConstFieldInstance & src,const ConstFieldInstance & dst,int size_increase_hint) const136   bool ForType(const ConstFieldInstance& src, const ConstFieldInstance& dst,
137                int size_increase_hint) const {
138     T s;
139     src.Load(&s);
140     if (!dst.CanStore(s)) return false;
141     T d;
142     dst.Load(&d);
143     return SizeDiff(s, d) <= size_increase_hint && !IsEqual(s, d);
144   }
145 
146  private:
IsEqual(const ConstFieldInstance::Enum & a,const ConstFieldInstance::Enum & b) const147   bool IsEqual(const ConstFieldInstance::Enum& a,
148                const ConstFieldInstance::Enum& b) const {
149     assert(a.count == b.count);
150     return a.index == b.index;
151   }
152 
IsEqual(const std::unique_ptr<Message> & a,const std::unique_ptr<Message> & b) const153   bool IsEqual(const std::unique_ptr<Message>& a,
154                const std::unique_ptr<Message>& b) const {
155     return MessageDifferencer::Equals(*a, *b);
156   }
157 
158   template <class T>
IsEqual(const T & a,const T & b) const159   bool IsEqual(const T& a, const T& b) const {
160     return a == b;
161   }
162 
SizeDiff(const std::unique_ptr<Message> & src,const std::unique_ptr<Message> & dst) const163   int64_t SizeDiff(const std::unique_ptr<Message>& src,
164                    const std::unique_ptr<Message>& dst) const {
165     return src->ByteSizeLong() - dst->ByteSizeLong();
166   }
167 
SizeDiff(const std::string & src,const std::string & dst) const168   int64_t SizeDiff(const std::string& src, const std::string& dst) const {
169     return src.size() - dst.size();
170   }
171 
172   template <class T>
SizeDiff(const T &,const T &) const173   int64_t SizeDiff(const T&, const T&) const {
174     return 0;
175   }
176 };
177 
178 // Selects random field and mutation from the given proto message.
179 class MutationSampler {
180  public:
MutationSampler(bool keep_initialized,MutationBitset allowed_mutations,RandomEngine * random)181   MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
182                   RandomEngine* random)
183       : keep_initialized_(keep_initialized),
184         allowed_mutations_(allowed_mutations),
185         random_(random),
186         sampler_(random) {}
187 
188   // Returns selected field.
field() const189   const FieldInstance& field() const { return sampler_.selected().field; }
190 
191   // Returns selected mutation.
mutation() const192   Mutation mutation() const { return sampler_.selected().mutation; }
193 
Sample(Message * message)194   void Sample(Message* message) {
195     SampleImpl(message);
196     assert(mutation() != Mutation::None ||
197            !allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
198            message->GetDescriptor()->field_count() == 0);
199   }
200 
201  private:
SampleImpl(Message * message)202   void SampleImpl(Message* message) {
203     const Descriptor* descriptor = message->GetDescriptor();
204     const Reflection* reflection = message->GetReflection();
205 
206     int field_count = descriptor->field_count();
207     for (int i = 0; i < field_count; ++i) {
208       const FieldDescriptor* field = descriptor->field(i);
209       if (const OneofDescriptor* oneof = field->containing_oneof()) {
210         // Handle entire oneof group on the first field.
211         if (field->index_in_oneof() == 0) {
212           assert(oneof->field_count());
213           const FieldDescriptor* current_field =
214               reflection->GetOneofFieldDescriptor(*message, oneof);
215           for (;;) {
216             const FieldDescriptor* add_field =
217                 oneof->field(GetRandomIndex(random_, oneof->field_count()));
218             if (add_field != current_field) {
219               Try({message, add_field}, Mutation::Add);
220               Try({message, add_field}, Mutation::Clone);
221               break;
222             }
223             if (oneof->field_count() < 2) break;
224           }
225           if (current_field) {
226             if (current_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
227               Try({message, current_field}, Mutation::Mutate);
228             Try({message, current_field}, Mutation::Delete);
229             Try({message, current_field}, Mutation::Copy);
230           }
231         }
232       } else {
233         if (field->is_repeated()) {
234           int field_size = reflection->FieldSize(*message, field);
235           size_t random_index = GetRandomIndex(random_, field_size + 1);
236           Try({message, field, random_index}, Mutation::Add);
237           Try({message, field, random_index}, Mutation::Clone);
238 
239           if (field_size) {
240             size_t random_index = GetRandomIndex(random_, field_size);
241             if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
242               Try({message, field, random_index}, Mutation::Mutate);
243             Try({message, field, random_index}, Mutation::Delete);
244             Try({message, field, random_index}, Mutation::Copy);
245           }
246         } else {
247           if (reflection->HasField(*message, field) ||
248               IsProto3SimpleField(*field)) {
249             if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
250               Try({message, field}, Mutation::Mutate);
251             if (!IsProto3SimpleField(*field) &&
252                 (!field->is_required() || !keep_initialized_)) {
253               Try({message, field}, Mutation::Delete);
254             }
255             Try({message, field}, Mutation::Copy);
256           } else {
257             Try({message, field}, Mutation::Add);
258             Try({message, field}, Mutation::Clone);
259           }
260         }
261       }
262 
263       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
264         if (field->is_repeated()) {
265           const int field_size = reflection->FieldSize(*message, field);
266           for (int j = 0; j < field_size; ++j)
267             SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
268         } else if (reflection->HasField(*message, field)) {
269           SampleImpl(reflection->MutableMessage(message, field));
270         }
271       }
272     }
273   }
274 
Try(const FieldInstance & field,Mutation mutation)275   void Try(const FieldInstance& field, Mutation mutation) {
276     assert(mutation != Mutation::None);
277     if (!allowed_mutations_[static_cast<size_t>(mutation)]) return;
278     sampler_.Try(kDefaultMutateWeight, {field, mutation});
279   }
280 
281   bool keep_initialized_ = false;
282   MutationBitset allowed_mutations_;
283 
284   RandomEngine* random_;
285 
286   struct Result {
287     Result() = default;
Resultprotobuf_mutator::__anon4db2e1e30111::MutationSampler::Result288     Result(const FieldInstance& f, Mutation m) : field(f), mutation(m) {}
289 
290     FieldInstance field;
291     Mutation mutation = Mutation::None;
292   };
293   WeightedReservoirSampler<Result, RandomEngine> sampler_;
294 };
295 
296 // Selects random field of compatible type to use for clone mutations.
297 class DataSourceSampler {
298  public:
DataSourceSampler(const ConstFieldInstance & match,RandomEngine * random,int size_increase_hint)299   DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
300                     int size_increase_hint)
301       : match_(match),
302         random_(random),
303         size_increase_hint_(size_increase_hint),
304         sampler_(random) {}
305 
Sample(const Message & message)306   void Sample(const Message& message) { SampleImpl(message); }
307 
308   // Returns selected field.
field() const309   const ConstFieldInstance& field() const {
310     assert(!IsEmpty());
311     return sampler_.selected();
312   }
313 
IsEmpty() const314   bool IsEmpty() const { return sampler_.IsEmpty(); }
315 
316  private:
SampleImpl(const Message & message)317   void SampleImpl(const Message& message) {
318     const Descriptor* descriptor = message.GetDescriptor();
319     const Reflection* reflection = message.GetReflection();
320 
321     int field_count = descriptor->field_count();
322     for (int i = 0; i < field_count; ++i) {
323       const FieldDescriptor* field = descriptor->field(i);
324       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
325         if (field->is_repeated()) {
326           const int field_size = reflection->FieldSize(message, field);
327           for (int j = 0; j < field_size; ++j) {
328             SampleImpl(reflection->GetRepeatedMessage(message, field, j));
329           }
330         } else if (reflection->HasField(message, field)) {
331           SampleImpl(reflection->GetMessage(message, field));
332         }
333       }
334 
335       if (field->cpp_type() != match_.cpp_type()) continue;
336       if (match_.cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
337         if (field->enum_type() != match_.enum_type()) continue;
338       } else if (match_.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
339         if (field->message_type() != match_.message_type()) continue;
340       }
341 
342       if (field->is_repeated()) {
343         if (int field_size = reflection->FieldSize(message, field)) {
344           ConstFieldInstance source(&message, field,
345                                     GetRandomIndex(random_, field_size));
346           if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
347             sampler_.Try(field_size, source);
348         }
349       } else {
350         if (reflection->HasField(message, field)) {
351           ConstFieldInstance source(&message, field);
352           if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
353             sampler_.Try(1, source);
354         }
355       }
356     }
357   }
358 
359   ConstFieldInstance match_;
360   RandomEngine* random_;
361   int size_increase_hint_;
362 
363   WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
364 };
365 
366 using UnpackedAny =
367     std::unordered_map<const Message*, std::unique_ptr<Message>>;
368 
GetAnyTypeDescriptor(const Any & any)369 const Descriptor* GetAnyTypeDescriptor(const Any& any) {
370   std::string type_name;
371   if (!Any::ParseAnyTypeUrl(std::string(any.type_url()), &type_name))
372     return nullptr;
373   return any.descriptor()->file()->pool()->FindMessageTypeByName(type_name);
374 }
375 
UnpackAny(const Any & any)376 std::unique_ptr<Message> UnpackAny(const Any& any) {
377   const Descriptor* desc = GetAnyTypeDescriptor(any);
378   if (!desc) return {};
379   std::unique_ptr<Message> message(
380       any.GetReflection()->GetMessageFactory()->GetPrototype(desc)->New());
381   message->ParsePartialFromString(std::string(any.value()));
382   return message;
383 }
384 
CastToAny(const Message * message)385 const Any* CastToAny(const Message* message) {
386   return Any::GetDescriptor() == message->GetDescriptor()
387              ? static_cast<const Any*>(message)
388              : nullptr;
389 }
390 
CastToAny(Message * message)391 Any* CastToAny(Message* message) {
392   return Any::GetDescriptor() == message->GetDescriptor()
393              ? static_cast<Any*>(message)
394              : nullptr;
395 }
396 
UnpackIfAny(const Message & message)397 std::unique_ptr<Message> UnpackIfAny(const Message& message) {
398   if (const Any* any = CastToAny(&message)) return UnpackAny(*any);
399   return {};
400 }
401 
UnpackAny(const Message & message,UnpackedAny * result)402 void UnpackAny(const Message& message, UnpackedAny* result) {
403   if (std::unique_ptr<Message> any = UnpackIfAny(message)) {
404     UnpackAny(*any, result);
405     result->emplace(&message, std::move(any));
406     return;
407   }
408 
409   const Descriptor* descriptor = message.GetDescriptor();
410   const Reflection* reflection = message.GetReflection();
411 
412   for (int i = 0; i < descriptor->field_count(); ++i) {
413     const FieldDescriptor* field = descriptor->field(i);
414     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
415       if (field->is_repeated()) {
416         const int field_size = reflection->FieldSize(message, field);
417         for (int j = 0; j < field_size; ++j) {
418           UnpackAny(reflection->GetRepeatedMessage(message, field, j), result);
419         }
420       } else if (reflection->HasField(message, field)) {
421         UnpackAny(reflection->GetMessage(message, field), result);
422       }
423     }
424   }
425 }
426 
427 class PostProcessing {
428  public:
429   using PostProcessors =
430       std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;
431 
PostProcessing(bool keep_initialized,const PostProcessors & post_processors,const UnpackedAny & any,RandomEngine * random)432   PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
433                  const UnpackedAny& any, RandomEngine* random)
434       : keep_initialized_(keep_initialized),
435         post_processors_(post_processors),
436         any_(any),
437         random_(random) {}
438 
Run(Message * message,int max_depth)439   void Run(Message* message, int max_depth) {
440     --max_depth;
441     const Descriptor* descriptor = message->GetDescriptor();
442 
443     // Apply custom mutators in nested messages before packing any.
444     const Reflection* reflection = message->GetReflection();
445     for (int i = 0; i < descriptor->field_count(); i++) {
446       const FieldDescriptor* field = descriptor->field(i);
447       if (keep_initialized_ &&
448           (field->is_required() || descriptor->options().map_entry()) &&
449           !reflection->HasField(*message, field)) {
450         CreateDefaultField()(FieldInstance(message, field));
451       }
452 
453       if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
454 
455       if (max_depth < 0 && !field->is_required()) {
456         // Clear deep optional fields to avoid stack overflow.
457         reflection->ClearField(message, field);
458         if (field->is_repeated())
459           assert(!reflection->FieldSize(*message, field));
460         else
461           assert(!reflection->HasField(*message, field));
462         continue;
463       }
464 
465       if (field->is_repeated()) {
466         const int field_size = reflection->FieldSize(*message, field);
467         for (int j = 0; j < field_size; ++j) {
468           Message* nested_message =
469               reflection->MutableRepeatedMessage(message, field, j);
470           Run(nested_message, max_depth);
471         }
472       } else if (reflection->HasField(*message, field)) {
473         Message* nested_message = reflection->MutableMessage(message, field);
474         Run(nested_message, max_depth);
475       }
476     }
477 
478     if (Any* any = CastToAny(message)) {
479       if (max_depth < 0) {
480         // Clear deep Any fields to avoid stack overflow.
481         any->Clear();
482       } else {
483         auto It = any_.find(message);
484         if (It != any_.end()) {
485           Run(It->second.get(), max_depth);
486           std::string value;
487           It->second->SerializePartialToString(&value);
488           *any->mutable_value() = value;
489         }
490       }
491     }
492 
493     // Call user callback after message trimmed, initialized and packed.
494     auto range = post_processors_.equal_range(descriptor);
495     for (auto it = range.first; it != range.second; ++it)
496       it->second(message, (*random_)());
497   }
498 
499  private:
500   bool keep_initialized_;
501   const PostProcessors& post_processors_;
502   const UnpackedAny& any_;
503   RandomEngine* random_;
504 };
505 
506 }  // namespace
507 
508 class FieldMutator {
509  public:
FieldMutator(int size_increase_hint,bool enforce_changes,bool enforce_utf8_strings,const ConstMessages & sources,Mutator * mutator)510   FieldMutator(int size_increase_hint, bool enforce_changes,
511                bool enforce_utf8_strings, const ConstMessages& sources,
512                Mutator* mutator)
513       : size_increase_hint_(size_increase_hint),
514         enforce_changes_(enforce_changes),
515         enforce_utf8_strings_(enforce_utf8_strings),
516         sources_(sources),
517         mutator_(mutator) {}
518 
Mutate(int32_t * value) const519   void Mutate(int32_t* value) const {
520     RepeatMutate(value, std::bind(&Mutator::MutateInt32, mutator_, _1));
521   }
522 
Mutate(int64_t * value) const523   void Mutate(int64_t* value) const {
524     RepeatMutate(value, std::bind(&Mutator::MutateInt64, mutator_, _1));
525   }
526 
Mutate(uint32_t * value) const527   void Mutate(uint32_t* value) const {
528     RepeatMutate(value, std::bind(&Mutator::MutateUInt32, mutator_, _1));
529   }
530 
Mutate(uint64_t * value) const531   void Mutate(uint64_t* value) const {
532     RepeatMutate(value, std::bind(&Mutator::MutateUInt64, mutator_, _1));
533   }
534 
Mutate(float * value) const535   void Mutate(float* value) const {
536     RepeatMutate(value, std::bind(&Mutator::MutateFloat, mutator_, _1));
537   }
538 
Mutate(double * value) const539   void Mutate(double* value) const {
540     RepeatMutate(value, std::bind(&Mutator::MutateDouble, mutator_, _1));
541   }
542 
Mutate(bool * value) const543   void Mutate(bool* value) const {
544     RepeatMutate(value, std::bind(&Mutator::MutateBool, mutator_, _1));
545   }
546 
Mutate(FieldInstance::Enum * value) const547   void Mutate(FieldInstance::Enum* value) const {
548     RepeatMutate(&value->index,
549                  std::bind(&Mutator::MutateEnum, mutator_, _1, value->count));
550     assert(value->index < value->count);
551   }
552 
Mutate(std::string * value) const553   void Mutate(std::string* value) const {
554     if (enforce_utf8_strings_) {
555       RepeatMutate(value, std::bind(&Mutator::MutateUtf8String, mutator_, _1,
556                                     size_increase_hint_));
557     } else {
558       RepeatMutate(value, std::bind(&Mutator::MutateString, mutator_, _1,
559                                     size_increase_hint_));
560     }
561   }
562 
Mutate(std::unique_ptr<Message> * message) const563   void Mutate(std::unique_ptr<Message>* message) const {
564     assert(!enforce_changes_);
565     assert(*message);
566     if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
567       return;
568     mutator_->MutateImpl(sources_, {message->get()}, false,
569                          size_increase_hint_);
570   }
571 
572  private:
573   template <class T, class F>
RepeatMutate(T * value,F mutate) const574   void RepeatMutate(T* value, F mutate) const {
575     if (!enforce_changes_ &&
576         GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_)) {
577       return;
578     }
579     T tmp = *value;
580     for (int i = 0; i < 10; ++i) {
581       *value = mutate(*value);
582       if (!enforce_changes_ || *value != tmp) return;
583     }
584   }
585 
586   int size_increase_hint_;
587   size_t enforce_changes_;
588   bool enforce_utf8_strings_;
589   const ConstMessages& sources_;
590   Mutator* mutator_;
591 };
592 
593 namespace {
594 
595 struct MutateField : public FieldFunction<MutateField> {
596   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30211::MutateField597   void ForType(const FieldInstance& field, int size_increase_hint,
598                const ConstMessages& sources, Mutator* mutator) const {
599     T value;
600     field.Load(&value);
601     FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
602                  mutator)
603         .Mutate(&value);
604     field.Store(value);
605   }
606 };
607 
608 struct CreateField : public FieldFunction<CreateField> {
609  public:
610   template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30211::CreateField611   void ForType(const FieldInstance& field, int size_increase_hint,
612                const ConstMessages& sources, Mutator* mutator) const {
613     T value;
614     field.GetDefault(&value);
615     FieldMutator field_mutator(size_increase_hint,
616                                false /* defaults could be useful */,
617                                field.EnforceUtf8(), sources, mutator);
618     field_mutator.Mutate(&value);
619     field.Create(value);
620   }
621 };
622 
623 }  // namespace
624 
Seed(uint32_t value)625 void Mutator::Seed(uint32_t value) { random_.seed(value); }
626 
Mutate(Message * message,size_t max_size_hint)627 void Mutator::Mutate(Message* message, size_t max_size_hint) {
628   UnpackedAny any;
629   UnpackAny(*message, &any);
630 
631   Messages messages;
632   messages.reserve(any.size() + 1);
633   messages.push_back(message);
634   for (const auto& kv : any) messages.push_back(kv.second.get());
635 
636   ConstMessages sources(messages.begin(), messages.end());
637   MutateImpl(sources, messages, false,
638              static_cast<int>(max_size_hint) -
639                  static_cast<int>(message->ByteSizeLong()));
640 
641   PostProcessing(keep_initialized_, post_processors_, any, &random_)
642       .Run(message, kMaxInitializeDepth);
643   assert(IsInitialized(*message));
644 }
645 
CrossOver(const Message & message1,Message * message2,size_t max_size_hint)646 void Mutator::CrossOver(const Message& message1, Message* message2,
647                         size_t max_size_hint) {
648   UnpackedAny any;
649   UnpackAny(*message2, &any);
650 
651   Messages messages;
652   messages.reserve(any.size() + 1);
653   messages.push_back(message2);
654   for (auto& kv : any) messages.push_back(kv.second.get());
655 
656   UnpackAny(message1, &any);
657 
658   ConstMessages sources;
659   sources.reserve(any.size() + 2);
660   sources.push_back(&message1);
661   sources.push_back(message2);
662   for (const auto& kv : any) sources.push_back(kv.second.get());
663 
664   MutateImpl(sources, messages, true,
665              static_cast<int>(max_size_hint) -
666                  static_cast<int>(message2->ByteSizeLong()));
667 
668   PostProcessing(keep_initialized_, post_processors_, any, &random_)
669       .Run(message2, kMaxInitializeDepth);
670   assert(IsInitialized(*message2));
671 }
672 
RegisterPostProcessor(const Descriptor * desc,PostProcess callback)673 void Mutator::RegisterPostProcessor(const Descriptor* desc,
674                                     PostProcess callback) {
675   post_processors_.emplace(desc, callback);
676 }
677 
MutateImpl(const ConstMessages & sources,const Messages & messages,bool copy_clone_only,int size_increase_hint)678 bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
679                          bool copy_clone_only, int size_increase_hint) {
680   MutationBitset mutations;
681   if (copy_clone_only) {
682     mutations[static_cast<size_t>(Mutation::Copy)] = true;
683     mutations[static_cast<size_t>(Mutation::Clone)] = true;
684   } else if (size_increase_hint <= 16) {
685     mutations[static_cast<size_t>(Mutation::Delete)] = true;
686   } else {
687     mutations.set();
688     mutations[static_cast<size_t>(Mutation::Copy)] = false;
689     mutations[static_cast<size_t>(Mutation::Clone)] = false;
690   }
691   while (mutations.any()) {
692     MutationSampler mutation(keep_initialized_, mutations, &random_);
693     for (Message* message : messages) mutation.Sample(message);
694 
695     switch (mutation.mutation()) {
696       case Mutation::None:
697         return true;
698       case Mutation::Add:
699         CreateField()(mutation.field(), size_increase_hint, sources, this);
700         return true;
701       case Mutation::Mutate:
702         MutateField()(mutation.field(), size_increase_hint, sources, this);
703         return true;
704       case Mutation::Delete:
705         DeleteField()(mutation.field());
706         return true;
707       case Mutation::Clone: {
708         CreateDefaultField()(mutation.field());
709         DataSourceSampler source_sampler(mutation.field(), &random_,
710                                          size_increase_hint);
711         for (const Message* source : sources) source_sampler.Sample(*source);
712         if (source_sampler.IsEmpty()) {
713           if (!IsProto3SimpleField(*mutation.field().descriptor()))
714             return true;  // CreateField is enough for proto2.
715           break;
716         }
717         CopyField()(source_sampler.field(), mutation.field());
718         return true;
719       }
720       case Mutation::Copy: {
721         DataSourceSampler source_sampler(mutation.field(), &random_,
722                                          size_increase_hint);
723         for (const Message* source : sources) source_sampler.Sample(*source);
724         if (source_sampler.IsEmpty()) break;
725         CopyField()(source_sampler.field(), mutation.field());
726         return true;
727       }
728       default:
729         assert(false && "unexpected mutation");
730         return false;
731     }
732 
733     // Don't try same mutation next time.
734     mutations[static_cast<size_t>(mutation.mutation())] = false;
735   }
736   return false;
737 }
738 
MutateInt32(int32_t value)739 int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
740 
MutateInt64(int64_t value)741 int64_t Mutator::MutateInt64(int64_t value) { return FlipBit(value, &random_); }
742 
MutateUInt32(uint32_t value)743 uint32_t Mutator::MutateUInt32(uint32_t value) {
744   return FlipBit(value, &random_);
745 }
746 
MutateUInt64(uint64_t value)747 uint64_t Mutator::MutateUInt64(uint64_t value) {
748   return FlipBit(value, &random_);
749 }
750 
MutateFloat(float value)751 float Mutator::MutateFloat(float value) { return FlipBit(value, &random_); }
752 
MutateDouble(double value)753 double Mutator::MutateDouble(double value) { return FlipBit(value, &random_); }
754 
MutateBool(bool value)755 bool Mutator::MutateBool(bool value) { return !value; }
756 
MutateEnum(size_t index,size_t item_count)757 size_t Mutator::MutateEnum(size_t index, size_t item_count) {
758   if (item_count <= 1) return 0;
759   return (index + 1 + GetRandomIndex(&random_, item_count - 1)) % item_count;
760 }
761 
MutateString(const std::string & value,int size_increase_hint)762 std::string Mutator::MutateString(const std::string& value,
763                                   int size_increase_hint) {
764   std::string result = value;
765 
766   while (!result.empty() && GetRandomBool(&random_)) {
767     result.erase(GetRandomIndex(&random_, result.size()), 1);
768   }
769 
770   while (size_increase_hint > 0 &&
771          result.size() < static_cast<size_t>(size_increase_hint) &&
772          GetRandomBool(&random_)) {
773     size_t index = GetRandomIndex(&random_, result.size() + 1);
774     result.insert(result.begin() + index, GetRandomIndex(&random_, 1 << 8));
775   }
776 
777   if (result != value) return result;
778 
779   if (result.empty()) {
780     result.push_back(GetRandomIndex(&random_, 1 << 8));
781     return result;
782   }
783 
784   if (!result.empty())
785     FlipBit(result.size(), reinterpret_cast<uint8_t*>(&result[0]), &random_);
786   return result;
787 }
788 
MutateUtf8String(const std::string & value,int size_increase_hint)789 std::string Mutator::MutateUtf8String(const std::string& value,
790                                       int size_increase_hint) {
791   std::string str = MutateString(value, size_increase_hint);
792   FixUtf8String(&str, &random_);
793   return str;
794 }
795 
IsInitialized(const Message & message) const796 bool Mutator::IsInitialized(const Message& message) const {
797   if (!keep_initialized_ || message.IsInitialized()) return true;
798   std::cerr << "Uninitialized: " << message.DebugString() << "\n";
799   return false;
800 }
801 
802 }  // namespace protobuf_mutator
803