• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/mutator.h"
16 
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "src/field_instance.h"
28 #include "src/utf8_fix.h"
29 #include "src/weighted_reservoir_sampler.h"
30 
31 namespace protobuf_mutator {
32 
33 using google::protobuf::Any;
34 using protobuf::Descriptor;
35 using protobuf::FieldDescriptor;
36 using protobuf::FileDescriptor;
37 using protobuf::Message;
38 using protobuf::OneofDescriptor;
39 using protobuf::Reflection;
40 using protobuf::util::MessageDifferencer;
41 using std::placeholders::_1;
42 
43 namespace {
44 
// Depth budget handed to PostProcessing::Run when trimming and initializing
// nested messages.
constexpr int kMaxInitializeDepth = 200;
// Weight with which every candidate mutation is offered to the reservoir
// sampler.
constexpr uint64_t kDefaultMutateWeight = 1000000;
47 
// Kinds of edits the mutator can apply to a single field.
enum class Mutation : uint8_t {
  None,    // No mutation selected.
  Add,     // Adds a new field with a default value.
  Mutate,  // Mutates the field's contents.
  Delete,  // Deletes the field.
  Copy,    // Overwrites the field with a value copied from another field.
  Clone,   // Creates a new field with a value copied from another field.

  Last = Clone,
};

// One bit per Mutation enumerator, indexed by the enumerator's value.
using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last) + 1>;
60 
61 using Messages = std::vector<Message*>;
62 using ConstMessages = std::vector<const Message*>;
63 
64 // Return random integer from [0, count)
GetRandomIndex(RandomEngine * random,size_t count)65 size_t GetRandomIndex(RandomEngine* random, size_t count) {
66   assert(count > 0);
67   if (count == 1) return 0;
68   return std::uniform_int_distribution<size_t>(0, count - 1)(*random);
69 }
70 
71 // Flips random bit in the buffer.
FlipBit(size_t size,uint8_t * bytes,RandomEngine * random)72 void FlipBit(size_t size, uint8_t* bytes, RandomEngine* random) {
73   size_t bit = GetRandomIndex(random, size * 8);
74   bytes[bit / 8] ^= (1u << (bit % 8));
75 }
76 
77 // Flips random bit in the value.
78 template <class T>
FlipBit(T value,RandomEngine * random)79 T FlipBit(T value, RandomEngine* random) {
80   FlipBit(sizeof(value), reinterpret_cast<uint8_t*>(&value), random);
81   return value;
82 }
83 
84 // Return true with probability about 1-of-n.
GetRandomBool(RandomEngine * random,size_t n=2)85 bool GetRandomBool(RandomEngine* random, size_t n = 2) {
86   return GetRandomIndex(random, n) == 0;
87 }
88 
IsProto3SimpleField(const FieldDescriptor & field)89 bool IsProto3SimpleField(const FieldDescriptor& field) {
90   assert(field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 ||
91          field.file()->syntax() == FileDescriptor::SYNTAX_PROTO2);
92   return field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 &&
93          field.cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE &&
94          !field.containing_oneof() && !field.is_repeated();
95 }
96 
97 struct CreateDefaultField : public FieldFunction<CreateDefaultField> {
98   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::CreateDefaultField99   void ForType(const FieldInstance& field) const {
100     T value;
101     field.GetDefault(&value);
102     field.Create(value);
103   }
104 };
105 
106 struct DeleteField : public FieldFunction<DeleteField> {
107   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::DeleteField108   void ForType(const FieldInstance& field) const {
109     field.Delete();
110   }
111 };
112 
113 struct CopyField : public FieldFunction<CopyField> {
114   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::CopyField115   void ForType(const ConstFieldInstance& source,
116                const FieldInstance& field) const {
117     T value;
118     source.Load(&value);
119     field.Store(value);
120   }
121 };
122 
123 struct AppendField : public FieldFunction<AppendField> {
124   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::AppendField125   void ForType(const ConstFieldInstance& source,
126                const FieldInstance& field) const {
127     T value;
128     source.Load(&value);
129     field.Create(value);
130   }
131 };
132 
133 class CanCopyAndDifferentField
134     : public FieldFunction<CanCopyAndDifferentField, bool> {
135  public:
136   template <class T>
ForType(const ConstFieldInstance & src,const ConstFieldInstance & dst,int size_increase_hint) const137   bool ForType(const ConstFieldInstance& src, const ConstFieldInstance& dst,
138                int size_increase_hint) const {
139     T s;
140     src.Load(&s);
141     if (!dst.CanStore(s)) return false;
142     T d;
143     dst.Load(&d);
144     return SizeDiff(s, d) <= size_increase_hint && !IsEqual(s, d);
145   }
146 
147  private:
IsEqual(const ConstFieldInstance::Enum & a,const ConstFieldInstance::Enum & b) const148   bool IsEqual(const ConstFieldInstance::Enum& a,
149                const ConstFieldInstance::Enum& b) const {
150     assert(a.count == b.count);
151     return a.index == b.index;
152   }
153 
IsEqual(const std::unique_ptr<Message> & a,const std::unique_ptr<Message> & b) const154   bool IsEqual(const std::unique_ptr<Message>& a,
155                const std::unique_ptr<Message>& b) const {
156     return MessageDifferencer::Equals(*a, *b);
157   }
158 
159   template <class T>
IsEqual(const T & a,const T & b) const160   bool IsEqual(const T& a, const T& b) const {
161     return a == b;
162   }
163 
SizeDiff(const std::unique_ptr<Message> & src,const std::unique_ptr<Message> & dst) const164   int64_t SizeDiff(const std::unique_ptr<Message>& src,
165                    const std::unique_ptr<Message>& dst) const {
166     return src->ByteSizeLong() - dst->ByteSizeLong();
167   }
168 
SizeDiff(const std::string & src,const std::string & dst) const169   int64_t SizeDiff(const std::string& src, const std::string& dst) const {
170     return src.size() - dst.size();
171   }
172 
173   template <class T>
SizeDiff(const T &,const T &) const174   int64_t SizeDiff(const T&, const T&) const {
175     return 0;
176   }
177 };
178 
179 // Selects random field and mutation from the given proto message.
180 class MutationSampler {
181  public:
MutationSampler(bool keep_initialized,MutationBitset allowed_mutations,RandomEngine * random)182   MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
183                   RandomEngine* random)
184       : keep_initialized_(keep_initialized),
185         allowed_mutations_(allowed_mutations),
186         random_(random),
187         sampler_(random) {}
188 
189   // Returns selected field.
field() const190   const FieldInstance& field() const { return sampler_.selected().field; }
191 
192   // Returns selected mutation.
mutation() const193   Mutation mutation() const { return sampler_.selected().mutation; }
194 
Sample(Message * message)195   void Sample(Message* message) {
196     SampleImpl(message);
197     assert(mutation() != Mutation::None ||
198            !allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
199            message->GetDescriptor()->field_count() == 0);
200   }
201 
202  private:
SampleImpl(Message * message)203   void SampleImpl(Message* message) {
204     const Descriptor* descriptor = message->GetDescriptor();
205     const Reflection* reflection = message->GetReflection();
206 
207     int field_count = descriptor->field_count();
208     for (int i = 0; i < field_count; ++i) {
209       const FieldDescriptor* field = descriptor->field(i);
210       if (const OneofDescriptor* oneof = field->containing_oneof()) {
211         // Handle entire oneof group on the first field.
212         if (field->index_in_oneof() == 0) {
213           assert(oneof->field_count());
214           const FieldDescriptor* current_field =
215               reflection->GetOneofFieldDescriptor(*message, oneof);
216           for (;;) {
217             const FieldDescriptor* add_field =
218                 oneof->field(GetRandomIndex(random_, oneof->field_count()));
219             if (add_field != current_field) {
220               Try({message, add_field}, Mutation::Add);
221               Try({message, add_field}, Mutation::Clone);
222               break;
223             }
224             if (oneof->field_count() < 2) break;
225           }
226           if (current_field) {
227             if (current_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
228               Try({message, current_field}, Mutation::Mutate);
229             Try({message, current_field}, Mutation::Delete);
230             Try({message, current_field}, Mutation::Copy);
231           }
232         }
233       } else {
234         if (field->is_repeated()) {
235           int field_size = reflection->FieldSize(*message, field);
236           size_t random_index = GetRandomIndex(random_, field_size + 1);
237           Try({message, field, random_index}, Mutation::Add);
238           Try({message, field, random_index}, Mutation::Clone);
239 
240           if (field_size) {
241             size_t random_index = GetRandomIndex(random_, field_size);
242             if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
243               Try({message, field, random_index}, Mutation::Mutate);
244             Try({message, field, random_index}, Mutation::Delete);
245             Try({message, field, random_index}, Mutation::Copy);
246           }
247         } else {
248           if (reflection->HasField(*message, field) ||
249               IsProto3SimpleField(*field)) {
250             if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
251               Try({message, field}, Mutation::Mutate);
252             if (!IsProto3SimpleField(*field) &&
253                 (!field->is_required() || !keep_initialized_)) {
254               Try({message, field}, Mutation::Delete);
255             }
256             Try({message, field}, Mutation::Copy);
257           } else {
258             Try({message, field}, Mutation::Add);
259             Try({message, field}, Mutation::Clone);
260           }
261         }
262       }
263 
264       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
265         if (field->is_repeated()) {
266           const int field_size = reflection->FieldSize(*message, field);
267           for (int j = 0; j < field_size; ++j)
268             SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
269         } else if (reflection->HasField(*message, field)) {
270           SampleImpl(reflection->MutableMessage(message, field));
271         }
272       }
273     }
274   }
275 
Try(const FieldInstance & field,Mutation mutation)276   void Try(const FieldInstance& field, Mutation mutation) {
277     assert(mutation != Mutation::None);
278     if (!allowed_mutations_[static_cast<size_t>(mutation)]) return;
279     sampler_.Try(kDefaultMutateWeight, {field, mutation});
280   }
281 
282   bool keep_initialized_ = false;
283   MutationBitset allowed_mutations_;
284 
285   RandomEngine* random_;
286 
287   struct Result {
288     Result() = default;
Resultprotobuf_mutator::__anonbdd7ecbf0111::MutationSampler::Result289     Result(const FieldInstance& f, Mutation m) : field(f), mutation(m) {}
290 
291     FieldInstance field;
292     Mutation mutation = Mutation::None;
293   };
294   WeightedReservoirSampler<Result, RandomEngine> sampler_;
295 };
296 
297 // Selects random field of compatible type to use for clone mutations.
298 class DataSourceSampler {
299  public:
DataSourceSampler(const ConstFieldInstance & match,RandomEngine * random,int size_increase_hint)300   DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
301                     int size_increase_hint)
302       : match_(match),
303         random_(random),
304         size_increase_hint_(size_increase_hint),
305         sampler_(random) {}
306 
Sample(const Message & message)307   void Sample(const Message& message) { SampleImpl(message); }
308 
309   // Returns selected field.
field() const310   const ConstFieldInstance& field() const {
311     assert(!IsEmpty());
312     return sampler_.selected();
313   }
314 
IsEmpty() const315   bool IsEmpty() const { return sampler_.IsEmpty(); }
316 
317  private:
SampleImpl(const Message & message)318   void SampleImpl(const Message& message) {
319     const Descriptor* descriptor = message.GetDescriptor();
320     const Reflection* reflection = message.GetReflection();
321 
322     int field_count = descriptor->field_count();
323     for (int i = 0; i < field_count; ++i) {
324       const FieldDescriptor* field = descriptor->field(i);
325       if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
326         if (field->is_repeated()) {
327           const int field_size = reflection->FieldSize(message, field);
328           for (int j = 0; j < field_size; ++j) {
329             SampleImpl(reflection->GetRepeatedMessage(message, field, j));
330           }
331         } else if (reflection->HasField(message, field)) {
332           SampleImpl(reflection->GetMessage(message, field));
333         }
334       }
335 
336       if (field->cpp_type() != match_.cpp_type()) continue;
337       if (match_.cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
338         if (field->enum_type() != match_.enum_type()) continue;
339       } else if (match_.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
340         if (field->message_type() != match_.message_type()) continue;
341       }
342 
343       if (field->is_repeated()) {
344         if (int field_size = reflection->FieldSize(message, field)) {
345           ConstFieldInstance source(&message, field,
346                                     GetRandomIndex(random_, field_size));
347           if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
348             sampler_.Try(field_size, source);
349         }
350       } else {
351         if (reflection->HasField(message, field)) {
352           ConstFieldInstance source(&message, field);
353           if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
354             sampler_.Try(1, source);
355         }
356       }
357     }
358   }
359 
360   ConstFieldInstance match_;
361   RandomEngine* random_;
362   int size_increase_hint_;
363 
364   WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
365 };
366 
367 using UnpackedAny =
368     std::unordered_map<const Message*, std::unique_ptr<Message>>;
369 
GetAnyTypeDescriptor(const Any & any)370 const Descriptor* GetAnyTypeDescriptor(const Any& any) {
371   std::string type_name;
372   if (!Any::ParseAnyTypeUrl(std::string(any.type_url()), &type_name))
373     return nullptr;
374   return any.descriptor()->file()->pool()->FindMessageTypeByName(type_name);
375 }
376 
UnpackAny(const Any & any)377 std::unique_ptr<Message> UnpackAny(const Any& any) {
378   const Descriptor* desc = GetAnyTypeDescriptor(any);
379   if (!desc) return {};
380   std::unique_ptr<Message> message(
381       any.GetReflection()->GetMessageFactory()->GetPrototype(desc)->New());
382   message->ParsePartialFromString(std::string(any.value()));
383   return message;
384 }
385 
CastToAny(const Message * message)386 const Any* CastToAny(const Message* message) {
387   return Any::GetDescriptor() == message->GetDescriptor()
388              ? static_cast<const Any*>(message)
389              : nullptr;
390 }
391 
CastToAny(Message * message)392 Any* CastToAny(Message* message) {
393   return Any::GetDescriptor() == message->GetDescriptor()
394              ? static_cast<Any*>(message)
395              : nullptr;
396 }
397 
UnpackIfAny(const Message & message)398 std::unique_ptr<Message> UnpackIfAny(const Message& message) {
399   if (const Any* any = CastToAny(&message)) return UnpackAny(*any);
400   return {};
401 }
402 
UnpackAny(const Message & message,UnpackedAny * result)403 void UnpackAny(const Message& message, UnpackedAny* result) {
404   if (std::unique_ptr<Message> any = UnpackIfAny(message)) {
405     UnpackAny(*any, result);
406     result->emplace(&message, std::move(any));
407     return;
408   }
409 
410   const Descriptor* descriptor = message.GetDescriptor();
411   const Reflection* reflection = message.GetReflection();
412 
413   for (int i = 0; i < descriptor->field_count(); ++i) {
414     const FieldDescriptor* field = descriptor->field(i);
415     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
416       if (field->is_repeated()) {
417         const int field_size = reflection->FieldSize(message, field);
418         for (int j = 0; j < field_size; ++j) {
419           UnpackAny(reflection->GetRepeatedMessage(message, field, j), result);
420         }
421       } else if (reflection->HasField(message, field)) {
422         UnpackAny(reflection->GetMessage(message, field), result);
423       }
424     }
425   }
426 }
427 
428 class PostProcessing {
429  public:
430   using PostProcessors =
431       std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;
432 
PostProcessing(bool keep_initialized,const PostProcessors & post_processors,const UnpackedAny & any,RandomEngine * random)433   PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
434                  const UnpackedAny& any, RandomEngine* random)
435       : keep_initialized_(keep_initialized),
436         post_processors_(post_processors),
437         any_(any),
438         random_(random) {}
439 
Run(Message * message,int max_depth)440   void Run(Message* message, int max_depth) {
441     --max_depth;
442     const Descriptor* descriptor = message->GetDescriptor();
443 
444     // Apply custom mutators in nested messages before packing any.
445     const Reflection* reflection = message->GetReflection();
446     for (int i = 0; i < descriptor->field_count(); i++) {
447       const FieldDescriptor* field = descriptor->field(i);
448       if (keep_initialized_ &&
449           (field->is_required() || descriptor->options().map_entry()) &&
450           !reflection->HasField(*message, field)) {
451         CreateDefaultField()(FieldInstance(message, field));
452       }
453 
454       if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
455 
456       if (max_depth < 0 && !field->is_required()) {
457         // Clear deep optional fields to avoid stack overflow.
458         reflection->ClearField(message, field);
459         if (field->is_repeated())
460           assert(!reflection->FieldSize(*message, field));
461         else
462           assert(!reflection->HasField(*message, field));
463         continue;
464       }
465 
466       if (field->is_repeated()) {
467         const int field_size = reflection->FieldSize(*message, field);
468         for (int j = 0; j < field_size; ++j) {
469           Message* nested_message =
470               reflection->MutableRepeatedMessage(message, field, j);
471           Run(nested_message, max_depth);
472         }
473       } else if (reflection->HasField(*message, field)) {
474         Message* nested_message = reflection->MutableMessage(message, field);
475         Run(nested_message, max_depth);
476       }
477     }
478 
479     if (Any* any = CastToAny(message)) {
480       if (max_depth < 0) {
481         // Clear deep Any fields to avoid stack overflow.
482         any->Clear();
483       } else {
484         auto It = any_.find(message);
485         if (It != any_.end()) {
486           Run(It->second.get(), max_depth);
487           std::string value;
488           It->second->SerializePartialToString(&value);
489           *any->mutable_value() = value;
490         }
491       }
492     }
493 
494     // Call user callback after message trimmed, initialized and packed.
495     auto range = post_processors_.equal_range(descriptor);
496     for (auto it = range.first; it != range.second; ++it)
497       it->second(message, (*random_)());
498   }
499 
500  private:
501   bool keep_initialized_;
502   const PostProcessors& post_processors_;
503   const UnpackedAny& any_;
504   RandomEngine* random_;
505 };
506 
507 }  // namespace
508 
509 class FieldMutator {
510  public:
FieldMutator(int size_increase_hint,bool enforce_changes,bool enforce_utf8_strings,const ConstMessages & sources,Mutator * mutator)511   FieldMutator(int size_increase_hint, bool enforce_changes,
512                bool enforce_utf8_strings, const ConstMessages& sources,
513                Mutator* mutator)
514       : size_increase_hint_(size_increase_hint),
515         enforce_changes_(enforce_changes),
516         enforce_utf8_strings_(enforce_utf8_strings),
517         sources_(sources),
518         mutator_(mutator) {}
519 
Mutate(int32_t * value) const520   void Mutate(int32_t* value) const {
521     RepeatMutate(value, std::bind(&Mutator::MutateInt32, mutator_, _1));
522   }
523 
Mutate(int64_t * value) const524   void Mutate(int64_t* value) const {
525     RepeatMutate(value, std::bind(&Mutator::MutateInt64, mutator_, _1));
526   }
527 
Mutate(uint32_t * value) const528   void Mutate(uint32_t* value) const {
529     RepeatMutate(value, std::bind(&Mutator::MutateUInt32, mutator_, _1));
530   }
531 
Mutate(uint64_t * value) const532   void Mutate(uint64_t* value) const {
533     RepeatMutate(value, std::bind(&Mutator::MutateUInt64, mutator_, _1));
534   }
535 
Mutate(float * value) const536   void Mutate(float* value) const {
537     RepeatMutate(value, std::bind(&Mutator::MutateFloat, mutator_, _1));
538   }
539 
Mutate(double * value) const540   void Mutate(double* value) const {
541     RepeatMutate(value, std::bind(&Mutator::MutateDouble, mutator_, _1));
542   }
543 
Mutate(bool * value) const544   void Mutate(bool* value) const {
545     RepeatMutate(value, std::bind(&Mutator::MutateBool, mutator_, _1));
546   }
547 
Mutate(FieldInstance::Enum * value) const548   void Mutate(FieldInstance::Enum* value) const {
549     RepeatMutate(&value->index,
550                  std::bind(&Mutator::MutateEnum, mutator_, _1, value->count));
551     assert(value->index < value->count);
552   }
553 
Mutate(std::string * value) const554   void Mutate(std::string* value) const {
555     if (enforce_utf8_strings_) {
556       RepeatMutate(value, std::bind(&Mutator::MutateUtf8String, mutator_, _1,
557                                     size_increase_hint_));
558     } else {
559       RepeatMutate(value, std::bind(&Mutator::MutateString, mutator_, _1,
560                                     size_increase_hint_));
561     }
562   }
563 
Mutate(std::unique_ptr<Message> * message) const564   void Mutate(std::unique_ptr<Message>* message) const {
565     assert(!enforce_changes_);
566     assert(*message);
567     if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
568       return;
569     mutator_->MutateImpl(sources_, {message->get()}, false,
570                          size_increase_hint_);
571   }
572 
573  private:
574   template <class T, class F>
RepeatMutate(T * value,F mutate) const575   void RepeatMutate(T* value, F mutate) const {
576     if (!enforce_changes_ &&
577         GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_)) {
578       return;
579     }
580     T tmp = *value;
581     for (int i = 0; i < 10; ++i) {
582       *value = mutate(*value);
583       if (!enforce_changes_ || *value != tmp) return;
584     }
585   }
586 
587   int size_increase_hint_;
588   size_t enforce_changes_;
589   bool enforce_utf8_strings_;
590   const ConstMessages& sources_;
591   Mutator* mutator_;
592 };
593 
594 namespace {
595 
596 struct MutateField : public FieldFunction<MutateField> {
597   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0211::MutateField598   void ForType(const FieldInstance& field, int size_increase_hint,
599                const ConstMessages& sources, Mutator* mutator) const {
600     T value;
601     field.Load(&value);
602     FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
603                  mutator)
604         .Mutate(&value);
605     field.Store(value);
606   }
607 };
608 
609 struct CreateField : public FieldFunction<CreateField> {
610  public:
611   template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0211::CreateField612   void ForType(const FieldInstance& field, int size_increase_hint,
613                const ConstMessages& sources, Mutator* mutator) const {
614     T value;
615     field.GetDefault(&value);
616     FieldMutator field_mutator(size_increase_hint,
617                                false /* defaults could be useful */,
618                                field.EnforceUtf8(), sources, mutator);
619     field_mutator.Mutate(&value);
620     field.Create(value);
621   }
622 };
623 
624 }  // namespace
625 
Seed(uint32_t value)626 void Mutator::Seed(uint32_t value) { random_.seed(value); }
627 
Fix(Message * message)628 void Mutator::Fix(Message* message) {
629   UnpackedAny any;
630   UnpackAny(*message, &any);
631 
632   PostProcessing(keep_initialized_, post_processors_, any, &random_)
633       .Run(message, kMaxInitializeDepth);
634   assert(IsInitialized(*message));
635 }
636 
Mutate(Message * message,size_t max_size_hint)637 void Mutator::Mutate(Message* message, size_t max_size_hint) {
638   UnpackedAny any;
639   UnpackAny(*message, &any);
640 
641   Messages messages;
642   messages.reserve(any.size() + 1);
643   messages.push_back(message);
644   for (const auto& kv : any) messages.push_back(kv.second.get());
645 
646   ConstMessages sources(messages.begin(), messages.end());
647   MutateImpl(sources, messages, false,
648              static_cast<int>(max_size_hint) -
649                  static_cast<int>(message->ByteSizeLong()));
650 
651   PostProcessing(keep_initialized_, post_processors_, any, &random_)
652       .Run(message, kMaxInitializeDepth);
653   assert(IsInitialized(*message));
654 }
655 
CrossOver(const Message & message1,Message * message2,size_t max_size_hint)656 void Mutator::CrossOver(const Message& message1, Message* message2,
657                         size_t max_size_hint) {
658   UnpackedAny any;
659   UnpackAny(*message2, &any);
660 
661   Messages messages;
662   messages.reserve(any.size() + 1);
663   messages.push_back(message2);
664   for (auto& kv : any) messages.push_back(kv.second.get());
665 
666   UnpackAny(message1, &any);
667 
668   ConstMessages sources;
669   sources.reserve(any.size() + 2);
670   sources.push_back(&message1);
671   sources.push_back(message2);
672   for (const auto& kv : any) sources.push_back(kv.second.get());
673 
674   MutateImpl(sources, messages, true,
675              static_cast<int>(max_size_hint) -
676                  static_cast<int>(message2->ByteSizeLong()));
677 
678   PostProcessing(keep_initialized_, post_processors_, any, &random_)
679       .Run(message2, kMaxInitializeDepth);
680   assert(IsInitialized(*message2));
681 }
682 
RegisterPostProcessor(const Descriptor * desc,PostProcess callback)683 void Mutator::RegisterPostProcessor(const Descriptor* desc,
684                                     PostProcess callback) {
685   post_processors_.emplace(desc, callback);
686 }
687 
MutateImpl(const ConstMessages & sources,const Messages & messages,bool copy_clone_only,int size_increase_hint)688 bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
689                          bool copy_clone_only, int size_increase_hint) {
690   MutationBitset mutations;
691   if (copy_clone_only) {
692     mutations[static_cast<size_t>(Mutation::Copy)] = true;
693     mutations[static_cast<size_t>(Mutation::Clone)] = true;
694   } else if (size_increase_hint <= 16) {
695     mutations[static_cast<size_t>(Mutation::Delete)] = true;
696   } else {
697     mutations.set();
698     mutations[static_cast<size_t>(Mutation::Copy)] = false;
699     mutations[static_cast<size_t>(Mutation::Clone)] = false;
700   }
701   while (mutations.any()) {
702     MutationSampler mutation(keep_initialized_, mutations, &random_);
703     for (Message* message : messages) mutation.Sample(message);
704 
705     switch (mutation.mutation()) {
706       case Mutation::None:
707         return true;
708       case Mutation::Add:
709         CreateField()(mutation.field(), size_increase_hint, sources, this);
710         return true;
711       case Mutation::Mutate:
712         MutateField()(mutation.field(), size_increase_hint, sources, this);
713         return true;
714       case Mutation::Delete:
715         DeleteField()(mutation.field());
716         return true;
717       case Mutation::Clone: {
718         CreateDefaultField()(mutation.field());
719         DataSourceSampler source_sampler(mutation.field(), &random_,
720                                          size_increase_hint);
721         for (const Message* source : sources) source_sampler.Sample(*source);
722         if (source_sampler.IsEmpty()) {
723           if (!IsProto3SimpleField(*mutation.field().descriptor()))
724             return true;  // CreateField is enough for proto2.
725           break;
726         }
727         CopyField()(source_sampler.field(), mutation.field());
728         return true;
729       }
730       case Mutation::Copy: {
731         DataSourceSampler source_sampler(mutation.field(), &random_,
732                                          size_increase_hint);
733         for (const Message* source : sources) source_sampler.Sample(*source);
734         if (source_sampler.IsEmpty()) break;
735         CopyField()(source_sampler.field(), mutation.field());
736         return true;
737       }
738       default:
739         assert(false && "unexpected mutation");
740         return false;
741     }
742 
743     // Don't try same mutation next time.
744     mutations[static_cast<size_t>(mutation.mutation())] = false;
745   }
746   return false;
747 }
748 
MutateInt32(int32_t value)749 int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
750 
MutateInt64(int64_t value)751 int64_t Mutator::MutateInt64(int64_t value) { return FlipBit(value, &random_); }
752 
MutateUInt32(uint32_t value)753 uint32_t Mutator::MutateUInt32(uint32_t value) {
754   return FlipBit(value, &random_);
755 }
756 
MutateUInt64(uint64_t value)757 uint64_t Mutator::MutateUInt64(uint64_t value) {
758   return FlipBit(value, &random_);
759 }
760 
MutateFloat(float value)761 float Mutator::MutateFloat(float value) { return FlipBit(value, &random_); }
762 
MutateDouble(double value)763 double Mutator::MutateDouble(double value) { return FlipBit(value, &random_); }
764 
MutateBool(bool value)765 bool Mutator::MutateBool(bool value) { return !value; }
766 
MutateEnum(size_t index,size_t item_count)767 size_t Mutator::MutateEnum(size_t index, size_t item_count) {
768   if (item_count <= 1) return 0;
769   return (index + 1 + GetRandomIndex(&random_, item_count - 1)) % item_count;
770 }
771 
MutateString(const std::string & value,int size_increase_hint)772 std::string Mutator::MutateString(const std::string& value,
773                                   int size_increase_hint) {
774   std::string result = value;
775 
776   while (!result.empty() && GetRandomBool(&random_)) {
777     result.erase(GetRandomIndex(&random_, result.size()), 1);
778   }
779 
780   while (size_increase_hint > 0 &&
781          result.size() < static_cast<size_t>(size_increase_hint) &&
782          GetRandomBool(&random_)) {
783     size_t index = GetRandomIndex(&random_, result.size() + 1);
784     result.insert(result.begin() + index, GetRandomIndex(&random_, 1 << 8));
785   }
786 
787   if (result != value) return result;
788 
789   if (result.empty()) {
790     result.push_back(GetRandomIndex(&random_, 1 << 8));
791     return result;
792   }
793 
794   if (!result.empty())
795     FlipBit(result.size(), reinterpret_cast<uint8_t*>(&result[0]), &random_);
796   return result;
797 }
798 
MutateUtf8String(const std::string & value,int size_increase_hint)799 std::string Mutator::MutateUtf8String(const std::string& value,
800                                       int size_increase_hint) {
801   std::string str = MutateString(value, size_increase_hint);
802   FixUtf8String(&str, &random_);
803   return str;
804 }
805 
IsInitialized(const Message & message) const806 bool Mutator::IsInitialized(const Message& message) const {
807   if (!keep_initialized_ || message.IsInitialized()) return true;
808   std::cerr << "Uninitialized: " << message.DebugString() << "\n";
809   return false;
810 }
811 
812 }  // namespace protobuf_mutator
813