/*
   american fuzzy lop++ - LLVM CmpLog instrumentation
   --------------------------------------------------

   Written by Andrea Fioraldi <andreafioraldi@gmail.com>

   Copyright 2015, 2016 Google Inc. All rights reserved.
   Copyright 2019-2022 AFLplusplus Project. All rights reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at:

     https://www.apache.org/licenses/LICENSE-2.0

*/

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <algorithm>
#include <iostream>
#include <list>
#include <string>
#include <fstream>
#include <vector>
#include <sys/time.h>

#include "llvm/Config/llvm-config.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#if LLVM_MAJOR >= 11
  #include "llvm/Passes/PassPlugin.h"
  #include "llvm/Passes/PassBuilder.h"
  #include "llvm/IR/PassManager.h"
#else
  #include "llvm/IR/LegacyPassManager.h"
  #include "llvm/Transforms/IPO/PassManagerBuilder.h"
#endif
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/ValueTracking.h"
#if LLVM_VERSION_MAJOR >= 14                /* how about stable interfaces? */
  #include "llvm/Passes/OptimizationLevel.h"
#endif

#if LLVM_VERSION_MAJOR >= 4 || \
    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
  #include "llvm/IR/Verifier.h"
  #include "llvm/IR/DebugInfo.h"
  #include "llvm/Support/raw_ostream.h"
#else
  #include "llvm/Analysis/Verifier.h"
  #include "llvm/DebugInfo.h"
  #define nullptr 0
#endif

#include <set>
#include "afl-llvm-common.h"

using namespace llvm;

