1 /* 2 * Copyright 2016 laf-intel 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * https://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdio.h> 18 #include <stdlib.h> 19 #include <unistd.h> 20 21 #include <list> 22 #include <string> 23 #include <fstream> 24 #include <sys/time.h> 25 26 #include "llvm/Config/llvm-config.h" 27 28 #include "llvm/ADT/Statistic.h" 29 #include "llvm/IR/IRBuilder.h" 30 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 31 #include "llvm/Passes/PassPlugin.h" 32 #include "llvm/Passes/PassBuilder.h" 33 #include "llvm/IR/PassManager.h" 34 #else 35 #include "llvm/IR/LegacyPassManager.h" 36 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 37 #endif 38 #include "llvm/IR/Module.h" 39 #include "llvm/Support/Debug.h" 40 #include "llvm/Support/raw_ostream.h" 41 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 42 #include "llvm/Pass.h" 43 #include "llvm/Analysis/ValueTracking.h" 44 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */ 45 #include "llvm/Passes/OptimizationLevel.h" 46 #endif 47 48 #include "llvm/IR/IRBuilder.h" 49 #if LLVM_VERSION_MAJOR >= 4 || \ 50 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4) 51 #include "llvm/IR/Verifier.h" 52 #include "llvm/IR/DebugInfo.h" 53 #else 54 #include "llvm/Analysis/Verifier.h" 55 #include "llvm/DebugInfo.h" 56 #define nullptr 0 57 #endif 58 59 #include <set> 60 #include "afl-llvm-common.h" 61 62 using namespace llvm; 63 64 namespace { 65 66 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 67 class SplitSwitchesTransform : public PassInfoMixin<SplitSwitchesTransform> { 68 69 public: SplitSwitchesTransform()70 SplitSwitchesTransform() { 71 72 #else 73 class SplitSwitchesTransform : public ModulePass { 74 75 public: 76 static char ID; 77 SplitSwitchesTransform() : ModulePass(ID) { 78 79 #endif 80 initInstrumentList(); 81 82 } 83 84 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 85 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); 86 #else 87 bool runOnModule(Module &M) override; 88 89 #if LLVM_VERSION_MAJOR >= 4 90 StringRef getPassName() const override { 91 92 #else 93 const char *getPassName() const override { 94 95 #endif 96 return "splits switch constructs"; 97 98 } 99 100 #endif 101 102 struct CaseExpr { 103 104 ConstantInt *Val; 105 BasicBlock * BB; 106 107 CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) 108 : Val(val), BB(bb) { 109 110 } 111 112 }; 113 114 using CaseVector = std::vector<CaseExpr>; 115 116 private: 117 bool splitSwitches(Module &M); 118 bool transformCmps(Module &M, const bool processStrcmp, 119 const bool processMemcmp); 120 BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, 121 BasicBlock *OrigBlock, BasicBlock *NewDefault, 122 Value *Val, unsigned level); 123 124 }; 125 126 } // namespace 127 128 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 129 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK 130 llvmGetPassPluginInfo() { 131 132 return {LLVM_PLUGIN_API_VERSION, "splitswitches", "v0.1", 133 /* lambda to insert our pass into the pass pipeline. */ 134 [](PassBuilder &PB) { 135 136 #if 1 137 #if LLVM_VERSION_MAJOR <= 13 138 using OptimizationLevel = typename PassBuilder::OptimizationLevel; 139 #endif 140 PB.registerOptimizerLastEPCallback( 141 [](ModulePassManager &MPM, OptimizationLevel OL) { 142 143 MPM.addPass(SplitSwitchesTransform()); 144 145 }); 146 147 /* TODO LTO registration */ 148 #else 149 using PipelineElement = typename PassBuilder::PipelineElement; 150 PB.registerPipelineParsingCallback([](StringRef Name, 151 ModulePassManager &MPM, 152 ArrayRef<PipelineElement>) { 153 154 if (Name == "splitswitches") { 155 156 MPM.addPass(SplitSwitchesTransform()); 157 return true; 158 159 } else { 160 161 return false; 162 163 } 164 165 }); 166 167 #endif 168 169 }}; 170 171 } 172 173 #else 174 char SplitSwitchesTransform::ID = 0; 175 #endif 176 177 /* switchConvert - Transform simple list of Cases into list of CaseRange's */ 178 BasicBlock *SplitSwitchesTransform::switchConvert( 179 CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock, 180 BasicBlock *NewDefault, Value *Val, unsigned level) { 181 182 unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); 183 IntegerType *ValType = 184 IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); 185 IntegerType * ByteType = IntegerType::get(OrigBlock->getContext(), 8); 186 unsigned BytesInValue = bytesChecked.size(); 187 std::vector<uint8_t> setSizes; 188 std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>()); 189 190 /* for each of the possible cases we iterate over all bytes of the values 191 * build a set of possible values at each byte position in byteSets */ 192 for (CaseExpr &Case : Cases) { 193 194 for (unsigned i = 0; i < BytesInValue; i++) { 195 196 uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF; 197 byteSets[i].insert(byte); 198 199 } 200 201 } 202 203 /* find the index of the first byte position that was not yet checked. then 204 * save the number of possible values at that byte position */ 205 unsigned smallestIndex = 0; 206 unsigned smallestSize = 257; 207 for (unsigned i = 0; i < byteSets.size(); i++) { 208 209 if (bytesChecked[i]) continue; 210 if (byteSets[i].size() < smallestSize) { 211 212 smallestIndex = i; 213 smallestSize = byteSets[i].size(); 214 215 } 216 217 } 218 219 assert(bytesChecked[smallestIndex] == false); 220 221 /* there are only smallestSize different bytes at index smallestIndex */ 222 223 Instruction *Shift, *Trunc; 224 Function * F = OrigBlock->getParent(); 225 BasicBlock * NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); 226 Shift = BinaryOperator::Create(Instruction::LShr, Val, 227 ConstantInt::get(ValType, smallestIndex * 8)); 228 NewNode->getInstList().push_back(Shift); 229 230 if (ValTypeBitWidth > 8) { 231 232 Trunc = new TruncInst(Shift, ByteType); 233 NewNode->getInstList().push_back(Trunc); 234 235 } else { 236 237 /* not necessary to trunc */ 238 Trunc = Shift; 239 240 } 241 242 /* this is a trivial case, we can directly check for the byte, 243 * if the byte is not found go to default. if the byte was found 244 * mark the byte as checked. if this was the last byte to check 245 * we can finally execute the block belonging to this case */ 246 247 if (smallestSize == 1) { 248 249 uint8_t byte = *(byteSets[smallestIndex].begin()); 250 251 /* insert instructions to check whether the value we are switching on is 252 * equal to byte */ 253 ICmpInst *Comp = 254 new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), 255 "byteMatch"); 256 NewNode->getInstList().push_back(Comp); 257 258 bytesChecked[smallestIndex] = true; 259 bool allBytesAreChecked = true; 260 261 for (std::vector<bool>::iterator BCI = bytesChecked.begin(), 262 E = bytesChecked.end(); 263 BCI != E; ++BCI) { 264 265 if (!*BCI) { 266 267 allBytesAreChecked = false; 268 break; 269 270 } 271 272 } 273 274 // if (std::all_of(bytesChecked.begin(), bytesChecked.end(), 275 // [](bool b) { return b; })) { 276 277 if (allBytesAreChecked) { 278 279 assert(Cases.size() == 1); 280 BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); 281 282 /* we have to update the phi nodes! */ 283 for (BasicBlock::iterator I = Cases[0].BB->begin(); 284 I != Cases[0].BB->end(); ++I) { 285 286 if (!isa<PHINode>(&*I)) { continue; } 287 PHINode *PN = cast<PHINode>(I); 288 289 /* Only update the first occurrence. */ 290 unsigned Idx = 0, E = PN->getNumIncomingValues(); 291 for (; Idx != E; ++Idx) { 292 293 if (PN->getIncomingBlock(Idx) == OrigBlock) { 294 295 PN->setIncomingBlock(Idx, NewNode); 296 break; 297 298 } 299 300 } 301 302 } 303 304 } else { 305 306 BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, 307 Val, level + 1); 308 BranchInst::Create(BB, NewDefault, Comp, NewNode); 309 310 } 311 312 } 313 314 /* there is no byte which we can directly check on, split the tree */ 315 else { 316 317 std::vector<uint8_t> byteVector; 318 std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), 319 std::back_inserter(byteVector)); 320 std::sort(byteVector.begin(), byteVector.end()); 321 uint8_t pivot = byteVector[byteVector.size() / 2]; 322 323 /* we already chose to divide the cases based on the value of byte at index 324 * smallestIndex the pivot value determines the threshold for the decicion; 325 * if a case value 326 * is smaller at this byte index move it to the LHS vector, otherwise to the 327 * RHS vector */ 328 329 CaseVector LHSCases, RHSCases; 330 331 for (CaseExpr &Case : Cases) { 332 333 uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF; 334 335 if (byte < pivot) { 336 337 LHSCases.push_back(Case); 338 339 } else { 340 341 RHSCases.push_back(Case); 342 343 } 344 345 } 346 347 BasicBlock *LBB, *RBB; 348 LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, 349 level + 1); 350 RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, 351 level + 1); 352 353 /* insert instructions to check whether the value we are switching on is 354 * equal to byte */ 355 ICmpInst *Comp = 356 new ICmpInst(ICmpInst::ICMP_ULT, Trunc, 357 ConstantInt::get(ByteType, pivot), "byteMatch"); 358 NewNode->getInstList().push_back(Comp); 359 BranchInst::Create(LBB, RBB, Comp, NewNode); 360 361 } 362 363 return NewNode; 364 365 } 366 367 bool SplitSwitchesTransform::splitSwitches(Module &M) { 368 369 #if (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR < 7) 370 LLVMContext &C = M.getContext(); 371 #endif 372 373 std::vector<SwitchInst *> switches; 374 375 /* iterate over all functions, bbs and instruction and add 376 * all switches to switches vector for later processing */ 377 for (auto &F : M) { 378 379 if (!isInInstrumentList(&F, MNAME)) continue; 380 381 for (auto &BB : F) { 382 383 SwitchInst *switchInst = nullptr; 384 385 if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { 386 387 if (switchInst->getNumCases() < 1) continue; 388 switches.push_back(switchInst); 389 390 } 391 392 } 393 394 } 395 396 if (!switches.size()) return false; 397 /* 398 if (!be_quiet) 399 errs() << "Rewriting " << switches.size() << " switch statements " 400 << "\n"; 401 */ 402 for (auto &SI : switches) { 403 404 BasicBlock *CurBlock = SI->getParent(); 405 BasicBlock *OrigBlock = CurBlock; 406 Function * F = CurBlock->getParent(); 407 /* this is the value we are switching on */ 408 Value * Val = SI->getCondition(); 409 BasicBlock *Default = SI->getDefaultDest(); 410 unsigned bitw = Val->getType()->getIntegerBitWidth(); 411 412 /* 413 if (!be_quiet) 414 errs() << "switch: " << SI->getNumCases() << " cases " << bitw 415 << " bit\n"; 416 */ 417 418 /* If there is only the default destination or the condition checks 8 bit or 419 * less, don't bother with the code below. */ 420 if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) { 421 422 // if (!be_quiet) errs() << "skip switch..\n"; 423 continue; 424 425 } 426 427 /* Create a new, empty default block so that the new hierarchy of 428 * if-then statements go to this and the PHI nodes are happy. 429 * if the default block is set as an unreachable we avoid creating one 430 * because will never be a valid target.*/ 431 BasicBlock *NewDefault = nullptr; 432 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault", F, Default); 433 BranchInst::Create(Default, NewDefault); 434 435 /* Prepare cases vector. */ 436 CaseVector Cases; 437 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; 438 ++i) 439 #if LLVM_VERSION_MAJOR >= 5 440 Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor())); 441 #else 442 Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor())); 443 #endif 444 /* bugfix thanks to pbst 445 * round up bytesChecked (in case getBitWidth() % 8 != 0) */ 446 std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, 447 false); 448 BasicBlock * SwitchBlock = 449 switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); 450 451 /* Branch to our shiny new if-then stuff... */ 452 BranchInst::Create(SwitchBlock, OrigBlock); 453 454 /* We are now done with the switch instruction, delete it. */ 455 CurBlock->getInstList().erase(SI); 456 457 /* we have to update the phi nodes! */ 458 for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { 459 460 if (!isa<PHINode>(&*I)) { continue; } 461 PHINode *PN = cast<PHINode>(I); 462 463 /* Only update the first occurrence. */ 464 unsigned Idx = 0, E = PN->getNumIncomingValues(); 465 for (; Idx != E; ++Idx) { 466 467 if (PN->getIncomingBlock(Idx) == OrigBlock) { 468 469 PN->setIncomingBlock(Idx, NewDefault); 470 break; 471 472 } 473 474 } 475 476 } 477 478 } 479 480 verifyModule(M); 481 return true; 482 483 } 484 485 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 486 PreservedAnalyses SplitSwitchesTransform::run(Module & M, 487 ModuleAnalysisManager &MAM) { 488 489 #else 490 bool SplitSwitchesTransform::runOnModule(Module &M) { 491 492 #endif 493 494 if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) 495 printf("Running split-switches-pass by laf.intel@gmail.com\n"); 496 else 497 be_quiet = 1; 498 499 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 500 auto PA = PreservedAnalyses::all(); 501 #endif 502 503 splitSwitches(M); 504 verifyModule(M); 505 506 #if LLVM_VERSION_MAJOR >= 11 /* use new pass manager */ 507 /* if (modified) { 508 509 PA.abandon<XX_Manager>(); 510 511 }*/ 512 513 return PA; 514 #else 515 return true; 516 #endif 517 518 } 519 520 #if LLVM_VERSION_MAJOR < 11 /* use old pass manager */ 521 static void registerSplitSwitchesTransPass(const PassManagerBuilder &, 522 legacy::PassManagerBase &PM) { 523 524 auto p = new SplitSwitchesTransform(); 525 PM.add(p); 526 527 } 528 529 static RegisterStandardPasses RegisterSplitSwitchesTransPass( 530 PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass); 531 532 static RegisterStandardPasses RegisterSplitSwitchesTransPass0( 533 PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass); 534 535 #if LLVM_VERSION_MAJOR >= 11 536 static RegisterStandardPasses RegisterSplitSwitchesTransPassLTO( 537 PassManagerBuilder::EP_FullLinkTimeOptimizationLast, 538 registerSplitSwitchesTransPass); 539 #endif 540 #endif 541 542