1 /*
2 * Copyright 2016 laf-intel
3 * extended for floating point by Heiko Eißfeldt
4 * adapted to new pass manager by Heiko Eißfeldt
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * https://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <unistd.h>
22
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27
28 #include "llvm/Config/llvm-config.h"
29
30 #include "llvm/Pass.h"
31 #include "llvm/Support/raw_ostream.h"
32
33 #if LLVM_MAJOR >= 11
34 #include "llvm/Passes/PassPlugin.h"
35 #include "llvm/Passes/PassBuilder.h"
36 #include "llvm/IR/PassManager.h"
37 #else
38 #include "llvm/IR/LegacyPassManager.h"
39 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
40 #endif
41 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
42 #include "llvm/IR/Module.h"
43 #if LLVM_VERSION_MAJOR >= 14 /* how about stable interfaces? */
44 #include "llvm/Passes/OptimizationLevel.h"
45 #endif
46
47 #include "llvm/IR/IRBuilder.h"
48 #if LLVM_VERSION_MAJOR >= 4 || \
49 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
50 #include "llvm/IR/Verifier.h"
51 #include "llvm/IR/DebugInfo.h"
52 #else
53 #include "llvm/Analysis/Verifier.h"
54 #include "llvm/DebugInfo.h"
55 #define nullptr 0
56 #endif
57
58 using namespace llvm;
59 #include "afl-llvm-common.h"
60
// Uncomment this to verify the function after each transformation step.
// Horribly slow, but helps to pinpoint a potential problem in the splitting code.
63 //#define VERIFY_TOO_MUCH 1
64
65 namespace {
66
67 #if LLVM_MAJOR >= 11
68 class SplitComparesTransform : public PassInfoMixin<SplitComparesTransform> {
69
70 public:
71 // static char ID;
SplitComparesTransform()72 SplitComparesTransform() : enableFPSplit(0) {
73
74 #else
75 class SplitComparesTransform : public ModulePass {
76
77 public:
78 static char ID;
79 SplitComparesTransform() : ModulePass(ID), enableFPSplit(0) {
80
81 #endif
82
83 initInstrumentList();
84
85 }
86
87 #if LLVM_MAJOR >= 11
88 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
89 #else
90 bool runOnModule(Module &M) override;
91 #endif
92
93 private:
94 int enableFPSplit;
95
96 unsigned target_bitwidth = 8;
97
98 size_t count = 0;
99
100 size_t splitFPCompares(Module &M);
101 bool simplifyFPCompares(Module &M);
102 size_t nextPowerOfTwo(size_t in);
103
104 using CmpWorklist = SmallVector<CmpInst *, 8>;
105
106 /// simplify the comparison and then split the comparison until the
107 /// target_bitwidth is reached.
108 bool simplifyAndSplit(CmpInst *I, Module &M);
109 /// simplify a non-strict comparison (e.g., less than or equals)
110 bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M,
111 CmpWorklist &worklist);
112 /// simplify a signed comparison (signed less or greater than)
113 bool simplifySignedCompare(CmpInst *IcmpInst, Module &M,
114 CmpWorklist &worklist);
115 /// splits an icmp into nested icmps recursivly until target_bitwidth is
116 /// reached
117 bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist);
118
119 /// print an error to llvm's errs stream, but only if not ordered to be quiet
120 void reportError(const StringRef msg, Instruction *I, Module &M) {
121
122 if (!be_quiet) {
123
124 errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n";
125 if (debug) {
126
127 if (I) {
128
129 errs() << "Instruction = " << *I << "\n";
130 if (auto BB = I->getParent()) {
131
132 if (auto F = BB->getParent()) {
133
134 if (F->hasName()) {
135
136 errs() << "|-> in function " << F->getName() << " ";
137
138 }
139
140 }
141
142 }
143
144 }
145
146 auto n = M.getName();
147 if (n.size() > 0) { errs() << "in module " << n << "\n"; }
148
149 }
150
151 }
152
153 }
154
155 bool isSupportedBitWidth(unsigned bitw) {
156
157 // IDK whether the icmp code works on other bitwidths. I guess not? So we
158 // try to avoid dealing with other weird icmp's that llvm might use (looking
159 // at you `icmp i0`).
160 switch (bitw) {
161
162 case 8:
163 case 16:
164 case 32:
165 case 64:
166 case 128:
167 case 256:
168 return true;
169 default:
170 return false;
171
172 }
173
174 }
175
176 };
177
178 } // namespace
179
180 #if LLVM_MAJOR >= 11
181 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()182 llvmGetPassPluginInfo() {
183
184 return {LLVM_PLUGIN_API_VERSION, "splitcompares", "v0.1",
185 /* lambda to insert our pass into the pass pipeline. */
186 [](PassBuilder &PB) {
187
188 #if 1
189 #if LLVM_VERSION_MAJOR <= 13
190 using OptimizationLevel = typename PassBuilder::OptimizationLevel;
191 #endif
192 PB.registerOptimizerLastEPCallback(
193 [](ModulePassManager &MPM, OptimizationLevel OL) {
194
195 MPM.addPass(SplitComparesTransform());
196
197 });
198
199 /* TODO LTO registration */
200 #else
201 using PipelineElement = typename PassBuilder::PipelineElement;
202 PB.registerPipelineParsingCallback([](StringRef Name,
203 ModulePassManager &MPM,
204 ArrayRef<PipelineElement>) {
205
206 if (Name == "splitcompares") {
207
208 MPM.addPass(SplitComparesTransform());
209 return true;
210
211 } else {
212
213 return false;
214
215 }
216
217 });
218
219 #endif
220
221 }};
222
223 }
224
225 #else
226 char SplitComparesTransform::ID = 0;
227 #endif
228
229 /// This function splits FCMP instructions with xGE or xLE predicates into two
230 /// FCMP instructions with predicate xGT or xLT and EQ
simplifyFPCompares(Module & M)231 bool SplitComparesTransform::simplifyFPCompares(Module &M) {
232
233 LLVMContext & C = M.getContext();
234 std::vector<Instruction *> fcomps;
235 IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
236
237 /* iterate over all functions, bbs and instruction and add
238 * all integer comparisons with >= and <= predicates to the icomps vector */
239 for (auto &F : M) {
240
241 if (!isInInstrumentList(&F, MNAME)) continue;
242
243 for (auto &BB : F) {
244
245 for (auto &IN : BB) {
246
247 CmpInst *selectcmpInst = nullptr;
248
249 if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
250
251 if (enableFPSplit &&
252 (selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
253 selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
254 selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
255 selectcmpInst->getPredicate() == CmpInst::FCMP_ULE)) {
256
257 auto op0 = selectcmpInst->getOperand(0);
258 auto op1 = selectcmpInst->getOperand(1);
259
260 Type *TyOp0 = op0->getType();
261 Type *TyOp1 = op1->getType();
262
263 /* this is probably not needed but we do it anyway */
264 if (TyOp0 != TyOp1) { continue; }
265
266 if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
267
268 fcomps.push_back(selectcmpInst);
269
270 }
271
272 }
273
274 }
275
276 }
277
278 }
279
280 if (!fcomps.size()) { return false; }
281
282 /* transform for floating point */
283 for (auto &FcmpInst : fcomps) {
284
285 BasicBlock *bb = FcmpInst->getParent();
286
287 auto op0 = FcmpInst->getOperand(0);
288 auto op1 = FcmpInst->getOperand(1);
289
290 /* find out what the new predicate is going to be */
291 auto cmp_inst = dyn_cast<CmpInst>(FcmpInst);
292 if (!cmp_inst) { continue; }
293 auto pred = cmp_inst->getPredicate();
294 CmpInst::Predicate new_pred;
295
296 switch (pred) {
297
298 case CmpInst::FCMP_UGE:
299 new_pred = CmpInst::FCMP_UGT;
300 break;
301 case CmpInst::FCMP_OGE:
302 new_pred = CmpInst::FCMP_OGT;
303 break;
304 case CmpInst::FCMP_ULE:
305 new_pred = CmpInst::FCMP_ULT;
306 break;
307 case CmpInst::FCMP_OLE:
308 new_pred = CmpInst::FCMP_OLT;
309 break;
310 default: // keep the compiler happy
311 continue;
312
313 }
314
315 /* split before the fcmp instruction */
316 BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
317
318 /* the old bb now contains a unconditional jump to the new one (end_bb)
319 * we need to delete it later */
320
321 /* create the FCMP instruction with new_pred and add it to the old basic
322 * block bb it is now at the position where the old FcmpInst was */
323 Instruction *fcmp_np;
324 fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1);
325 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
326 fcmp_np);
327
328 /* create a new basic block which holds the new EQ fcmp */
329 Instruction *fcmp_eq;
330 /* insert middle_bb before end_bb */
331 BasicBlock *middle_bb =
332 BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
333 fcmp_eq = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, op0, op1);
334 middle_bb->getInstList().push_back(fcmp_eq);
335 /* add an unconditional branch to the end of middle_bb with destination
336 * end_bb */
337 BranchInst::Create(end_bb, middle_bb);
338
339 /* replace the uncond branch with a conditional one, which depends on the
340 * new_pred fcmp. True goes to end, false to the middle (injected) bb */
341 auto term = bb->getTerminator();
342 BranchInst::Create(end_bb, middle_bb, fcmp_np, bb);
343 term->eraseFromParent();
344
345 /* replace the old FcmpInst (which is the first inst in end_bb) with a PHI
346 * inst to wire up the loose ends */
347 PHINode *PN = PHINode::Create(Int1Ty, 2, "");
348 /* the first result depends on the outcome of fcmp_eq */
349 PN->addIncoming(fcmp_eq, middle_bb);
350 /* if the source was the original bb we know that the fcmp_np yielded true
351 * hence we can hardcode this value */
352 PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
353 /* replace the old FcmpInst with our new and shiny PHI inst */
354 BasicBlock::iterator ii(FcmpInst);
355 ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
356
357 }
358
359 return true;
360
361 }
362
363 /// This function splits ICMP instructions with xGE or xLE predicates into two
364 /// ICMP instructions with predicate xGT or xLT and EQ
simplifyOrEqualsCompare(CmpInst * IcmpInst,Module & M,CmpWorklist & worklist)365 bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst * IcmpInst,
366 Module & M,
367 CmpWorklist &worklist) {
368
369 LLVMContext &C = M.getContext();
370 IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
371
372 /* find out what the new predicate is going to be */
373 auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
374 if (!cmp_inst) { return false; }
375
376 BasicBlock *bb = IcmpInst->getParent();
377
378 auto op0 = IcmpInst->getOperand(0);
379 auto op1 = IcmpInst->getOperand(1);
380
381 CmpInst::Predicate pred = cmp_inst->getPredicate();
382 CmpInst::Predicate new_pred;
383
384 switch (pred) {
385
386 case CmpInst::ICMP_UGE:
387 new_pred = CmpInst::ICMP_UGT;
388 break;
389 case CmpInst::ICMP_SGE:
390 new_pred = CmpInst::ICMP_SGT;
391 break;
392 case CmpInst::ICMP_ULE:
393 new_pred = CmpInst::ICMP_ULT;
394 break;
395 case CmpInst::ICMP_SLE:
396 new_pred = CmpInst::ICMP_SLT;
397 break;
398 default: // keep the compiler happy
399 return false;
400
401 }
402
403 /* split before the icmp instruction */
404 BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
405
406 /* the old bb now contains a unconditional jump to the new one (end_bb)
407 * we need to delete it later */
408
409 /* create the ICMP instruction with new_pred and add it to the old basic
410 * block bb it is now at the position where the old IcmpInst was */
411 CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
412 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np);
413
414 /* create a new basic block which holds the new EQ icmp */
415 CmpInst *icmp_eq;
416 /* insert middle_bb before end_bb */
417 BasicBlock *middle_bb =
418 BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
419 icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
420 middle_bb->getInstList().push_back(icmp_eq);
421 /* add an unconditional branch to the end of middle_bb with destination
422 * end_bb */
423 BranchInst::Create(end_bb, middle_bb);
424
425 /* replace the uncond branch with a conditional one, which depends on the
426 * new_pred icmp. True goes to end, false to the middle (injected) bb */
427 auto term = bb->getTerminator();
428 BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
429 term->eraseFromParent();
430
431 /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
432 * inst to wire up the loose ends */
433 PHINode *PN = PHINode::Create(Int1Ty, 2, "");
434 /* the first result depends on the outcome of icmp_eq */
435 PN->addIncoming(icmp_eq, middle_bb);
436 /* if the source was the original bb we know that the icmp_np yielded true
437 * hence we can hardcode this value */
438 PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
439 /* replace the old IcmpInst with our new and shiny PHI inst */
440 BasicBlock::iterator ii(IcmpInst);
441 ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
442
443 worklist.push_back(icmp_np);
444 worklist.push_back(icmp_eq);
445
446 return true;
447
448 }
449
450 /// Simplify a signed comparison operator by splitting it into a unsigned and
451 /// bit comparison. add all resulting comparisons to
452 /// the worklist passed as a reference.
simplifySignedCompare(CmpInst * IcmpInst,Module & M,CmpWorklist & worklist)453 bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M,
454 CmpWorklist &worklist) {
455
456 LLVMContext &C = M.getContext();
457 IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
458
459 BasicBlock *bb = IcmpInst->getParent();
460
461 auto op0 = IcmpInst->getOperand(0);
462 auto op1 = IcmpInst->getOperand(1);
463
464 IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
465 if (!intTyOp0) { return false; }
466 unsigned bitw = intTyOp0->getBitWidth();
467 IntegerType *IntType = IntegerType::get(C, bitw);
468
469 /* get the new predicate */
470 auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
471 if (!cmp_inst) { return false; }
472 auto pred = cmp_inst->getPredicate();
473 CmpInst::Predicate new_pred;
474
475 if (pred == CmpInst::ICMP_SGT) {
476
477 new_pred = CmpInst::ICMP_UGT;
478
479 } else {
480
481 new_pred = CmpInst::ICMP_ULT;
482
483 }
484
485 BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
486
487 /* create a 1 bit compare for the sign bit. to do this shift and trunc
488 * the original operands so only the first bit remains.*/
489 Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
490
491 IRBuilder<> IRB(bb->getTerminator());
492 s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1));
493 t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty);
494 s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1));
495 t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty);
496 /* compare of the sign bits */
497 icmp_sign_bit = IRB.CreateICmp(CmpInst::ICMP_EQ, t_op0, t_op1);
498
499 /* create a new basic block which is executed if the signedness bit is
500 * different */
501 CmpInst * icmp_inv_sig_cmp;
502 BasicBlock *sign_bb =
503 BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
504 if (pred == CmpInst::ICMP_SGT) {
505
506 /* if we check for > and the op0 positive and op1 negative then the final
507 * result is true. if op0 negative and op1 pos, the cmp must result
508 * in false
509 */
510 icmp_inv_sig_cmp =
511 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
512
513 } else {
514
515 /* just the inverse of the above statement */
516 icmp_inv_sig_cmp =
517 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
518
519 }
520
521 sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
522 BranchInst::Create(end_bb, sign_bb);
523
524 /* create a new bb which is executed if signedness is equal */
525 CmpInst * icmp_usign_cmp;
526 BasicBlock *middle_bb =
527 BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
528 /* we can do a normal unsigned compare now */
529 icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
530
531 middle_bb->getInstList().push_back(icmp_usign_cmp);
532 BranchInst::Create(end_bb, middle_bb);
533
534 auto term = bb->getTerminator();
535 /* if the sign is eq do a normal unsigned cmp, else we have to check the
536 * signedness bit */
537 BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
538 term->eraseFromParent();
539
540 PHINode *PN = PHINode::Create(Int1Ty, 2, "");
541
542 PN->addIncoming(icmp_usign_cmp, middle_bb);
543 PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
544
545 BasicBlock::iterator ii(IcmpInst);
546 ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
547
548 // save for later
549 worklist.push_back(icmp_usign_cmp);
550
551 // signed comparisons are not supported by the splitting code, so we must not
552 // add it to the worklist.
553 // worklist.push_back(icmp_inv_sig_cmp);
554
555 return true;
556
557 }
558
splitCompare(CmpInst * cmp_inst,Module & M,CmpWorklist & worklist)559 bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M,
560 CmpWorklist &worklist) {
561
562 auto pred = cmp_inst->getPredicate();
563 switch (pred) {
564
565 case CmpInst::ICMP_EQ:
566 case CmpInst::ICMP_NE:
567 case CmpInst::ICMP_UGT:
568 case CmpInst::ICMP_ULT:
569 break;
570 default:
571 // unsupported predicate!
572 return false;
573
574 }
575
576 auto op0 = cmp_inst->getOperand(0);
577 auto op1 = cmp_inst->getOperand(1);
578
579 // get bitwidth by checking the bitwidth of the first operator
580 IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
581 if (!intTyOp0) {
582
583 // not an integer type
584 return false;
585
586 }
587
588 unsigned bitw = intTyOp0->getBitWidth();
589 if (bitw == target_bitwidth) {
590
591 // already the target bitwidth so we have to do nothing here.
592 return true;
593
594 }
595
596 LLVMContext &C = M.getContext();
597 IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
598 BasicBlock * bb = cmp_inst->getParent();
599 IntegerType *OldIntType = IntegerType::get(C, bitw);
600 IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
601 BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst));
602 CmpInst * icmp_high, *icmp_low;
603
604 /* create the comparison of the top halves of the original operands */
605 Value *s_op0, *op0_high, *s_op1, *op1_high;
606
607 IRBuilder<> IRB(bb->getTerminator());
608
609 s_op0 = IRB.CreateBinOp(Instruction::LShr, op0,
610 ConstantInt::get(OldIntType, bitw / 2));
611 op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType);
612
613 s_op1 = IRB.CreateBinOp(Instruction::LShr, op1,
614 ConstantInt::get(OldIntType, bitw / 2));
615 op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType);
616 icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high));
617
618 PHINode *PN = nullptr;
619
620 /* now we have to destinguish between == != and > < */
621 switch (pred) {
622
623 case CmpInst::ICMP_EQ:
624 case CmpInst::ICMP_NE: {
625
626 /* transformation for == and != icmps */
627
628 /* create a compare for the lower half of the original operands */
629 BasicBlock *cmp_low_bb =
630 BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
631
632 Value * op0_low, *op1_low;
633 IRBuilder<> Builder(cmp_low_bb);
634
635 op0_low = Builder.CreateTrunc(op0, NewIntType);
636 op1_low = Builder.CreateTrunc(op1, NewIntType);
637 icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low));
638
639 BranchInst::Create(end_bb, cmp_low_bb);
640
641 /* dependent on the cmp of the high parts go to the end or go on with
642 * the comparison */
643 auto term = bb->getTerminator();
644
645 if (pred == CmpInst::ICMP_EQ) {
646
647 BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
648
649 } else {
650
651 // CmpInst::ICMP_NE
652 BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
653
654 }
655
656 term->eraseFromParent();
657
658 /* create the PHI and connect the edges accordingly */
659 PN = PHINode::Create(Int1Ty, 2, "");
660 PN->addIncoming(icmp_low, cmp_low_bb);
661 Value *val = nullptr;
662 if (pred == CmpInst::ICMP_EQ) {
663
664 val = ConstantInt::get(Int1Ty, 0);
665
666 } else {
667
668 /* CmpInst::ICMP_NE */
669 val = ConstantInt::get(Int1Ty, 1);
670
671 }
672
673 PN->addIncoming(val, icmp_high->getParent());
674 break;
675
676 }
677
678 case CmpInst::ICMP_UGT:
679 case CmpInst::ICMP_ULT: {
680
681 /* transformations for < and > */
682
683 /* create a basic block which checks for the inverse predicate.
684 * if this is true we can go to the end if not we have to go to the
685 * bb which checks the lower half of the operands */
686 Instruction *op0_low, *op1_low;
687 CmpInst * icmp_inv_cmp = nullptr;
688 BasicBlock * inv_cmp_bb =
689 BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
690 if (pred == CmpInst::ICMP_UGT) {
691
692 icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
693 op0_high, op1_high);
694
695 } else {
696
697 icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
698 op0_high, op1_high);
699
700 }
701
702 inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
703 worklist.push_back(icmp_inv_cmp);
704
705 auto term = bb->getTerminator();
706 term->eraseFromParent();
707 BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
708
709 /* create a bb which handles the cmp of the lower halves */
710 BasicBlock *cmp_low_bb =
711 BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
712 op0_low = new TruncInst(op0, NewIntType);
713 cmp_low_bb->getInstList().push_back(op0_low);
714 op1_low = new TruncInst(op1, NewIntType);
715 cmp_low_bb->getInstList().push_back(op1_low);
716
717 icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
718 cmp_low_bb->getInstList().push_back(icmp_low);
719 BranchInst::Create(end_bb, cmp_low_bb);
720
721 BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
722
723 PN = PHINode::Create(Int1Ty, 3);
724 PN->addIncoming(icmp_low, cmp_low_bb);
725 PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
726 PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
727 break;
728
729 }
730
731 default:
732 return false;
733
734 }
735
736 BasicBlock::iterator ii(cmp_inst);
737 ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN);
738
739 // We split the comparison into low and high. If this isn't our target
740 // bitwidth we recursively split the low and high parts again until we have
741 // target bitwidth.
742 if ((bitw / 2) > target_bitwidth) {
743
744 worklist.push_back(icmp_high);
745 worklist.push_back(icmp_low);
746
747 }
748
749 return true;
750
751 }
752
simplifyAndSplit(CmpInst * I,Module & M)753 bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) {
754
755 CmpWorklist worklist;
756
757 auto op0 = I->getOperand(0);
758 auto op1 = I->getOperand(1);
759 if (!op0 || !op1) { return false; }
760 auto op0Ty = dyn_cast<IntegerType>(op0->getType());
761 if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; }
762
763 unsigned bitw = op0Ty->getBitWidth();
764
765 #ifdef VERIFY_TOO_MUCH
766 auto F = I->getParent()->getParent();
767 #endif
768
769 // we run the comparison simplification on all compares regardless of their
770 // bitwidth.
771 if (I->getPredicate() == CmpInst::ICMP_UGE ||
772 I->getPredicate() == CmpInst::ICMP_SGE ||
773 I->getPredicate() == CmpInst::ICMP_ULE ||
774 I->getPredicate() == CmpInst::ICMP_SLE) {
775
776 if (!simplifyOrEqualsCompare(I, M, worklist)) {
777
778 reportError(
779 "Failed to simplify inequality or equals comparison "
780 "(UGE,SGE,ULE,SLE)",
781 I, M);
782
783 }
784
785 } else if (I->getPredicate() == CmpInst::ICMP_SGT ||
786
787 I->getPredicate() == CmpInst::ICMP_SLT) {
788
789 if (!simplifySignedCompare(I, M, worklist)) {
790
791 reportError("Failed to simplify signed comparison (SGT,SLT)", I, M);
792
793 }
794
795 }
796
797 #ifdef VERIFY_TOO_MUCH
798 if (verifyFunction(*F, &errs())) {
799
800 reportError("simpliyfing compare lead to broken function", nullptr, M);
801
802 }
803
804 #endif
805
806 // the simplification methods replace the original CmpInst and push the
807 // resulting new CmpInst into the worklist. If the worklist is empty then
808 // we only have to split the original CmpInst.
809 if (worklist.size() == 0) { worklist.push_back(I); }
810
811 while (!worklist.empty()) {
812
813 CmpInst *cmp = worklist.pop_back_val();
814 // we split the simplified compares into comparisons with smaller bitwidths
815 // if they are larger than our target_bitwidth.
816 if (bitw > target_bitwidth) {
817
818 if (!splitCompare(cmp, M, worklist)) {
819
820 reportError("Failed to split comparison", cmp, M);
821
822 }
823
824 #ifdef VERIFY_TOO_MUCH
825 if (verifyFunction(*F, &errs())) {
826
827 reportError("splitting compare lead to broken function", nullptr, M);
828
829 }
830
831 #endif
832
833 }
834
835 }
836
837 count++;
838 return true;
839
840 }
841
nextPowerOfTwo(size_t in)842 size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
843
844 --in;
845 in |= in >> 1;
846 in |= in >> 2;
847 in |= in >> 4;
848 // in |= in >> 8;
849 // in |= in >> 16;
850 return in + 1;
851
852 }
853
854 /* splits fcmps into two nested fcmps with sign compare and the rest */
splitFPCompares(Module & M)855 size_t SplitComparesTransform::splitFPCompares(Module &M) {
856
857 size_t count = 0;
858
859 LLVMContext &C = M.getContext();
860
861 #if LLVM_VERSION_MAJOR >= 4 || \
862 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
863 const DataLayout &dl = M.getDataLayout();
864
865 /* define unions with floating point and (sign, exponent, mantissa) triples
866 */
867 if (dl.isLittleEndian()) {
868
869 } else if (dl.isBigEndian()) {
870
871 } else {
872
873 return count;
874
875 }
876
877 #endif
878
879 std::vector<CmpInst *> fcomps;
880
881 /* get all EQ, NE, GT, and LT fcmps. if the other two
882 * functions were executed only these four predicates should exist */
883 for (auto &F : M) {
884
885 if (!isInInstrumentList(&F, MNAME)) continue;
886
887 for (auto &BB : F) {
888
889 for (auto &IN : BB) {
890
891 CmpInst *selectcmpInst = nullptr;
892
893 if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
894
895 if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
896 selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
897 selectcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
898 selectcmpInst->getPredicate() == CmpInst::FCMP_UNE ||
899 selectcmpInst->getPredicate() == CmpInst::FCMP_UGT ||
900 selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
901 selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
902 selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) {
903
904 auto op0 = selectcmpInst->getOperand(0);
905 auto op1 = selectcmpInst->getOperand(1);
906
907 Type *TyOp0 = op0->getType();
908 Type *TyOp1 = op1->getType();
909
910 if (TyOp0 != TyOp1) { continue; }
911
912 if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
913
914 fcomps.push_back(selectcmpInst);
915
916 }
917
918 }
919
920 }
921
922 }
923
924 }
925
926 if (!fcomps.size()) { return count; }
927
928 IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
929
930 for (auto &FcmpInst : fcomps) {
931
932 BasicBlock *bb = FcmpInst->getParent();
933
934 auto op0 = FcmpInst->getOperand(0);
935 auto op1 = FcmpInst->getOperand(1);
936
937 unsigned op_size;
938 op_size = op0->getType()->getPrimitiveSizeInBits();
939
940 if (op_size != op1->getType()->getPrimitiveSizeInBits()) { continue; }
941
942 const unsigned int sizeInBits = op0->getType()->getPrimitiveSizeInBits();
943
944 // BUG FIXME TODO: u64 does not work for > 64 bit ... e.g. 80 and 128 bit
945 if (sizeInBits > 64) { continue; }
946
947 IntegerType * intType = IntegerType::get(C, op_size);
948 const unsigned int precision = sizeInBits == 32 ? 24
949 : sizeInBits == 64 ? 53
950 : sizeInBits == 128 ? 113
951 : sizeInBits == 16 ? 11
952 : sizeInBits == 80 ? 65
953 : sizeInBits - 8;
954
955 const unsigned shiftR_exponent = precision - 1;
956 const unsigned long long mask_fraction =
957 (1ULL << (shiftR_exponent - 1)) | ((1ULL << (shiftR_exponent - 1)) - 1);
958 const unsigned long long mask_exponent =
959 (1ULL << (sizeInBits - precision)) - 1;
960
961 // round up sizes to the next power of two
962 // this should help with integer compare splitting
963 size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3);
964 size_t frTySizeBytes = ((precision - 1ULL + 7) >> 3);
965
966 IntegerType *IntExponentTy =
967 IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3);
968 IntegerType *IntFractionTy =
969 IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3);
970
971 // errs() << "Fractions: IntFractionTy size " <<
972 // IntFractionTy->getPrimitiveSizeInBits() << ", op_size " << op_size <<
973 // ", mask " << mask_fraction <<
974 // ", precision " << precision << "\n";
975
976 BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
977
978 /* create the integers from floats directly */
979 Instruction *bpre_op0, *bpre_op1;
980 bpre_op0 = CastInst::Create(Instruction::BitCast, op0,
981 IntegerType::get(C, op_size));
982 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
983 bpre_op0);
984
985 bpre_op1 = CastInst::Create(Instruction::BitCast, op1,
986 IntegerType::get(C, op_size));
987 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
988 bpre_op1);
989
990 /* Check if any operand is NaN.
991 * If so, all comparisons except unequal (which yields true) yield false */
992
993 /* build mask for NaN */
994 const unsigned long long NaN_lowend = mask_exponent << precision;
995 // errs() << "Fractions: IntFractionTy size " <<
996 // IntFractionTy->getPrimitiveSizeInBits() << ", op_size " << op_size <<
997 // ", mask_fraction 0x";
998 // errs().write_hex(mask_fraction);
999 // errs() << ", precision " << precision <<
1000 // ", NaN_lowend 0x";
1001 // errs().write_hex(NaN_lowend); errs() << "\n";
1002
1003 /* Check op0 for NaN */
1004 /* Shift left 1 Bit, ignore sign bit */
1005 Instruction *nan_op0, *nan_op1;
1006 nan_op0 = BinaryOperator::Create(Instruction::Shl, bpre_op0,
1007 ConstantInt::get(bpre_op0->getType(), 1));
1008 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1009 nan_op0);
1010
1011 /* compare to NaN interval */
1012 Instruction *is_op0_nan =
1013 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op0,
1014 ConstantInt::get(intType, NaN_lowend));
1015 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1016 is_op0_nan);
1017
1018 /* Check op1 for NaN */
1019 /* Shift right 1 Bit, ignore sign bit */
1020 nan_op1 = BinaryOperator::Create(Instruction::Shl, bpre_op1,
1021 ConstantInt::get(bpre_op1->getType(), 1));
1022 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1023 nan_op1);
1024
1025 /* compare to NaN interval */
1026 Instruction *is_op1_nan =
1027 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op1,
1028 ConstantInt::get(intType, NaN_lowend));
1029 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1030 is_op1_nan);
1031
1032 /* combine checks */
1033 Instruction *is_nan =
1034 BinaryOperator::Create(Instruction::Or, is_op0_nan, is_op1_nan);
1035 bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), is_nan);
1036
1037 /* the result of the comparison, when at least one op is NaN
1038 is true only for the "NOT EQUAL" predicates. */
1039 bool NaNcmp_result = FcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
1040 FcmpInst->getPredicate() == CmpInst::FCMP_UNE;
1041
1042 BasicBlock *nonan_bb =
1043 BasicBlock::Create(C, "noNaN", end_bb->getParent(), end_bb);
1044
1045 BranchInst::Create(end_bb, nonan_bb);
1046
1047 auto term = bb->getTerminator();
1048 /* if no operand is NaN goto nonan_bb else to handleNaN_bb */
1049 BranchInst::Create(end_bb, nonan_bb, is_nan, bb);
1050 term->eraseFromParent();
1051
1052 /*** now working in nonan_bb ***/
1053
1054 /* Treat -0.0 as equal to +0.0, that is for -0.0 make it +0.0 */
1055 Instruction * b_op0, *b_op1;
1056 Instruction * isMzero_op0, *isMzero_op1;
1057 const unsigned long long MinusZero = 1UL << (sizeInBits - 1U);
1058 const unsigned long long PlusZero = 0;
1059
1060 isMzero_op0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op0,
1061 ConstantInt::get(intType, MinusZero));
1062 nonan_bb->getInstList().insert(
1063 BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op0);
1064
1065 isMzero_op1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op1,
1066 ConstantInt::get(intType, MinusZero));
1067 nonan_bb->getInstList().insert(
1068 BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op1);
1069
1070 b_op0 = SelectInst::Create(isMzero_op0, ConstantInt::get(intType, PlusZero),
1071 bpre_op0);
1072 nonan_bb->getInstList().insert(
1073 BasicBlock::iterator(nonan_bb->getTerminator()), b_op0);
1074
1075 b_op1 = SelectInst::Create(isMzero_op1, ConstantInt::get(intType, PlusZero),
1076 bpre_op1);
1077 nonan_bb->getInstList().insert(
1078 BasicBlock::iterator(nonan_bb->getTerminator()), b_op1);
1079
1080 /* isolate signs of value of floating point type */
1081
1082 /* create a 1 bit compare for the sign bit. to do this shift and trunc
1083 * the original operands so only the first bit remains.*/
1084 Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit;
1085
1086 s_s0 =
1087 BinaryOperator::Create(Instruction::LShr, b_op0,
1088 ConstantInt::get(b_op0->getType(), op_size - 1));
1089 nonan_bb->getInstList().insert(
1090 BasicBlock::iterator(nonan_bb->getTerminator()), s_s0);
1091 t_s0 = new TruncInst(s_s0, Int1Ty);
1092 nonan_bb->getInstList().insert(
1093 BasicBlock::iterator(nonan_bb->getTerminator()), t_s0);
1094
1095 s_s1 =
1096 BinaryOperator::Create(Instruction::LShr, b_op1,
1097 ConstantInt::get(b_op1->getType(), op_size - 1));
1098 nonan_bb->getInstList().insert(
1099 BasicBlock::iterator(nonan_bb->getTerminator()), s_s1);
1100 t_s1 = new TruncInst(s_s1, Int1Ty);
1101 nonan_bb->getInstList().insert(
1102 BasicBlock::iterator(nonan_bb->getTerminator()), t_s1);
1103
1104 /* compare of the sign bits */
1105 icmp_sign_bit =
1106 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
1107 nonan_bb->getInstList().insert(
1108 BasicBlock::iterator(nonan_bb->getTerminator()), icmp_sign_bit);
1109
1110 /* create a new basic block which is executed if the signedness bits are
1111 * equal */
1112 BasicBlock *signequal_bb =
1113 BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb);
1114
1115 BranchInst::Create(end_bb, signequal_bb);
1116
1117 /* create a new bb which is executed if exponents are satisfying the compare
1118 */
1119 BasicBlock *middle_bb =
1120 BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
1121
1122 BranchInst::Create(end_bb, middle_bb);
1123
1124 term = nonan_bb->getTerminator();
1125 /* if the signs are different goto end_bb else to signequal_bb */
1126 BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, nonan_bb);
1127 term->eraseFromParent();
1128
1129 /* insert code for equal signs */
1130
1131 /* isolate the exponents */
1132 Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1;
1133
1134 s_e0 = BinaryOperator::Create(
1135 Instruction::LShr, b_op0,
1136 ConstantInt::get(b_op0->getType(), shiftR_exponent));
1137 s_e1 = BinaryOperator::Create(
1138 Instruction::LShr, b_op1,
1139 ConstantInt::get(b_op1->getType(), shiftR_exponent));
1140 signequal_bb->getInstList().insert(
1141 BasicBlock::iterator(signequal_bb->getTerminator()), s_e0);
1142 signequal_bb->getInstList().insert(
1143 BasicBlock::iterator(signequal_bb->getTerminator()), s_e1);
1144
1145 t_e0 = new TruncInst(s_e0, IntExponentTy);
1146 t_e1 = new TruncInst(s_e1, IntExponentTy);
1147 signequal_bb->getInstList().insert(
1148 BasicBlock::iterator(signequal_bb->getTerminator()), t_e0);
1149 signequal_bb->getInstList().insert(
1150 BasicBlock::iterator(signequal_bb->getTerminator()), t_e1);
1151
1152 if (sizeInBits - precision < exTySizeBytes * 8) {
1153
1154 m_e0 = BinaryOperator::Create(
1155 Instruction::And, t_e0,
1156 ConstantInt::get(t_e0->getType(), mask_exponent));
1157 m_e1 = BinaryOperator::Create(
1158 Instruction::And, t_e1,
1159 ConstantInt::get(t_e1->getType(), mask_exponent));
1160 signequal_bb->getInstList().insert(
1161 BasicBlock::iterator(signequal_bb->getTerminator()), m_e0);
1162 signequal_bb->getInstList().insert(
1163 BasicBlock::iterator(signequal_bb->getTerminator()), m_e1);
1164
1165 } else {
1166
1167 m_e0 = t_e0;
1168 m_e1 = t_e1;
1169
1170 }
1171
1172 /* compare the exponents of the operands */
1173 Instruction *icmp_exponents_equal;
1174 Instruction *icmp_exponent_result;
1175 BasicBlock * signequal2_bb = signequal_bb;
1176 switch (FcmpInst->getPredicate()) {
1177
1178 case CmpInst::FCMP_UEQ:
1179 case CmpInst::FCMP_OEQ:
1180 icmp_exponent_result =
1181 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1182 break;
1183 case CmpInst::FCMP_ONE:
1184 case CmpInst::FCMP_UNE:
1185 icmp_exponent_result =
1186 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, m_e0, m_e1);
1187 break;
1188 /* compare the exponents of the operands (signs are equal)
1189 * if exponents are equal -> proceed to mantissa comparison
1190 * else get result depending on sign
1191 */
1192 case CmpInst::FCMP_OGT:
1193 case CmpInst::FCMP_UGT:
1194 Instruction *icmp_exponent;
1195 icmp_exponents_equal =
1196 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1197 signequal_bb->getInstList().insert(
1198 BasicBlock::iterator(signequal_bb->getTerminator()),
1199 icmp_exponents_equal);
1200
1201 // shortcut for unequal exponents
1202 signequal2_bb = signequal_bb->splitBasicBlock(
1203 BasicBlock::iterator(signequal_bb->getTerminator()));
1204
1205 /* if the exponents are equal goto middle_bb else to signequal2_bb */
1206 term = signequal_bb->getTerminator();
1207 BranchInst::Create(middle_bb, signequal2_bb, icmp_exponents_equal,
1208 signequal_bb);
1209 term->eraseFromParent();
1210
1211 icmp_exponent =
1212 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1);
1213 signequal2_bb->getInstList().insert(
1214 BasicBlock::iterator(signequal2_bb->getTerminator()),
1215 icmp_exponent);
1216 icmp_exponent_result =
1217 BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
1218 break;
1219 case CmpInst::FCMP_OLT:
1220 case CmpInst::FCMP_ULT:
1221 icmp_exponents_equal =
1222 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1223 signequal_bb->getInstList().insert(
1224 BasicBlock::iterator(signequal_bb->getTerminator()),
1225 icmp_exponents_equal);
1226
1227 // shortcut for unequal exponents
1228 signequal2_bb = signequal_bb->splitBasicBlock(
1229 BasicBlock::iterator(signequal_bb->getTerminator()));
1230
1231 /* if the exponents are equal goto middle_bb else to signequal2_bb */
1232 term = signequal_bb->getTerminator();
1233 BranchInst::Create(middle_bb, signequal2_bb, icmp_exponents_equal,
1234 signequal_bb);
1235 term->eraseFromParent();
1236
1237 icmp_exponent =
1238 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1);
1239 signequal2_bb->getInstList().insert(
1240 BasicBlock::iterator(signequal2_bb->getTerminator()),
1241 icmp_exponent);
1242 icmp_exponent_result =
1243 BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
1244 break;
1245 default:
1246 continue;
1247
1248 }
1249
1250 signequal2_bb->getInstList().insert(
1251 BasicBlock::iterator(signequal2_bb->getTerminator()),
1252 icmp_exponent_result);
1253
1254 {
1255
1256 term = signequal2_bb->getTerminator();
1257
1258 switch (FcmpInst->getPredicate()) {
1259
1260 case CmpInst::FCMP_UEQ:
1261 case CmpInst::FCMP_OEQ:
1262 /* if the exponents are satifying the compare do a fraction cmp in
1263 * middle_bb */
1264 BranchInst::Create(middle_bb, end_bb, icmp_exponent_result,
1265 signequal2_bb);
1266 break;
1267 case CmpInst::FCMP_ONE:
1268 case CmpInst::FCMP_UNE:
1269 /* if the exponents are satifying the compare do a fraction cmp in
1270 * middle_bb */
1271 BranchInst::Create(end_bb, middle_bb, icmp_exponent_result,
1272 signequal2_bb);
1273 break;
1274 case CmpInst::FCMP_OGT:
1275 case CmpInst::FCMP_UGT:
1276 case CmpInst::FCMP_OLT:
1277 case CmpInst::FCMP_ULT:
1278 BranchInst::Create(end_bb, signequal2_bb);
1279 break;
1280 default:
1281 continue;
1282
1283 }
1284
1285 term->eraseFromParent();
1286
1287 }
1288
1289 /* isolate the mantissa aka fraction */
1290 Instruction *t_f0, *t_f1;
1291 bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op_size;
1292
1293 if (precision - 1 < frTySizeBytes * 8) {
1294
1295 Instruction *m_f0, *m_f1;
1296 m_f0 = BinaryOperator::Create(
1297 Instruction::And, b_op0,
1298 ConstantInt::get(b_op0->getType(), mask_fraction));
1299 m_f1 = BinaryOperator::Create(
1300 Instruction::And, b_op1,
1301 ConstantInt::get(b_op1->getType(), mask_fraction));
1302 middle_bb->getInstList().insert(
1303 BasicBlock::iterator(middle_bb->getTerminator()), m_f0);
1304 middle_bb->getInstList().insert(
1305 BasicBlock::iterator(middle_bb->getTerminator()), m_f1);
1306
1307 if (needTrunc) {
1308
1309 t_f0 = new TruncInst(m_f0, IntFractionTy);
1310 t_f1 = new TruncInst(m_f1, IntFractionTy);
1311 middle_bb->getInstList().insert(
1312 BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
1313 middle_bb->getInstList().insert(
1314 BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
1315
1316 } else {
1317
1318 t_f0 = m_f0;
1319 t_f1 = m_f1;
1320
1321 }
1322
1323 } else {
1324
1325 if (needTrunc) {
1326
1327 t_f0 = new TruncInst(b_op0, IntFractionTy);
1328 t_f1 = new TruncInst(b_op1, IntFractionTy);
1329 middle_bb->getInstList().insert(
1330 BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
1331 middle_bb->getInstList().insert(
1332 BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
1333
1334 } else {
1335
1336 t_f0 = b_op0;
1337 t_f1 = b_op1;
1338
1339 }
1340
1341 }
1342
1343 /* compare the fractions of the operands */
1344 Instruction *icmp_fraction_result;
1345 BasicBlock * middle2_bb = middle_bb;
1346 PHINode * PN2 = nullptr;
1347 switch (FcmpInst->getPredicate()) {
1348
1349 case CmpInst::FCMP_UEQ:
1350 case CmpInst::FCMP_OEQ:
1351 icmp_fraction_result =
1352 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1);
1353 middle2_bb->getInstList().insert(
1354 BasicBlock::iterator(middle2_bb->getTerminator()),
1355 icmp_fraction_result);
1356
1357 break;
1358 case CmpInst::FCMP_UNE:
1359 case CmpInst::FCMP_ONE:
1360 icmp_fraction_result =
1361 CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1);
1362 middle2_bb->getInstList().insert(
1363 BasicBlock::iterator(middle2_bb->getTerminator()),
1364 icmp_fraction_result);
1365
1366 break;
1367 case CmpInst::FCMP_OGT:
1368 case CmpInst::FCMP_UGT:
1369 case CmpInst::FCMP_OLT:
1370 case CmpInst::FCMP_ULT: {
1371
1372 Instruction *icmp_fraction_result2;
1373
1374 middle2_bb = middle_bb->splitBasicBlock(
1375 BasicBlock::iterator(middle_bb->getTerminator()));
1376
1377 BasicBlock *negative_bb = BasicBlock::Create(
1378 C, "negative_value", middle2_bb->getParent(), middle2_bb);
1379 BasicBlock *positive_bb = BasicBlock::Create(
1380 C, "positive_value", negative_bb->getParent(), negative_bb);
1381
1382 if (FcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
1383 FcmpInst->getPredicate() == CmpInst::FCMP_UGT) {
1384
1385 negative_bb->getInstList().push_back(
1386 icmp_fraction_result = CmpInst::Create(
1387 Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
1388 positive_bb->getInstList().push_back(
1389 icmp_fraction_result2 = CmpInst::Create(
1390 Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
1391
1392 } else {
1393
1394 negative_bb->getInstList().push_back(
1395 icmp_fraction_result = CmpInst::Create(
1396 Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
1397 positive_bb->getInstList().push_back(
1398 icmp_fraction_result2 = CmpInst::Create(
1399 Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
1400
1401 }
1402
1403 BranchInst::Create(middle2_bb, negative_bb);
1404 BranchInst::Create(middle2_bb, positive_bb);
1405
1406 term = middle_bb->getTerminator();
1407 BranchInst::Create(negative_bb, positive_bb, t_s0, middle_bb);
1408 term->eraseFromParent();
1409
1410 PN2 = PHINode::Create(Int1Ty, 2, "");
1411 PN2->addIncoming(icmp_fraction_result, negative_bb);
1412 PN2->addIncoming(icmp_fraction_result2, positive_bb);
1413 middle2_bb->getInstList().insert(
1414 BasicBlock::iterator(middle2_bb->getTerminator()), PN2);
1415
1416 } break;
1417
1418 default:
1419 continue;
1420
1421 }
1422
1423 PHINode *PN = PHINode::Create(Int1Ty, 4, "");
1424
1425 switch (FcmpInst->getPredicate()) {
1426
1427 case CmpInst::FCMP_UEQ:
1428 case CmpInst::FCMP_OEQ:
1429 /* unequal signs cannot be equal values */
1430 /* goto false branch */
1431 PN->addIncoming(ConstantInt::get(Int1Ty, 0), nonan_bb);
1432 /* unequal exponents cannot be equal values, too */
1433 PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb);
1434 /* fractions comparison */
1435 PN->addIncoming(icmp_fraction_result, middle2_bb);
1436 /* NaNs */
1437 PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1438 break;
1439 case CmpInst::FCMP_ONE:
1440 case CmpInst::FCMP_UNE:
1441 /* unequal signs are unequal values */
1442 /* goto true branch */
1443 PN->addIncoming(ConstantInt::get(Int1Ty, 1), nonan_bb);
1444 /* unequal exponents are unequal values, too */
1445 PN->addIncoming(icmp_exponent_result, signequal_bb);
1446 /* fractions comparison */
1447 PN->addIncoming(icmp_fraction_result, middle2_bb);
1448 /* NaNs */
1449 PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1450 break;
1451 case CmpInst::FCMP_OGT:
1452 case CmpInst::FCMP_UGT:
1453 /* if op1 is negative goto true branch,
1454 else go on comparing */
1455 PN->addIncoming(t_s1, nonan_bb);
1456 PN->addIncoming(icmp_exponent_result, signequal2_bb);
1457 PN->addIncoming(PN2, middle2_bb);
1458 /* NaNs */
1459 PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1460 break;
1461 case CmpInst::FCMP_OLT:
1462 case CmpInst::FCMP_ULT:
1463 /* if op0 is negative goto true branch,
1464 else go on comparing */
1465 PN->addIncoming(t_s0, nonan_bb);
1466 PN->addIncoming(icmp_exponent_result, signequal2_bb);
1467 PN->addIncoming(PN2, middle2_bb);
1468 /* NaNs */
1469 PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1470 break;
1471 default:
1472 continue;
1473
1474 }
1475
1476 BasicBlock::iterator ii(FcmpInst);
1477 ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
1478 ++count;
1479
1480 }
1481
1482 return count;
1483
1484 }
1485
1486 #if LLVM_MAJOR >= 11
run(Module & M,ModuleAnalysisManager & MAM)1487 PreservedAnalyses SplitComparesTransform::run(Module & M,
1488 ModuleAnalysisManager &MAM) {
1489
1490 #else
1491 bool SplitComparesTransform::runOnModule(Module &M) {
1492
1493 #endif
1494
1495 char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
1496 if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
1497 if (bitw_env) { target_bitwidth = atoi(bitw_env); }
1498
1499 enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
1500
1501 if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
1502 getenv("AFL_DEBUG") != NULL) {
1503
1504 errs() << "Split-compare-newpass by laf.intel@gmail.com, extended by "
1505 "heiko@hexco.de (splitting icmp to "
1506 << target_bitwidth << " bit)\n";
1507
1508 if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; }
1509
1510 } else {
1511
1512 be_quiet = 1;
1513
1514 }
1515
1516 #if LLVM_MAJOR >= 11
1517 auto PA = PreservedAnalyses::all();
1518 #endif
1519
1520 if (enableFPSplit) {
1521
1522 simplifyFPCompares(M);
1523 count = splitFPCompares(M);
1524
1525 if (!be_quiet && !debug) {
1526
1527 errs() << "Split-floatingpoint-compare-pass: " << count
1528 << " FP comparisons splitted\n";
1529
1530 }
1531
1532 }
1533
1534 std::vector<CmpInst *> worklist;
1535 /* iterate over all functions, bbs and instruction search for all integer
1536 * compare instructions. Save them into the worklist for later. */
1537 for (auto &F : M) {
1538
1539 if (!isInInstrumentList(&F, MNAME)) continue;
1540
1541 for (auto &BB : F) {
1542
1543 for (auto &IN : BB) {
1544
1545 if (auto CI = dyn_cast<CmpInst>(&IN)) {
1546
1547 auto op0 = CI->getOperand(0);
1548 auto op1 = CI->getOperand(1);
1549 if (!op0 || !op1) {
1550
1551 #if LLVM_MAJOR >= 11
1552 return PA;
1553 #else
1554 return false;
1555 #endif
1556
1557 }
1558
1559 auto iTy1 = dyn_cast<IntegerType>(op0->getType());
1560 if (iTy1 && isa<IntegerType>(op1->getType())) {
1561
1562 unsigned bitw = iTy1->getBitWidth();
1563 if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
1564
1565 }
1566
1567 }
1568
1569 }
1570
1571 }
1572
1573 }
1574
1575 // now that we have a list of all integer comparisons we can start replacing
1576 // them with the splitted alternatives.
1577 for (auto CI : worklist) {
1578
1579 simplifyAndSplit(CI, M);
1580
1581 }
1582
1583 bool brokenDebug = false;
1584 if (verifyModule(M, &errs()
1585 #if LLVM_VERSION_MAJOR >= 4 || \
1586 (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9)
1587 ,
1588 &brokenDebug // 9th May 2016
1589 #endif
1590 )) {
1591
1592 reportError(
1593 "Module Verifier failed! Consider reporting a bug with the AFL++ "
1594 "project.",
1595 nullptr, M);
1596
1597 }
1598
1599 if (brokenDebug) {
1600
1601 reportError("Module Verifier reported broken Debug Infos - Stripping!",
1602 nullptr, M);
1603 StripDebugInfo(M);
1604
1605 }
1606
1607 if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
1608 getenv("AFL_DEBUG") != NULL) {
1609
1610 errs() << count << " comparisons found\n";
1611
1612 }
1613
1614 #if LLVM_MAJOR >= 11
1615 /* if (modified) {
1616
1617 PA.abandon<XX_Manager>();
1618
1619 }*/
1620
1621 return PA;
1622 #else
1623 return true;
1624 #endif
1625
1626 }
1627
1628 #if LLVM_MAJOR < 11 /* use old pass manager */
1629
1630 static void registerSplitComparesPass(const PassManagerBuilder &,
1631 legacy::PassManagerBase &PM) {
1632
1633 PM.add(new SplitComparesTransform());
1634
1635 }
1636
1637 static RegisterStandardPasses RegisterSplitComparesPass(
1638 PassManagerBuilder::EP_OptimizerLast, registerSplitComparesPass);
1639
1640 static RegisterStandardPasses RegisterSplitComparesTransPass0(
1641 PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitComparesPass);
1642
1643 #if LLVM_VERSION_MAJOR >= 11
1644 static RegisterStandardPasses RegisterSplitComparesTransPassLTO(
1645 PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
1646 registerSplitComparesPass);
1647 #endif
1648
1649 static RegisterPass<SplitComparesTransform> X("splitcompares",
1650 "AFL++ split compares",
1651 true /* Only looks at CFG */,
1652 true /* Analysis Pass */);
1653 #endif
1654
1655