1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "src/mutator.h"

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "src/field_instance.h"
#include "src/utf8_fix.h"
#include "src/weighted_reservoir_sampler.h"
29
30 namespace protobuf_mutator {
31
32 using google::protobuf::Any;
33 using protobuf::Descriptor;
34 using protobuf::FieldDescriptor;
35 using protobuf::FileDescriptor;
36 using protobuf::Message;
37 using protobuf::OneofDescriptor;
38 using protobuf::Reflection;
39 using protobuf::util::MessageDifferencer;
40 using std::placeholders::_1;
41
42 namespace {
43
// Nesting depth budget handed to PostProcessing::Run for a top-level message.
constexpr int kMaxInitializeDepth = 200;
// Weight attached to every (field, mutation) candidate offered to the sampler.
constexpr uint64_t kDefaultMutateWeight = 1000000;
46
// Kinds of change that can be applied to a single field.
enum class Mutation : uint8_t {
  None,
  Add,     // Adds new field with default value.
  Mutate,  // Mutates field contents.
  Delete,  // Deletes field.
  Copy,    // Copies value from another field into an existing field.
  Clone,   // Create new field with value copied from another.

  Last = Clone,
};

// One flag per Mutation enumerator, indexed by the enumerator's numeric value.
// The "+ 1" is required: enumerators run from 0 to Last inclusive, so a bitset
// of only Last bits would make the Mutation::Clone index out of range.
using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last) + 1>;
59
// Messages that a mutation is allowed to modify.
using Messages = std::vector<Message*>;
// Read-only messages; MutateImpl uses them as value sources for Copy/Clone.
using ConstMessages = std::vector<const Message*>;
62
63 // Return random integer from [0, count)
GetRandomIndex(RandomEngine * random,size_t count)64 size_t GetRandomIndex(RandomEngine* random, size_t count) {
65 assert(count > 0);
66 if (count == 1) return 0;
67 return std::uniform_int_distribution<size_t>(0, count - 1)(*random);
68 }
69
70 // Flips random bit in the buffer.
FlipBit(size_t size,uint8_t * bytes,RandomEngine * random)71 void FlipBit(size_t size, uint8_t* bytes, RandomEngine* random) {
72 size_t bit = GetRandomIndex(random, size * 8);
73 bytes[bit / 8] ^= (1u << (bit % 8));
74 }
75
76 // Flips random bit in the value.
77 template <class T>
FlipBit(T value,RandomEngine * random)78 T FlipBit(T value, RandomEngine* random) {
79 FlipBit(sizeof(value), reinterpret_cast<uint8_t*>(&value), random);
80 return value;
81 }
82
83 // Return true with probability about 1-of-n.
GetRandomBool(RandomEngine * random,size_t n=2)84 bool GetRandomBool(RandomEngine* random, size_t n = 2) {
85 return GetRandomIndex(random, n) == 0;
86 }
87
IsProto3SimpleField(const FieldDescriptor & field)88 bool IsProto3SimpleField(const FieldDescriptor& field) {
89 assert(field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 ||
90 field.file()->syntax() == FileDescriptor::SYNTAX_PROTO2);
91 return field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 &&
92 field.cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE &&
93 !field.containing_oneof() && !field.is_repeated();
94 }
95
96 struct CreateDefaultField : public FieldFunction<CreateDefaultField> {
97 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::CreateDefaultField98 void ForType(const FieldInstance& field) const {
99 T value;
100 field.GetDefault(&value);
101 field.Create(value);
102 }
103 };
104
105 struct DeleteField : public FieldFunction<DeleteField> {
106 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::DeleteField107 void ForType(const FieldInstance& field) const {
108 field.Delete();
109 }
110 };
111
112 struct CopyField : public FieldFunction<CopyField> {
113 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::CopyField114 void ForType(const ConstFieldInstance& source,
115 const FieldInstance& field) const {
116 T value;
117 source.Load(&value);
118 field.Store(value);
119 }
120 };
121
122 struct AppendField : public FieldFunction<AppendField> {
123 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30111::AppendField124 void ForType(const ConstFieldInstance& source,
125 const FieldInstance& field) const {
126 T value;
127 source.Load(&value);
128 field.Create(value);
129 }
130 };
131
132 class CanCopyAndDifferentField
133 : public FieldFunction<CanCopyAndDifferentField, bool> {
134 public:
135 template <class T>
ForType(const ConstFieldInstance & src,const ConstFieldInstance & dst,int size_increase_hint) const136 bool ForType(const ConstFieldInstance& src, const ConstFieldInstance& dst,
137 int size_increase_hint) const {
138 T s;
139 src.Load(&s);
140 if (!dst.CanStore(s)) return false;
141 T d;
142 dst.Load(&d);
143 return SizeDiff(s, d) <= size_increase_hint && !IsEqual(s, d);
144 }
145
146 private:
IsEqual(const ConstFieldInstance::Enum & a,const ConstFieldInstance::Enum & b) const147 bool IsEqual(const ConstFieldInstance::Enum& a,
148 const ConstFieldInstance::Enum& b) const {
149 assert(a.count == b.count);
150 return a.index == b.index;
151 }
152
IsEqual(const std::unique_ptr<Message> & a,const std::unique_ptr<Message> & b) const153 bool IsEqual(const std::unique_ptr<Message>& a,
154 const std::unique_ptr<Message>& b) const {
155 return MessageDifferencer::Equals(*a, *b);
156 }
157
158 template <class T>
IsEqual(const T & a,const T & b) const159 bool IsEqual(const T& a, const T& b) const {
160 return a == b;
161 }
162
SizeDiff(const std::unique_ptr<Message> & src,const std::unique_ptr<Message> & dst) const163 int64_t SizeDiff(const std::unique_ptr<Message>& src,
164 const std::unique_ptr<Message>& dst) const {
165 return src->ByteSizeLong() - dst->ByteSizeLong();
166 }
167
SizeDiff(const std::string & src,const std::string & dst) const168 int64_t SizeDiff(const std::string& src, const std::string& dst) const {
169 return src.size() - dst.size();
170 }
171
172 template <class T>
SizeDiff(const T &,const T &) const173 int64_t SizeDiff(const T&, const T&) const {
174 return 0;
175 }
176 };
177
178 // Selects random field and mutation from the given proto message.
179 class MutationSampler {
180 public:
MutationSampler(bool keep_initialized,MutationBitset allowed_mutations,RandomEngine * random)181 MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
182 RandomEngine* random)
183 : keep_initialized_(keep_initialized),
184 allowed_mutations_(allowed_mutations),
185 random_(random),
186 sampler_(random) {}
187
188 // Returns selected field.
field() const189 const FieldInstance& field() const { return sampler_.selected().field; }
190
191 // Returns selected mutation.
mutation() const192 Mutation mutation() const { return sampler_.selected().mutation; }
193
Sample(Message * message)194 void Sample(Message* message) {
195 SampleImpl(message);
196 assert(mutation() != Mutation::None ||
197 !allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
198 message->GetDescriptor()->field_count() == 0);
199 }
200
201 private:
SampleImpl(Message * message)202 void SampleImpl(Message* message) {
203 const Descriptor* descriptor = message->GetDescriptor();
204 const Reflection* reflection = message->GetReflection();
205
206 int field_count = descriptor->field_count();
207 for (int i = 0; i < field_count; ++i) {
208 const FieldDescriptor* field = descriptor->field(i);
209 if (const OneofDescriptor* oneof = field->containing_oneof()) {
210 // Handle entire oneof group on the first field.
211 if (field->index_in_oneof() == 0) {
212 assert(oneof->field_count());
213 const FieldDescriptor* current_field =
214 reflection->GetOneofFieldDescriptor(*message, oneof);
215 for (;;) {
216 const FieldDescriptor* add_field =
217 oneof->field(GetRandomIndex(random_, oneof->field_count()));
218 if (add_field != current_field) {
219 Try({message, add_field}, Mutation::Add);
220 Try({message, add_field}, Mutation::Clone);
221 break;
222 }
223 if (oneof->field_count() < 2) break;
224 }
225 if (current_field) {
226 if (current_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
227 Try({message, current_field}, Mutation::Mutate);
228 Try({message, current_field}, Mutation::Delete);
229 Try({message, current_field}, Mutation::Copy);
230 }
231 }
232 } else {
233 if (field->is_repeated()) {
234 int field_size = reflection->FieldSize(*message, field);
235 size_t random_index = GetRandomIndex(random_, field_size + 1);
236 Try({message, field, random_index}, Mutation::Add);
237 Try({message, field, random_index}, Mutation::Clone);
238
239 if (field_size) {
240 size_t random_index = GetRandomIndex(random_, field_size);
241 if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
242 Try({message, field, random_index}, Mutation::Mutate);
243 Try({message, field, random_index}, Mutation::Delete);
244 Try({message, field, random_index}, Mutation::Copy);
245 }
246 } else {
247 if (reflection->HasField(*message, field) ||
248 IsProto3SimpleField(*field)) {
249 if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
250 Try({message, field}, Mutation::Mutate);
251 if (!IsProto3SimpleField(*field) &&
252 (!field->is_required() || !keep_initialized_)) {
253 Try({message, field}, Mutation::Delete);
254 }
255 Try({message, field}, Mutation::Copy);
256 } else {
257 Try({message, field}, Mutation::Add);
258 Try({message, field}, Mutation::Clone);
259 }
260 }
261 }
262
263 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
264 if (field->is_repeated()) {
265 const int field_size = reflection->FieldSize(*message, field);
266 for (int j = 0; j < field_size; ++j)
267 SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
268 } else if (reflection->HasField(*message, field)) {
269 SampleImpl(reflection->MutableMessage(message, field));
270 }
271 }
272 }
273 }
274
Try(const FieldInstance & field,Mutation mutation)275 void Try(const FieldInstance& field, Mutation mutation) {
276 assert(mutation != Mutation::None);
277 if (!allowed_mutations_[static_cast<size_t>(mutation)]) return;
278 sampler_.Try(kDefaultMutateWeight, {field, mutation});
279 }
280
281 bool keep_initialized_ = false;
282 MutationBitset allowed_mutations_;
283
284 RandomEngine* random_;
285
286 struct Result {
287 Result() = default;
Resultprotobuf_mutator::__anon4db2e1e30111::MutationSampler::Result288 Result(const FieldInstance& f, Mutation m) : field(f), mutation(m) {}
289
290 FieldInstance field;
291 Mutation mutation = Mutation::None;
292 };
293 WeightedReservoirSampler<Result, RandomEngine> sampler_;
294 };
295
296 // Selects random field of compatible type to use for clone mutations.
297 class DataSourceSampler {
298 public:
DataSourceSampler(const ConstFieldInstance & match,RandomEngine * random,int size_increase_hint)299 DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
300 int size_increase_hint)
301 : match_(match),
302 random_(random),
303 size_increase_hint_(size_increase_hint),
304 sampler_(random) {}
305
Sample(const Message & message)306 void Sample(const Message& message) { SampleImpl(message); }
307
308 // Returns selected field.
field() const309 const ConstFieldInstance& field() const {
310 assert(!IsEmpty());
311 return sampler_.selected();
312 }
313
IsEmpty() const314 bool IsEmpty() const { return sampler_.IsEmpty(); }
315
316 private:
SampleImpl(const Message & message)317 void SampleImpl(const Message& message) {
318 const Descriptor* descriptor = message.GetDescriptor();
319 const Reflection* reflection = message.GetReflection();
320
321 int field_count = descriptor->field_count();
322 for (int i = 0; i < field_count; ++i) {
323 const FieldDescriptor* field = descriptor->field(i);
324 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
325 if (field->is_repeated()) {
326 const int field_size = reflection->FieldSize(message, field);
327 for (int j = 0; j < field_size; ++j) {
328 SampleImpl(reflection->GetRepeatedMessage(message, field, j));
329 }
330 } else if (reflection->HasField(message, field)) {
331 SampleImpl(reflection->GetMessage(message, field));
332 }
333 }
334
335 if (field->cpp_type() != match_.cpp_type()) continue;
336 if (match_.cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
337 if (field->enum_type() != match_.enum_type()) continue;
338 } else if (match_.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
339 if (field->message_type() != match_.message_type()) continue;
340 }
341
342 if (field->is_repeated()) {
343 if (int field_size = reflection->FieldSize(message, field)) {
344 ConstFieldInstance source(&message, field,
345 GetRandomIndex(random_, field_size));
346 if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
347 sampler_.Try(field_size, source);
348 }
349 } else {
350 if (reflection->HasField(message, field)) {
351 ConstFieldInstance source(&message, field);
352 if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
353 sampler_.Try(1, source);
354 }
355 }
356 }
357 }
358
359 ConstFieldInstance match_;
360 RandomEngine* random_;
361 int size_increase_hint_;
362
363 WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
364 };
365
// Maps the address of an Any message to the message unpacked from it.
using UnpackedAny =
    std::unordered_map<const Message*, std::unique_ptr<Message>>;
