1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/mutator.h"
16
#include <algorithm>
#include <bitset>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
26
27 #include "src/field_instance.h"
28 #include "src/utf8_fix.h"
29 #include "src/weighted_reservoir_sampler.h"
30
31 namespace protobuf_mutator {
32
33 using google::protobuf::Any;
34 using protobuf::Descriptor;
35 using protobuf::FieldDescriptor;
36 using protobuf::FileDescriptor;
37 using protobuf::Message;
38 using protobuf::OneofDescriptor;
39 using protobuf::Reflection;
40 using protobuf::util::MessageDifferencer;
41 using std::placeholders::_1;
42
43 namespace {
44
// Depth budget handed to PostProcessing::Run; message fields nested deeper
// than this are cleared instead of being visited.
constexpr int kMaxInitializeDepth = 200;

// Weight given to every candidate that MutationSampler offers to its
// reservoir sampler, i.e. all candidate mutations are equally likely.
constexpr uint64_t kDefaultMutateWeight = 1000000;
47
// Kinds of edits that can be applied to a single field. The enumerator values
// double as bit positions in MutationBitset.
enum class Mutation : uint8_t {
  None,    // No mutation was selected.
  Add,     // Adds new field with default value.
  Mutate,  // Mutates field contents.
  Delete,  // Deletes field.
  Copy,    // Overwrites the field with a value copied from another field.
  Clone,   // Creates new field with value copied from another.

  Last = Clone,  // Highest enumerator; used to size MutationBitset.
};
58
// One bit per Mutation enumerator (including None); a set bit means that kind
// of mutation is currently allowed.
using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last) + 1>;

// Lists of messages that may be modified / that are only read as data sources.
using Messages = std::vector<Message*>;
using ConstMessages = std::vector<const Message*>;
63
64 // Return random integer from [0, count)
GetRandomIndex(RandomEngine * random,size_t count)65 size_t GetRandomIndex(RandomEngine* random, size_t count) {
66 assert(count > 0);
67 if (count == 1) return 0;
68 return std::uniform_int_distribution<size_t>(0, count - 1)(*random);
69 }
70
71 // Flips random bit in the buffer.
FlipBit(size_t size,uint8_t * bytes,RandomEngine * random)72 void FlipBit(size_t size, uint8_t* bytes, RandomEngine* random) {
73 size_t bit = GetRandomIndex(random, size * 8);
74 bytes[bit / 8] ^= (1u << (bit % 8));
75 }
76
77 // Flips random bit in the value.
78 template <class T>
FlipBit(T value,RandomEngine * random)79 T FlipBit(T value, RandomEngine* random) {
80 FlipBit(sizeof(value), reinterpret_cast<uint8_t*>(&value), random);
81 return value;
82 }
83
84 // Return true with probability about 1-of-n.
GetRandomBool(RandomEngine * random,size_t n=2)85 bool GetRandomBool(RandomEngine* random, size_t n = 2) {
86 return GetRandomIndex(random, n) == 0;
87 }
88
IsProto3SimpleField(const FieldDescriptor & field)89 bool IsProto3SimpleField(const FieldDescriptor& field) {
90 assert(field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 ||
91 field.file()->syntax() == FileDescriptor::SYNTAX_PROTO2);
92 return field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 &&
93 field.cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE &&
94 !field.containing_oneof() && !field.is_repeated();
95 }
96
97 struct CreateDefaultField : public FieldFunction<CreateDefaultField> {
98 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::CreateDefaultField99 void ForType(const FieldInstance& field) const {
100 T value;
101 field.GetDefault(&value);
102 field.Create(value);
103 }
104 };
105
106 struct DeleteField : public FieldFunction<DeleteField> {
107 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::DeleteField108 void ForType(const FieldInstance& field) const {
109 field.Delete();
110 }
111 };
112
113 struct CopyField : public FieldFunction<CopyField> {
114 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::CopyField115 void ForType(const ConstFieldInstance& source,
116 const FieldInstance& field) const {
117 T value;
118 source.Load(&value);
119 field.Store(value);
120 }
121 };
122
123 struct AppendField : public FieldFunction<AppendField> {
124 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0111::AppendField125 void ForType(const ConstFieldInstance& source,
126 const FieldInstance& field) const {
127 T value;
128 source.Load(&value);
129 field.Create(value);
130 }
131 };
132
133 class CanCopyAndDifferentField
134 : public FieldFunction<CanCopyAndDifferentField, bool> {
135 public:
136 template <class T>
ForType(const ConstFieldInstance & src,const ConstFieldInstance & dst,int size_increase_hint) const137 bool ForType(const ConstFieldInstance& src, const ConstFieldInstance& dst,
138 int size_increase_hint) const {
139 T s;
140 src.Load(&s);
141 if (!dst.CanStore(s)) return false;
142 T d;
143 dst.Load(&d);
144 return SizeDiff(s, d) <= size_increase_hint && !IsEqual(s, d);
145 }
146
147 private:
IsEqual(const ConstFieldInstance::Enum & a,const ConstFieldInstance::Enum & b) const148 bool IsEqual(const ConstFieldInstance::Enum& a,
149 const ConstFieldInstance::Enum& b) const {
150 assert(a.count == b.count);
151 return a.index == b.index;
152 }
153
IsEqual(const std::unique_ptr<Message> & a,const std::unique_ptr<Message> & b) const154 bool IsEqual(const std::unique_ptr<Message>& a,
155 const std::unique_ptr<Message>& b) const {
156 return MessageDifferencer::Equals(*a, *b);
157 }
158
159 template <class T>
IsEqual(const T & a,const T & b) const160 bool IsEqual(const T& a, const T& b) const {
161 return a == b;
162 }
163
SizeDiff(const std::unique_ptr<Message> & src,const std::unique_ptr<Message> & dst) const164 int64_t SizeDiff(const std::unique_ptr<Message>& src,
165 const std::unique_ptr<Message>& dst) const {
166 return src->ByteSizeLong() - dst->ByteSizeLong();
167 }
168
SizeDiff(const std::string & src,const std::string & dst) const169 int64_t SizeDiff(const std::string& src, const std::string& dst) const {
170 return src.size() - dst.size();
171 }
172
173 template <class T>
SizeDiff(const T &,const T &) const174 int64_t SizeDiff(const T&, const T&) const {
175 return 0;
176 }
177 };
178
// Selects random field and mutation from the given proto message.
//
// Every (field, mutation) pair that is structurally possible and enabled in
// |allowed_mutations| is offered to a weighted reservoir sampler with the
// same weight. Sample() may be called for several messages; candidates from
// all calls go to the same sampler. If nothing was offered, mutation() is
// Mutation::None (the default-constructed Result).
class MutationSampler {
 public:
  MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
                  RandomEngine* random)
      : keep_initialized_(keep_initialized),
        allowed_mutations_(allowed_mutations),
        random_(random),
        sampler_(random) {}

  // Returns selected field.
  const FieldInstance& field() const { return sampler_.selected().field; }

  // Returns selected mutation.
  Mutation mutation() const { return sampler_.selected().mutation; }

  // Offers all candidates found in |message| and its nested messages.
  void Sample(Message* message) {
    SampleImpl(message);
    // A selection may only be missing when Mutate is disallowed or the
    // top-level message has no fields at all.
    assert(mutation() != Mutation::None ||
           !allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
           message->GetDescriptor()->field_count() == 0);
  }

 private:
  // Recursively enumerates candidate mutations for every field of |message|.
  void SampleImpl(Message* message) {
    const Descriptor* descriptor = message->GetDescriptor();
    const Reflection* reflection = message->GetReflection();

    int field_count = descriptor->field_count();
    for (int i = 0; i < field_count; ++i) {
      const FieldDescriptor* field = descriptor->field(i);
      if (const OneofDescriptor* oneof = field->containing_oneof()) {
        // Handle entire oneof group on the first field.
        if (field->index_in_oneof() == 0) {
          assert(oneof->field_count());
          const FieldDescriptor* current_field =
              reflection->GetOneofFieldDescriptor(*message, oneof);
          // Pick a random member other than the one currently set and offer
          // to add it. Gives up when the oneof has a single member and that
          // member is already the current one.
          for (;;) {
            const FieldDescriptor* add_field =
                oneof->field(GetRandomIndex(random_, oneof->field_count()));
            if (add_field != current_field) {
              Try({message, add_field}, Mutation::Add);
              Try({message, add_field}, Mutation::Clone);
              break;
            }
            if (oneof->field_count() < 2) break;
          }
          if (current_field) {
            // Message-typed fields are not mutated in place; their contents
            // are sampled by the recursion at the end of the loop body.
            if (current_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
              Try({message, current_field}, Mutation::Mutate);
            Try({message, current_field}, Mutation::Delete);
            Try({message, current_field}, Mutation::Copy);
          }
        }
      } else {
        if (field->is_repeated()) {
          int field_size = reflection->FieldSize(*message, field);
          // Insertion point may be any position, including one past the end.
          size_t random_index = GetRandomIndex(random_, field_size + 1);
          Try({message, field, random_index}, Mutation::Add);
          Try({message, field, random_index}, Mutation::Clone);

          if (field_size) {
            // Separate draw: an existing element to mutate, delete or
            // overwrite (shadows the insertion index above).
            size_t random_index = GetRandomIndex(random_, field_size);
            if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
              Try({message, field, random_index}, Mutation::Mutate);
            Try({message, field, random_index}, Mutation::Delete);
            Try({message, field, random_index}, Mutation::Copy);
          }
        } else {
          // Proto3 simple fields are handled as if always present.
          if (reflection->HasField(*message, field) ||
              IsProto3SimpleField(*field)) {
            if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
              Try({message, field}, Mutation::Mutate);
            // Never delete proto3 simple fields, nor required fields when
            // the message must stay initialized.
            if (!IsProto3SimpleField(*field) &&
                (!field->is_required() || !keep_initialized_)) {
              Try({message, field}, Mutation::Delete);
            }
            Try({message, field}, Mutation::Copy);
          } else {
            Try({message, field}, Mutation::Add);
            Try({message, field}, Mutation::Clone);
          }
        }
      }

      // Descend into every present nested message.
      if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
        if (field->is_repeated()) {
          const int field_size = reflection->FieldSize(*message, field);
          for (int j = 0; j < field_size; ++j)
            SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
        } else if (reflection->HasField(*message, field)) {
          SampleImpl(reflection->MutableMessage(message, field));
        }
      }
    }
  }

  // Offers (field, mutation) to the sampler unless that mutation kind is
  // disabled.
  void Try(const FieldInstance& field, Mutation mutation) {
    assert(mutation != Mutation::None);
    if (!allowed_mutations_[static_cast<size_t>(mutation)]) return;
    sampler_.Try(kDefaultMutateWeight, {field, mutation});
  }

  bool keep_initialized_ = false;
  MutationBitset allowed_mutations_;

  RandomEngine* random_;

  // A candidate: the field to act on and the action to take.
  struct Result {
    Result() = default;
    Result(const FieldInstance& f, Mutation m) : field(f), mutation(m) {}

    FieldInstance field;
    Mutation mutation = Mutation::None;
  };
  WeightedReservoirSampler<Result, RandomEngine> sampler_;
};
296
297 // Selects random field of compatible type to use for clone mutations.
298 class DataSourceSampler {
299 public:
DataSourceSampler(const ConstFieldInstance & match,RandomEngine * random,int size_increase_hint)300 DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
301 int size_increase_hint)
302 : match_(match),
303 random_(random),
304 size_increase_hint_(size_increase_hint),
305 sampler_(random) {}
306
Sample(const Message & message)307 void Sample(const Message& message) { SampleImpl(message); }
308
309 // Returns selected field.
field() const310 const ConstFieldInstance& field() const {
311 assert(!IsEmpty());
312 return sampler_.selected();
313 }
314
IsEmpty() const315 bool IsEmpty() const { return sampler_.IsEmpty(); }
316
317 private:
SampleImpl(const Message & message)318 void SampleImpl(const Message& message) {
319 const Descriptor* descriptor = message.GetDescriptor();
320 const Reflection* reflection = message.GetReflection();
321
322 int field_count = descriptor->field_count();
323 for (int i = 0; i < field_count; ++i) {
324 const FieldDescriptor* field = descriptor->field(i);
325 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
326 if (field->is_repeated()) {
327 const int field_size = reflection->FieldSize(message, field);
328 for (int j = 0; j < field_size; ++j) {
329 SampleImpl(reflection->GetRepeatedMessage(message, field, j));
330 }
331 } else if (reflection->HasField(message, field)) {
332 SampleImpl(reflection->GetMessage(message, field));
333 }
334 }
335
336 if (field->cpp_type() != match_.cpp_type()) continue;
337 if (match_.cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
338 if (field->enum_type() != match_.enum_type()) continue;
339 } else if (match_.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
340 if (field->message_type() != match_.message_type()) continue;
341 }
342
343 if (field->is_repeated()) {
344 if (int field_size = reflection->FieldSize(message, field)) {
345 ConstFieldInstance source(&message, field,
346 GetRandomIndex(random_, field_size));
347 if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
348 sampler_.Try(field_size, source);
349 }
350 } else {
351 if (reflection->HasField(message, field)) {
352 ConstFieldInstance source(&message, field);
353 if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
354 sampler_.Try(1, source);
355 }
356 }
357 }
358 }
359
360 ConstFieldInstance match_;
361 RandomEngine* random_;
362 int size_increase_hint_;
363
364 WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
365 };
366
// Maps the address of an Any-typed message to the message decoded from its
// payload. The map owns the decoded messages.
using UnpackedAny =
    std::unordered_map<const Message*, std::unique_ptr<Message>>;
