1 /*
2 american fuzzy lop++ - LLVM CmpLog instrumentation
3 --------------------------------------------------
4
5 Written by Andrea Fioraldi <andreafioraldi@gmail.com>
6
7 Copyright 2015, 2016 Google Inc. All rights reserved.
8 Copyright 2019-2022 AFLplusplus Project. All rights reserved.
9
10 Licensed under the Apache License, Version 2.0 (the "License");
11 you may not use this file except in compliance with the License.
12 You may obtain a copy of the License at:
13
14 https://www.apache.org/licenses/LICENSE-2.0
15
16 */
17
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21
22 #include <iostream>
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27
28 #include "llvm/Config/llvm-config.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #if LLVM_MAJOR >= 11
35 #include "llvm/Passes/PassPlugin.h"
36 #include "llvm/Passes/PassBuilder.h"
37 #include "llvm/IR/PassManager.h"
38 #else
39 #include "llvm/IR/LegacyPassManager.h"
40 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
41 #endif
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Analysis/ValueTracking.h"
45 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */
46 #include "llvm/Passes/OptimizationLevel.h"
47 #endif
48
49 #if LLVM_VERSION_MAJOR >= 4 || \
50 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
51 #include "llvm/IR/Verifier.h"
52 #include "llvm/IR/DebugInfo.h"
53 #include "llvm/Support/raw_ostream.h"
54 #else
55 #include "llvm/Analysis/Verifier.h"
56 #include "llvm/DebugInfo.h"
57 #define nullptr 0
58 #endif
59
60 #include <set>
61 #include "afl-llvm-common.h"
62
63 using namespace llvm;
64
65 namespace {
66
67 #if LLVM_MAJOR >= 11 /* use new pass manager */
68 class CmpLogInstructions : public PassInfoMixin<CmpLogInstructions> {
69
70 public:
CmpLogInstructions()71 CmpLogInstructions() {
72
73 initInstrumentList();
74
75 }
76
77 #else
78 class CmpLogInstructions : public ModulePass {
79
80 public:
81 static char ID;
82 CmpLogInstructions() : ModulePass(ID) {
83
84 initInstrumentList();
85
86 }
87
88 #endif
89
90 #if LLVM_MAJOR >= 11 /* use new pass manager */
91 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
92 #else
93 bool runOnModule(Module &M) override;
94
95 #if LLVM_VERSION_MAJOR >= 4
getPassName() const96 StringRef getPassName() const override {
97
98 #else
99 const char *getPassName() const override {
100
101 #endif
102 return "cmplog instructions";
103
104 }
105
106 #endif
107
108 private:
109 bool hookInstrs(Module &M);
110
111 };
112
113 } // namespace
114
115 #if LLVM_MAJOR >= 11
116 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()117 llvmGetPassPluginInfo() {
118
119 return {LLVM_PLUGIN_API_VERSION, "cmploginstructions", "v0.1",
120 /* lambda to insert our pass into the pass pipeline. */
121 [](PassBuilder &PB) {
122
123 #if LLVM_VERSION_MAJOR <= 13
124 using OptimizationLevel = typename PassBuilder::OptimizationLevel;
125 #endif
126 PB.registerOptimizerLastEPCallback(
127 [](ModulePassManager &MPM, OptimizationLevel OL) {
128
129 MPM.addPass(CmpLogInstructions());
130
131 });
132
133 }};
134
135 }
136
137 #else
138 char CmpLogInstructions::ID = 0;
139 #endif
140
/* Removes repeated values from [first, last) without sorting, keeping the
   first occurrence of every value in its original relative order.
   Returns the new logical end of the range; elements at and after the
   returned iterator are left in an unspecified state (as with std::remove).
   Every kept element triggers one std::remove pass over the rest. */
template <class Iterator>
Iterator Unique(Iterator first, Iterator last) {

  for (Iterator keep = first; keep != last; ++keep) {

    Iterator rest = keep;
    ++rest;
    /* drop every later copy of *keep; this may pull `last` closer */
    last = std::remove(rest, last, *keep);

  }

  return last;

}
155
hookInstrs(Module & M)156 bool CmpLogInstructions::hookInstrs(Module &M) {
157
158 std::vector<Instruction *> icomps;
159 LLVMContext & C = M.getContext();
160
161 Type * VoidTy = Type::getVoidTy(C);
162 IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
163 IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
164 IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
165 IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
166 IntegerType *Int128Ty = IntegerType::getInt128Ty(C);
167
168 #if LLVM_VERSION_MAJOR >= 9
169 FunctionCallee
170 #else
171 Constant *
172 #endif
173 c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
174 Int8Ty
175 #if LLVM_VERSION_MAJOR < 5
176 ,
177 NULL
178 #endif
179 );
180 #if LLVM_VERSION_MAJOR >= 9
181 FunctionCallee cmplogHookIns1 = c1;
182 #else
183 Function *cmplogHookIns1 = cast<Function>(c1);
184 #endif
185
186 #if LLVM_VERSION_MAJOR >= 9
187 FunctionCallee
188 #else
189 Constant *
190 #endif
191 c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
192 Int8Ty
193 #if LLVM_VERSION_MAJOR < 5
194 ,
195 NULL
196 #endif
197 );
198 #if LLVM_VERSION_MAJOR >= 9
199 FunctionCallee cmplogHookIns2 = c2;
200 #else
201 Function *cmplogHookIns2 = cast<Function>(c2);
202 #endif
203
204 #if LLVM_VERSION_MAJOR >= 9
205 FunctionCallee
206 #else
207 Constant *
208 #endif
209 c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
210 Int8Ty
211 #if LLVM_VERSION_MAJOR < 5
212 ,
213 NULL
214 #endif
215 );
216 #if LLVM_VERSION_MAJOR >= 9
217 FunctionCallee cmplogHookIns4 = c4;
218 #else
219 Function *cmplogHookIns4 = cast<Function>(c4);
220 #endif
221
222 #if LLVM_VERSION_MAJOR >= 9
223 FunctionCallee
224 #else
225 Constant *
226 #endif
227 c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
228 Int8Ty
229 #if LLVM_VERSION_MAJOR < 5
230 ,
231 NULL
232 #endif
233 );
234 #if LLVM_VERSION_MAJOR >= 9
235 FunctionCallee cmplogHookIns8 = c8;
236 #else
237 Function *cmplogHookIns8 = cast<Function>(c8);
238 #endif
239
240 #if LLVM_VERSION_MAJOR >= 9
241 FunctionCallee
242 #else
243 Constant *
244 #endif
245 c16 = M.getOrInsertFunction("__cmplog_ins_hook16", VoidTy, Int128Ty,
246 Int128Ty, Int8Ty
247 #if LLVM_VERSION_MAJOR < 5
248 ,
249 NULL
250 #endif
251 );
252 #if LLVM_VERSION_MAJOR < 9
253 Function *cmplogHookIns16 = cast<Function>(c16);
254 #else
255 FunctionCallee cmplogHookIns16 = c16;
256 #endif
257
258 #if LLVM_VERSION_MAJOR >= 9
259 FunctionCallee
260 #else
261 Constant *
262 #endif
263 cN = M.getOrInsertFunction("__cmplog_ins_hookN", VoidTy, Int128Ty,
264 Int128Ty, Int8Ty, Int8Ty
265 #if LLVM_VERSION_MAJOR < 5
266 ,
267 NULL
268 #endif
269 );
270 #if LLVM_VERSION_MAJOR >= 9
271 FunctionCallee cmplogHookInsN = cN;
272 #else
273 Function *cmplogHookInsN = cast<Function>(cN);
274 #endif
275
276 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
277
278 if (!AFLCmplogPtr) {
279
280 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
281 GlobalValue::ExternalWeakLinkage, 0,
282 "__afl_cmp_map");
283
284 }
285
286 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
287
288 /* iterate over all functions, bbs and instruction and add suitable calls */
289 for (auto &F : M) {
290
291 if (!isInInstrumentList(&F, MNAME)) continue;
292
293 for (auto &BB : F) {
294
295 for (auto &IN : BB) {
296
297 CmpInst *selectcmpInst = nullptr;
298 if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
299
300 icomps.push_back(selectcmpInst);
301
302 }
303
304 }
305
306 }
307
308 }
309
310 if (icomps.size()) {
311
312 // if (!be_quiet) errs() << "Hooking " << icomps.size() <<
313 // " cmp instructions\n";
314
315 for (auto &selectcmpInst : icomps) {
316
317 IRBuilder<> IRB2(selectcmpInst->getParent());
318 IRB2.SetInsertPoint(selectcmpInst);
319 LoadInst *CmpPtr = IRB2.CreateLoad(
320 #if LLVM_VERSION_MAJOR >= 14
321 PointerType::get(Int8Ty, 0),
322 #endif
323 AFLCmplogPtr);
324 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
325 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
326 auto ThenTerm =
327 SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false);
328
329 IRBuilder<> IRB(ThenTerm);
330
331 Value *op0 = selectcmpInst->getOperand(0);
332 Value *op1 = selectcmpInst->getOperand(1);
333 Value *op0_saved = op0, *op1_saved = op1;
334 auto ty0 = op0->getType();
335 auto ty1 = op1->getType();
336
337 IntegerType *intTyOp0 = NULL;
338 IntegerType *intTyOp1 = NULL;
339 unsigned max_size = 0, cast_size = 0;
340 unsigned attr = 0, vector_cnt = 0, is_fp = 0;
341 CmpInst * cmpInst = dyn_cast<CmpInst>(selectcmpInst);
342
343 if (!cmpInst) { continue; }
344
345 switch (cmpInst->getPredicate()) {
346
347 case CmpInst::ICMP_NE:
348 case CmpInst::FCMP_UNE:
349 case CmpInst::FCMP_ONE:
350 break;
351 case CmpInst::ICMP_EQ:
352 case CmpInst::FCMP_UEQ:
353 case CmpInst::FCMP_OEQ:
354 attr += 1;
355 break;
356 case CmpInst::ICMP_UGT:
357 case CmpInst::ICMP_SGT:
358 case CmpInst::FCMP_OGT:
359 case CmpInst::FCMP_UGT:
360 attr += 2;
361 break;
362 case CmpInst::ICMP_UGE:
363 case CmpInst::ICMP_SGE:
364 case CmpInst::FCMP_OGE:
365 case CmpInst::FCMP_UGE:
366 attr += 3;
367 break;
368 case CmpInst::ICMP_ULT:
369 case CmpInst::ICMP_SLT:
370 case CmpInst::FCMP_OLT:
371 case CmpInst::FCMP_ULT:
372 attr += 4;
373 break;
374 case CmpInst::ICMP_ULE:
375 case CmpInst::ICMP_SLE:
376 case CmpInst::FCMP_OLE:
377 case CmpInst::FCMP_ULE:
378 attr += 5;
379 break;
380 default:
381 break;
382
383 }
384
385 if (selectcmpInst->getOpcode() == Instruction::FCmp) {
386
387 if (ty0->isVectorTy()) {
388
389 VectorType *tt = dyn_cast<VectorType>(ty0);
390 if (!tt) {
391
392 fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
393 continue;
394
395 }
396
397 #if (LLVM_VERSION_MAJOR >= 12)
398 vector_cnt = tt->getElementCount().getKnownMinValue();
399 ty0 = tt->getElementType();
400 #endif
401
402 }
403
404 if (ty0->isHalfTy()
405 #if LLVM_VERSION_MAJOR >= 11
406 || ty0->isBFloatTy()
407 #endif
408 )
409 max_size = 16;
410 else if (ty0->isFloatTy())
411 max_size = 32;
412 else if (ty0->isDoubleTy())
413 max_size = 64;
414 else if (ty0->isX86_FP80Ty())
415 max_size = 80;
416 else if (ty0->isFP128Ty() || ty0->isPPC_FP128Ty())
417 max_size = 128;
418 #if (LLVM_VERSION_MAJOR >= 12)
419 else if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet)
420 fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u!\n",
421 ty0->getTypeID());
422 #endif
423
424 attr += 8;
425 is_fp = 1;
426 // fprintf(stderr, "HAVE FP %u!\n", vector_cnt);
427
428 } else {
429
430 if (ty0->isVectorTy()) {
431
432 #if (LLVM_VERSION_MAJOR >= 12)
433 VectorType *tt = dyn_cast<VectorType>(ty0);
434 if (!tt) {
435
436 fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
437 continue;
438
439 }
440
441 vector_cnt = tt->getElementCount().getKnownMinValue();
442 ty1 = ty0 = tt->getElementType();
443 #endif
444
445 }
446
447 intTyOp0 = dyn_cast<IntegerType>(ty0);
448 intTyOp1 = dyn_cast<IntegerType>(ty1);
449
450 if (intTyOp0 && intTyOp1) {
451
452 max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
453 ? intTyOp0->getBitWidth()
454 : intTyOp1->getBitWidth();
455
456 } else {
457
458 #if (LLVM_VERSION_MAJOR >= 12)
459 if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet) {
460
461 fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u\n",
462 ty0->getTypeID());
463
464 }
465
466 #endif
467
468 }
469
470 }
471
472 if (!max_size || max_size < 16) {
473
474 // fprintf(stderr, "too small\n");
475 continue;
476
477 }
478
479 if (max_size % 8) { max_size = (((max_size / 8) + 1) * 8); }
480
481 if (max_size > 128) {
482
483 if (!be_quiet) {
484
485 fprintf(stderr,
486 "Cannot handle this compare bit size: %u (truncating)\n",
487 max_size);
488
489 }
490
491 max_size = 128;
492
493 }
494
495 // do we need to cast?
496 switch (max_size) {
497
498 case 8:
499 case 16:
500 case 32:
501 case 64:
502 case 128:
503 cast_size = max_size;
504 break;
505 default:
506 cast_size = 128;
507
508 }
509
510 // XXX FIXME BUG TODO
511 if (is_fp && vector_cnt) { continue; }
512
513 uint64_t cur = 0, last_val0 = 0, last_val1 = 0, cur_val;
514
515 while (1) {
516
517 std::vector<Value *> args;
518 bool skip = false;
519
520 if (vector_cnt) {
521
522 op0 = IRB.CreateExtractElement(op0_saved, cur);
523 op1 = IRB.CreateExtractElement(op1_saved, cur);
524 /*
525 std::string errMsg;
526 raw_string_ostream os(errMsg);
527 op0_saved->print(os);
528 fprintf(stderr, "X: %s\n", os.str().c_str());
529 */
530 if (is_fp) {
531
532 /*
533 ConstantFP *i0 = dyn_cast<ConstantFP>(op0);
534 ConstantFP *i1 = dyn_cast<ConstantFP>(op1);
535 // BUG FIXME TODO: this is null ... but why?
536 // fprintf(stderr, "%p %p\n", i0, i1);
537 if (i0) {
538
539 cur_val = (uint64_t)i0->getValue().convertToDouble();
540 if (last_val0 && last_val0 == cur_val) { skip = true;
541
542 } last_val0 = cur_val;
543
544 }
545
546 if (i1) {
547
548 cur_val = (uint64_t)i1->getValue().convertToDouble();
549 if (last_val1 && last_val1 == cur_val) { skip = true;
550
551 } last_val1 = cur_val;
552
553 }
554
555 */
556
557 } else {
558
559 ConstantInt *i0 = dyn_cast<ConstantInt>(op0);
560 ConstantInt *i1 = dyn_cast<ConstantInt>(op1);
561 if (i0 && i0->uge(0xffffffffffffffff) == false) {
562
563 cur_val = i0->getZExtValue();
564 if (last_val0 && last_val0 == cur_val) { skip = true; }
565 last_val0 = cur_val;
566
567 }
568
569 if (i1 && i1->uge(0xffffffffffffffff) == false) {
570
571 cur_val = i1->getZExtValue();
572 if (last_val1 && last_val1 == cur_val) { skip = true; }
573 last_val1 = cur_val;
574
575 }
576
577 }
578
579 }
580
581 if (!skip) {
582
583 // errs() << "[CMPLOG] cmp " << *cmpInst << "(in function " <<
584 // cmpInst->getFunction()->getName() << ")\n";
585
586 // first bitcast to integer type of the same bitsize as the original
587 // type (this is a nop, if already integer)
588 Value *op0_i = IRB.CreateBitCast(
589 op0, IntegerType::get(C, ty0->getPrimitiveSizeInBits()));
590 // then create a int cast, which does zext, trunc or bitcast. In our
591 // case usually zext to the next larger supported type (this is a nop
592 // if already the right type)
593 Value *V0 =
594 IRB.CreateIntCast(op0_i, IntegerType::get(C, cast_size), false);
595 args.push_back(V0);
596 Value *op1_i = IRB.CreateBitCast(
597 op1, IntegerType::get(C, ty1->getPrimitiveSizeInBits()));
598 Value *V1 =
599 IRB.CreateIntCast(op1_i, IntegerType::get(C, cast_size), false);
600 args.push_back(V1);
601
602 // errs() << "[CMPLOG] casted parameters:\n0: " << *V0 << "\n1: " <<
603 // *V1
604 // << "\n";
605
606 ConstantInt *attribute = ConstantInt::get(Int8Ty, attr);
607 args.push_back(attribute);
608
609 if (cast_size != max_size) {
610
611 ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1);
612 args.push_back(bitsize);
613
614 }
615
616 // fprintf(stderr, "_ExtInt(%u) castTo %u with attr %u didcast %u\n",
617 // max_size, cast_size, attr);
618
619 switch (cast_size) {
620
621 case 8:
622 IRB.CreateCall(cmplogHookIns1, args);
623 break;
624 case 16:
625 IRB.CreateCall(cmplogHookIns2, args);
626 break;
627 case 32:
628 IRB.CreateCall(cmplogHookIns4, args);
629 break;
630 case 64:
631 IRB.CreateCall(cmplogHookIns8, args);
632 break;
633 case 128:
634 if (max_size == 128) {
635
636 IRB.CreateCall(cmplogHookIns16, args);
637
638 } else {
639
640 IRB.CreateCall(cmplogHookInsN, args);
641
642 }
643
644 break;
645
646 }
647
648 }
649
650 /* else fprintf(stderr, "skipped\n"); */
651
652 ++cur;
653 if (cur >= vector_cnt) { break; }
654
655 }
656
657 }
658
659 }
660
661 if (icomps.size())
662 return true;
663 else
664 return false;
665
666 }
667
668 #if LLVM_MAJOR >= 11 /* use new pass manager */
run(Module & M,ModuleAnalysisManager & MAM)669 PreservedAnalyses CmpLogInstructions::run(Module & M,
670 ModuleAnalysisManager &MAM) {
671
672 #else
673 bool CmpLogInstructions::runOnModule(Module &M) {
674
675 #endif
676
677 if (getenv("AFL_QUIET") == NULL)
678 printf("Running cmplog-instructions-pass by andreafioraldi@gmail.com\n");
679 else
680 be_quiet = 1;
681 hookInstrs(M);
682 verifyModule(M);
683
684 #if LLVM_MAJOR >= 11 /* use new pass manager */
685 return PreservedAnalyses::all();
686 #else
687 return true;
688 #endif
689
690 }
691
692 #if LLVM_MAJOR < 11 /* use old pass manager */
693 static void registerCmpLogInstructionsPass(const PassManagerBuilder &,
694 legacy::PassManagerBase &PM) {
695
696 auto p = new CmpLogInstructions();
697 PM.add(p);
698
699 }
700
701 static RegisterStandardPasses RegisterCmpLogInstructionsPass(
702 PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass);
703
704 static RegisterStandardPasses RegisterCmpLogInstructionsPass0(
705 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass);
706
707 #if LLVM_VERSION_MAJOR >= 11
708 static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO(
709 PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
710 registerCmpLogInstructionsPass);
711 #endif
712 #endif
713
714