• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* SanitizeCoverage.cpp ported to afl++ LTO :-) */
2 
3 #define AFL_LLVM_PASS
4 
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <unistd.h>
8 #include <string.h>
9 #include <sys/time.h>
10 
11 #include <list>
12 #include <string>
13 #include <fstream>
14 #include <set>
15 #include <iostream>
16 
17 #include "llvm/Transforms/Instrumentation/SanitizerCoverage.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Triple.h"
21 #include "llvm/Analysis/EHPersonalities.h"
22 #include "llvm/Analysis/PostDominators.h"
23 #include "llvm/Analysis/ValueTracking.h"
24 #include "llvm/IR/BasicBlock.h"
25 #include "llvm/IR/CFG.h"
26 #include "llvm/IR/Constant.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DebugInfo.h"
29 #include "llvm/IR/Dominators.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/GlobalVariable.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/InlineAsm.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/IR/Intrinsics.h"
37 #include "llvm/IR/LLVMContext.h"
38 #include "llvm/IR/MDBuilder.h"
39 #include "llvm/IR/Mangler.h"
40 #include "llvm/IR/Module.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/InitializePasses.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/CommandLine.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/SpecialCaseList.h"
47 #include "llvm/Support/VirtualFileSystem.h"
48 #include "llvm/Support/raw_ostream.h"
49 #include "llvm/Transforms/Instrumentation.h"
50 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
51 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/ModuleUtils.h"
54 #include "llvm/Passes/PassPlugin.h"
55 #include "llvm/Passes/PassBuilder.h"
56 #include "llvm/IR/PassManager.h"
57 
58 #include "config.h"
59 #include "debug.h"
60 #include "afl-llvm-common.h"
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "sancov"
65 
/* Names of the sanitizer-coverage runtime callbacks. */
static constexpr char SanCovTracePCIndirName[] =
    "__sanitizer_cov_trace_pc_indir";
static constexpr char SanCovTracePCName[] = "__sanitizer_cov_trace_pc";
// The trace-pc-guard callback name is intentionally left disabled:
// static constexpr char SanCovTracePCGuardName[] =
//    "__sanitizer_cov_trace_pc_guard";

