• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2    american fuzzy lop++ - LLVM CmpLog instrumentation
3    --------------------------------------------------
4 
5    Written by Andrea Fioraldi <andreafioraldi@gmail.com>
6 
7    Copyright 2015, 2016 Google Inc. All rights reserved.
8    Copyright 2019-2022 AFLplusplus Project. All rights reserved.
9 
10    Licensed under the Apache License, Version 2.0 (the "License");
11    you may not use this file except in compliance with the License.
12    You may obtain a copy of the License at:
13 
14      https://www.apache.org/licenses/LICENSE-2.0
15 
16 */
17 
18 #include <stdio.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21 
22 #include <iostream>
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27 
28 #include "llvm/Config/llvm-config.h"
29 #include "llvm/ADT/Statistic.h"
30 #include "llvm/IR/IRBuilder.h"
31 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
32   #include "llvm/Passes/PassPlugin.h"
33   #include "llvm/Passes/PassBuilder.h"
34   #include "llvm/IR/PassManager.h"
35 #else
36   #include "llvm/IR/LegacyPassManager.h"
37   #include "llvm/Transforms/IPO/PassManagerBuilder.h"
38 #endif
39 #include "llvm/IR/Module.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Pass.h"
45 #include "llvm/Analysis/ValueTracking.h"
46 
47 #include "llvm/IR/IRBuilder.h"
48 #if LLVM_VERSION_MAJOR >= 4 || \
49     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
50   #include "llvm/IR/Verifier.h"
51   #include "llvm/IR/DebugInfo.h"
52 #else
53   #include "llvm/Analysis/Verifier.h"
54   #include "llvm/DebugInfo.h"
55   #define nullptr 0
56 #endif
57 
58 #include <set>
59 #include "afl-llvm-common.h"
60 
61 using namespace llvm;
62 
63 namespace {
64 
65 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
66 class CmplogSwitches : public PassInfoMixin<CmplogSwitches> {
67 
68  public:
CmplogSwitches()69   CmplogSwitches() {
70 
71 #else
72 class CmplogSwitches : public ModulePass {
73 
74  public:
75   static char ID;
76   CmplogSwitches() : ModulePass(ID) {
77 
78 #endif
79     initInstrumentList();
80 
81   }
82 
83 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
84   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
85 #else
86   bool        runOnModule(Module &M) override;
87 
88   #if LLVM_VERSION_MAJOR < 4
89   const char *getPassName() const override {
90 
91   #else
92   StringRef getPassName() const override {
93 
94   #endif
95     return "cmplog switch split";
96 
97   }
98 
99 #endif
100 
101  private:
102   bool hookInstrs(Module &M);
103 
104 };
105 
106 }  // namespace
107 
108 #if LLVM_MAJOR >= 11
109 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
110 llvmGetPassPluginInfo() {
111 
112   return {LLVM_PLUGIN_API_VERSION, "cmplogswitches", "v0.1",
113           /* lambda to insert our pass into the pass pipeline. */
114           [](PassBuilder &PB) {
115 
116   #if LLVM_VERSION_MAJOR <= 13
117             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
118   #endif
119             PB.registerOptimizerLastEPCallback(
120                 [](ModulePassManager &MPM, OptimizationLevel OL) {
121 
122                   MPM.addPass(CmplogSwitches());
123 
124                 });
125 
126           }};
127 
128 }
129 
130 #else
131 char CmplogSwitches::ID = 0;
132 #endif
133 
/* Order-preserving de-duplication of [first, last).

   For each element in turn, every later element equal to it is removed with
   std::remove. Unlike std::unique this also drops non-adjacent duplicates.
   Returns the new logical end of the range; elements past it are left in a
   valid but unspecified state (as with std::remove), so callers erase them.
   Worst case is quadratic in the length of the range. */
template <class Iterator>
Iterator Unique(Iterator first, Iterator last) {

  for (Iterator cur = first; cur != last; ++cur) {

    Iterator rest = cur;
    ++rest;
    // *cur lies before [rest, last), so the value being removed stays stable.
    last = std::remove(rest, last, *cur);

  }

  return last;

}
148 
149 bool CmplogSwitches::hookInstrs(Module &M) {
150 
151   std::vector<SwitchInst *> switches;
152   LLVMContext &             C = M.getContext();
153 
154   Type *       VoidTy = Type::getVoidTy(C);
155   IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
156   IntegerType *Int16Ty = IntegerType::getInt16Ty(C);
157   IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
158   IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
159 
160 #if LLVM_VERSION_MAJOR >= 9
161   FunctionCallee
162 #else
163   Constant *
164 #endif
165       c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty,
166                                  Int8Ty
167 #if LLVM_VERSION_MAJOR < 5
168                                  ,
169                                  NULL
170 #endif
171       );
172 #if LLVM_VERSION_MAJOR >= 9
173   FunctionCallee cmplogHookIns1 = c1;
174 #else
175   Function *cmplogHookIns1 = cast<Function>(c1);
176 #endif
177 
178 #if LLVM_VERSION_MAJOR >= 9
179   FunctionCallee
180 #else
181   Constant *
182 #endif
183       c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty,
184                                  Int8Ty
185 #if LLVM_VERSION_MAJOR < 5
186                                  ,
187                                  NULL
188 #endif
189       );
190 #if LLVM_VERSION_MAJOR >= 9
191   FunctionCallee cmplogHookIns2 = c2;
192 #else
193   Function *cmplogHookIns2 = cast<Function>(c2);
194 #endif
195 
196 #if LLVM_VERSION_MAJOR >= 9
197   FunctionCallee
198 #else
199   Constant *
200 #endif
201       c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty,
202                                  Int8Ty
203 #if LLVM_VERSION_MAJOR < 5
204                                  ,
205                                  NULL
206 #endif
207       );
208 #if LLVM_VERSION_MAJOR >= 9
209   FunctionCallee cmplogHookIns4 = c4;
210 #else
211   Function *cmplogHookIns4 = cast<Function>(c4);
212 #endif
213 
214 #if LLVM_VERSION_MAJOR >= 9
215   FunctionCallee
216 #else
217   Constant *
218 #endif
219       c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty,
220                                  Int8Ty
221 #if LLVM_VERSION_MAJOR < 5
222                                  ,
223                                  NULL
224 #endif
225       );
226 #if LLVM_VERSION_MAJOR >= 9
227   FunctionCallee cmplogHookIns8 = c8;
228 #else
229   Function *cmplogHookIns8 = cast<Function>(c8);
230 #endif
231 
232   GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map");
233 
234   if (!AFLCmplogPtr) {
235 
236     AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false,
237                                       GlobalValue::ExternalWeakLinkage, 0,
238                                       "__afl_cmp_map");
239 
240   }
241 
242   Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0));
243 
244   /* iterate over all functions, bbs and instruction and add suitable calls */
245   for (auto &F : M) {
246 
247     if (!isInInstrumentList(&F, MNAME)) continue;
248 
249     for (auto &BB : F) {
250 
251       SwitchInst *switchInst = nullptr;
252       if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
253 
254         if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); }
255 
256       }
257 
258     }
259 
260   }
261 
262   // unique the collected switches
263   switches.erase(Unique(switches.begin(), switches.end()), switches.end());
264 
265   // Instrument switch values for cmplog
266   if (switches.size()) {
267 
268     if (!be_quiet)
269       errs() << "Hooking " << switches.size() << " switch instructions\n";
270 
271     for (auto &SI : switches) {
272 
273       Value *       Val = SI->getCondition();
274       unsigned int  max_size = Val->getType()->getIntegerBitWidth(), cast_size;
275       unsigned char do_cast = 0;
276 
277       if (!SI->getNumCases() || max_size < 16) {
278 
279         // if (!be_quiet) errs() << "skip trivial switch..\n";
280         continue;
281 
282       }
283 
284       if (max_size % 8) {
285 
286         max_size = (((max_size / 8) + 1) * 8);
287         do_cast = 1;
288 
289       }
290 
291       IRBuilder<> IRB2(SI->getParent());
292       IRB2.SetInsertPoint(SI);
293 
294       LoadInst *CmpPtr = IRB2.CreateLoad(
295 #if LLVM_VERSION_MAJOR >= 14
296           PointerType::get(Int8Ty, 0),
297 #endif
298           AFLCmplogPtr);
299       CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
300       auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null);
301       auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false);
302 
303       IRBuilder<> IRB(ThenTerm);
304 
305       if (max_size > 128) {
306 
307         if (!be_quiet) {
308 
309           fprintf(stderr,
310                   "Cannot handle this switch bit size: %u (truncating)\n",
311                   max_size);
312 
313         }
314 
315         max_size = 128;
316         do_cast = 1;
317 
318       }
319 
320       // do we need to cast?
321       switch (max_size) {
322 
323         case 8:
324         case 16:
325         case 32:
326         case 64:
327         case 128:
328           cast_size = max_size;
329           break;
330         default:
331           cast_size = 128;
332           do_cast = 1;
333 
334       }
335 
336       Value *CompareTo = Val;
337 
338       if (do_cast) {
339 
340         CompareTo =
341             IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false);
342 
343       }
344 
345       for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
346            ++i) {
347 
348 #if LLVM_VERSION_MAJOR < 5
349         ConstantInt *cint = i.getCaseValue();
350 #else
351         ConstantInt *cint = i->getCaseValue();
352 #endif
353 
354         if (cint) {
355 
356           std::vector<Value *> args;
357           args.push_back(CompareTo);
358 
359           Value *new_param = cint;
360 
361           if (do_cast) {
362 
363             new_param =
364                 IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false);
365 
366           }
367 
368           if (new_param) {
369 
370             args.push_back(new_param);
371             ConstantInt *attribute = ConstantInt::get(Int8Ty, 1);
372             args.push_back(attribute);
373             if (cast_size != max_size) {
374 
375               ConstantInt *bitsize =
376                   ConstantInt::get(Int8Ty, (max_size / 8) - 1);
377               args.push_back(bitsize);
378 
379             }
380 
381             switch (cast_size) {
382 
383               case 8:
384                 IRB.CreateCall(cmplogHookIns1, args);
385                 break;
386               case 16:
387                 IRB.CreateCall(cmplogHookIns2, args);
388                 break;
389               case 32:
390                 IRB.CreateCall(cmplogHookIns4, args);
391                 break;
392               case 64:
393                 IRB.CreateCall(cmplogHookIns8, args);
394                 break;
395               case 128:
396 #ifdef WORD_SIZE_64
397                 if (max_size == 128) {
398 
399                   IRB.CreateCall(cmplogHookIns16, args);
400 
401                 } else {
402 
403                   IRB.CreateCall(cmplogHookInsN, args);
404 
405                 }
406 
407 #endif
408                 break;
409               default:
410                 break;
411 
412             }
413 
414           }
415 
416         }
417 
418       }
419 
420     }
421 
422   }
423 
424   if (switches.size())
425     return true;
426   else
427     return false;
428 
429 }
430 
431 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
432 PreservedAnalyses CmplogSwitches::run(Module &M, ModuleAnalysisManager &MAM) {
433 
434 #else
435 bool CmplogSwitches::runOnModule(Module &M) {
436 
437 #endif
438 
439   if (getenv("AFL_QUIET") == NULL)
440     printf("Running cmplog-switches-pass by andreafioraldi@gmail.com\n");
441   else
442     be_quiet = 1;
443   hookInstrs(M);
444 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
445   auto PA = PreservedAnalyses::all();
446 #endif
447   verifyModule(M);
448 
449 #if LLVM_VERSION_MAJOR >= 11                        /* use new pass manager */
450   return PA;
451 #else
452   return true;
453 #endif
454 
455 }
456 
457 #if LLVM_VERSION_MAJOR < 11                         /* use old pass manager */
458 static void registerCmplogSwitchesPass(const PassManagerBuilder &,
459                                        legacy::PassManagerBase &PM) {
460 
461   auto p = new CmplogSwitches();
462   PM.add(p);
463 
464 }
465 
466 static RegisterStandardPasses RegisterCmplogSwitchesPass(
467     PassManagerBuilder::EP_OptimizerLast, registerCmplogSwitchesPass);
468 
469 static RegisterStandardPasses RegisterCmplogSwitchesPass0(
470     PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmplogSwitchesPass);
471 
472   #if LLVM_VERSION_MAJOR >= 11
473 static RegisterStandardPasses RegisterCmplogSwitchesPassLTO(
474     PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
475     registerCmplogSwitchesPass);
476   #endif
477 #endif
478 
479