368
GetAnyTypeDescriptor(const Any & any)369 const Descriptor* GetAnyTypeDescriptor(const Any& any) {
370 std::string type_name;
371 if (!Any::ParseAnyTypeUrl(std::string(any.type_url()), &type_name))
372 return nullptr;
373 return any.descriptor()->file()->pool()->FindMessageTypeByName(type_name);
374 }
375
UnpackAny(const Any & any)376 std::unique_ptr<Message> UnpackAny(const Any& any) {
377 const Descriptor* desc = GetAnyTypeDescriptor(any);
378 if (!desc) return {};
379 std::unique_ptr<Message> message(
380 any.GetReflection()->GetMessageFactory()->GetPrototype(desc)->New());
381 message->ParsePartialFromString(std::string(any.value()));
382 return message;
383 }
384
CastToAny(const Message * message)385 const Any* CastToAny(const Message* message) {
386 return Any::GetDescriptor() == message->GetDescriptor()
387 ? static_cast<const Any*>(message)
388 : nullptr;
389 }
390
CastToAny(Message * message)391 Any* CastToAny(Message* message) {
392 return Any::GetDescriptor() == message->GetDescriptor()
393 ? static_cast<Any*>(message)
394 : nullptr;
395 }
396
UnpackIfAny(const Message & message)397 std::unique_ptr<Message> UnpackIfAny(const Message& message) {
398 if (const Any* any = CastToAny(&message)) return UnpackAny(*any);
399 return {};
400 }
401
UnpackAny(const Message & message,UnpackedAny * result)402 void UnpackAny(const Message& message, UnpackedAny* result) {
403 if (std::unique_ptr<Message> any = UnpackIfAny(message)) {
404 UnpackAny(*any, result);
405 result->emplace(&message, std::move(any));
406 return;
407 }
408
409 const Descriptor* descriptor = message.GetDescriptor();
410 const Reflection* reflection = message.GetReflection();
411
412 for (int i = 0; i < descriptor->field_count(); ++i) {
413 const FieldDescriptor* field = descriptor->field(i);
414 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
415 if (field->is_repeated()) {
416 const int field_size = reflection->FieldSize(message, field);
417 for (int j = 0; j < field_size; ++j) {
418 UnpackAny(reflection->GetRepeatedMessage(message, field, j), result);
419 }
420 } else if (reflection->HasField(message, field)) {
421 UnpackAny(reflection->GetMessage(message, field), result);
422 }
423 }
424 }
425 }
426
427 class PostProcessing {
428 public:
429 using PostProcessors =
430 std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;
431
PostProcessing(bool keep_initialized,const PostProcessors & post_processors,const UnpackedAny & any,RandomEngine * random)432 PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
433 const UnpackedAny& any, RandomEngine* random)
434 : keep_initialized_(keep_initialized),
435 post_processors_(post_processors),
436 any_(any),
437 random_(random) {}
438
Run(Message * message,int max_depth)439 void Run(Message* message, int max_depth) {
440 --max_depth;
441 const Descriptor* descriptor = message->GetDescriptor();
442
443 // Apply custom mutators in nested messages before packing any.
444 const Reflection* reflection = message->GetReflection();
445 for (int i = 0; i < descriptor->field_count(); i++) {
446 const FieldDescriptor* field = descriptor->field(i);
447 if (keep_initialized_ &&
448 (field->is_required() || descriptor->options().map_entry()) &&
449 !reflection->HasField(*message, field)) {
450 CreateDefaultField()(FieldInstance(message, field));
451 }
452
453 if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
454
455 if (max_depth < 0 && !field->is_required()) {
456 // Clear deep optional fields to avoid stack overflow.
457 reflection->ClearField(message, field);
458 if (field->is_repeated())
459 assert(!reflection->FieldSize(*message, field));
460 else
461 assert(!reflection->HasField(*message, field));
462 continue;
463 }
464
465 if (field->is_repeated()) {
466 const int field_size = reflection->FieldSize(*message, field);
467 for (int j = 0; j < field_size; ++j) {
468 Message* nested_message =
469 reflection->MutableRepeatedMessage(message, field, j);
470 Run(nested_message, max_depth);
471 }
472 } else if (reflection->HasField(*message, field)) {
473 Message* nested_message = reflection->MutableMessage(message, field);
474 Run(nested_message, max_depth);
475 }
476 }
477
478 if (Any* any = CastToAny(message)) {
479 if (max_depth < 0) {
480 // Clear deep Any fields to avoid stack overflow.
481 any->Clear();
482 } else {
483 auto It = any_.find(message);
484 if (It != any_.end()) {
485 Run(It->second.get(), max_depth);
486 std::string value;
487 It->second->SerializePartialToString(&value);
488 *any->mutable_value() = value;
489 }
490 }
491 }
492
493 // Call user callback after message trimmed, initialized and packed.
494 auto range = post_processors_.equal_range(descriptor);
495 for (auto it = range.first; it != range.second; ++it)
496 it->second(message, (*random_)());
497 }
498
499 private:
500 bool keep_initialized_;
501 const PostProcessors& post_processors_;
502 const UnpackedAny& any_;
503 RandomEngine* random_;
504 };
505
506 } // namespace
507
508 class FieldMutator {
509 public:
FieldMutator(int size_increase_hint,bool enforce_changes,bool enforce_utf8_strings,const ConstMessages & sources,Mutator * mutator)510 FieldMutator(int size_increase_hint, bool enforce_changes,
511 bool enforce_utf8_strings, const ConstMessages& sources,
512 Mutator* mutator)
513 : size_increase_hint_(size_increase_hint),
514 enforce_changes_(enforce_changes),
515 enforce_utf8_strings_(enforce_utf8_strings),
516 sources_(sources),
517 mutator_(mutator) {}
518
Mutate(int32_t * value) const519 void Mutate(int32_t* value) const {
520 RepeatMutate(value, std::bind(&Mutator::MutateInt32, mutator_, _1));
521 }
522
Mutate(int64_t * value) const523 void Mutate(int64_t* value) const {
524 RepeatMutate(value, std::bind(&Mutator::MutateInt64, mutator_, _1));
525 }
526
Mutate(uint32_t * value) const527 void Mutate(uint32_t* value) const {
528 RepeatMutate(value, std::bind(&Mutator::MutateUInt32, mutator_, _1));
529 }
530
Mutate(uint64_t * value) const531 void Mutate(uint64_t* value) const {
532 RepeatMutate(value, std::bind(&Mutator::MutateUInt64, mutator_, _1));
533 }
534
Mutate(float * value) const535 void Mutate(float* value) const {
536 RepeatMutate(value, std::bind(&Mutator::MutateFloat, mutator_, _1));
537 }
538
Mutate(double * value) const539 void Mutate(double* value) const {
540 RepeatMutate(value, std::bind(&Mutator::MutateDouble, mutator_, _1));
541 }
542
Mutate(bool * value) const543 void Mutate(bool* value) const {
544 RepeatMutate(value, std::bind(&Mutator::MutateBool, mutator_, _1));
545 }
546
Mutate(FieldInstance::Enum * value) const547 void Mutate(FieldInstance::Enum* value) const {
548 RepeatMutate(&value->index,
549 std::bind(&Mutator::MutateEnum, mutator_, _1, value->count));
550 assert(value->index < value->count);
551 }
552
Mutate(std::string * value) const553 void Mutate(std::string* value) const {
554 if (enforce_utf8_strings_) {
555 RepeatMutate(value, std::bind(&Mutator::MutateUtf8String, mutator_, _1,
556 size_increase_hint_));
557 } else {
558 RepeatMutate(value, std::bind(&Mutator::MutateString, mutator_, _1,
559 size_increase_hint_));
560 }
561 }
562
Mutate(std::unique_ptr<Message> * message) const563 void Mutate(std::unique_ptr<Message>* message) const {
564 assert(!enforce_changes_);
565 assert(*message);
566 if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
567 return;
568 mutator_->MutateImpl(sources_, {message->get()}, false,
569 size_increase_hint_);
570 }
571
572 private:
573 template <class T, class F>
RepeatMutate(T * value,F mutate) const574 void RepeatMutate(T* value, F mutate) const {
575 if (!enforce_changes_ &&
576 GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_)) {
577 return;
578 }
579 T tmp = *value;
580 for (int i = 0; i < 10; ++i) {
581 *value = mutate(*value);
582 if (!enforce_changes_ || *value != tmp) return;
583 }
584 }
585
586 int size_increase_hint_;
587 size_t enforce_changes_;
588 bool enforce_utf8_strings_;
589 const ConstMessages& sources_;
590 Mutator* mutator_;
591 };
592
593 namespace {
594
595 struct MutateField : public FieldFunction<MutateField> {
596 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30211::MutateField597 void ForType(const FieldInstance& field, int size_increase_hint,
598 const ConstMessages& sources, Mutator* mutator) const {
599 T value;
600 field.Load(&value);
601 FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
602 mutator)
603 .Mutate(&value);
604 field.Store(value);
605 }
606 };
607
608 struct CreateField : public FieldFunction<CreateField> {
609 public:
610 template <class T>
ForTypeprotobuf_mutator::__anon4db2e1e30211::CreateField611 void ForType(const FieldInstance& field, int size_increase_hint,
612 const ConstMessages& sources, Mutator* mutator) const {
613 T value;
614 field.GetDefault(&value);
615 FieldMutator field_mutator(size_increase_hint,
616 false /* defaults could be useful */,
617 field.EnforceUtf8(), sources, mutator);
618 field_mutator.Mutate(&value);
619 field.Create(value);
620 }
621 };
622
623 } // namespace
624
Seed(uint32_t value)625 void Mutator::Seed(uint32_t value) { random_.seed(value); }
626
Mutate(Message * message,size_t max_size_hint)627 void Mutator::Mutate(Message* message, size_t max_size_hint) {
628 UnpackedAny any;
629 UnpackAny(*message, &any);
630
631 Messages messages;
632 messages.reserve(any.size() + 1);
633 messages.push_back(message);
634 for (const auto& kv : any) messages.push_back(kv.second.get());
635
636 ConstMessages sources(messages.begin(), messages.end());
637 MutateImpl(sources, messages, false,
638 static_cast<int>(max_size_hint) -
639 static_cast<int>(message->ByteSizeLong()));
640
641 PostProcessing(keep_initialized_, post_processors_, any, &random_)
642 .Run(message, kMaxInitializeDepth);
643 assert(IsInitialized(*message));
644 }
645
CrossOver(const Message & message1,Message * message2,size_t max_size_hint)646 void Mutator::CrossOver(const Message& message1, Message* message2,
647 size_t max_size_hint) {
648 UnpackedAny any;
649 UnpackAny(*message2, &any);
650
651 Messages messages;
652 messages.reserve(any.size() + 1);
653 messages.push_back(message2);
654 for (auto& kv : any) messages.push_back(kv.second.get());
655
656 UnpackAny(message1, &any);
657
658 ConstMessages sources;
659 sources.reserve(any.size() + 2);
660 sources.push_back(&message1);
661 sources.push_back(message2);
662 for (const auto& kv : any) sources.push_back(kv.second.get());
663
664 MutateImpl(sources, messages, true,
665 static_cast<int>(max_size_hint) -
666 static_cast<int>(message2->ByteSizeLong()));
667
668 PostProcessing(keep_initialized_, post_processors_, any, &random_)
669 .Run(message2, kMaxInitializeDepth);
670 assert(IsInitialized(*message2));
671 }
672
RegisterPostProcessor(const Descriptor * desc,PostProcess callback)673 void Mutator::RegisterPostProcessor(const Descriptor* desc,
674 PostProcess callback) {
675 post_processors_.emplace(desc, callback);
676 }
677
MutateImpl(const ConstMessages & sources,const Messages & messages,bool copy_clone_only,int size_increase_hint)678 bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
679 bool copy_clone_only, int size_increase_hint) {
680 MutationBitset mutations;
681 if (copy_clone_only) {
682 mutations[static_cast<size_t>(Mutation::Copy)] = true;
683 mutations[static_cast<size_t>(Mutation::Clone)] = true;
684 } else if (size_increase_hint <= 16) {
685 mutations[static_cast<size_t>(Mutation::Delete)] = true;
686 } else {
687 mutations.set();
688 mutations[static_cast<size_t>(Mutation::Copy)] = false;
689 mutations[static_cast<size_t>(Mutation::Clone)] = false;
690 }
691 while (mutations.any()) {
692 MutationSampler mutation(keep_initialized_, mutations, &random_);
693 for (Message* message : messages) mutation.Sample(message);
694
695 switch (mutation.mutation()) {
696 case Mutation::None:
697 return true;
698 case Mutation::Add:
699 CreateField()(mutation.field(), size_increase_hint, sources, this);
700 return true;
701 case Mutation::Mutate:
702 MutateField()(mutation.field(), size_increase_hint, sources, this);
703 return true;
704 case Mutation::Delete:
705 DeleteField()(mutation.field());
706 return true;
707 case Mutation::Clone: {
708 CreateDefaultField()(mutation.field());
709 DataSourceSampler source_sampler(mutation.field(), &random_,
710 size_increase_hint);
711 for (const Message* source : sources) source_sampler.Sample(*source);
712 if (source_sampler.IsEmpty()) {
713 if (!IsProto3SimpleField(*mutation.field().descriptor()))
714 return true; // CreateField is enough for proto2.
715 break;
716 }
717 CopyField()(source_sampler.field(), mutation.field());
718 return true;
719 }
720 case Mutation::Copy: {
721 DataSourceSampler source_sampler(mutation.field(), &random_,
722 size_increase_hint);
723 for (const Message* source : sources) source_sampler.Sample(*source);
724 if (source_sampler.IsEmpty()) break;
725 CopyField()(source_sampler.field(), mutation.field());
726 return true;
727 }
728 default:
729 assert(false && "unexpected mutation");
730 return false;
731 }
732
733 // Don't try same mutation next time.
734 mutations[static_cast<size_t>(mutation.mutation())] = false;
735 }
736 return false;
737 }
738
MutateInt32(int32_t value)739 int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
740
MutateInt64(int64_t value)741 int64_t Mutator::MutateInt64(int64_t value) { return FlipBit(value, &random_); }
742
MutateUInt32(uint32_t value)743 uint32_t Mutator::MutateUInt32(uint32_t value) {
744 return FlipBit(value, &random_);
745 }
746
MutateUInt64(uint64_t value)747 uint64_t Mutator::MutateUInt64(uint64_t value) {
748 return FlipBit(value, &random_);
749 }
750
MutateFloat(float value)751 float Mutator::MutateFloat(float value) { return FlipBit(value, &random_); }
752
MutateDouble(double value)753 double Mutator::MutateDouble(double value) { return FlipBit(value, &random_); }
754
MutateBool(bool value)755 bool Mutator::MutateBool(bool value) { return !value; }
756
MutateEnum(size_t index,size_t item_count)757 size_t Mutator::MutateEnum(size_t index, size_t item_count) {
758 if (item_count <= 1) return 0;
759 return (index + 1 + GetRandomIndex(&random_, item_count - 1)) % item_count;
760 }
761
MutateString(const std::string & value,int size_increase_hint)762 std::string Mutator::MutateString(const std::string& value,
763 int size_increase_hint) {
764 std::string result = value;
765
766 while (!result.empty() && GetRandomBool(&random_)) {
767 result.erase(GetRandomIndex(&random_, result.size()), 1);
768 }
769
770 while (size_increase_hint > 0 &&
771 result.size() < static_cast<size_t>(size_increase_hint) &&
772 GetRandomBool(&random_)) {
773 size_t index = GetRandomIndex(&random_, result.size() + 1);
774 result.insert(result.begin() + index, GetRandomIndex(&random_, 1 << 8));
775 }
776
777 if (result != value) return result;
778
779 if (result.empty()) {
780 result.push_back(GetRandomIndex(&random_, 1 << 8));
781 return result;
782 }
783
784 if (!result.empty())
785 FlipBit(result.size(), reinterpret_cast<uint8_t*>(&result[0]), &random_);
786 return result;
787 }
788
MutateUtf8String(const std::string & value,int size_increase_hint)789 std::string Mutator::MutateUtf8String(const std::string& value,
790 int size_increase_hint) {
791 std::string str = MutateString(value, size_increase_hint);
792 FixUtf8String(&str, &random_);
793 return str;
794 }
795
IsInitialized(const Message & message) const796 bool Mutator::IsInitialized(const Message& message) const {
797 if (!keep_initialized_ || message.IsInitialized()) return true;
798 std::cerr << "Uninitialized: " << message.DebugString() << "\n";
799 return false;
800 }
801
802 } // namespace protobuf_mutator
803