369
GetAnyTypeDescriptor(const Any & any)370 const Descriptor* GetAnyTypeDescriptor(const Any& any) {
371 std::string type_name;
372 if (!Any::ParseAnyTypeUrl(std::string(any.type_url()), &type_name))
373 return nullptr;
374 return any.descriptor()->file()->pool()->FindMessageTypeByName(type_name);
375 }
376
UnpackAny(const Any & any)377 std::unique_ptr<Message> UnpackAny(const Any& any) {
378 const Descriptor* desc = GetAnyTypeDescriptor(any);
379 if (!desc) return {};
380 std::unique_ptr<Message> message(
381 any.GetReflection()->GetMessageFactory()->GetPrototype(desc)->New());
382 message->ParsePartialFromString(std::string(any.value()));
383 return message;
384 }
385
CastToAny(const Message * message)386 const Any* CastToAny(const Message* message) {
387 return Any::GetDescriptor() == message->GetDescriptor()
388 ? static_cast<const Any*>(message)
389 : nullptr;
390 }
391
CastToAny(Message * message)392 Any* CastToAny(Message* message) {
393 return Any::GetDescriptor() == message->GetDescriptor()
394 ? static_cast<Any*>(message)
395 : nullptr;
396 }
397
UnpackIfAny(const Message & message)398 std::unique_ptr<Message> UnpackIfAny(const Message& message) {
399 if (const Any* any = CastToAny(&message)) return UnpackAny(*any);
400 return {};
401 }
402
UnpackAny(const Message & message,UnpackedAny * result)403 void UnpackAny(const Message& message, UnpackedAny* result) {
404 if (std::unique_ptr<Message> any = UnpackIfAny(message)) {
405 UnpackAny(*any, result);
406 result->emplace(&message, std::move(any));
407 return;
408 }
409
410 const Descriptor* descriptor = message.GetDescriptor();
411 const Reflection* reflection = message.GetReflection();
412
413 for (int i = 0; i < descriptor->field_count(); ++i) {
414 const FieldDescriptor* field = descriptor->field(i);
415 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
416 if (field->is_repeated()) {
417 const int field_size = reflection->FieldSize(message, field);
418 for (int j = 0; j < field_size; ++j) {
419 UnpackAny(reflection->GetRepeatedMessage(message, field, j), result);
420 }
421 } else if (reflection->HasField(message, field)) {
422 UnpackAny(reflection->GetMessage(message, field), result);
423 }
424 }
425 }
426 }
427
// Repairs and finalizes a message tree after it has been mutated:
//   * when |keep_initialized| is set, creates missing required fields and
//     missing fields of map entries with default values,
//   * clears non-required message fields (and Any contents) that lie beyond
//     the depth budget,
//   * re-serializes decoded Any payloads back into their Any wrappers,
//   * invokes the registered post-processors for each visited message.
//
// |post_processors| and |any| are held by reference and must outlive this
// object.
class PostProcessing {
 public:
  using PostProcessors =
      std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;

  PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
                 const UnpackedAny& any, RandomEngine* random)
      : keep_initialized_(keep_initialized),
        post_processors_(post_processors),
        any_(any),
        random_(random) {}

  // Processes |message| and, depth-first, everything nested inside it.
  // |max_depth| is the remaining nesting budget; it is decremented on entry.
  void Run(Message* message, int max_depth) {
    --max_depth;
    const Descriptor* descriptor = message->GetDescriptor();

    // Apply custom mutators in nested messages before packing any.
    const Reflection* reflection = message->GetReflection();
    for (int i = 0; i < descriptor->field_count(); i++) {
      const FieldDescriptor* field = descriptor->field(i);
      // Fill in fields that must be present: required fields, and all fields
      // of a map entry message.
      if (keep_initialized_ &&
          (field->is_required() || descriptor->options().map_entry()) &&
          !reflection->HasField(*message, field)) {
        CreateDefaultField()(FieldInstance(message, field));
      }

      // Only message-typed fields need the recursive treatment below.
      if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;

      if (max_depth < 0 && !field->is_required()) {
        // Clear deep optional fields to avoid stack overflow.
        reflection->ClearField(message, field);
        if (field->is_repeated())
          assert(!reflection->FieldSize(*message, field));
        else
          assert(!reflection->HasField(*message, field));
        continue;
      }

      if (field->is_repeated()) {
        const int field_size = reflection->FieldSize(*message, field);
        for (int j = 0; j < field_size; ++j) {
          Message* nested_message =
              reflection->MutableRepeatedMessage(message, field, j);
          Run(nested_message, max_depth);
        }
      } else if (reflection->HasField(*message, field)) {
        Message* nested_message = reflection->MutableMessage(message, field);
        Run(nested_message, max_depth);
      }
    }

    if (Any* any = CastToAny(message)) {
      if (max_depth < 0) {
        // Clear deep Any fields to avoid stack overflow.
        any->Clear();
      } else {
        // If this Any was decoded earlier, process the decoded payload and
        // write its serialized form back into the Any's value.
        auto It = any_.find(message);
        if (It != any_.end()) {
          Run(It->second.get(), max_depth);
          std::string value;
          It->second->SerializePartialToString(&value);
          *any->mutable_value() = value;
        }
      }
    }

    // Call user callback after message trimmed, initialized and packed.
    // Each callback receives a fresh value drawn from the random engine.
    auto range = post_processors_.equal_range(descriptor);
    for (auto it = range.first; it != range.second; ++it)
      it->second(message, (*random_)());
  }

 private:
  bool keep_initialized_;
  const PostProcessors& post_processors_;
  const UnpackedAny& any_;
  RandomEngine* random_;
};
506
507 } // namespace
508
509 class FieldMutator {
510 public:
FieldMutator(int size_increase_hint,bool enforce_changes,bool enforce_utf8_strings,const ConstMessages & sources,Mutator * mutator)511 FieldMutator(int size_increase_hint, bool enforce_changes,
512 bool enforce_utf8_strings, const ConstMessages& sources,
513 Mutator* mutator)
514 : size_increase_hint_(size_increase_hint),
515 enforce_changes_(enforce_changes),
516 enforce_utf8_strings_(enforce_utf8_strings),
517 sources_(sources),
518 mutator_(mutator) {}
519
Mutate(int32_t * value) const520 void Mutate(int32_t* value) const {
521 RepeatMutate(value, std::bind(&Mutator::MutateInt32, mutator_, _1));
522 }
523
Mutate(int64_t * value) const524 void Mutate(int64_t* value) const {
525 RepeatMutate(value, std::bind(&Mutator::MutateInt64, mutator_, _1));
526 }
527
Mutate(uint32_t * value) const528 void Mutate(uint32_t* value) const {
529 RepeatMutate(value, std::bind(&Mutator::MutateUInt32, mutator_, _1));
530 }
531
Mutate(uint64_t * value) const532 void Mutate(uint64_t* value) const {
533 RepeatMutate(value, std::bind(&Mutator::MutateUInt64, mutator_, _1));
534 }
535
Mutate(float * value) const536 void Mutate(float* value) const {
537 RepeatMutate(value, std::bind(&Mutator::MutateFloat, mutator_, _1));
538 }
539
Mutate(double * value) const540 void Mutate(double* value) const {
541 RepeatMutate(value, std::bind(&Mutator::MutateDouble, mutator_, _1));
542 }
543
Mutate(bool * value) const544 void Mutate(bool* value) const {
545 RepeatMutate(value, std::bind(&Mutator::MutateBool, mutator_, _1));
546 }
547
Mutate(FieldInstance::Enum * value) const548 void Mutate(FieldInstance::Enum* value) const {
549 RepeatMutate(&value->index,
550 std::bind(&Mutator::MutateEnum, mutator_, _1, value->count));
551 assert(value->index < value->count);
552 }
553
Mutate(std::string * value) const554 void Mutate(std::string* value) const {
555 if (enforce_utf8_strings_) {
556 RepeatMutate(value, std::bind(&Mutator::MutateUtf8String, mutator_, _1,
557 size_increase_hint_));
558 } else {
559 RepeatMutate(value, std::bind(&Mutator::MutateString, mutator_, _1,
560 size_increase_hint_));
561 }
562 }
563
Mutate(std::unique_ptr<Message> * message) const564 void Mutate(std::unique_ptr<Message>* message) const {
565 assert(!enforce_changes_);
566 assert(*message);
567 if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
568 return;
569 mutator_->MutateImpl(sources_, {message->get()}, false,
570 size_increase_hint_);
571 }
572
573 private:
574 template <class T, class F>
RepeatMutate(T * value,F mutate) const575 void RepeatMutate(T* value, F mutate) const {
576 if (!enforce_changes_ &&
577 GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_)) {
578 return;
579 }
580 T tmp = *value;
581 for (int i = 0; i < 10; ++i) {
582 *value = mutate(*value);
583 if (!enforce_changes_ || *value != tmp) return;
584 }
585 }
586
587 int size_increase_hint_;
588 size_t enforce_changes_;
589 bool enforce_utf8_strings_;
590 const ConstMessages& sources_;
591 Mutator* mutator_;
592 };
593
594 namespace {
595
596 struct MutateField : public FieldFunction<MutateField> {
597 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0211::MutateField598 void ForType(const FieldInstance& field, int size_increase_hint,
599 const ConstMessages& sources, Mutator* mutator) const {
600 T value;
601 field.Load(&value);
602 FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
603 mutator)
604 .Mutate(&value);
605 field.Store(value);
606 }
607 };
608
609 struct CreateField : public FieldFunction<CreateField> {
610 public:
611 template <class T>
ForTypeprotobuf_mutator::__anonbdd7ecbf0211::CreateField612 void ForType(const FieldInstance& field, int size_increase_hint,
613 const ConstMessages& sources, Mutator* mutator) const {
614 T value;
615 field.GetDefault(&value);
616 FieldMutator field_mutator(size_increase_hint,
617 false /* defaults could be useful */,
618 field.EnforceUtf8(), sources, mutator);
619 field_mutator.Mutate(&value);
620 field.Create(value);
621 }
622 };
623
624 } // namespace
625
Seed(uint32_t value)626 void Mutator::Seed(uint32_t value) { random_.seed(value); }
627
Fix(Message * message)628 void Mutator::Fix(Message* message) {
629 UnpackedAny any;
630 UnpackAny(*message, &any);
631
632 PostProcessing(keep_initialized_, post_processors_, any, &random_)
633 .Run(message, kMaxInitializeDepth);
634 assert(IsInitialized(*message));
635 }
636
Mutate(Message * message,size_t max_size_hint)637 void Mutator::Mutate(Message* message, size_t max_size_hint) {
638 UnpackedAny any;
639 UnpackAny(*message, &any);
640
641 Messages messages;
642 messages.reserve(any.size() + 1);
643 messages.push_back(message);
644 for (const auto& kv : any) messages.push_back(kv.second.get());
645
646 ConstMessages sources(messages.begin(), messages.end());
647 MutateImpl(sources, messages, false,
648 static_cast<int>(max_size_hint) -
649 static_cast<int>(message->ByteSizeLong()));
650
651 PostProcessing(keep_initialized_, post_processors_, any, &random_)
652 .Run(message, kMaxInitializeDepth);
653 assert(IsInitialized(*message));
654 }
655
CrossOver(const Message & message1,Message * message2,size_t max_size_hint)656 void Mutator::CrossOver(const Message& message1, Message* message2,
657 size_t max_size_hint) {
658 UnpackedAny any;
659 UnpackAny(*message2, &any);
660
661 Messages messages;
662 messages.reserve(any.size() + 1);
663 messages.push_back(message2);
664 for (auto& kv : any) messages.push_back(kv.second.get());
665
666 UnpackAny(message1, &any);
667
668 ConstMessages sources;
669 sources.reserve(any.size() + 2);
670 sources.push_back(&message1);
671 sources.push_back(message2);
672 for (const auto& kv : any) sources.push_back(kv.second.get());
673
674 MutateImpl(sources, messages, true,
675 static_cast<int>(max_size_hint) -
676 static_cast<int>(message2->ByteSizeLong()));
677
678 PostProcessing(keep_initialized_, post_processors_, any, &random_)
679 .Run(message2, kMaxInitializeDepth);
680 assert(IsInitialized(*message2));
681 }
682
RegisterPostProcessor(const Descriptor * desc,PostProcess callback)683 void Mutator::RegisterPostProcessor(const Descriptor* desc,
684 PostProcess callback) {
685 post_processors_.emplace(desc, callback);
686 }
687
MutateImpl(const ConstMessages & sources,const Messages & messages,bool copy_clone_only,int size_increase_hint)688 bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
689 bool copy_clone_only, int size_increase_hint) {
690 MutationBitset mutations;
691 if (copy_clone_only) {
692 mutations[static_cast<size_t>(Mutation::Copy)] = true;
693 mutations[static_cast<size_t>(Mutation::Clone)] = true;
694 } else if (size_increase_hint <= 16) {
695 mutations[static_cast<size_t>(Mutation::Delete)] = true;
696 } else {
697 mutations.set();
698 mutations[static_cast<size_t>(Mutation::Copy)] = false;
699 mutations[static_cast<size_t>(Mutation::Clone)] = false;
700 }
701 while (mutations.any()) {
702 MutationSampler mutation(keep_initialized_, mutations, &random_);
703 for (Message* message : messages) mutation.Sample(message);
704
705 switch (mutation.mutation()) {
706 case Mutation::None:
707 return true;
708 case Mutation::Add:
709 CreateField()(mutation.field(), size_increase_hint, sources, this);
710 return true;
711 case Mutation::Mutate:
712 MutateField()(mutation.field(), size_increase_hint, sources, this);
713 return true;
714 case Mutation::Delete:
715 DeleteField()(mutation.field());
716 return true;
717 case Mutation::Clone: {
718 CreateDefaultField()(mutation.field());
719 DataSourceSampler source_sampler(mutation.field(), &random_,
720 size_increase_hint);
721 for (const Message* source : sources) source_sampler.Sample(*source);
722 if (source_sampler.IsEmpty()) {
723 if (!IsProto3SimpleField(*mutation.field().descriptor()))
724 return true; // CreateField is enough for proto2.
725 break;
726 }
727 CopyField()(source_sampler.field(), mutation.field());
728 return true;
729 }
730 case Mutation::Copy: {
731 DataSourceSampler source_sampler(mutation.field(), &random_,
732 size_increase_hint);
733 for (const Message* source : sources) source_sampler.Sample(*source);
734 if (source_sampler.IsEmpty()) break;
735 CopyField()(source_sampler.field(), mutation.field());
736 return true;
737 }
738 default:
739 assert(false && "unexpected mutation");
740 return false;
741 }
742
743 // Don't try same mutation next time.
744 mutations[static_cast<size_t>(mutation.mutation())] = false;
745 }
746 return false;
747 }
748
MutateInt32(int32_t value)749 int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
750
MutateInt64(int64_t value)751 int64_t Mutator::MutateInt64(int64_t value) { return FlipBit(value, &random_); }
752
MutateUInt32(uint32_t value)753 uint32_t Mutator::MutateUInt32(uint32_t value) {
754 return FlipBit(value, &random_);
755 }
756
MutateUInt64(uint64_t value)757 uint64_t Mutator::MutateUInt64(uint64_t value) {
758 return FlipBit(value, &random_);
759 }
760
MutateFloat(float value)761 float Mutator::MutateFloat(float value) { return FlipBit(value, &random_); }
762
MutateDouble(double value)763 double Mutator::MutateDouble(double value) { return FlipBit(value, &random_); }
764
MutateBool(bool value)765 bool Mutator::MutateBool(bool value) { return !value; }
766
MutateEnum(size_t index,size_t item_count)767 size_t Mutator::MutateEnum(size_t index, size_t item_count) {
768 if (item_count <= 1) return 0;
769 return (index + 1 + GetRandomIndex(&random_, item_count - 1)) % item_count;
770 }
771
MutateString(const std::string & value,int size_increase_hint)772 std::string Mutator::MutateString(const std::string& value,
773 int size_increase_hint) {
774 std::string result = value;
775
776 while (!result.empty() && GetRandomBool(&random_)) {
777 result.erase(GetRandomIndex(&random_, result.size()), 1);
778 }
779
780 while (size_increase_hint > 0 &&
781 result.size() < static_cast<size_t>(size_increase_hint) &&
782 GetRandomBool(&random_)) {
783 size_t index = GetRandomIndex(&random_, result.size() + 1);
784 result.insert(result.begin() + index, GetRandomIndex(&random_, 1 << 8));
785 }
786
787 if (result != value) return result;
788
789 if (result.empty()) {
790 result.push_back(GetRandomIndex(&random_, 1 << 8));
791 return result;
792 }
793
794 if (!result.empty())
795 FlipBit(result.size(), reinterpret_cast<uint8_t*>(&result[0]), &random_);
796 return result;
797 }
798
MutateUtf8String(const std::string & value,int size_increase_hint)799 std::string Mutator::MutateUtf8String(const std::string& value,
800 int size_increase_hint) {
801 std::string str = MutateString(value, size_increase_hint);
802 FixUtf8String(&str, &random_);
803 return str;
804 }
805
IsInitialized(const Message & message) const806 bool Mutator::IsInitialized(const Message& message) const {
807 if (!keep_initialized_ || message.IsInitialized()) return true;
808 std::cerr << "Uninitialized: " << message.DebugString() << "\n";
809 return false;
810 }
811
812 } // namespace protobuf_mutator
813