1 /* 2 american fuzzy lop++ - LLVM CmpLog instrumentation 3 -------------------------------------------------- 4 5 Written by Andrea Fioraldi <andreafioraldi@gmail.com> 6 7 Copyright 2015, 2016 Google Inc. All rights reserved. 8 Copyright 2019-2022 AFLplusplus Project. All rights reserved. 9 10 Licensed under the Apache License, Version 2.0 (the "License"); 11 you may not use this file except in compliance with the License. 12 You may obtain a copy of the License at: 13 14 https://www.apache.org/licenses/LICENSE-2.0 15 16 */ 17 18 #include <stdio.h> 19 #include <stdlib.h> 20 #include <unistd.h> 21 22 #include <list> 23 #include <string> 24 #include <fstream> 25 #include <sys/time.h> 26 #include "llvm/Config/llvm-config.h" 27 28 #include "llvm/ADT/Statistic.h" 29 #include "llvm/IR/IRBuilder.h" 30 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 31 #include "llvm/Passes/PassPlugin.h" 32 #include "llvm/Passes/PassBuilder.h" 33 #include "llvm/IR/PassManager.h" 34 #else 35 #include "llvm/IR/LegacyPassManager.h" 36 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 37 #endif 38 #include "llvm/IR/Module.h" 39 #include "llvm/Support/Debug.h" 40 #include "llvm/Support/raw_ostream.h" 41 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 42 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 43 #include "llvm/Pass.h" 44 #include "llvm/Analysis/ValueTracking.h" 45 46 #include "llvm/IR/IRBuilder.h" 47 #if LLVM_VERSION_MAJOR >= 4 || \ 48 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 49 #include "llvm/IR/Verifier.h" 50 #include "llvm/IR/DebugInfo.h" 51 #else 52 #include "llvm/Analysis/Verifier.h" 53 #include "llvm/DebugInfo.h" 54 #define nullptr 0 55 #endif 56 57 #include <set> 58 #include "afl-llvm-common.h" 59 60 using namespace llvm; 61 62 namespace { 63 64 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 65 class CmpLogRoutines : public PassInfoMixin<CmpLogRoutines> { 66 67 public: CmpLogRoutines()68 CmpLogRoutines() { 69 70 #else 71 class CmpLogRoutines : public ModulePass { 72 73 public: 74 static char ID; 75 CmpLogRoutines() : ModulePass(ID) { 76 77 #endif 78 79 initInstrumentList(); 80 81 } 82 83 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 84 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 85 #else 86 bool runOnModule(Module &M) override; 87 88 #if LLVM_VERSION_MAJOR >= 4 89 StringRef getPassName() const override { 90 91 #else 92 const char *getPassName() const override { 93 94 #endif 95 return "cmplog routines"; 96 97 } 98 99 #endif 100 101 private: 102 bool hookRtns(Module &M); 103 104 }; 105 106 } // namespace 107 108 #if LLVM_MAJOR >= 11 109 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 110 llvmGetPassPluginInfo() { 111 112 return {LLVM_PLUGIN_API_VERSION, "cmplogroutines", "v0.1", 113 /* lambda to insert our pass into the pass pipeline. */ 114 [](PassBuilder &PB) { 115 116 #if LLVM_VERSION_MAJOR <= 13 117 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 118 #endif 119 PB.registerOptimizerLastEPCallback( 120 [](ModulePassManager &MPM, OptimizationLevel OL) { 121 122 MPM.addPass(CmpLogRoutines()); 123 124 }); 125 126 }}; 127 128 } 129 130 #else 131 char CmpLogRoutines::ID = 0; 132 #endif 133 134 bool CmpLogRoutines::hookRtns(Module &M) { 135 136 std::vector<CallInst *> calls, llvmStdStd, llvmStdC, gccStdStd, gccStdC, 137 Memcmp, Strcmp, Strncmp; 138 LLVMContext &C = M.getContext(); 139 140 Type *VoidTy = Type::getVoidTy(C); 141 // PointerType *VoidPtrTy = PointerType::get(VoidTy, 0); 142 IntegerType *Int8Ty = IntegerType::getInt8Ty(C); 143 IntegerType *Int64Ty = IntegerType::getInt64Ty(C); 144 PointerType *i8PtrTy = PointerType::get(Int8Ty, 0); 145 146 #if LLVM_VERSION_MAJOR >= 9 147 FunctionCallee 148 #else 149 Constant * 150 #endif 151 c = M.getOrInsertFunction("__cmplog_rtn_hook", VoidTy, i8PtrTy, i8PtrTy 152 #if LLVM_VERSION_MAJOR < 5 153 , 154 NULL 155 #endif 156 ); 157 #if LLVM_VERSION_MAJOR >= 9 158 FunctionCallee cmplogHookFn = c; 159 #else 160 Function *cmplogHookFn = cast<Function>(c); 161 #endif 162 163 #if LLVM_VERSION_MAJOR >= 9 164 FunctionCallee 165 #else 166 Constant * 167 #endif 168 c1 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_stdstring", 169 VoidTy, i8PtrTy, i8PtrTy 170 #if LLVM_VERSION_MAJOR < 5 171 , 172 NULL 173 #endif 174 ); 175 #if LLVM_VERSION_MAJOR >= 9 176 FunctionCallee cmplogLlvmStdStd = c1; 177 #else 178 Function *cmplogLlvmStdStd = cast<Function>(c1); 179 #endif 180 181 #if LLVM_VERSION_MAJOR >= 9 182 FunctionCallee 183 #else 184 Constant * 185 #endif 186 c2 = M.getOrInsertFunction("__cmplog_rtn_llvm_stdstring_cstring", VoidTy, 187 i8PtrTy, i8PtrTy 188 #if LLVM_VERSION_MAJOR < 5 189 , 190 NULL 191 #endif 192 ); 193 #if LLVM_VERSION_MAJOR >= 9 194 FunctionCallee cmplogLlvmStdC = c2; 195 #else 196 Function *cmplogLlvmStdC = cast<Function>(c2); 197 #endif 198 199 #if LLVM_VERSION_MAJOR >= 9 200 FunctionCallee 201 #else 202 Constant * 203 #endif 204 c3 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_stdstring", VoidTy, 205 i8PtrTy, i8PtrTy 206 #if LLVM_VERSION_MAJOR < 5 207 , 208 NULL 209 #endif 210 ); 211 #if LLVM_VERSION_MAJOR >= 9 212 FunctionCallee cmplogGccStdStd = c3; 213 #else 214 Function *cmplogGccStdStd = cast<Function>(c3); 215 #endif 216 217 #if LLVM_VERSION_MAJOR >= 9 218 FunctionCallee 219 #else 220 Constant * 221 #endif 222 c4 = M.getOrInsertFunction("__cmplog_rtn_gcc_stdstring_cstring", VoidTy, 223 i8PtrTy, i8PtrTy 224 #if LLVM_VERSION_MAJOR < 5 225 , 226 NULL 227 #endif 228 ); 229 #if LLVM_VERSION_MAJOR >= 9 230 FunctionCallee cmplogGccStdC = c4; 231 #else 232 Function *cmplogGccStdC = cast<Function>(c4); 233 #endif 234 235 #if LLVM_VERSION_MAJOR >= 9 236 FunctionCallee 237 #else 238 Constant * 239 #endif 240 c5 = M.getOrInsertFunction("__cmplog_rtn_hook_n", VoidTy, i8PtrTy, 241 i8PtrTy, Int64Ty 242 #if LLVM_VERSION_MAJOR < 5 243 , 244 NULL 245 #endif 246 ); 247 #if LLVM_VERSION_MAJOR >= 9 248 FunctionCallee cmplogHookFnN = c5; 249 #else 250 Function *cmplogHookFnN = cast<Function>(c5); 251 #endif 252 253 #if LLVM_VERSION_MAJOR >= 9 254 FunctionCallee 255 #else 256 Constant * 257 #endif 258 c6 = M.getOrInsertFunction("__cmplog_rtn_hook_strn", VoidTy, i8PtrTy, 259 i8PtrTy, Int64Ty 260 #if LLVM_VERSION_MAJOR < 5 261 , 262 NULL 263 #endif 264 ); 265 #if LLVM_VERSION_MAJOR >= 9 266 FunctionCallee cmplogHookFnStrN = c6; 267 #else 268 Function *cmplogHookFnStrN = cast<Function>(c6); 269 #endif 270 271 #if LLVM_VERSION_MAJOR >= 9 272 FunctionCallee 273 #else 274 Constant * 275 #endif 276 c7 = M.getOrInsertFunction("__cmplog_rtn_hook_str", VoidTy, i8PtrTy, 277 i8PtrTy 278 #if LLVM_VERSION_MAJOR < 5 279 , 280 NULL 281 #endif 282 ); 283 #if LLVM_VERSION_MAJOR >= 9 284 FunctionCallee cmplogHookFnStr = c7; 285 #else 286 Function *cmplogHookFnStr = cast<Function>(c7); 287 #endif 288 289 GlobalVariable *AFLCmplogPtr = M.getNamedGlobal("__afl_cmp_map"); 290 291 if (!AFLCmplogPtr) { 292 293 AFLCmplogPtr = new GlobalVariable(M, PointerType::get(Int8Ty, 0), false, 294 GlobalValue::ExternalWeakLinkage, 0, 295 "__afl_cmp_map"); 296 297 } 298 299 Constant *Null = Constant::getNullValue(PointerType::get(Int8Ty, 0)); 300 301 /* iterate over all functions, bbs and instruction and add suitable calls */ 302 for (auto &F : M) { 303 304 if (!isInInstrumentList(&F, MNAME)) continue; 305 306 for (auto &BB : F) { 307 308 for (auto &IN : BB) { 309 310 CallInst *callInst = nullptr; 311 312 if ((callInst = dyn_cast<CallInst>(&IN))) { 313 314 Function *Callee = callInst->getCalledFunction(); 315 if (!Callee) continue; 316 if (callInst->getCallingConv() != llvm::CallingConv::C) continue; 317 318 FunctionType *FT = Callee->getFunctionType(); 319 std::string FuncName = Callee->getName().str(); 320 321 bool isPtrRtn = FT->getNumParams() >= 2 && 322 !FT->getReturnType()->isVoidTy() && 323 FT->getParamType(0) == FT->getParamType(1) && 324 FT->getParamType(0)->isPointerTy(); 325 326 bool isPtrRtnN = FT->getNumParams() >= 3 && 327 !FT->getReturnType()->isVoidTy() && 328 FT->getParamType(0) == FT->getParamType(1) && 329 FT->getParamType(0)->isPointerTy() && 330 FT->getParamType(2)->isIntegerTy(); 331 if (isPtrRtnN) { 332 333 auto intTyOp = 334 dyn_cast<IntegerType>(callInst->getArgOperand(2)->getType()); 335 if (intTyOp) { 336 337 if (intTyOp->getBitWidth() != 32 && 338 intTyOp->getBitWidth() != 64) { 339 340 isPtrRtnN = false; 341 342 } 343 344 } 345 346 } 347 348 bool isMemcmp = 349 (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") || 350 !FuncName.compare("CRYPTO_memcmp") || 351 !FuncName.compare("OPENSSL_memcmp") || 352 !FuncName.compare("memcmp_const_time") || 353 !FuncName.compare("memcmpct")); 354 isMemcmp &= FT->getNumParams() == 3 && 355 FT->getReturnType()->isIntegerTy(32) && 356 FT->getParamType(0)->isPointerTy() && 357 FT->getParamType(1)->isPointerTy() && 358 FT->getParamType(2)->isIntegerTy(); 359 360 bool isStrcmp = 361 (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") || 362 !FuncName.compare("xmlStrEqual") || 363 !FuncName.compare("g_strcmp0") || 364 !FuncName.compare("curl_strequal") || 365 !FuncName.compare("strcsequal") || 366 !FuncName.compare("strcasecmp") || 367 !FuncName.compare("stricmp") || 368 !FuncName.compare("ap_cstr_casecmp") || 369 !FuncName.compare("OPENSSL_strcasecmp") || 370 !FuncName.compare("xmlStrcasecmp") || 371 !FuncName.compare("g_strcasecmp") || 372 !FuncName.compare("g_ascii_strcasecmp") || 373 !FuncName.compare("Curl_strcasecompare") || 374 !FuncName.compare("Curl_safe_strcasecompare") || 375 !FuncName.compare("cmsstrcasecmp") || 376 !FuncName.compare("strstr") || 377 !FuncName.compare("g_strstr_len") || 378 !FuncName.compare("ap_strcasestr") || 379 !FuncName.compare("xmlStrstr") || 380 !FuncName.compare("xmlStrcasestr") || 381 !FuncName.compare("g_str_has_prefix") || 382 !FuncName.compare("g_str_has_suffix")); 383 isStrcmp &= 384 FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && 385 FT->getParamType(0) == FT->getParamType(1) && 386 FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); 387 388 bool isStrncmp = (!FuncName.compare("strncmp") || 389 !FuncName.compare("xmlStrncmp") || 390 !FuncName.compare("curl_strnequal") || 391 !FuncName.compare("strncasecmp") || 392 !FuncName.compare("strnicmp") || 393 !FuncName.compare("ap_cstr_casecmpn") || 394 !FuncName.compare("OPENSSL_strncasecmp") || 395 !FuncName.compare("xmlStrncasecmp") || 396 !FuncName.compare("g_ascii_strncasecmp") || 397 !FuncName.compare("Curl_strncasecompare") || 398 !FuncName.compare("g_strncasecmp")); 399 isStrncmp &= FT->getNumParams() == 3 && 400 FT->getReturnType()->isIntegerTy(32) && 401 FT->getParamType(0) == FT->getParamType(1) && 402 FT->getParamType(0) == 403 IntegerType::getInt8PtrTy(M.getContext()) && 404 FT->getParamType(2)->isIntegerTy(); 405 406 bool isGccStdStringStdString = 407 Callee->getName().find("__is_charIT_EE7__value") != 408 std::string::npos && 409 Callee->getName().find( 410 "St7__cxx1112basic_stringIS2_St11char_traits") != 411 std::string::npos && 412 FT->getNumParams() >= 2 && 413 FT->getParamType(0) == FT->getParamType(1) && 414 FT->getParamType(0)->isPointerTy(); 415 416 bool isGccStdStringCString = 417 Callee->getName().find( 418 "St7__cxx1112basic_stringIcSt11char_" 419 "traitsIcESaIcEE7compareEPK") != std::string::npos && 420 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 421 FT->getParamType(1)->isPointerTy(); 422 423 bool isLlvmStdStringStdString = 424 Callee->getName().find("_ZNSt3__1eqI") != std::string::npos && 425 Callee->getName().find("_12basic_stringI") != std::string::npos && 426 Callee->getName().find("_11char_traits") != std::string::npos && 427 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 428 FT->getParamType(1)->isPointerTy(); 429 430 bool isLlvmStdStringCString = 431 Callee->getName().find("_ZNSt3__1eqI") != std::string::npos && 432 Callee->getName().find("_12basic_stringI") != std::string::npos && 433 FT->getNumParams() >= 2 && FT->getParamType(0)->isPointerTy() && 434 FT->getParamType(1)->isPointerTy(); 435 436 /* 437 { 438 439 fprintf(stderr, "F:%s C:%s argc:%u\n", 440 F.getName().str().c_str(), 441 Callee->getName().str().c_str(), FT->getNumParams()); 442 fprintf(stderr, "ptr0:%u ptr1:%u ptr2:%u\n", 443 FT->getParamType(0)->isPointerTy(), 444 FT->getParamType(1)->isPointerTy(), 445 FT->getNumParams() > 2 ? 446 FT->getParamType(2)->isPointerTy() : 22 ); 447 448 } 449 450 */ 451 452 if (isGccStdStringCString || isGccStdStringStdString || 453 isLlvmStdStringStdString || isLlvmStdStringCString || isMemcmp || 454 isStrcmp || isStrncmp) { 455 456 isPtrRtnN = isPtrRtn = false; 457 458 } 459 460 if (isPtrRtnN) { isPtrRtn = false; } 461 462 if (isPtrRtn) { calls.push_back(callInst); } 463 if (isMemcmp || isPtrRtnN) { Memcmp.push_back(callInst); } 464 if (isStrcmp) { Strcmp.push_back(callInst); } 465 if (isStrncmp) { Strncmp.push_back(callInst); } 466 if (isGccStdStringStdString) { gccStdStd.push_back(callInst); } 467 if (isGccStdStringCString) { gccStdC.push_back(callInst); } 468 if (isLlvmStdStringStdString) { llvmStdStd.push_back(callInst); } 469 if (isLlvmStdStringCString) { llvmStdC.push_back(callInst); } 470 471 } 472 473 } 474 475 } 476 477 } 478 479 if (!calls.size() && !gccStdStd.size() && !gccStdC.size() && 480 !llvmStdStd.size() && !llvmStdC.size() && !Memcmp.size() && 481 Strcmp.size() && Strncmp.size()) 482 return false; 483 484 /* 485 if (!be_quiet) 486 errs() << "Hooking " << calls.size() 487 << " calls with pointers as arguments\n"; 488 */ 489 490 for (auto &callInst : calls) { 491 492 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 493 494 IRBuilder<> IRB2(callInst->getParent()); 495 IRB2.SetInsertPoint(callInst); 496 497 LoadInst *CmpPtr = IRB2.CreateLoad( 498 #if LLVM_VERSION_MAJOR >= 14 499 PointerType::get(Int8Ty, 0), 500 #endif 501 AFLCmplogPtr); 502 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 503 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 504 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 505 506 IRBuilder<> IRB(ThenTerm); 507 508 std::vector<Value *> args; 509 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 510 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 511 args.push_back(v1Pcasted); 512 args.push_back(v2Pcasted); 513 514 IRB.CreateCall(cmplogHookFn, args); 515 516 // errs() << callInst->getCalledFunction()->getName() << "\n"; 517 518 } 519 520 for (auto &callInst : Memcmp) { 521 522 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), 523 *v3P = callInst->getArgOperand(2); 524 525 IRBuilder<> IRB2(callInst->getParent()); 526 IRB2.SetInsertPoint(callInst); 527 528 LoadInst *CmpPtr = IRB2.CreateLoad( 529 #if LLVM_VERSION_MAJOR >= 14 530 PointerType::get(Int8Ty, 0), 531 #endif 532 AFLCmplogPtr); 533 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 534 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 535 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 536 537 IRBuilder<> IRB(ThenTerm); 538 539 std::vector<Value *> args; 540 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 541 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 542 Value * v3Pbitcast = IRB.CreateBitCast( 543 v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); 544 Value *v3Pcasted = 545 IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); 546 args.push_back(v1Pcasted); 547 args.push_back(v2Pcasted); 548 args.push_back(v3Pcasted); 549 550 IRB.CreateCall(cmplogHookFnN, args); 551 552 // errs() << callInst->getCalledFunction()->getName() << "\n"; 553 554 } 555 556 for (auto &callInst : Strcmp) { 557 558 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 559 560 IRBuilder<> IRB2(callInst->getParent()); 561 IRB2.SetInsertPoint(callInst); 562 563 LoadInst *CmpPtr = IRB2.CreateLoad( 564 #if LLVM_VERSION_MAJOR >= 14 565 PointerType::get(Int8Ty, 0), 566 #endif 567 AFLCmplogPtr); 568 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 569 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 570 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 571 572 IRBuilder<> IRB(ThenTerm); 573 574 std::vector<Value *> args; 575 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 576 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 577 args.push_back(v1Pcasted); 578 args.push_back(v2Pcasted); 579 580 IRB.CreateCall(cmplogHookFnStr, args); 581 582 // errs() << callInst->getCalledFunction()->getName() << "\n"; 583 584 } 585 586 for (auto &callInst : Strncmp) { 587 588 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1), 589 *v3P = callInst->getArgOperand(2); 590 591 IRBuilder<> IRB2(callInst->getParent()); 592 IRB2.SetInsertPoint(callInst); 593 594 LoadInst *CmpPtr = IRB2.CreateLoad( 595 #if LLVM_VERSION_MAJOR >= 14 596 PointerType::get(Int8Ty, 0), 597 #endif 598 AFLCmplogPtr); 599 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 600 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 601 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 602 603 IRBuilder<> IRB(ThenTerm); 604 605 std::vector<Value *> args; 606 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 607 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 608 Value * v3Pbitcast = IRB.CreateBitCast( 609 v3P, IntegerType::get(C, v3P->getType()->getPrimitiveSizeInBits())); 610 Value *v3Pcasted = 611 IRB.CreateIntCast(v3Pbitcast, IntegerType::get(C, 64), false); 612 args.push_back(v1Pcasted); 613 args.push_back(v2Pcasted); 614 args.push_back(v3Pcasted); 615 616 IRB.CreateCall(cmplogHookFnStrN, args); 617 618 // errs() << callInst->getCalledFunction()->getName() << "\n"; 619 620 } 621 622 for (auto &callInst : gccStdStd) { 623 624 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 625 626 IRBuilder<> IRB2(callInst->getParent()); 627 IRB2.SetInsertPoint(callInst); 628 629 LoadInst *CmpPtr = IRB2.CreateLoad( 630 #if LLVM_VERSION_MAJOR >= 14 631 PointerType::get(Int8Ty, 0), 632 #endif 633 AFLCmplogPtr); 634 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 635 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 636 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 637 638 IRBuilder<> IRB(ThenTerm); 639 640 std::vector<Value *> args; 641 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 642 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 643 args.push_back(v1Pcasted); 644 args.push_back(v2Pcasted); 645 646 IRB.CreateCall(cmplogGccStdStd, args); 647 648 // errs() << callInst->getCalledFunction()->getName() << "\n"; 649 650 } 651 652 for (auto &callInst : gccStdC) { 653 654 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 655 656 IRBuilder<> IRB2(callInst->getParent()); 657 IRB2.SetInsertPoint(callInst); 658 659 LoadInst *CmpPtr = IRB2.CreateLoad( 660 #if LLVM_VERSION_MAJOR >= 14 661 PointerType::get(Int8Ty, 0), 662 #endif 663 AFLCmplogPtr); 664 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 665 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 666 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 667 668 IRBuilder<> IRB(ThenTerm); 669 670 std::vector<Value *> args; 671 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 672 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 673 args.push_back(v1Pcasted); 674 args.push_back(v2Pcasted); 675 676 IRB.CreateCall(cmplogGccStdC, args); 677 678 // errs() << callInst->getCalledFunction()->getName() << "\n"; 679 680 } 681 682 for (auto &callInst : llvmStdStd) { 683 684 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 685 686 IRBuilder<> IRB2(callInst->getParent()); 687 IRB2.SetInsertPoint(callInst); 688 689 LoadInst *CmpPtr = IRB2.CreateLoad( 690 #if LLVM_VERSION_MAJOR >= 14 691 PointerType::get(Int8Ty, 0), 692 #endif 693 AFLCmplogPtr); 694 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 695 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 696 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 697 698 IRBuilder<> IRB(ThenTerm); 699 700 std::vector<Value *> args; 701 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 702 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 703 args.push_back(v1Pcasted); 704 args.push_back(v2Pcasted); 705 706 IRB.CreateCall(cmplogLlvmStdStd, args); 707 708 // errs() << callInst->getCalledFunction()->getName() << "\n"; 709 710 } 711 712 for (auto &callInst : llvmStdC) { 713 714 Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1); 715 716 IRBuilder<> IRB2(callInst->getParent()); 717 IRB2.SetInsertPoint(callInst); 718 719 LoadInst *CmpPtr = IRB2.CreateLoad( 720 #if LLVM_VERSION_MAJOR >= 14 721 PointerType::get(Int8Ty, 0), 722 #endif 723 AFLCmplogPtr); 724 CmpPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); 725 auto is_not_null = IRB2.CreateICmpNE(CmpPtr, Null); 726 auto ThenTerm = SplitBlockAndInsertIfThen(is_not_null, callInst, false); 727 728 IRBuilder<> IRB(ThenTerm); 729 730 std::vector<Value *> args; 731 Value * v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy); 732 Value * v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy); 733 args.push_back(v1Pcasted); 734 args.push_back(v2Pcasted); 735 736 IRB.CreateCall(cmplogLlvmStdC, args); 737 738 // errs() << callInst->getCalledFunction()->getName() << "\n"; 739 740 } 741 742 return true; 743 744 } 745 746 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 747 PreservedAnalyses CmpLogRoutines::run(Module &M, ModuleAnalysisManager &MAM) { 748 749 #else 750 bool CmpLogRoutines::runOnModule(Module &M) { 751 752 #endif 753 754 if (getenv("AFL_QUIET") == NULL) 755 printf("Running cmplog-routines-pass by andreafioraldi@gmail.com\n"); 756 else 757 be_quiet = 1; 758 hookRtns(M); 759 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 760 auto PA = PreservedAnalyses::all(); 761 #endif 762 verifyModule(M); 763 764 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 765 return PA; 766 #else 767 return true; 768 #endif 769 770 } 771 772 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 773 static void registerCmpLogRoutinesPass(const PassManagerBuilder &, 774 legacy::PassManagerBase &PM) { 775 776 auto p = new CmpLogRoutines(); 777 PM.add(p); 778 779 } 780 781 static RegisterStandardPasses RegisterCmpLogRoutinesPass( 782 PassManagerBuilder::EP_OptimizerLast, registerCmpLogRoutinesPass); 783 784 static RegisterStandardPasses RegisterCmpLogRoutinesPass0( 785 PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogRoutinesPass); 786 787 #if LLVM_VERSION_MAJOR >= 11 788 static RegisterStandardPasses RegisterCmpLogRoutinesPassLTO( 789 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 790 registerCmpLogRoutinesPass); 791 #endif 792 #endif 793 794