/* Names of the sections associated with each instrumentation flavour. */
static constexpr char SanCovGuardsSectionName[] = "sancov_guards";
static constexpr char SanCovCountersSectionName[] = "sancov_cntrs";
static constexpr char SanCovBoolFlagSectionName[] = "sancov_bools";
static constexpr char SanCovPCsSectionName[] = "sancov_pcs";
74 
75 static cl::opt<int> ClCoverageLevel(
76     "lto-coverage-level",
77     cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, "
78              "3: all blocks and critical edges"),
79     cl::Hidden, cl::init(3));
80 
81 static cl::opt<bool> ClTracePC("lto-coverage-trace-pc",
82                                cl::desc("Experimental pc tracing"), cl::Hidden,
83                                cl::init(false));
84 
85 static cl::opt<bool> ClTracePCGuard("lto-coverage-trace-pc-guard",
86                                     cl::desc("pc tracing with a guard"),
87                                     cl::Hidden, cl::init(false));
88 
89 // If true, we create a global variable that contains PCs of all instrumented
90 // BBs, put this global into a named section, and pass this section's bounds
91 // to __sanitizer_cov_pcs_init.
92 // This way the coverage instrumentation does not need to acquire the PCs
93 // at run-time. Works with trace-pc-guard, inline-8bit-counters, and
94 // inline-bool-flag.
95 static cl::opt<bool> ClCreatePCTable("lto-coverage-pc-table",
96                                      cl::desc("create a static PC table"),
97                                      cl::Hidden, cl::init(false));
98 
99 static cl::opt<bool> ClInline8bitCounters(
100     "lto-coverage-inline-8bit-counters",
101     cl::desc("increments 8-bit counter for every edge"), cl::Hidden,
102     cl::init(false));
103 
104 static cl::opt<bool> ClInlineBoolFlag(
105     "lto-coverage-inline-bool-flag",
106     cl::desc("sets a boolean flag for every edge"), cl::Hidden,
107     cl::init(false));
108 
109 static cl::opt<bool> ClPruneBlocks(
110     "lto-coverage-prune-blocks",
111     cl::desc("Reduce the number of instrumented blocks"), cl::Hidden,
112     cl::init(true));
113 
114 namespace {
115 
getOptions(int LegacyCoverageLevel)116 SanitizerCoverageOptions getOptions(int LegacyCoverageLevel) {
117 
118   SanitizerCoverageOptions Res;
119   switch (LegacyCoverageLevel) {
120 
121     case 0:
122       Res.CoverageType = SanitizerCoverageOptions::SCK_None;
123       break;
124     case 1:
125       Res.CoverageType = SanitizerCoverageOptions::SCK_Function;
126       break;
127     case 2:
128       Res.CoverageType = SanitizerCoverageOptions::SCK_BB;
129       break;
130     case 3:
131       Res.CoverageType = SanitizerCoverageOptions::SCK_Edge;
132       break;
133     case 4:
134       Res.CoverageType = SanitizerCoverageOptions::SCK_Edge;
135       Res.IndirectCalls = true;
136       break;
137 
138   }
139 
140   return Res;
141 
142 }
143 
OverrideFromCL(SanitizerCoverageOptions Options)144 SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) {
145 
146   // Sets CoverageType and IndirectCalls.
147   SanitizerCoverageOptions CLOpts = getOptions(ClCoverageLevel);
148   Options.CoverageType = std::max(Options.CoverageType, CLOpts.CoverageType);
149   Options.IndirectCalls |= CLOpts.IndirectCalls;
150   Options.TracePC |= ClTracePC;
151   Options.TracePCGuard |= ClTracePCGuard;
152   Options.Inline8bitCounters |= ClInline8bitCounters;
153   Options.InlineBoolFlag |= ClInlineBoolFlag;
154   Options.PCTable |= ClCreatePCTable;
155   Options.NoPrune |= !ClPruneBlocks;
156   if (!Options.TracePCGuard && !Options.TracePC &&
157       !Options.Inline8bitCounters && !Options.InlineBoolFlag)
158     Options.TracePCGuard = true;  // TracePCGuard is default.
159   return Options;
160 
161 }
162 
163 using DomTreeCallback = function_ref<const DominatorTree *(Function &F)>;
164 using PostDomTreeCallback =
165     function_ref<const PostDominatorTree *(Function &F)>;
166 
167 class ModuleSanitizerCoverageLTO
168     : public PassInfoMixin<ModuleSanitizerCoverageLTO> {
169 
170  public:
ModuleSanitizerCoverageLTO(const SanitizerCoverageOptions & Options=SanitizerCoverageOptions ())171   ModuleSanitizerCoverageLTO(
172       const SanitizerCoverageOptions &Options = SanitizerCoverageOptions())
173       : Options(OverrideFromCL(Options)) {
174 
175   }
176 
177   bool instrumentModule(Module &M, DomTreeCallback DTCallback,
178                         PostDomTreeCallback PDTCallback);
179 
180   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
181 
182  private:
183   void            instrumentFunction(Function &F, DomTreeCallback DTCallback,
184                                      PostDomTreeCallback PDTCallback);
185   void            InjectCoverageForIndirectCalls(Function &              F,
186                                                  ArrayRef<Instruction *> IndirCalls);
187   bool            InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks,
188                                  bool IsLeafFunc = true);
189   GlobalVariable *CreateFunctionLocalArrayInSection(size_t    NumElements,
190                                                     Function &F, Type *Ty,
191                                                     const char *Section);
192   GlobalVariable *CreatePCArray(Function &F, ArrayRef<BasicBlock *> AllBlocks);
193   void CreateFunctionLocalArrays(Function &F, ArrayRef<BasicBlock *> AllBlocks);
194   void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx,
195                              bool IsLeafFunc = true);
196   //  std::pair<Value *, Value *> CreateSecStartEnd(Module &M, const char
197   //  *Section,
198   //                                                Type *Ty);
199 
SetNoSanitizeMetadata(Instruction * I)200   void SetNoSanitizeMetadata(Instruction *I) {
201 
202     I->setMetadata(I->getModule()->getMDKindID("nosanitize"),
203                    MDNode::get(*C, None));
204 
205   }
206 
207   std::string getSectionName(const std::string &Section) const;
208   //  std::string    getSectionStart(const std::string &Section) const;
209   //  std::string    getSectionEnd(const std::string &Section) const;
210   FunctionCallee SanCovTracePCIndir;
211   FunctionCallee SanCovTracePC /*, SanCovTracePCGuard*/;
212   Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy,
213       *Int16Ty, *Int8Ty, *Int8PtrTy, *Int1Ty, *Int1PtrTy;
214   Module *          CurModule;
215   std::string       CurModuleUniqueId;
216   Triple            TargetTriple;
217   LLVMContext *     C;
218   const DataLayout *DL;
219 
220   GlobalVariable *FunctionGuardArray;        // for trace-pc-guard.
221   GlobalVariable *Function8bitCounterArray;  // for inline-8bit-counters.
222   GlobalVariable *FunctionBoolArray;         // for inline-bool-flag.
223   GlobalVariable *FunctionPCsArray;          // for pc-table.
224   SmallVector<GlobalValue *, 20> GlobalsToAppendToUsed;
225   SmallVector<GlobalValue *, 20> GlobalsToAppendToCompilerUsed;
226 
227   SanitizerCoverageOptions Options;
228 
229   // afl++ START
230   // const SpecialCaseList *          Allowlist;
231   // const SpecialCaseList *          Blocklist;
232   uint32_t                         autodictionary = 1;
233   uint32_t                         inst = 0;
234   uint32_t                         afl_global_id = 0;
235   uint32_t                         unhandled = 0;
236   uint32_t                         select_cnt = 0;
237   uint64_t                         map_addr = 0;
238   const char *                     skip_nozero = NULL;
239   const char *                     use_threadsafe_counters = nullptr;
240   std::vector<BasicBlock *>        BlockList;
241   DenseMap<Value *, std::string *> valueMap;
242   std::vector<std::string>         dictionary;
243   IntegerType *                    Int8Tyi = NULL;
244   IntegerType *                    Int32Tyi = NULL;
245   IntegerType *                    Int64Tyi = NULL;
246   ConstantInt *                    Zero = NULL;
247   ConstantInt *                    One = NULL;
248   LLVMContext *                    Ct = NULL;
249   Module *                         Mo = NULL;
250   GlobalVariable *                 AFLMapPtr = NULL;
251   Value *                          MapPtrFixed = NULL;
252   std::ofstream                    dFile;
253   size_t                           found = 0;
254   // afl++ END
255 
256 };
257 
258 class ModuleSanitizerCoverageLegacyPass : public ModulePass {
259 
260  public:
261   static char ID;
getPassName() const262   StringRef   getPassName() const override {
263 
264     return "sancov";
265 
266   }
267 
getAnalysisUsage(AnalysisUsage & AU) const268   void getAnalysisUsage(AnalysisUsage &AU) const override {
269 
270     AU.addRequired<DominatorTreeWrapperPass>();
271     AU.addRequired<PostDominatorTreeWrapperPass>();
272 
273   }
274 
ModuleSanitizerCoverageLegacyPass(const SanitizerCoverageOptions & Options=SanitizerCoverageOptions ())275   ModuleSanitizerCoverageLegacyPass(
276       const SanitizerCoverageOptions &Options = SanitizerCoverageOptions())
277       : ModulePass(ID), Options(Options) {
278 
279     initializeModuleSanitizerCoverageLegacyPassPass(
280         *PassRegistry::getPassRegistry());
281 
282   }
283 
runOnModule(Module & M)284   bool runOnModule(Module &M) override {
285 
286     ModuleSanitizerCoverageLTO ModuleSancov(Options);
287     auto DTCallback = [this](Function &F) -> const DominatorTree * {
288 
289       return &this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
290 
291     };
292 
293     auto PDTCallback = [this](Function &F) -> const PostDominatorTree * {
294 
295       return &this->getAnalysis<PostDominatorTreeWrapperPass>(F)
296                   .getPostDomTree();
297 
298     };
299 
300     return ModuleSancov.instrumentModule(M, DTCallback, PDTCallback);
301 
302   }
303 
304  private:
305   SanitizerCoverageOptions Options;
306 
307 };
308 
309 }  // namespace
310 
311 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()312 llvmGetPassPluginInfo() {
313 
314   return {LLVM_PLUGIN_API_VERSION, "SanitizerCoverageLTO", "v0.1",
315           /* lambda to insert our pass into the pass pipeline. */
316           [](PassBuilder &PB) {
317 
318 #if LLVM_VERSION_MAJOR <= 13
319             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
320 #endif
321             //            PB.registerFullLinkTimeOptimizationLastEPCallback(
322             PB.registerOptimizerLastEPCallback(
323                 [](ModulePassManager &MPM, OptimizationLevel OL) {
324 
325                   MPM.addPass(ModuleSanitizerCoverageLTO());
326 
327                 });
328 
329           }};
330 
331 }
332 
run(Module & M,ModuleAnalysisManager & MAM)333 PreservedAnalyses ModuleSanitizerCoverageLTO::run(Module &               M,
334                                                   ModuleAnalysisManager &MAM) {
335 
336   ModuleSanitizerCoverageLTO ModuleSancov(Options);
337   auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
338   auto  DTCallback = [&FAM](Function &F) -> const DominatorTree * {
339 
340     return &FAM.getResult<DominatorTreeAnalysis>(F);
341 
342   };
343 
344   auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree * {
345 
346     return &FAM.getResult<PostDominatorTreeAnalysis>(F);
347 
348   };
349 
350   if (ModuleSancov.instrumentModule(M, DTCallback, PDTCallback))
351     return PreservedAnalyses::none();
352 
353   return PreservedAnalyses::all();
354 
355 }
356 
instrumentModule(Module & M,DomTreeCallback DTCallback,PostDomTreeCallback PDTCallback)357 bool ModuleSanitizerCoverageLTO::instrumentModule(
358     Module &M, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) {
359 
360   if (Options.CoverageType == SanitizerCoverageOptions::SCK_None) return false;
361   /*
362     if (Allowlist &&
363         !Allowlist->inSection("coverage", "src", MNAME))
364       return false;
365     if (Blocklist &&
366         Blocklist->inSection("coverage", "src", MNAME))
367       return false;
368   */
369   BlockList.clear();
370   valueMap.clear();
371   dictionary.clear();
372   C = &(M.getContext());
373   DL = &M.getDataLayout();
374   CurModule = &M;
375   CurModuleUniqueId = getUniqueModuleId(CurModule);
376   TargetTriple = Triple(M.getTargetTriple());
377   FunctionGuardArray = nullptr;
378   Function8bitCounterArray = nullptr;
379   FunctionBoolArray = nullptr;
380   FunctionPCsArray = nullptr;
381   IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits());
382   IntptrPtrTy = PointerType::getUnqual(IntptrTy);
383   Type *      VoidTy = Type::getVoidTy(*C);
384   IRBuilder<> IRB(*C);
385   Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty());
386   Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty());
387   Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty());
388   Int1PtrTy = PointerType::getUnqual(IRB.getInt1Ty());
389   Int64Ty = IRB.getInt64Ty();
390   Int32Ty = IRB.getInt32Ty();
391   Int16Ty = IRB.getInt16Ty();
392   Int8Ty = IRB.getInt8Ty();
393   Int1Ty = IRB.getInt1Ty();
394 
395   /* afl++ START */
396   char *       ptr;
397   LLVMContext &Ctx = M.getContext();
398   Ct = &Ctx;
399   Int8Tyi = IntegerType::getInt8Ty(Ctx);
400   Int32Tyi = IntegerType::getInt32Ty(Ctx);
401   Int64Tyi = IntegerType::getInt64Ty(Ctx);
402 
403   /* Show a banner */
404   setvbuf(stdout, NULL, _IONBF, 0);
405   if (getenv("AFL_DEBUG")) debug = 1;
406 
407   if ((isatty(2) && !getenv("AFL_QUIET")) || debug) {
408 
409     SAYF(cCYA "afl-llvm-lto" VERSION cRST
410               " by Marc \"vanHauser\" Heuse <mh@mh-sec.de>\n");
411 
412   } else
413 
414     be_quiet = 1;
415 
416   skip_nozero = getenv("AFL_LLVM_SKIP_NEVERZERO");
417   use_threadsafe_counters = getenv("AFL_LLVM_THREADSAFE_INST");
418 
419   if ((ptr = getenv("AFL_LLVM_LTO_STARTID")) != NULL)
420     if ((afl_global_id = atoi(ptr)) < 0)
421       FATAL("AFL_LLVM_LTO_STARTID value of \"%s\" is negative\n", ptr);
422 
423   if ((ptr = getenv("AFL_LLVM_DOCUMENT_IDS")) != NULL) {
424 
425     dFile.open(ptr, std::ofstream::out | std::ofstream::app);
426     if (dFile.is_open()) WARNF("Cannot access document file %s", ptr);
427 
428   }
429 
430   // we make this the default as the fixed map has problems with
431   // defered forkserver, early constructors, ifuncs and maybe more
432   /*if (getenv("AFL_LLVM_MAP_DYNAMIC"))*/
433   map_addr = 0;
434 
435   if ((ptr = getenv("AFL_LLVM_MAP_ADDR"))) {
436 
437     uint64_t val;
438     if (!*ptr || !strcmp(ptr, "0") || !strcmp(ptr, "0x0")) {
439 
440       map_addr = 0;
441 
442     } else if (getenv("AFL_LLVM_MAP_DYNAMIC")) {
443 
444       FATAL(
445           "AFL_LLVM_MAP_ADDR and AFL_LLVM_MAP_DYNAMIC cannot be used together");
446 
447     } else if (strncmp(ptr, "0x", 2) != 0) {
448 
449       map_addr = 0x10000;  // the default
450 
451     } else {
452 
453       val = strtoull(ptr, NULL, 16);
454       if (val < 0x100 || val > 0xffffffff00000000) {
455 
456         FATAL(
457             "AFL_LLVM_MAP_ADDR must be a value between 0x100 and "
458             "0xffffffff00000000");
459 
460       }
461 
462       map_addr = val;
463 
464     }
465 
466   }
467 
468   /* Get/set the globals for the SHM region. */
469 
470   if (!map_addr) {
471 
472     AFLMapPtr =
473         new GlobalVariable(M, PointerType::get(Int8Tyi, 0), false,
474                            GlobalValue::ExternalLinkage, 0, "__afl_area_ptr");
475 
476   } else {
477 
478     ConstantInt *MapAddr = ConstantInt::get(Int64Tyi, map_addr);
479     MapPtrFixed =
480         ConstantExpr::getIntToPtr(MapAddr, PointerType::getUnqual(Int8Tyi));
481 
482   }
483 
484   Zero = ConstantInt::get(Int8Tyi, 0);
485   One = ConstantInt::get(Int8Tyi, 1);
486 
487   initInstrumentList();
488   scanForDangerousFunctions(&M);
489   Mo = &M;
490 
491   if (autodictionary) {
492 
493     for (auto &F : M) {
494 
495       if (!isInInstrumentList(&F, MNAME) || !F.size()) { continue; }
496 
497       for (auto &BB : F) {
498 
499         for (auto &IN : BB) {
500 
501           CallInst *callInst = nullptr;
502           CmpInst * cmpInst = nullptr;
503 
504           if ((cmpInst = dyn_cast<CmpInst>(&IN))) {
505 
506             Value *      op = cmpInst->getOperand(1);
507             ConstantInt *ilen = dyn_cast<ConstantInt>(op);
508 
509             if (ilen && ilen->uge(0xffffffffffffffff) == false) {
510 
511               u64 val2 = 0, val = ilen->getZExtValue();
512               u32 len = 0;
513               if (val > 0x10000 && val < 0xffffffff) len = 4;
514               if (val > 0x100000001 && val < 0xffffffffffffffff) len = 8;
515 
516               if (len) {
517 
518                 auto c = cmpInst->getPredicate();
519 
520                 switch (c) {
521 
522                   case CmpInst::FCMP_OGT:  // fall through
523                   case CmpInst::FCMP_OLE:  // fall through
524                   case CmpInst::ICMP_SLE:  // fall through
525                   case CmpInst::ICMP_SGT:
526 
527                     // signed comparison and it is a negative constant
528                     if ((len == 4 && (val & 80000000)) ||
529                         (len == 8 && (val & 8000000000000000))) {
530 
531                       if ((val & 0xffff) != 1) val2 = val - 1;
532                       break;
533 
534                     }
535 
536                     // fall through
537 
538                   case CmpInst::FCMP_UGT:  // fall through
539                   case CmpInst::FCMP_ULE:  // fall through
540                   case CmpInst::ICMP_UGT:  // fall through
541                   case CmpInst::ICMP_ULE:
542                     if ((val & 0xffff) != 0xfffe) val2 = val + 1;
543                     break;
544 
545                   case CmpInst::FCMP_OLT:  // fall through
546                   case CmpInst::FCMP_OGE:  // fall through
547                   case CmpInst::ICMP_SLT:  // fall through
548                   case CmpInst::ICMP_SGE:
549 
550                     // signed comparison and it is a negative constant
551                     if ((len == 4 && (val & 80000000)) ||
552                         (len == 8 && (val & 8000000000000000))) {
553 
554                       if ((val & 0xffff) != 1) val2 = val - 1;
555                       break;
556 
557                     }
558 
559                     // fall through
560 
561                   case CmpInst::FCMP_ULT:  // fall through
562                   case CmpInst::FCMP_UGE:  // fall through
563                   case CmpInst::ICMP_ULT:  // fall through
564                   case CmpInst::ICMP_UGE:
565                     if ((val & 0xffff) != 1) val2 = val - 1;
566                     break;
567 
568                   default:
569                     val2 = 0;
570 
571                 }
572 
573                 dictionary.push_back(std::string((char *)&val, len));
574                 found++;
575 
576                 if (val2) {
577 
578                   dictionary.push_back(std::string((char *)&val2, len));
579                   found++;
580 
581                 }
582 
583               }
584 
585             }
586 
587           }
588 
589           if ((callInst = dyn_cast<CallInst>(&IN))) {
590 
591             bool   isStrcmp = true;
592             bool   isMemcmp = true;
593             bool   isStrncmp = true;
594             bool   isStrcasecmp = true;
595             bool   isStrncasecmp = true;
596             bool   isIntMemcpy = true;
597             bool   isStdString = true;
598             size_t optLen = 0;
599 
600             Function *Callee = callInst->getCalledFunction();
601             if (!Callee) continue;
602             if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
603             std::string FuncName = Callee->getName().str();
604 
605             isStrcmp &= (!FuncName.compare("strcmp") ||
606                          !FuncName.compare("xmlStrcmp") ||
607                          !FuncName.compare("xmlStrEqual") ||
608                          !FuncName.compare("g_strcmp0") ||
609                          !FuncName.compare("curl_strequal") ||
610                          !FuncName.compare("strcsequal"));
611             isMemcmp &=
612                 (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") ||
613                  !FuncName.compare("CRYPTO_memcmp") ||
614                  !FuncName.compare("OPENSSL_memcmp") ||
615                  !FuncName.compare("memcmp_const_time") ||
616                  !FuncName.compare("memcmpct"));
617             isStrncmp &= (!FuncName.compare("strncmp") ||
618                           !FuncName.compare("xmlStrncmp") ||
619                           !FuncName.compare("curl_strnequal"));
620             isStrcasecmp &= (!FuncName.compare("strcasecmp") ||
621                              !FuncName.compare("stricmp") ||
622                              !FuncName.compare("ap_cstr_casecmp") ||
623                              !FuncName.compare("OPENSSL_strcasecmp") ||
624                              !FuncName.compare("xmlStrcasecmp") ||
625                              !FuncName.compare("g_strcasecmp") ||
626                              !FuncName.compare("g_ascii_strcasecmp") ||
627                              !FuncName.compare("Curl_strcasecompare") ||
628                              !FuncName.compare("Curl_safe_strcasecompare") ||
629                              !FuncName.compare("cmsstrcasecmp"));
630             isStrncasecmp &= (!FuncName.compare("strncasecmp") ||
631                               !FuncName.compare("strnicmp") ||
632                               !FuncName.compare("ap_cstr_casecmpn") ||
633                               !FuncName.compare("OPENSSL_strncasecmp") ||
634                               !FuncName.compare("xmlStrncasecmp") ||
635                               !FuncName.compare("g_ascii_strncasecmp") ||
636                               !FuncName.compare("Curl_strncasecompare") ||
637                               !FuncName.compare("g_strncasecmp"));
638 
639             isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64");
640             isStdString &=
641                 ((FuncName.find("basic_string") != std::string::npos &&
642                   FuncName.find("compare") != std::string::npos) ||
643                  (FuncName.find("basic_string") != std::string::npos &&
644                   FuncName.find("find") != std::string::npos));
645 
646             /* we do something different here, putting this BB and the
647                successors in a block map */
648             if (!FuncName.compare("__afl_persistent_loop")) {
649 
650               BlockList.push_back(&BB);
651               for (succ_iterator SI = succ_begin(&BB), SE = succ_end(&BB);
652                    SI != SE; ++SI) {
653 
654                 BasicBlock *succ = *SI;
655                 BlockList.push_back(succ);
656 
657               }
658 
659             }
660 
661             if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
662                 !isStrncasecmp && !isIntMemcpy && !isStdString)
663               continue;
664 
665             /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
666              * prototype */
667             FunctionType *FT = Callee->getFunctionType();
668 
669             isStrcmp &= FT->getNumParams() == 2 &&
670                         FT->getReturnType()->isIntegerTy(32) &&
671                         FT->getParamType(0) == FT->getParamType(1) &&
672                         FT->getParamType(0) ==
673                             IntegerType::getInt8PtrTy(M.getContext());
674             isStrcasecmp &= FT->getNumParams() == 2 &&
675                             FT->getReturnType()->isIntegerTy(32) &&
676                             FT->getParamType(0) == FT->getParamType(1) &&
677                             FT->getParamType(0) ==
678                                 IntegerType::getInt8PtrTy(M.getContext());
679             isMemcmp &= FT->getNumParams() == 3 &&
680                         FT->getReturnType()->isIntegerTy(32) &&
681                         FT->getParamType(0)->isPointerTy() &&
682                         FT->getParamType(1)->isPointerTy() &&
683                         FT->getParamType(2)->isIntegerTy();
684             isStrncmp &= FT->getNumParams() == 3 &&
685                          FT->getReturnType()->isIntegerTy(32) &&
686                          FT->getParamType(0) == FT->getParamType(1) &&
687                          FT->getParamType(0) ==
688                              IntegerType::getInt8PtrTy(M.getContext()) &&
689                          FT->getParamType(2)->isIntegerTy();
690             isStrncasecmp &= FT->getNumParams() == 3 &&
691                              FT->getReturnType()->isIntegerTy(32) &&
692                              FT->getParamType(0) == FT->getParamType(1) &&
693                              FT->getParamType(0) ==
694                                  IntegerType::getInt8PtrTy(M.getContext()) &&
695                              FT->getParamType(2)->isIntegerTy();
696             isStdString &= FT->getNumParams() >= 2 &&
697                            FT->getParamType(0)->isPointerTy() &&
698                            FT->getParamType(1)->isPointerTy();
699 
700             if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
701                 !isStrncasecmp && !isIntMemcpy && !isStdString)
702               continue;
703 
704             /* is a str{n,}{case,}cmp/memcmp, check if we have
705              * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
706              * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
707              * memcmp(x, "const", ..) or memcmp("const", x, ..) */
708             Value *Str1P = callInst->getArgOperand(0),
709                   *Str2P = callInst->getArgOperand(1);
710             std::string Str1, Str2;
711             StringRef   TmpStr;
712             bool        HasStr1 = getConstantStringInfo(Str1P, TmpStr);
713             if (TmpStr.empty())
714               HasStr1 = false;
715             else
716               Str1 = TmpStr.str();
717             bool HasStr2 = getConstantStringInfo(Str2P, TmpStr);
718             if (TmpStr.empty())
719               HasStr2 = false;
720             else
721               Str2 = TmpStr.str();
722 
723             if (debug)
724               fprintf(stderr, "F:%s %p(%s)->\"%s\"(%s) %p(%s)->\"%s\"(%s)\n",
725                       FuncName.c_str(), Str1P, Str1P->getName().str().c_str(),
726                       Str1.c_str(), HasStr1 == true ? "true" : "false", Str2P,
727                       Str2P->getName().str().c_str(), Str2.c_str(),
728                       HasStr2 == true ? "true" : "false");
729 
730             // we handle the 2nd parameter first because of llvm memcpy
731             if (!HasStr2) {
732 
733               auto *Ptr = dyn_cast<ConstantExpr>(Str2P);
734               if (Ptr && Ptr->getOpcode() == Instruction::GetElementPtr) {
735 
736                 if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
737 
738                   if (Var->hasInitializer()) {
739 
740                     if (auto *Array = dyn_cast<ConstantDataArray>(
741                             Var->getInitializer())) {
742 
743                       HasStr2 = true;
744                       Str2 = Array->getRawDataValues().str();
745 
746                     }
747 
748                   }
749 
750                 }
751 
752               }
753 
754             }
755 
756             // for the internal memcpy routine we only care for the second
757             // parameter and are not reporting anything.
758             if (isIntMemcpy == true) {
759 
760               if (HasStr2 == true) {
761 
762                 Value *      op2 = callInst->getArgOperand(2);
763                 ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
764                 if (ilen) {
765 
766                   uint64_t literalLength = Str2.size();
767                   uint64_t optLength = ilen->getZExtValue();
768                   if (optLength > literalLength + 1) {
769 
770                     optLength = Str2.length() + 1;
771 
772                   }
773 
774                   if (literalLength + 1 == optLength) {
775 
776                     Str2.append("\0", 1);  // add null byte
777 
778                   }
779 
780                 }
781 
782                 valueMap[Str1P] = new std::string(Str2);
783 
784                 if (debug)
785                   fprintf(stderr, "Saved: %s for %p\n", Str2.c_str(), Str1P);
786                 continue;
787 
788               }
789 
790               continue;
791 
792             }
793 
794             // Neither a literal nor a global variable?
795             // maybe it is a local variable that we saved
796             if (!HasStr2) {
797 
798               std::string *strng = valueMap[Str2P];
799               if (strng && !strng->empty()) {
800 
801                 Str2 = *strng;
802                 HasStr2 = true;
803                 if (debug)
804                   fprintf(stderr, "Filled2: %s for %p\n", strng->c_str(),
805                           Str2P);
806 
807               }
808 
809             }
810 
811             if (!HasStr1) {
812 
813               auto Ptr = dyn_cast<ConstantExpr>(Str1P);
814 
815               if (Ptr && Ptr->getOpcode() == Instruction::GetElementPtr) {
816 
817                 if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
818 
819                   if (Var->hasInitializer()) {
820 
821                     if (auto *Array = dyn_cast<ConstantDataArray>(
822                             Var->getInitializer())) {
823 
824                       HasStr1 = true;
825                       Str1 = Array->getRawDataValues().str();
826 
827                     }
828 
829                   }
830 
831                 }
832 
833               }
834 
835             }
836 
837             // Neither a literal nor a global variable?
838             // maybe it is a local variable that we saved
839             if (!HasStr1) {
840 
841               std::string *strng = valueMap[Str1P];
842               if (strng && !strng->empty()) {
843 
844                 Str1 = *strng;
845                 HasStr1 = true;
846                 if (debug)
847                   fprintf(stderr, "Filled1: %s for %p\n", strng->c_str(),
848                           Str1P);
849 
850               }
851 
852             }
853 
854             /* handle cases of one string is const, one string is variable */
855             if (!(HasStr1 ^ HasStr2)) continue;
856 
857             std::string thestring;
858 
859             if (HasStr1)
860               thestring = Str1;
861             else
862               thestring = Str2;
863 
864             optLen = thestring.length();
865             if (optLen < 2 || (optLen == 2 && !thestring[1])) { continue; }
866 
867             if (isMemcmp || isStrncmp || isStrncasecmp) {
868 
869               Value *      op2 = callInst->getArgOperand(2);
870               ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
871 
872               if (ilen) {
873 
874                 uint64_t literalLength = optLen;
875                 optLen = ilen->getZExtValue();
876                 if (optLen > thestring.length() + 1) {
877 
878                   optLen = thestring.length() + 1;
879 
880                 }
881 
882                 if (optLen < 2) { continue; }
883                 if (literalLength + 1 == optLen) {  // add null byte
884 
885                   thestring.append("\0", 1);
886 
887                 }
888 
889               }
890 
891             }
892 
893             // add null byte if this is a string compare function and a null
894             // was not already added
895             if (!isMemcmp) {
896 
897               /*
898                             if (addedNull == false && thestring[optLen - 1] !=
899                  '\0') {
900 
901                               thestring.append("\0", 1);  // add null byte
902                               optLen++;
903 
904                             }
905 
906               */
907               if (!isStdString &&
908                   thestring.find('\0', 0) != std::string::npos) {
909 
910                 // ensure we do not have garbage
911                 size_t offset = thestring.find('\0', 0);
912                 if (offset + 1 < optLen) optLen = offset + 1;
913                 thestring = thestring.substr(0, optLen);
914 
915               }
916 
917             }
918 
919             if (!be_quiet) {
920 
921               std::string outstring;
922               fprintf(stderr, "%s: length %zu/%zu \"", FuncName.c_str(), optLen,
923                       thestring.length());
924               for (uint8_t i = 0; i < thestring.length(); i++) {
925 
926                 uint8_t c = thestring[i];
927                 if (c <= 32 || c >= 127)
928                   fprintf(stderr, "\\x%02x", c);
929                 else
930                   fprintf(stderr, "%c", c);
931 
932               }
933 
934               fprintf(stderr, "\"\n");
935 
936             }
937 
938             // we take the longer string, even if the compare was to a
939             // shorter part. Note that depending on the optimizer of the
940             // compiler this can be wrong, but it is more likely that this
941             // is helping the fuzzer
942             if (optLen != thestring.length()) optLen = thestring.length();
943             if (optLen > MAX_AUTO_EXTRA) optLen = MAX_AUTO_EXTRA;
944             if (optLen < MIN_AUTO_EXTRA)  // too short? skip
945               continue;
946 
947             dictionary.push_back(thestring.substr(0, optLen));
948 
949           }
950 
951         }
952 
953       }
954 
955     }
956 
957   }
958 
959   // afl++ END
960 
961   SanCovTracePCIndir =
962       M.getOrInsertFunction(SanCovTracePCIndirName, VoidTy, IntptrTy);
963   // Make sure smaller parameters are zero-extended to i64 as required by the
964   // x86_64 ABI.
965   AttributeList SanCovTraceCmpZeroExtAL;
966   if (TargetTriple.getArch() == Triple::x86_64) {
967 
968     SanCovTraceCmpZeroExtAL =
969         SanCovTraceCmpZeroExtAL.addParamAttribute(*C, 0, Attribute::ZExt);
970     SanCovTraceCmpZeroExtAL =
971         SanCovTraceCmpZeroExtAL.addParamAttribute(*C, 1, Attribute::ZExt);
972 
973   }
974 
975   SanCovTracePC = M.getOrInsertFunction(SanCovTracePCName, VoidTy);
976 
977   // SanCovTracePCGuard =
978   //    M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, Int32PtrTy);
979 
980   for (auto &F : M)
981     instrumentFunction(F, DTCallback, PDTCallback);
982 
983   // afl++ START
984   if (dFile.is_open()) dFile.close();
985 
986   if (!getenv("AFL_LLVM_LTO_DONTWRITEID") || dictionary.size() || map_addr) {
987 
988     // yes we could create our own function, insert it into ctors ...
989     // but this would be a pain in the butt ... so we use afl-llvm-rt-lto.o
990 
991     Function *f = M.getFunction("__afl_auto_init_globals");
992 
993     if (!f) {
994 
995       fprintf(stderr,
996               "Error: init function could not be found (this should not "
997               "happen)\n");
998       exit(-1);
999 
1000     }
1001 
1002     BasicBlock *bb = &f->getEntryBlock();
1003     if (!bb) {
1004 
1005       fprintf(stderr,
1006               "Error: init function does not have an EntryBlock (this should "
1007               "not happen)\n");
1008       exit(-1);
1009 
1010     }
1011 
1012     BasicBlock::iterator IP = bb->getFirstInsertionPt();
1013     IRBuilder<>          IRB(&(*IP));
1014 
1015     if (map_addr) {
1016 
1017       GlobalVariable *AFLMapAddrFixed = new GlobalVariable(
1018           M, Int64Tyi, true, GlobalValue::ExternalLinkage, 0, "__afl_map_addr");
1019       ConstantInt *MapAddr = ConstantInt::get(Int64Tyi, map_addr);
1020       StoreInst *  StoreMapAddr = IRB.CreateStore(MapAddr, AFLMapAddrFixed);
1021       ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(StoreMapAddr);
1022 
1023     }
1024 
1025     if (getenv("AFL_LLVM_LTO_DONTWRITEID") == NULL) {
1026 
1027       uint32_t write_loc = afl_global_id;
1028 
1029       write_loc = (((afl_global_id + 8) >> 3) << 3);
1030 
1031       GlobalVariable *AFLFinalLoc =
1032           new GlobalVariable(M, Int32Tyi, true, GlobalValue::ExternalLinkage, 0,
1033                              "__afl_final_loc");
1034       ConstantInt *const_loc = ConstantInt::get(Int32Tyi, write_loc);
1035       StoreInst *  StoreFinalLoc = IRB.CreateStore(const_loc, AFLFinalLoc);
1036       ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(StoreFinalLoc);
1037 
1038     }
1039 
1040     if (dictionary.size()) {
1041 
1042       size_t memlen = 0, count = 0, offset = 0;
1043 
1044       // sort and unique the dictionary
1045       std::sort(dictionary.begin(), dictionary.end());
1046       auto last = std::unique(dictionary.begin(), dictionary.end());
1047       dictionary.erase(last, dictionary.end());
1048 
1049       for (auto token : dictionary) {
1050 
1051         memlen += token.length();
1052         count++;
1053 
1054       }
1055 
1056       if (!be_quiet)
1057         printf("AUTODICTIONARY: %lu string%s found\n", count,
1058                count == 1 ? "" : "s");
1059 
1060       if (count) {
1061 
1062         auto ptrhld = std::unique_ptr<char[]>(new char[memlen + count]);
1063 
1064         count = 0;
1065 
1066         for (auto token : dictionary) {
1067 
1068           if (offset + token.length() < 0xfffff0 && count < MAX_AUTO_EXTRAS) {
1069 
1070             ptrhld.get()[offset++] = (uint8_t)token.length();
1071             memcpy(ptrhld.get() + offset, token.c_str(), token.length());
1072             offset += token.length();
1073             count++;
1074 
1075           }
1076 
1077         }
1078 
1079         GlobalVariable *AFLDictionaryLen =
1080             new GlobalVariable(M, Int32Tyi, false, GlobalValue::ExternalLinkage,
1081                                0, "__afl_dictionary_len");
1082         ConstantInt *const_len = ConstantInt::get(Int32Tyi, offset);
1083         StoreInst *StoreDictLen = IRB.CreateStore(const_len, AFLDictionaryLen);
1084         ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(StoreDictLen);
1085 
1086         ArrayType *ArrayTy = ArrayType::get(IntegerType::get(Ctx, 8), offset);
1087         GlobalVariable *AFLInternalDictionary = new GlobalVariable(
1088             M, ArrayTy, true, GlobalValue::ExternalLinkage,
1089             ConstantDataArray::get(Ctx,
1090                                    *(new ArrayRef<char>(ptrhld.get(), offset))),
1091             "__afl_internal_dictionary");
1092         AFLInternalDictionary->setInitializer(ConstantDataArray::get(
1093             Ctx, *(new ArrayRef<char>(ptrhld.get(), offset))));
1094         AFLInternalDictionary->setConstant(true);
1095 
1096         GlobalVariable *AFLDictionary = new GlobalVariable(
1097             M, PointerType::get(Int8Tyi, 0), false,
1098             GlobalValue::ExternalLinkage, 0, "__afl_dictionary");
1099 
1100         Value *AFLDictOff = IRB.CreateGEP(Int8Ty, AFLInternalDictionary, Zero);
1101         Value *AFLDictPtr =
1102             IRB.CreatePointerCast(AFLDictOff, PointerType::get(Int8Tyi, 0));
1103         StoreInst *StoreDict = IRB.CreateStore(AFLDictPtr, AFLDictionary);
1104         ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(StoreDict);
1105 
1106       }
1107 
1108     }
1109 
1110   }
1111 
1112   /* Say something nice. */
1113 
1114   if (!be_quiet) {
1115 
1116     if (!inst)
1117       WARNF("No instrumentation targets found.");
1118     else {
1119 
1120       char modeline[100];
1121       snprintf(modeline, sizeof(modeline), "%s%s%s%s%s%s",
1122                getenv("AFL_HARDEN") ? "hardened" : "non-hardened",
1123                getenv("AFL_USE_ASAN") ? ", ASAN" : "",
1124                getenv("AFL_USE_MSAN") ? ", MSAN" : "",
1125                getenv("AFL_USE_TSAN") ? ", TSAN" : "",
1126                getenv("AFL_USE_CFISAN") ? ", CFISAN" : "",
1127                getenv("AFL_USE_UBSAN") ? ", UBSAN" : "");
1128       OKF("Instrumented %u locations (%u selects) without collisions (%llu "
1129           "collisions have been avoided) (%s mode).",
1130           inst, select_cnt, calculateCollisions(inst), modeline);
1131 
1132     }
1133 
1134   }
1135 
1136   // afl++ END
1137 
1138   // We don't reference these arrays directly in any of our runtime functions,
1139   // so we need to prevent them from being dead stripped.
1140   if (TargetTriple.isOSBinFormatMachO()) appendToUsed(M, GlobalsToAppendToUsed);
1141   appendToCompilerUsed(M, GlobalsToAppendToCompilerUsed);
1142   return true;
1143 
1144 }
1145 
1146 // True if block has successors and it dominates all of them.
isFullDominator(const BasicBlock * BB,const DominatorTree * DT)1147 static bool isFullDominator(const BasicBlock *BB, const DominatorTree *DT) {
1148 
1149   if (succ_begin(BB) == succ_end(BB)) return false;
1150 
1151   for (const BasicBlock *SUCC : make_range(succ_begin(BB), succ_end(BB))) {
1152 
1153     if (!DT->dominates(BB, SUCC)) return false;
1154 
1155   }
1156 
1157   return true;
1158 
1159 }
1160 
1161 // True if block has predecessors and it postdominates all of them.
isFullPostDominator(const BasicBlock * BB,const PostDominatorTree * PDT)1162 static bool isFullPostDominator(const BasicBlock *       BB,
1163                                 const PostDominatorTree *PDT) {
1164 
1165   if (pred_begin(BB) == pred_end(BB)) return false;
1166 
1167   for (const BasicBlock *PRED : make_range(pred_begin(BB), pred_end(BB))) {
1168 
1169     if (!PDT->dominates(BB, PRED)) return false;
1170 
1171   }
1172 
1173   return true;
1174 
1175 }
1176 
shouldInstrumentBlock(const Function & F,const BasicBlock * BB,const DominatorTree * DT,const PostDominatorTree * PDT,const SanitizerCoverageOptions & Options)1177 static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB,
1178                                   const DominatorTree *           DT,
1179                                   const PostDominatorTree *       PDT,
1180                                   const SanitizerCoverageOptions &Options) {
1181 
1182   // Don't insert coverage for blocks containing nothing but unreachable: we
1183   // will never call __sanitizer_cov() for them, so counting them in
1184   // NumberOfInstrumentedBlocks() might complicate calculation of code coverage
1185   // percentage. Also, unreachable instructions frequently have no debug
1186   // locations.
1187   if (isa<UnreachableInst>(BB->getFirstNonPHIOrDbgOrLifetime())) return false;
1188 
1189   // Don't insert coverage into blocks without a valid insertion point
1190   // (catchswitch blocks).
1191   if (BB->getFirstInsertionPt() == BB->end()) return false;
1192 
1193   // afl++ START
1194   if (!Options.NoPrune && &F.getEntryBlock() == BB && F.size() > 1)
1195     return false;
1196   // afl++ END
1197 
1198   if (Options.NoPrune || &F.getEntryBlock() == BB) return true;
1199 
1200   if (Options.CoverageType == SanitizerCoverageOptions::SCK_Function &&
1201       &F.getEntryBlock() != BB)
1202     return false;
1203 
1204   // Do not instrument full dominators, or full post-dominators with multiple
1205   // predecessors.
1206   return !isFullDominator(BB, DT) &&
1207          !(isFullPostDominator(BB, PDT) && !BB->getSinglePredecessor());
1208 
1209 }
1210 
instrumentFunction(Function & F,DomTreeCallback DTCallback,PostDomTreeCallback PDTCallback)1211 void ModuleSanitizerCoverageLTO::instrumentFunction(
1212     Function &F, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) {
1213 
1214   if (F.empty()) return;
1215   if (F.getName().find(".module_ctor") != std::string::npos)
1216     return;  // Should not instrument sanitizer init functions.
1217   if (F.getName().startswith("__sanitizer_"))
1218     return;  // Don't instrument __sanitizer_* callbacks.
1219   // Don't touch available_externally functions, their actual body is elsewhere.
1220   if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return;
1221   // Don't instrument MSVC CRT configuration helpers. They may run before normal
1222   // initialization.
1223   if (F.getName() == "__local_stdio_printf_options" ||
1224       F.getName() == "__local_stdio_scanf_options")
1225     return;
1226   if (isa<UnreachableInst>(F.getEntryBlock().getTerminator())) return;
1227   // Don't instrument functions using SEH for now. Splitting basic blocks like
1228   // we do for coverage breaks WinEHPrepare.
1229   // FIXME: Remove this when SEH no longer uses landingpad pattern matching.
1230   if (F.hasPersonalityFn() &&
1231       isAsynchronousEHPersonality(classifyEHPersonality(F.getPersonalityFn())))
1232     return;
1233   // if (Allowlist && !Allowlist->inSection("coverage", "fun", F.getName()))
1234   //  return;
1235   // if (Blocklist && Blocklist->inSection("coverage", "fun", F.getName()))
1236   // return;
1237 
1238   // afl++ START
1239   if (!F.size()) return;
1240   if (!isInInstrumentList(&F, FMNAME)) return;
1241   // afl++ END
1242 
1243   if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge)
1244     SplitAllCriticalEdges(
1245         F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests());
1246   SmallVector<Instruction *, 8> IndirCalls;
1247   SmallVector<BasicBlock *, 16> BlocksToInstrument;
1248 
1249   const DominatorTree *    DT = DTCallback(F);
1250   const PostDominatorTree *PDT = PDTCallback(F);
1251   bool                     IsLeafFunc = true;
1252   uint32_t                 skip_next = 0;
1253 
1254   for (auto &BB : F) {
1255 
1256     for (auto &IN : BB) {
1257 
1258       CallInst *callInst = nullptr;
1259 
1260       if ((callInst = dyn_cast<CallInst>(&IN))) {
1261 
1262         Function *Callee = callInst->getCalledFunction();
1263         if (!Callee) continue;
1264         if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
1265         StringRef FuncName = Callee->getName();
1266         if (!FuncName.compare(StringRef("dlopen")) ||
1267             !FuncName.compare(StringRef("_dlopen"))) {
1268 
1269           fprintf(stderr,
1270                   "WARNING: dlopen() detected. To have coverage for a library "
1271                   "that your target dlopen()'s this must either happen before "
1272                   "__AFL_INIT() or you must use AFL_PRELOAD to preload all "
1273                   "dlopen()'ed libraries!\n");
1274           continue;
1275 
1276         }
1277 
1278         if (FuncName.compare(StringRef("__afl_coverage_interesting"))) continue;
1279 
1280         Value *val = ConstantInt::get(Int32Ty, ++afl_global_id);
1281         callInst->setOperand(1, val);
1282         ++inst;
1283 
1284       }
1285 
1286       SelectInst *selectInst = nullptr;
1287 
1288       /*
1289             std::string errMsg;
1290             raw_string_ostream os(errMsg);
1291             IN.print(os);
1292             fprintf(stderr, "X(%u): %s\n", skip_next, os.str().c_str());
1293       */
1294       if (!skip_next && (selectInst = dyn_cast<SelectInst>(&IN))) {
1295 
1296         uint32_t    vector_cnt = 0;
1297         Value *     condition = selectInst->getCondition();
1298         Value *     result;
1299         auto        t = condition->getType();
1300         IRBuilder<> IRB(selectInst->getNextNode());
1301 
1302         ++select_cnt;
1303 
1304         if (t->getTypeID() == llvm::Type::IntegerTyID) {
1305 
1306           Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
1307           Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
1308           result = IRB.CreateSelect(condition, val1, val2);
1309           skip_next = 1;
1310           inst += 2;
1311 
1312         } else
1313 
1314 #if LLVM_VERSION_MAJOR >= 14
1315             if (t->getTypeID() == llvm::Type::FixedVectorTyID) {
1316 
1317           FixedVectorType *tt = dyn_cast<FixedVectorType>(t);
1318           if (tt) {
1319 
1320             uint32_t elements = tt->getElementCount().getFixedValue();
1321             vector_cnt = elements;
1322             inst += vector_cnt * 2;
1323             if (elements) {
1324 
1325               FixedVectorType *GuardPtr1 =
1326                   FixedVectorType::get(Int32Ty, elements);
1327               FixedVectorType *GuardPtr2 =
1328                   FixedVectorType::get(Int32Ty, elements);
1329               Value *x, *y;
1330 
1331               Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
1332               Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
1333               x = IRB.CreateInsertElement(GuardPtr1, val1, (uint64_t)0);
1334               y = IRB.CreateInsertElement(GuardPtr2, val2, (uint64_t)0);
1335 
1336               for (uint64_t i = 1; i < elements; i++) {
1337 
1338                 val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
1339                 val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
1340                 x = IRB.CreateInsertElement(GuardPtr1, val1, i);
1341                 y = IRB.CreateInsertElement(GuardPtr2, val2, i);
1342 
1343               }
1344 
1345               result = IRB.CreateSelect(condition, x, y);
1346               skip_next = 1;
1347 
1348             }
1349 
1350           }
1351 
1352         } else
1353 
1354 #endif
1355         {
1356 
1357           unhandled++;
1358           continue;
1359 
1360         }
1361 
1362         uint32_t vector_cur = 0;
1363         /* Load SHM pointer */
1364         LoadInst *MapPtr =
1365             IRB.CreateLoad(PointerType::get(Int8Ty, 0), AFLMapPtr);
1366         ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(MapPtr);
1367 
1368         while (1) {
1369 
1370           /* Get CurLoc */
1371           Value *MapPtrIdx = nullptr;
1372 
1373           /* Load counter for CurLoc */
1374           if (!vector_cnt) {
1375 
1376             MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, result);
1377 
1378           } else {
1379 
1380             auto element = IRB.CreateExtractElement(result, vector_cur++);
1381             MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, element);
1382 
1383           }
1384 
1385           if (use_threadsafe_counters) {
1386 
1387             IRB.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, MapPtrIdx, One,
1388 #if LLVM_VERSION_MAJOR >= 13
1389                                 llvm::MaybeAlign(1),
1390 #endif
1391                                 llvm::AtomicOrdering::Monotonic);
1392 
1393           } else {
1394 
1395             LoadInst *Counter = IRB.CreateLoad(IRB.getInt8Ty(), MapPtrIdx);
1396             ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(Counter);
1397 
1398             /* Update bitmap */
1399 
1400             Value *Incr = IRB.CreateAdd(Counter, One);
1401 
1402             if (skip_nozero == NULL) {
1403 
1404               auto cf = IRB.CreateICmpEQ(Incr, Zero);
1405               auto carry = IRB.CreateZExt(cf, Int8Ty);
1406               Incr = IRB.CreateAdd(Incr, carry);
1407 
1408             }
1409 
1410             auto nosan = IRB.CreateStore(Incr, MapPtrIdx);
1411             ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(nosan);
1412 
1413           }
1414 
1415           if (!vector_cnt || vector_cnt == vector_cur) { break; }
1416 
1417         }
1418 
1419         skip_next = 1;
1420 
1421       } else {
1422 
1423         skip_next = 0;
1424 
1425       }
1426 
1427     }
1428 
1429     if (shouldInstrumentBlock(F, &BB, DT, PDT, Options))
1430       BlocksToInstrument.push_back(&BB);
1431     for (auto &Inst : BB) {
1432 
1433       if (Options.IndirectCalls) {
1434 
1435         CallBase *CB = dyn_cast<CallBase>(&Inst);
1436         if (CB && !CB->getCalledFunction()) IndirCalls.push_back(&Inst);
1437 
1438       }
1439 
1440     }
1441 
1442   }
1443 
1444   InjectCoverage(F, BlocksToInstrument, IsLeafFunc);
1445   InjectCoverageForIndirectCalls(F, IndirCalls);
1446 
1447 }
1448 
CreateFunctionLocalArrayInSection(size_t NumElements,Function & F,Type * Ty,const char * Section)1449 GlobalVariable *ModuleSanitizerCoverageLTO::CreateFunctionLocalArrayInSection(
1450     size_t NumElements, Function &F, Type *Ty, const char *Section) {
1451 
1452   ArrayType *ArrayTy = ArrayType::get(Ty, NumElements);
1453   auto       Array = new GlobalVariable(
1454       *CurModule, ArrayTy, false, GlobalVariable::PrivateLinkage,
1455       Constant::getNullValue(ArrayTy), "__sancov_gen_");
1456 
1457 #if LLVM_VERSION_MAJOR >= 13
1458   if (TargetTriple.supportsCOMDAT() &&
1459       (TargetTriple.isOSBinFormatELF() || !F.isInterposable()))
1460     if (auto Comdat = getOrCreateFunctionComdat(F, TargetTriple))
1461       Array->setComdat(Comdat);
1462 #else
1463   if (TargetTriple.supportsCOMDAT() && !F.isInterposable())
1464     if (auto Comdat =
1465             GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId))
1466       Array->setComdat(Comdat);
1467 #endif
1468   Array->setSection(getSectionName(Section));
1469   Array->setAlignment(Align(DL->getTypeStoreSize(Ty).getFixedSize()));
1470   GlobalsToAppendToUsed.push_back(Array);
1471   GlobalsToAppendToCompilerUsed.push_back(Array);
1472   MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F));
1473   Array->addMetadata(LLVMContext::MD_associated, *MD);
1474 
1475   return Array;
1476 
1477 }
1478 
CreatePCArray(Function & F,ArrayRef<BasicBlock * > AllBlocks)1479 GlobalVariable *ModuleSanitizerCoverageLTO::CreatePCArray(
1480     Function &F, ArrayRef<BasicBlock *> AllBlocks) {
1481 
1482   size_t N = AllBlocks.size();
1483   assert(N);
1484   SmallVector<Constant *, 32> PCs;
1485   IRBuilder<>                 IRB(&*F.getEntryBlock().getFirstInsertionPt());
1486   for (size_t i = 0; i < N; i++) {
1487 
1488     if (&F.getEntryBlock() == AllBlocks[i]) {
1489 
1490       PCs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy));
1491       PCs.push_back((Constant *)IRB.CreateIntToPtr(
1492           ConstantInt::get(IntptrTy, 1), IntptrPtrTy));
1493 
1494     } else {
1495 
1496       PCs.push_back((Constant *)IRB.CreatePointerCast(
1497           BlockAddress::get(AllBlocks[i]), IntptrPtrTy));
1498       PCs.push_back((Constant *)IRB.CreateIntToPtr(
1499           ConstantInt::get(IntptrTy, 0), IntptrPtrTy));
1500 
1501     }
1502 
1503   }
1504 
1505   auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, IntptrPtrTy,
1506                                                     SanCovPCsSectionName);
1507   PCArray->setInitializer(
1508       ConstantArray::get(ArrayType::get(IntptrPtrTy, N * 2), PCs));
1509   PCArray->setConstant(true);
1510 
1511   return PCArray;
1512 
1513 }
1514 
CreateFunctionLocalArrays(Function & F,ArrayRef<BasicBlock * > AllBlocks)1515 void ModuleSanitizerCoverageLTO::CreateFunctionLocalArrays(
1516     Function &F, ArrayRef<BasicBlock *> AllBlocks) {
1517 
1518   if (Options.TracePCGuard)
1519     FunctionGuardArray = CreateFunctionLocalArrayInSection(
1520         AllBlocks.size(), F, Int32Ty, SanCovGuardsSectionName);
1521   if (Options.Inline8bitCounters)
1522     Function8bitCounterArray = CreateFunctionLocalArrayInSection(
1523         AllBlocks.size(), F, Int8Ty, SanCovCountersSectionName);
1524   if (Options.InlineBoolFlag)
1525     FunctionBoolArray = CreateFunctionLocalArrayInSection(
1526         AllBlocks.size(), F, Int1Ty, SanCovBoolFlagSectionName);
1527   if (Options.PCTable) FunctionPCsArray = CreatePCArray(F, AllBlocks);
1528 
1529 }
1530 
InjectCoverage(Function & F,ArrayRef<BasicBlock * > AllBlocks,bool IsLeafFunc)1531 bool ModuleSanitizerCoverageLTO::InjectCoverage(
1532     Function &F, ArrayRef<BasicBlock *> AllBlocks, bool IsLeafFunc) {
1533 
1534   if (AllBlocks.empty()) return false;
1535   CreateFunctionLocalArrays(F, AllBlocks);
1536 
1537   for (size_t i = 0, N = AllBlocks.size(); i < N; i++) {
1538 
1539     // afl++ START
1540     if (BlockList.size()) {
1541 
1542       int skip = 0;
1543       for (uint32_t k = 0; k < BlockList.size(); k++) {
1544 
1545         if (AllBlocks[i] == BlockList[k]) {
1546 
1547           if (debug)
1548             fprintf(stderr,
1549                     "DEBUG: Function %s skipping BB with/after __afl_loop\n",
1550                     F.getName().str().c_str());
1551           skip = 1;
1552 
1553         }
1554 
1555       }
1556 
1557       if (skip) continue;
1558 
1559     }
1560 
1561     // afl++ END
1562 
1563     InjectCoverageAtBlock(F, *AllBlocks[i], i, IsLeafFunc);
1564 
1565   }
1566 
1567   return true;
1568 
1569 }
1570 
1571 // On every indirect call we call a run-time function
1572 // __sanitizer_cov_indir_call* with two parameters:
1573 //   - callee address,
1574 //   - global cache array that contains CacheSize pointers (zero-initialized).
1575 //     The cache is used to speed up recording the caller-callee pairs.
1576 // The address of the caller is passed implicitly via caller PC.
1577 // CacheSize is encoded in the name of the run-time function.
InjectCoverageForIndirectCalls(Function & F,ArrayRef<Instruction * > IndirCalls)1578 void ModuleSanitizerCoverageLTO::InjectCoverageForIndirectCalls(
1579     Function &F, ArrayRef<Instruction *> IndirCalls) {
1580 
1581   if (IndirCalls.empty()) return;
1582   assert(Options.TracePC || Options.TracePCGuard ||
1583          Options.Inline8bitCounters || Options.InlineBoolFlag);
1584   for (auto I : IndirCalls) {
1585 
1586     IRBuilder<> IRB(I);
1587     CallBase &  CB = cast<CallBase>(*I);
1588     Value *     Callee = CB.getCalledOperand();
1589     if (isa<InlineAsm>(Callee)) continue;
1590     IRB.CreateCall(SanCovTracePCIndir, IRB.CreatePointerCast(Callee, IntptrTy));
1591 
1592   }
1593 
1594 }
1595 
InjectCoverageAtBlock(Function & F,BasicBlock & BB,size_t Idx,bool IsLeafFunc)1596 void ModuleSanitizerCoverageLTO::InjectCoverageAtBlock(Function &  F,
1597                                                        BasicBlock &BB,
1598                                                        size_t      Idx,
1599                                                        bool        IsLeafFunc) {
1600 
1601   BasicBlock::iterator IP = BB.getFirstInsertionPt();
1602   bool                 IsEntryBB = &BB == &F.getEntryBlock();
1603 
1604   if (IsEntryBB) {
1605 
1606     // Keep static allocas and llvm.localescape calls in the entry block.  Even
1607     // if we aren't splitting the block, it's nice for allocas to be before
1608     // calls.
1609     IP = PrepareToSplitEntryBlock(BB, IP);
1610 
1611   }
1612 
1613   IRBuilder<> IRB(&*IP);
1614   if (Options.TracePC) {
1615 
1616     IRB.CreateCall(SanCovTracePC)
1617 #if LLVM_VERSION_MAJOR >= 12
1618         ->setCannotMerge();  // gets the PC using GET_CALLER_PC.
1619 #else
1620         ->cannotMerge();  // gets the PC using GET_CALLER_PC.
1621 #endif
1622 
1623   }
1624 
1625   if (Options.TracePCGuard) {
1626 
1627     // afl++ START
1628     ++afl_global_id;
1629 
1630     if (dFile.is_open()) {
1631 
1632       unsigned long long int moduleID =
1633           (((unsigned long long int)(rand() & 0xffffffff)) << 32) | getpid();
1634       dFile << "ModuleID=" << moduleID << " Function=" << F.getName().str()
1635             << " edgeID=" << afl_global_id << "\n";
1636 
1637     }
1638 
1639     /* Set the ID of the inserted basic block */
1640 
1641     ConstantInt *CurLoc = ConstantInt::get(Int32Tyi, afl_global_id);
1642 
1643     /* Load SHM pointer */
1644 
1645     Value *MapPtrIdx;
1646 
1647     if (map_addr) {
1648 
1649       MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtrFixed, CurLoc);
1650 
1651     } else {
1652 
1653       LoadInst *MapPtr = IRB.CreateLoad(PointerType::get(Int8Ty, 0), AFLMapPtr);
1654       ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(MapPtr);
1655       MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, CurLoc);
1656 
1657     }
1658 
1659     /* Update bitmap */
1660     if (use_threadsafe_counters) {                                /* Atomic */
1661 
1662       IRB.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, MapPtrIdx, One,
1663 #if LLVM_VERSION_MAJOR >= 13
1664                           llvm::MaybeAlign(1),
1665 #endif
1666                           llvm::AtomicOrdering::Monotonic);
1667 
1668     } else {
1669 
1670       LoadInst *Counter = IRB.CreateLoad(IRB.getInt8Ty(), MapPtrIdx);
1671       ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(Counter);
1672 
1673       Value *Incr = IRB.CreateAdd(Counter, One);
1674 
1675       if (skip_nozero == NULL) {
1676 
1677         auto cf = IRB.CreateICmpEQ(Incr, Zero);
1678         auto carry = IRB.CreateZExt(cf, Int8Tyi);
1679         Incr = IRB.CreateAdd(Incr, carry);
1680 
1681       }
1682 
1683       auto nosan = IRB.CreateStore(Incr, MapPtrIdx);
1684       ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(nosan);
1685 
1686     }
1687 
1688     // done :)
1689 
1690     inst++;
1691     // afl++ END
1692 
1693     /*
1694     XXXXXXXXXXXXXXXXXXX
1695 
1696         auto GuardPtr = IRB.CreateIntToPtr(
1697             IRB.CreateAdd(IRB.CreatePointerCast(FunctionGuardArray, IntptrTy),
1698                           ConstantInt::get(IntptrTy, Idx * 4)),
1699             Int32PtrTy);
1700 
1701         IRB.CreateCall(SanCovTracePCGuard, GuardPtr)->setCannotMerge();
1702     */
1703 
1704   }
1705 
1706   if (Options.Inline8bitCounters) {
1707 
1708     auto CounterPtr = IRB.CreateGEP(
1709         Function8bitCounterArray->getValueType(), Function8bitCounterArray,
1710         {ConstantInt::get(IntptrTy, 0), ConstantInt::get(IntptrTy, Idx)});
1711     auto Load = IRB.CreateLoad(Int8Ty, CounterPtr);
1712     auto Inc = IRB.CreateAdd(Load, ConstantInt::get(Int8Ty, 1));
1713     auto Store = IRB.CreateStore(Inc, CounterPtr);
1714     SetNoSanitizeMetadata(Load);
1715     SetNoSanitizeMetadata(Store);
1716 
1717   }
1718 
1719   if (Options.InlineBoolFlag) {
1720 
1721     auto FlagPtr = IRB.CreateGEP(
1722         FunctionBoolArray->getValueType(), FunctionBoolArray,
1723         {ConstantInt::get(IntptrTy, 0), ConstantInt::get(IntptrTy, Idx)});
1724     auto Load = IRB.CreateLoad(Int1Ty, FlagPtr);
1725     auto ThenTerm =
1726         SplitBlockAndInsertIfThen(IRB.CreateIsNull(Load), &*IP, false);
1727     IRBuilder<> ThenIRB(ThenTerm);
1728     auto Store = ThenIRB.CreateStore(ConstantInt::getTrue(Int1Ty), FlagPtr);
1729     SetNoSanitizeMetadata(Load);
1730     SetNoSanitizeMetadata(Store);
1731 
1732   }
1733 
1734 }
1735 
getSectionName(const std::string & Section) const1736 std::string ModuleSanitizerCoverageLTO::getSectionName(
1737     const std::string &Section) const {
1738 
1739   if (TargetTriple.isOSBinFormatCOFF()) {
1740 
1741     if (Section == SanCovCountersSectionName) return ".SCOV$CM";
1742     if (Section == SanCovBoolFlagSectionName) return ".SCOV$BM";
1743     if (Section == SanCovPCsSectionName) return ".SCOVP$M";
1744     return ".SCOV$GM";  // For SanCovGuardsSectionName.
1745 
1746   }
1747 
1748   if (TargetTriple.isOSBinFormatMachO()) return "__DATA,__" + Section;
1749   return "__" + Section;
1750 
1751 }
1752 
/* Out-of-class definition of the legacy pass's static ID member. */
char ModuleSanitizerCoverageLegacyPass::ID = 0;

