1 /* 2 american fuzzy lop++ - LLVM CmpLog instrumentation 3 -------------------------------------------------- 4 5 Written by Andrea Fioraldi <andreafioraldi@gmail.com> 6 7 Copyright 2015, 2016 Google Inc. All rights reserved. 8 Copyright 2019-2022 AFLplusplus Project. All rights reserved. 9 10 Licensed under the Apache License, Version 2.0 (the "License"); 11 you may not use this file except in compliance with the License. 12 You may obtain a copy of the License at: 13 14 https://www.apache.org/licenses/LICENSE-2.0 15 16 */ 17 18 #include <stdio.h> 19 #include <stdlib.h> 20 #include <unistd.h> 21 22 #include <iostream> 23 #include <list> 24 #include <string> 25 #include <fstream> 26 #include <sys/time.h> 27 28 #include "llvm/Config/llvm-config.h" 29 #include "llvm/ADT/Statistic.h" 30 #include "llvm/IR/IRBuilder.h" 31 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 32 #include "llvm/Passes/PassPlugin.h" 33 #include "llvm/Passes/PassBuilder.h" 34 #include "llvm/IR/PassManager.h" 35 #else 36 #include "llvm/IR/LegacyPassManager.h" 37 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 38 #endif 39 #include "llvm/IR/Module.h" 40 #include "llvm/Support/Debug.h" 41 #include "llvm/Support/raw_ostream.h" 42 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 43 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 44 #include "llvm/Pass.h" 45 #include "llvm/Analysis/ValueTracking.h" 46 47 #include "llvm/IR/IRBuilder.h" 48 #if LLVM_VERSION_MAJOR >= 4 || \ 49 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 50 #include "llvm/IR/Verifier.h" 51 #include "llvm/IR/DebugInfo.h" 52 #else 53 #include "llvm/Analysis/Verifier.h" 54 #include "llvm/DebugInfo.h" 55 #define nullptr 0 56 #endif 57 58 #include <set> 59 #include "afl-llvm-common.h" 60 61 using namespace llvm; 62 63 namespace { 64 65 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 66 class CmplogSwitches : public PassInfoMixin<CmplogSwitches> { 67 68 public: CmplogSwitches()69 CmplogSwitches() { 70 71 #else 72 class CmplogSwitches : public ModulePass { 73 74 public: 75 static char ID; 76 CmplogSwitches() : ModulePass(ID) { 77 78 #endif 79 initInstrumentList(); 80 81 } 82 83 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 84 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 85 #else 86 bool runOnModule(Module &M) override; 87 88 #if LLVM_VERSION_MAJOR < 4 89 const char *getPassName() const override { 90 91 #else 92 StringRef getPassName() const override { 93 94 #endif 95 return "cmplog switch split"; 96 97 } 98 99 #endif 100 101 private: 102 bool hookInstrs(Module &M); 103 104 }; 105 106 } // namespace 107 108 #if LLVM_MAJOR >= 11 109 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 110 llvmGetPassPluginInfo() { 111 112 return {LLVM_PLUGIN_API_VERSION, "cmplogswitches", "v0.1", 113 /* lambda to insert our pass into the pass pipeline. */ 114 [](PassBuilder &PB) { 115 116 #if LLVM_VERSION_MAJOR <= 13 117 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 118 #endif 119 PB.registerOptimizerLastEPCallback( 120 [](ModulePassManager &MPM, OptimizationLevel OL) { 121 122 MPM.addPass(CmplogSwitches()); 123 124 }); 125 126 }}; 127 128 } 129 130 #else 131 char CmplogSwitches::ID = 0; 132 #endif 133 134 template <class Iterator> 135 Iterator Unique(Iterator first, Iterator last) { 136 137 while (first != last) { 138 139 Iterator next(first); 140 last = std::remove(++next, last, *first); 141 first = next; 142 143 } 144 145 return last; 146 147 } 148 149 bool CmplogSwitches::hookInstrs(Module &M) { 150 151 std::vector<SwitchInst *> switches; 152 LLVMContext & C = M.getContext(); 153 154 Type * VoidTy = Type::getVoidTy(C); 155 IntegerType *Int8Ty = IntegerType::getInt8Ty(C); 156 IntegerType *Int16Ty = IntegerType::getInt16Ty(C); 157 IntegerType *Int32Ty = IntegerType::getInt32Ty(C); 158 IntegerType *Int64Ty = IntegerType::getInt64Ty(C); 159 160 #if LLVM_VERSION_MAJOR >= 9 161 FunctionCallee 162 #else 163 Constant * 164 #endif 165 c1 = M.getOrInsertFunction("__cmplog_ins_hook1", VoidTy, Int8Ty, Int8Ty, 166 Int8Ty 167 #if LLVM_VERSION_MAJOR < 5 168 , 169 NULL 170 #endif 171 ); 172 #if LLVM_VERSION_MAJOR >= 9 173 FunctionCallee cmplogHookIns1 = c1; 174 #else 175 Function *cmplogHookIns1 = cast<Function>(c1); 176 #endif 177 178 #if LLVM_VERSION_MAJOR >= 9 179 FunctionCallee 180 #else 181 Constant * 182 #endif 183 c2 = M.getOrInsertFunction("__cmplog_ins_hook2", VoidTy, Int16Ty, Int16Ty, 184 Int8Ty 185 #if LLVM_VERSION_MAJOR < 5 186 , 187 NULL 188 #endif 189 ); 190 #if LLVM_VERSION_MAJOR >= 9 191 FunctionCallee cmplogHookIns2 = c2; 192 #else 193 Function *cmplogHookIns2 = cast<Function>(c2); 194 #endif 195 196 #if LLVM_VERSION_MAJOR >= 9 197 FunctionCallee 198 #else 199 Constant * 200 #endif 201 c4 = M.getOrInsertFunction("__cmplog_ins_hook4", VoidTy, Int32Ty, Int32Ty, 202 Int8Ty 203 #if LLVM_VERSION_MAJOR < 5 204 , 205 NULL 206 #endif 207 ); 208 #if LLVM_VERSION_MAJOR >= 9 209 FunctionCallee cmplogHookIns4 = c4; 210 #else 211 Function *cmplogHookIns4 = cast<Function>(c4); 212 #endif 213 214 #if LLVM_VERSION_MAJOR >= 9 215 FunctionCallee 216 #else 217 Constant * 218 #endif 219 c8 = M.getOrInsertFunction("__cmplog_ins_hook8", VoidTy, Int64Ty, Int64Ty, 220 Int8Ty 221 #if LLVM_VERSION_MAJOR < 5 222 , 223 NULL 224 #endif 225 ); 226 #if LLVM_VERSION_MAJOR >= 9 227 FunctionCallee cmplogHookIns8 = c8; 228 #else 229 Function *cmplogHookIns8 = cast<Function>(c8); 230 #endif 231 232 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); 233 234 if (!AFLCmplogPtr) { 235 236 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, 237 GlobalValue::ExternalWeakLinkage, 0, 238 "__afl_cmp_map"); 239 240 } 241 242 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); 243 244 /* iterate over all functions, bbs and instruction and add suitable calls */ 245 for (auto &F : M) { 246 247 if (!isInInstrumentList(&F, MNAME)) continue; 248 249 for (auto &BB : F) { 250 251 SwitchInst *switchInst = nullptr; 252 if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { 253 254 if (switchInst->getNumCases() > 1) { switches.push_back(switchInst); } 255 256 } 257 258 } 259 260 } 261 262 // unique the collected switches 263 switches.erase(Unique(switches.begin(), switches.end()), switches.end()); 264 265 // Instrument switch values for cmplog 266 if (switches.size()) { 267 268 if (!be_quiet) 269 errs() << "Hooking " << switches.size() << " switch instructions\n"; 270 271 for (auto &SI : switches) { 272 273 Value * Val = SI->getCondition(); 274 unsigned int max_size = Val->getType()->getIntegerBitWidth(), cast_size; 275 unsigned char do_cast = 0; 276 277 if (!SI->getNumCases() || max_size < 16) { 278 279 // if (!be_quiet) errs() << "skip trivial switch..\n"; 280 continue; 281 282 } 283 284 if (max_size % 8) { 285 286 max_size = (((max_size / 8) + 1) * 8); 287 do_cast = 1; 288 289 } 290 291 IRBuilder<> IRB2(SI->getParent()); 292 IRB2.SetInsertPoint(SI); 293 294 LoadInst *CmpPtr = IRB2.CreateLoad( 295 #if LLVM_VERSION_MAJOR >= 14 296 PointerType::get(Int8Ty, 0), 297 #endif 298 AFLCmplogPtr); 299 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 300 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 301 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, SI, false); 302 303 IRBuilder<> IRB(ThenTerm); 304 305 if (max_size > 128) { 306 307 if (!be_quiet) { 308 309 fprintf(stderr, 310 "Cannot handle this switch bit size: %u (truncating)\n", 311 max_size); 312 313 } 314 315 max_size = 128; 316 do_cast = 1; 317 318 } 319 320 // do we need to cast? 321 switch (max_size) { 322 323 case 8: 324 case 16: 325 case 32: 326 case 64: 327 case 128: 328 cast_size = max_size; 329 break; 330 default: 331 cast_size = 128; 332 do_cast = 1; 333 334 } 335 336 Value *CompareTo = Val; 337 338 if (do_cast) { 339 340 CompareTo = 341 IRB.CreateIntCast(CompareTo, IntegerType::get(C, cast_size), false); 342 343 } 344 345 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; 346 ++i) { 347 348 #if LLVM_VERSION_MAJOR < 5 349 ConstantInt *cint = i.getCaseValue(); 350 #else 351 ConstantInt *cint = i->getCaseValue(); 352 #endif 353 354 if (cint) { 355 356 std::vector<Value *> args; 357 args.push_back(CompareTo); 358 359 Value *new_param = cint; 360 361 if (do_cast) { 362 363 new_param = 364 IRB.CreateIntCast(cint, IntegerType::get(C, cast_size), false); 365 366 } 367 368 if (new_param) { 369 370 args.push_back(new_param); 371 ConstantInt *attribute = ConstantInt::get(Int8Ty, 1); 372 args.push_back(attribute); 373 if (cast_size != max_size) { 374 375 ConstantInt *bitsize = 376 ConstantInt::get(Int8Ty, (max_size / 8) - 1); 377 args.push_back(bitsize); 378 379 } 380 381 switch (cast_size) { 382 383 case 8: 384 IRB.CreateCall(cmplogHookIns1, args); 385 break; 386 case 16: 387 IRB.CreateCall(cmplogHookIns2, args); 388 break; 389 case 32: 390 IRB.CreateCall(cmplogHookIns4, args); 391 break; 392 case 64: 393 IRB.CreateCall(cmplogHookIns8, args); 394 break; 395 case 128: 396 #ifdef WORD_SIZE_64 397 if (max_size == 128) { 398 399 IRB.CreateCall(cmplogHookIns16, args); 400 401 } else { 402 403 IRB.CreateCall(cmplogHookInsN, args); 404 405 } 406 407 #endif 408 break; 409 default: 410 break; 411 412 } 413 414 } 415 416 } 417 418 } 419 420 } 421 422 } 423 424 if (switches.size()) 425 return true; 426 else 427 return false; 428 429 } 430 431 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 432 PreservedAnalyses CmplogSwitches::run(Module &M, ModuleAnalysisManager &MAM) { 433 434 #else 435 bool CmplogSwitches::runOnModule(Module &M) { 436 437 #endif 438 439 if (getenv("AFL_QUIET") == NULL) 440 printf("Running cmplog-switches-pass by andreafioraldi@gmail.com\n"); 441 else 442 be_quiet = 1; 443 hookInstrs(M); 444 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 445 auto PA = PreservedAnalyses::all(); 446 #endif 447 verifyModule(M); 448 449 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 450 return PA; 451 #else 452 return true; 453 #endif 454 455 } 456 457 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 458 static void registerCmplogSwitchesPass(const PassManagerBuilder &, 459 legacy::PassManagerBase &PM) { 460 461 auto p = new CmplogSwitches(); 462 PM.add(p); 463 464 } 465 466 static RegisterStandardPasses RegisterCmplogSwitchesPass( 467 PassManagerBuilder::EP_OptimizerLast, registerCmplogSwitchesPass); 468 469 static RegisterStandardPasses RegisterCmplogSwitchesPass0( 470 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmplogSwitchesPass); 471 472 #if LLVM_VERSION_MAJOR >= 11 473 static RegisterStandardPasses RegisterCmplogSwitchesPassLTO( 474 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 475 registerCmplogSwitchesPass); 476 #endif 477 #endif 478 479