65 namespace {
66 
67 #if LLVM_MAJOR >= 11                                /* use new pass manager */
68 class CmpLogInstructions : public PassInfoMixin<CmpLogInstructions> {
69 
70  public:
CmpLogInstructions()71   CmpLogInstructions() {
72 
73     initInstrumentList();
74 
75   }
76 
77 #else
78 class CmpLogInstructions : public ModulePass {
79 
80  public:
81   static char ID;
82   CmpLogInstructions() : ModulePass(ID) {
83 
84     initInstrumentList();
85 
86   }
87 
88 #endif
89 
90 #if LLVM_MAJOR >= 11                                /* use new pass manager */
91   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
92 #else
93   bool      runOnModule(Module &M) override;
94 
95   #if LLVM_VERSION_MAJOR >= 4
getPassName() const96   StringRef getPassName() const override {
97 
98   #else
99   const char *getPassName() const override {
100 
101   #endif
102     return "cmplog instructions";
103 
104   }
105 
106 #endif
107 
108  private:
109   bool hookInstrs(Module &M);
110 
111 };
112 
113 }  // namespace
114 
115 #if LLVM_MAJOR >= 11
116 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()117 llvmGetPassPluginInfo() {
118 
119   return {LLVM_PLUGIN_API_VERSION, "cmploginstructions", "v0.1",
120           /* lambda to insert our pass into the pass pipeline. */
121           [](PassBuilder &PB) {
122 
123   #if LLVM_VERSION_MAJOR <= 13
124             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
125   #endif
126             PB.registerOptimizerLastEPCallback(
127                 [](ModulePassManager &MPM, OptimizationLevel OL) {
128 
129                   MPM.addPass(CmpLogInstructions());
130 
131                 });
132 
133           }};
134 
135 }
136 
137 #else
138 char CmpLogInstructions::ID = 0;
139 #endif
140 
/* Removes every repeated value from [first, last) while keeping the first
   occurrence of each value in its original relative order.  Unlike
   std::unique the duplicates do not have to be adjacent, so no sorting is
   required; the price is a std::remove() sweep per kept element.

   Returns the new logical end of the range.  As with std::remove(), the
   container itself is not shrunk - erase [result, old_end) to do that. */
template <class Iterator>
Iterator Unique(Iterator first, Iterator last) {

  for (; first != last; ++first) {

    /* *first lies just before the swept range, so std::remove() never
       overwrites the value it is comparing against. */
    Iterator rest(first);
    ++rest;
    last = std::remove(rest, last, *first);

  }

  return last;

}

hookInstrs(Module & M)156 bool CmpLogInstructions::hookInstrs(Module &M) {
157 
158   std::vector<Instruction *> icomps;
159   LLVMContext &              C = M.getContext();
160 
161   Type *       VoidTy = Type::getVoidTy(C);
162   IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
163   IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
164   IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
165   IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
166   IntegerType *Int128Ty = IntegerType::getInt128Ty(C);
167 
168 #if LLVM_VERSION_MAJOR >= 9
169   FunctionCallee
170 #else
171   Constant *
172 #endif
173       c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
174                                  Int8Ty
175 #if LLVM_VERSION_MAJOR < 5
176                                  ,
177                                  NULL
178 #endif
179       );
180 #if LLVM_VERSION_MAJOR >= 9
181   FunctionCallee cmplogHookIns1 = c1;
182 #else
183   Function *cmplogHookIns1 = cast<Function>(c1);
184 #endif
185 
186 #if LLVM_VERSION_MAJOR >= 9
187   FunctionCallee
188 #else
189   Constant *
190 #endif
191       c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
192                                  Int8Ty
193 #if LLVM_VERSION_MAJOR < 5
194                                  ,
195                                  NULL
196 #endif
197       );
198 #if LLVM_VERSION_MAJOR >= 9
199   FunctionCallee cmplogHookIns2 = c2;
200 #else
201   Function *cmplogHookIns2 = cast<Function>(c2);
202 #endif
203 
204 #if LLVM_VERSION_MAJOR >= 9
205   FunctionCallee
206 #else
207   Constant *
208 #endif
209       c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
210                                  Int8Ty
211 #if LLVM_VERSION_MAJOR < 5
212                                  ,
213                                  NULL
214 #endif
215       );
216 #if LLVM_VERSION_MAJOR >= 9
217   FunctionCallee cmplogHookIns4 = c4;
218 #else
219   Function *cmplogHookIns4 = cast<Function>(c4);
220 #endif
221 
222 #if LLVM_VERSION_MAJOR >= 9
223   FunctionCallee
224 #else
225   Constant *
226 #endif
227       c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
228                                  Int8Ty
229 #if LLVM_VERSION_MAJOR < 5
230                                  ,
231                                  NULL
232 #endif
233       );
234 #if LLVM_VERSION_MAJOR >= 9
235   FunctionCallee cmplogHookIns8 = c8;
236 #else
237   Function *cmplogHookIns8 = cast<Function>(c8);
238 #endif
239 
240 #if LLVM_VERSION_MAJOR >= 9
241   FunctionCallee
242 #else
243   Constant *
244 #endif
245       c16 = M.getOrInsertFunction("__cmplog_ins_hook16", VoidTy, Int128Ty,
246                                   Int128Ty, Int8Ty
247 #if LLVM_VERSION_MAJOR < 5
248                                   ,
249                                   NULL
250 #endif
251       );
252 #if LLVM_VERSION_MAJOR < 9
253   Function *cmplogHookIns16 = cast<Function>(c16);
254 #else
255   FunctionCallee cmplogHookIns16 = c16;
256 #endif
257 
258 #if LLVM_VERSION_MAJOR >= 9
259   FunctionCallee
260 #else
261   Constant *
262 #endif
263       cN = M.getOrInsertFunction("__cmplog_ins_hookN", VoidTy, Int128Ty,
264                                  Int128Ty, Int8Ty, Int8Ty
265 #if LLVM_VERSION_MAJOR < 5
266                                  ,
267                                  NULL
268 #endif
269       );
270 #if LLVM_VERSION_MAJOR >= 9
271   FunctionCallee cmplogHookInsN = cN;
272 #else
273   Function *cmplogHookInsN = cast<Function>(cN);
274 #endif
275 
276   GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
277 
278   if (!AFLCmplogPtr) {
279 
280     AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
281                                       GlobalValue::ExternalWeakLinkage, 0,
282                                       "__afl_cmp_map");
283 
284   }
285 
286   Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
287 
288   /* iterate over all functions, bbs and instruction and add suitable calls */
289   for (auto &F : M) {
290 
291     if (!isInInstrumentList(&F, MNAME)) continue;
292 
293     for (auto &BB : F) {
294 
295       for (auto &IN : BB) {
296 
297         CmpInst *selectcmpInst = nullptr;
298         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
299 
300           icomps.push_back(selectcmpInst);
301 
302         }
303 
304       }
305 
306     }
307 
308   }
309 
310   if (icomps.size()) {
311 
312     // if (!be_quiet) errs() << "Hooking " << icomps.size() <<
313     //                          " cmp instructions\n";
314 
315     for (auto &selectcmpInst : icomps) {
316 
317       IRBuilder<> IRB2(selectcmpInst->getParent());
318       IRB2.SetInsertPoint(selectcmpInst);
319       LoadInst *CmpPtr = IRB2.CreateLoad(
320 #if LLVM_VERSION_MAJOR >= 14
321           PointerType::get(Int8Ty, 0),
322 #endif
323           AFLCmplogPtr);
324       CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
325       auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
326       auto ThenTerm =
327           SplitBlockAndInsertIfThen(is_not_null, selectcmpInst, false);
328 
329       IRBuilder<> IRB(ThenTerm);
330 
331       Value *op0 = selectcmpInst->getOperand(0);
332       Value *op1 = selectcmpInst->getOperand(1);
333       Value *op0_saved = op0, *op1_saved = op1;
334       auto   ty0 = op0->getType();
335       auto   ty1 = op1->getType();
336 
337       IntegerType *intTyOp0 = NULL;
338       IntegerType *intTyOp1 = NULL;
339       unsigned     max_size = 0, cast_size = 0;
340       unsigned     attr = 0, vector_cnt = 0, is_fp = 0;
341       CmpInst *    cmpInst = dyn_cast<CmpInst>(selectcmpInst);
342 
343       if (!cmpInst) { continue; }
344 
345       switch (cmpInst->getPredicate()) {
346 
347         case CmpInst::ICMP_NE:
348         case CmpInst::FCMP_UNE:
349         case CmpInst::FCMP_ONE:
350           break;
351         case CmpInst::ICMP_EQ:
352         case CmpInst::FCMP_UEQ:
353         case CmpInst::FCMP_OEQ:
354           attr += 1;
355           break;
356         case CmpInst::ICMP_UGT:
357         case CmpInst::ICMP_SGT:
358         case CmpInst::FCMP_OGT:
359         case CmpInst::FCMP_UGT:
360           attr += 2;
361           break;
362         case CmpInst::ICMP_UGE:
363         case CmpInst::ICMP_SGE:
364         case CmpInst::FCMP_OGE:
365         case CmpInst::FCMP_UGE:
366           attr += 3;
367           break;
368         case CmpInst::ICMP_ULT:
369         case CmpInst::ICMP_SLT:
370         case CmpInst::FCMP_OLT:
371         case CmpInst::FCMP_ULT:
372           attr += 4;
373           break;
374         case CmpInst::ICMP_ULE:
375         case CmpInst::ICMP_SLE:
376         case CmpInst::FCMP_OLE:
377         case CmpInst::FCMP_ULE:
378           attr += 5;
379           break;
380         default:
381           break;
382 
383       }
384 
385       if (selectcmpInst->getOpcode() == Instruction::FCmp) {
386 
387         if (ty0->isVectorTy()) {
388 
389           VectorType *tt = dyn_cast<VectorType>(ty0);
390           if (!tt) {
391 
392             fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
393             continue;
394 
395           }
396 
397 #if (LLVM_VERSION_MAJOR >= 12)
398           vector_cnt = tt->getElementCount().getKnownMinValue();
399           ty0 = tt->getElementType();
400 #endif
401 
402         }
403 
404         if (ty0->isHalfTy()
405 #if LLVM_VERSION_MAJOR >= 11
406             || ty0->isBFloatTy()
407 #endif
408         )
409           max_size = 16;
410         else if (ty0->isFloatTy())
411           max_size = 32;
412         else if (ty0->isDoubleTy())
413           max_size = 64;
414         else if (ty0->isX86_FP80Ty())
415           max_size = 80;
416         else if (ty0->isFP128Ty() || ty0->isPPC_FP128Ty())
417           max_size = 128;
418 #if (LLVM_VERSION_MAJOR >= 12)
419         else if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet)
420           fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u!\n",
421                   ty0->getTypeID());
422 #endif
423 
424         attr += 8;
425         is_fp = 1;
426         // fprintf(stderr, "HAVE FP %u!\n", vector_cnt);
427 
428       } else {
429 
430         if (ty0->isVectorTy()) {
431 
432 #if (LLVM_VERSION_MAJOR >= 12)
433           VectorType *tt = dyn_cast<VectorType>(ty0);
434           if (!tt) {
435 
436             fprintf(stderr, "Warning: cmplog cmp vector is not a vector!\n");
437             continue;
438 
439           }
440 
441           vector_cnt = tt->getElementCount().getKnownMinValue();
442           ty1 = ty0 = tt->getElementType();
443 #endif
444 
445         }
446 
447         intTyOp0 = dyn_cast<IntegerType>(ty0);
448         intTyOp1 = dyn_cast<IntegerType>(ty1);
449 
450         if (intTyOp0 && intTyOp1) {
451 
452           max_size = intTyOp0->getBitWidth() > intTyOp1->getBitWidth()
453                          ? intTyOp0->getBitWidth()
454                          : intTyOp1->getBitWidth();
455 
456         } else {
457 
458 #if (LLVM_VERSION_MAJOR >= 12)
459           if (ty0->getTypeID() != llvm::Type::PointerTyID && !be_quiet) {
460 
461             fprintf(stderr, "Warning: unsupported cmp type for cmplog: %u\n",
462                     ty0->getTypeID());
463 
464           }
465 
466 #endif
467 
468         }
469 
470       }
471 
472       if (!max_size || max_size < 16) {
473 
474         // fprintf(stderr, "too small\n");
475         continue;
476 
477       }
478 
479       if (max_size % 8) { max_size = (((max_size / 8) + 1) * 8); }
480 
481       if (max_size > 128) {
482 
483         if (!be_quiet) {
484 
485           fprintf(stderr,
486                   "Cannot handle this compare bit size: %u (truncating)\n",
487                   max_size);
488 
489         }
490 
491         max_size = 128;
492 
493       }
494 
495       // do we need to cast?
496       switch (max_size) {
497 
498         case 8:
499         case 16:
500         case 32:
501         case 64:
502         case 128:
503           cast_size = max_size;
504           break;
505         default:
506           cast_size = 128;
507 
508       }
509 
510       // XXX FIXME BUG TODO
511       if (is_fp && vector_cnt) { continue; }
512 
513       uint64_t cur = 0, last_val0 = 0, last_val1 = 0, cur_val;
514 
515       while (1) {
516 
517         std::vector<Value *> args;
518         bool                 skip = false;
519 
520         if (vector_cnt) {
521 
522           op0 = IRB.CreateExtractElement(op0_saved, cur);
523           op1 = IRB.CreateExtractElement(op1_saved, cur);
524           /*
525           std::string errMsg;
526           raw_string_ostream os(errMsg);
527           op0_saved->print(os);
528           fprintf(stderr, "X: %s\n", os.str().c_str());
529           */
530           if (is_fp) {
531 
532             /*
533                         ConstantFP *i0 = dyn_cast<ConstantFP>(op0);
534                         ConstantFP *i1 = dyn_cast<ConstantFP>(op1);
535                         // BUG FIXME TODO: this is null ... but why?
536                         // fprintf(stderr, "%p %p\n", i0, i1);
537                         if (i0) {
538 
539                           cur_val = (uint64_t)i0->getValue().convertToDouble();
540                           if (last_val0 && last_val0 == cur_val) { skip = true;
541 
542                } last_val0 = cur_val;
543 
544                         }
545 
546                         if (i1) {
547 
548                           cur_val = (uint64_t)i1->getValue().convertToDouble();
549                           if (last_val1 && last_val1 == cur_val) { skip = true;
550 
551                } last_val1 = cur_val;
552 
553                         }
554 
555             */
556 
557           } else {
558 
559             ConstantInt *i0 = dyn_cast<ConstantInt>(op0);
560             ConstantInt *i1 = dyn_cast<ConstantInt>(op1);
561             if (i0 && i0->uge(0xffffffffffffffff) == false) {
562 
563               cur_val = i0->getZExtValue();
564               if (last_val0 && last_val0 == cur_val) { skip = true; }
565               last_val0 = cur_val;
566 
567             }
568 
569             if (i1 && i1->uge(0xffffffffffffffff) == false) {
570 
571               cur_val = i1->getZExtValue();
572               if (last_val1 && last_val1 == cur_val) { skip = true; }
573               last_val1 = cur_val;
574 
575             }
576 
577           }
578 
579         }
580 
581         if (!skip) {
582 
583           // errs() << "[CMPLOG] cmp  " << *cmpInst << "(in function " <<
584           // cmpInst->getFunction()->getName() << ")\n";
585 
586           // first bitcast to integer type of the same bitsize as the original
587           // type (this is a nop, if already integer)
588           Value *op0_i = IRB.CreateBitCast(
589               op0, IntegerType::get(C, ty0->getPrimitiveSizeInBits()));
590           // then create a int cast, which does zext, trunc or bitcast. In our
591           // case usually zext to the next larger supported type (this is a nop
592           // if already the right type)
593           Value *V0 =
594               IRB.CreateIntCast(op0_i, IntegerType::get(C, cast_size), false);
595           args.push_back(V0);
596           Value *op1_i = IRB.CreateBitCast(
597               op1, IntegerType::get(C, ty1->getPrimitiveSizeInBits()));
598           Value *V1 =
599               IRB.CreateIntCast(op1_i, IntegerType::get(C, cast_size), false);
600           args.push_back(V1);
601 
602           // errs() << "[CMPLOG] casted parameters:\n0: " << *V0 << "\n1: " <<
603           // *V1
604           // << "\n";
605 
606           ConstantInt *attribute = ConstantInt::get(Int8Ty, attr);
607           args.push_back(attribute);
608 
609           if (cast_size != max_size) {
610 
611             ConstantInt *bitsize = ConstantInt::get(Int8Ty, (max_size / 8) - 1);
612             args.push_back(bitsize);
613 
614           }
615 
616           // fprintf(stderr, "_ExtInt(%u) castTo %u with attr %u didcast %u\n",
617           //         max_size, cast_size, attr);
618 
619           switch (cast_size) {
620 
621             case 8:
622               IRB.CreateCall(cmplogHookIns1, args);
623               break;
624             case 16:
625               IRB.CreateCall(cmplogHookIns2, args);
626               break;
627             case 32:
628               IRB.CreateCall(cmplogHookIns4, args);
629               break;
630             case 64:
631               IRB.CreateCall(cmplogHookIns8, args);
632               break;
633             case 128:
634               if (max_size == 128) {
635 
636                 IRB.CreateCall(cmplogHookIns16, args);
637 
638               } else {
639 
640                 IRB.CreateCall(cmplogHookInsN, args);
641 
642               }
643 
644               break;
645 
646           }
647 
648         }
649 
650         /* else fprintf(stderr, "skipped\n"); */
651 
652         ++cur;
653         if (cur >= vector_cnt) { break; }
654 
655       }
656 
657     }
658 
659   }
660 
661   if (icomps.size())
662     return true;
663   else
664     return false;
665 
666 }
667 
668 #if LLVM_MAJOR >= 11                                /* use new pass manager */
run(Module & M,ModuleAnalysisManager & MAM)669 PreservedAnalyses CmpLogInstructions::run(Module &               M,
670                                           ModuleAnalysisManager &MAM) {
671 
672 #else
673 bool CmpLogInstructions::runOnModule(Module &M) {
674 
675 #endif
676 
677   if (getenv("AFL_QUIET") == NULL)
678     printf("Running cmplog-instructions-pass by andreafioraldi@gmail.com\n");
679   else
680     be_quiet = 1;
681   hookInstrs(M);
682   verifyModule(M);
683 
684 #if LLVM_MAJOR >= 11                                /* use new pass manager */
685   return PreservedAnalyses::all();
686 #else
687   return true;
688 #endif
689 
690 }
691 
692 #if LLVM_MAJOR < 11                                 /* use old pass manager */
693 static void registerCmpLogInstructionsPass(const PassManagerBuilder &,
694                                            legacy::PassManagerBase &PM) {
695 
696   auto p = new CmpLogInstructions();
697   PM.add(p);
698 
699 }
700 
701 static RegisterStandardPasses RegisterCmpLogInstructionsPass(
702     PassManagerBuilder::EP_OptimizerLast, registerCmpLogInstructionsPass);
703 
704 static RegisterStandardPasses RegisterCmpLogInstructionsPass0(
705     PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogInstructionsPass);
706 
707   #if LLVM_VERSION_MAJOR >= 11
708 static RegisterStandardPasses RegisterCmpLogInstructionsPassLTO(
709     PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
710     registerCmpLogInstructionsPass);
711   #endif
712 #endif
713 
714