/* Pass-registry boilerplate for the legacy pass. It is registered under the
   argument string "sancov" with the description below, and declares the
   dominator-tree and post-dominator-tree wrapper passes as dependencies.
   NOTE(review): the two trailing `false` arguments are presumably the
   cfg-only / is-analysis flags of INITIALIZE_PASS -- confirm against
   llvm/PassSupport.h. */
INITIALIZE_PASS_BEGIN(ModuleSanitizerCoverageLegacyPass, "sancov",
                      "Pass for instrumenting coverage on functions", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
INITIALIZE_PASS_END(ModuleSanitizerCoverageLegacyPass, "sancov",
                    "Pass for instrumenting coverage on functions", false,
                    false)
1763 
1764 ModulePass *llvm::createModuleSanitizerCoverageLegacyPassPass(
1765     const SanitizerCoverageOptions &Options,
1766     const std::vector<std::string> &AllowlistFiles,
1767     const std::vector<std::string> &BlocklistFiles) {
1768 
1769   return new ModuleSanitizerCoverageLegacyPass(Options);
1770 
1771 }
1772 
registerLTOPass(const PassManagerBuilder &,legacy::PassManagerBase & PM)1773 static void registerLTOPass(const PassManagerBuilder &,
1774                             legacy::PassManagerBase &PM) {
1775 
1776   auto p = new ModuleSanitizerCoverageLegacyPass();
1777   PM.add(p);
1778 
1779 }
1780 
/* Attach registerLTOPass to the PassManagerBuilder extension point
   EP_OptimizerLast. */
static RegisterStandardPasses RegisterCompTransPass(
    PassManagerBuilder::EP_OptimizerLast, registerLTOPass);

/* Same callback, attached to the EP_EnabledOnOptLevel0 extension point. */
static RegisterStandardPasses RegisterCompTransPass0(
    PassManagerBuilder::EP_EnabledOnOptLevel0, registerLTOPass);

/* The EP_FullLinkTimeOptimizationLast registration is only compiled when
   building against LLVM 11 or newer. */
#if LLVM_VERSION_MAJOR >= 11
static RegisterStandardPasses RegisterCompTransPassLTO(
    PassManagerBuilder::EP_FullLinkTimeOptimizationLast, registerLTOPass);
#endif
1791 
1792