• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 laf-intel
3  * extended for floating point by Heiko Eißfeldt
4  * adapted to new pass manager by Heiko Eißfeldt
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     https://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include <stdio.h>
20 #include <stdlib.h>
21 #include <unistd.h>
22 
23 #include <list>
24 #include <string>
25 #include <fstream>
26 #include <sys/time.h>
27 
28 #include "llvm/Config/llvm-config.h"
29 
30 #include "llvm/Pass.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #if LLVM_MAJOR >= 11
34   #include "llvm/Passes/PassPlugin.h"
35   #include "llvm/Passes/PassBuilder.h"
36   #include "llvm/IR/PassManager.h"
37 #else
38   #include "llvm/IR/LegacyPassManager.h"
39   #include "llvm/Transforms/IPO/PassManagerBuilder.h"
40 #endif
41 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
42 #include "llvm/IR/Module.h"
43 #if LLVM_VERSION_MAJOR >= 14                /* how about stable interfaces? */
44   #include "llvm/Passes/OptimizationLevel.h"
45 #endif
46 
47 #include "llvm/IR/IRBuilder.h"
48 #if LLVM_VERSION_MAJOR >= 4 || \
49     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
50   #include "llvm/IR/Verifier.h"
51   #include "llvm/IR/DebugInfo.h"
52 #else
53   #include "llvm/Analysis/Verifier.h"
54   #include "llvm/DebugInfo.h"
55   #define nullptr 0
56 #endif
57 
58 using namespace llvm;
59 #include "afl-llvm-common.h"
60 
// Uncomment the define below to verify the function after every simplification
// and splitting step. This is horribly slow, but helps to pinpoint a potential
// problem in the splitting code.
//#define VERIFY_TOO_MUCH 1
64 
65 namespace {
66 
67 #if LLVM_MAJOR >= 11
68 class SplitComparesTransform : public PassInfoMixin<SplitComparesTransform> {
69 
70  public:
71   //  static char ID;
SplitComparesTransform()72   SplitComparesTransform() : enableFPSplit(0) {
73 
74 #else
75 class SplitComparesTransform : public ModulePass {
76 
77  public:
78   static char ID;
79   SplitComparesTransform() : ModulePass(ID), enableFPSplit(0) {
80 
81 #endif
82 
83     initInstrumentList();
84 
85   }
86 
87 #if LLVM_MAJOR >= 11
88   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
89 #else
90   bool runOnModule(Module &M) override;
91 #endif
92 
93  private:
94   int enableFPSplit;
95 
96   unsigned target_bitwidth = 8;
97 
98   size_t count = 0;
99 
100   size_t splitFPCompares(Module &M);
101   bool   simplifyFPCompares(Module &M);
102   size_t nextPowerOfTwo(size_t in);
103 
104   using CmpWorklist = SmallVector<CmpInst *, 8>;
105 
106   /// simplify the comparison and then split the comparison until the
107   /// target_bitwidth is reached.
108   bool simplifyAndSplit(CmpInst *I, Module &M);
109   /// simplify a non-strict comparison (e.g., less than or equals)
110   bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M,
111                                CmpWorklist &worklist);
112   /// simplify a signed comparison (signed less or greater than)
113   bool simplifySignedCompare(CmpInst *IcmpInst, Module &M,
114                              CmpWorklist &worklist);
115   /// splits an icmp into nested icmps recursivly until target_bitwidth is
116   /// reached
117   bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist);
118 
119   /// print an error to llvm's errs stream, but only if not ordered to be quiet
120   void reportError(const StringRef msg, Instruction *I, Module &M) {
121 
122     if (!be_quiet) {
123 
124       errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n";
125       if (debug) {
126 
127         if (I) {
128 
129           errs() << "Instruction = " << *I << "\n";
130           if (auto BB = I->getParent()) {
131 
132             if (auto F = BB->getParent()) {
133 
134               if (F->hasName()) {
135 
136                 errs() << "|-> in function " << F->getName() << " ";
137 
138               }
139 
140             }
141 
142           }
143 
144         }
145 
146         auto n = M.getName();
147         if (n.size() > 0) { errs() << "in module " << n << "\n"; }
148 
149       }
150 
151     }
152 
153   }
154 
155   bool isSupportedBitWidth(unsigned bitw) {
156 
157     // IDK whether the icmp code works on other bitwidths. I guess not? So we
158     // try to avoid dealing with other weird icmp's that llvm might use (looking
159     // at you `icmp i0`).
160     switch (bitw) {
161 
162       case 8:
163       case 16:
164       case 32:
165       case 64:
166       case 128:
167       case 256:
168         return true;
169       default:
170         return false;
171 
172     }
173 
174   }
175 
176 };
177 
178 }  // namespace
179 
180 #if LLVM_MAJOR >= 11
181 extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo()182 llvmGetPassPluginInfo() {
183 
184   return {LLVM_PLUGIN_API_VERSION, "splitcompares", "v0.1",
185           /* lambda to insert our pass into the pass pipeline. */
186           [](PassBuilder &PB) {
187 
188   #if 1
189     #if LLVM_VERSION_MAJOR <= 13
190             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
191     #endif
192             PB.registerOptimizerLastEPCallback(
193                 [](ModulePassManager &MPM, OptimizationLevel OL) {
194 
195                   MPM.addPass(SplitComparesTransform());
196 
197                 });
198 
199   /* TODO LTO registration */
200   #else
201             using PipelineElement = typename PassBuilder::PipelineElement;
202             PB.registerPipelineParsingCallback([](StringRef          Name,
203                                                   ModulePassManager &MPM,
204                                                   ArrayRef<PipelineElement>) {
205 
206               if (Name == "splitcompares") {
207 
208                 MPM.addPass(SplitComparesTransform());
209                 return true;
210 
211               } else {
212 
213                 return false;
214 
215               }
216 
217             });
218 
219   #endif
220 
221           }};
222 
223 }
224 
225 #else
226 char SplitComparesTransform::ID = 0;
227 #endif
228 
229 /// This function splits FCMP instructions with xGE or xLE predicates into two
230 /// FCMP instructions with predicate xGT or xLT and EQ
simplifyFPCompares(Module & M)231 bool SplitComparesTransform::simplifyFPCompares(Module &M) {
232 
233   LLVMContext &              C = M.getContext();
234   std::vector<Instruction *> fcomps;
235   IntegerType *              Int1Ty = IntegerType::getInt1Ty(C);
236 
237   /* iterate over all functions, bbs and instruction and add
238    * all integer comparisons with >= and <= predicates to the icomps vector */
239   for (auto &F : M) {
240 
241     if (!isInInstrumentList(&F, MNAME)) continue;
242 
243     for (auto &BB : F) {
244 
245       for (auto &IN : BB) {
246 
247         CmpInst *selectcmpInst = nullptr;
248 
249         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
250 
251           if (enableFPSplit &&
252               (selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
253                selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
254                selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
255                selectcmpInst->getPredicate() == CmpInst::FCMP_ULE)) {
256 
257             auto op0 = selectcmpInst->getOperand(0);
258             auto op1 = selectcmpInst->getOperand(1);
259 
260             Type *TyOp0 = op0->getType();
261             Type *TyOp1 = op1->getType();
262 
263             /* this is probably not needed but we do it anyway */
264             if (TyOp0 != TyOp1) { continue; }
265 
266             if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
267 
268             fcomps.push_back(selectcmpInst);
269 
270           }
271 
272         }
273 
274       }
275 
276     }
277 
278   }
279 
280   if (!fcomps.size()) { return false; }
281 
282   /* transform for floating point */
283   for (auto &FcmpInst : fcomps) {
284 
285     BasicBlock *bb = FcmpInst->getParent();
286 
287     auto op0 = FcmpInst->getOperand(0);
288     auto op1 = FcmpInst->getOperand(1);
289 
290     /* find out what the new predicate is going to be */
291     auto cmp_inst = dyn_cast<CmpInst>(FcmpInst);
292     if (!cmp_inst) { continue; }
293     auto               pred = cmp_inst->getPredicate();
294     CmpInst::Predicate new_pred;
295 
296     switch (pred) {
297 
298       case CmpInst::FCMP_UGE:
299         new_pred = CmpInst::FCMP_UGT;
300         break;
301       case CmpInst::FCMP_OGE:
302         new_pred = CmpInst::FCMP_OGT;
303         break;
304       case CmpInst::FCMP_ULE:
305         new_pred = CmpInst::FCMP_ULT;
306         break;
307       case CmpInst::FCMP_OLE:
308         new_pred = CmpInst::FCMP_OLT;
309         break;
310       default:  // keep the compiler happy
311         continue;
312 
313     }
314 
315     /* split before the fcmp instruction */
316     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
317 
318     /* the old bb now contains a unconditional jump to the new one (end_bb)
319      * we need to delete it later */
320 
321     /* create the FCMP instruction with new_pred and add it to the old basic
322      * block bb it is now at the position where the old FcmpInst was */
323     Instruction *fcmp_np;
324     fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1);
325     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
326                              fcmp_np);
327 
328     /* create a new basic block which holds the new EQ fcmp */
329     Instruction *fcmp_eq;
330     /* insert middle_bb before end_bb */
331     BasicBlock *middle_bb =
332         BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
333     fcmp_eq = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, op0, op1);
334     middle_bb->getInstList().push_back(fcmp_eq);
335     /* add an unconditional branch to the end of middle_bb with destination
336      * end_bb */
337     BranchInst::Create(end_bb, middle_bb);
338 
339     /* replace the uncond branch with a conditional one, which depends on the
340      * new_pred fcmp. True goes to end, false to the middle (injected) bb */
341     auto term = bb->getTerminator();
342     BranchInst::Create(end_bb, middle_bb, fcmp_np, bb);
343     term->eraseFromParent();
344 
345     /* replace the old FcmpInst (which is the first inst in end_bb) with a PHI
346      * inst to wire up the loose ends */
347     PHINode *PN = PHINode::Create(Int1Ty, 2, "");
348     /* the first result depends on the outcome of fcmp_eq */
349     PN->addIncoming(fcmp_eq, middle_bb);
350     /* if the source was the original bb we know that the fcmp_np yielded true
351      * hence we can hardcode this value */
352     PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
353     /* replace the old FcmpInst with our new and shiny PHI inst */
354     BasicBlock::iterator ii(FcmpInst);
355     ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
356 
357   }
358 
359   return true;
360 
361 }
362 
363 /// This function splits ICMP instructions with xGE or xLE predicates into two
364 /// ICMP instructions with predicate xGT or xLT and EQ
simplifyOrEqualsCompare(CmpInst * IcmpInst,Module & M,CmpWorklist & worklist)365 bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst *    IcmpInst,
366                                                      Module &     M,
367                                                      CmpWorklist &worklist) {
368 
369   LLVMContext &C = M.getContext();
370   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
371 
372   /* find out what the new predicate is going to be */
373   auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
374   if (!cmp_inst) { return false; }
375 
376   BasicBlock *bb = IcmpInst->getParent();
377 
378   auto op0 = IcmpInst->getOperand(0);
379   auto op1 = IcmpInst->getOperand(1);
380 
381   CmpInst::Predicate pred = cmp_inst->getPredicate();
382   CmpInst::Predicate new_pred;
383 
384   switch (pred) {
385 
386     case CmpInst::ICMP_UGE:
387       new_pred = CmpInst::ICMP_UGT;
388       break;
389     case CmpInst::ICMP_SGE:
390       new_pred = CmpInst::ICMP_SGT;
391       break;
392     case CmpInst::ICMP_ULE:
393       new_pred = CmpInst::ICMP_ULT;
394       break;
395     case CmpInst::ICMP_SLE:
396       new_pred = CmpInst::ICMP_SLT;
397       break;
398     default:  // keep the compiler happy
399       return false;
400 
401   }
402 
403   /* split before the icmp instruction */
404   BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
405 
406   /* the old bb now contains a unconditional jump to the new one (end_bb)
407    * we need to delete it later */
408 
409   /* create the ICMP instruction with new_pred and add it to the old basic
410    * block bb it is now at the position where the old IcmpInst was */
411   CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
412   bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np);
413 
414   /* create a new basic block which holds the new EQ icmp */
415   CmpInst *icmp_eq;
416   /* insert middle_bb before end_bb */
417   BasicBlock *middle_bb =
418       BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
419   icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
420   middle_bb->getInstList().push_back(icmp_eq);
421   /* add an unconditional branch to the end of middle_bb with destination
422    * end_bb */
423   BranchInst::Create(end_bb, middle_bb);
424 
425   /* replace the uncond branch with a conditional one, which depends on the
426    * new_pred icmp. True goes to end, false to the middle (injected) bb */
427   auto term = bb->getTerminator();
428   BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
429   term->eraseFromParent();
430 
431   /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
432    * inst to wire up the loose ends */
433   PHINode *PN = PHINode::Create(Int1Ty, 2, "");
434   /* the first result depends on the outcome of icmp_eq */
435   PN->addIncoming(icmp_eq, middle_bb);
436   /* if the source was the original bb we know that the icmp_np yielded true
437    * hence we can hardcode this value */
438   PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
439   /* replace the old IcmpInst with our new and shiny PHI inst */
440   BasicBlock::iterator ii(IcmpInst);
441   ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
442 
443   worklist.push_back(icmp_np);
444   worklist.push_back(icmp_eq);
445 
446   return true;
447 
448 }
449 
450 /// Simplify a signed comparison operator by splitting it into a unsigned and
451 /// bit comparison. add all resulting comparisons to
452 /// the worklist passed as a reference.
simplifySignedCompare(CmpInst * IcmpInst,Module & M,CmpWorklist & worklist)453 bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M,
454                                                    CmpWorklist &worklist) {
455 
456   LLVMContext &C = M.getContext();
457   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
458 
459   BasicBlock *bb = IcmpInst->getParent();
460 
461   auto op0 = IcmpInst->getOperand(0);
462   auto op1 = IcmpInst->getOperand(1);
463 
464   IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
465   if (!intTyOp0) { return false; }
466   unsigned     bitw = intTyOp0->getBitWidth();
467   IntegerType *IntType = IntegerType::get(C, bitw);
468 
469   /* get the new predicate */
470   auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
471   if (!cmp_inst) { return false; }
472   auto               pred = cmp_inst->getPredicate();
473   CmpInst::Predicate new_pred;
474 
475   if (pred == CmpInst::ICMP_SGT) {
476 
477     new_pred = CmpInst::ICMP_UGT;
478 
479   } else {
480 
481     new_pred = CmpInst::ICMP_ULT;
482 
483   }
484 
485   BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
486 
487   /* create a 1 bit compare for the sign bit. to do this shift and trunc
488    * the original operands so only the first bit remains.*/
489   Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
490 
491   IRBuilder<> IRB(bb->getTerminator());
492   s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1));
493   t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty);
494   s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1));
495   t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty);
496   /* compare of the sign bits */
497   icmp_sign_bit = IRB.CreateICmp(CmpInst::ICMP_EQ, t_op0, t_op1);
498 
499   /* create a new basic block which is executed if the signedness bit is
500    * different */
501   CmpInst *   icmp_inv_sig_cmp;
502   BasicBlock *sign_bb =
503       BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
504   if (pred == CmpInst::ICMP_SGT) {
505 
506     /* if we check for > and the op0 positive and op1 negative then the final
507      * result is true. if op0 negative and op1 pos, the cmp must result
508      * in false
509      */
510     icmp_inv_sig_cmp =
511         CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
512 
513   } else {
514 
515     /* just the inverse of the above statement */
516     icmp_inv_sig_cmp =
517         CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
518 
519   }
520 
521   sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
522   BranchInst::Create(end_bb, sign_bb);
523 
524   /* create a new bb which is executed if signedness is equal */
525   CmpInst *   icmp_usign_cmp;
526   BasicBlock *middle_bb =
527       BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
528   /* we can do a normal unsigned compare now */
529   icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
530 
531   middle_bb->getInstList().push_back(icmp_usign_cmp);
532   BranchInst::Create(end_bb, middle_bb);
533 
534   auto term = bb->getTerminator();
535   /* if the sign is eq do a normal unsigned cmp, else we have to check the
536    * signedness bit */
537   BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
538   term->eraseFromParent();
539 
540   PHINode *PN = PHINode::Create(Int1Ty, 2, "");
541 
542   PN->addIncoming(icmp_usign_cmp, middle_bb);
543   PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
544 
545   BasicBlock::iterator ii(IcmpInst);
546   ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
547 
548   // save for later
549   worklist.push_back(icmp_usign_cmp);
550 
551   // signed comparisons are not supported by the splitting code, so we must not
552   // add it to the worklist.
553   // worklist.push_back(icmp_inv_sig_cmp);
554 
555   return true;
556 
557 }
558 
splitCompare(CmpInst * cmp_inst,Module & M,CmpWorklist & worklist)559 bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M,
560                                           CmpWorklist &worklist) {
561 
562   auto pred = cmp_inst->getPredicate();
563   switch (pred) {
564 
565     case CmpInst::ICMP_EQ:
566     case CmpInst::ICMP_NE:
567     case CmpInst::ICMP_UGT:
568     case CmpInst::ICMP_ULT:
569       break;
570     default:
571       // unsupported predicate!
572       return false;
573 
574   }
575 
576   auto op0 = cmp_inst->getOperand(0);
577   auto op1 = cmp_inst->getOperand(1);
578 
579   // get bitwidth by checking the bitwidth of the first operator
580   IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
581   if (!intTyOp0) {
582 
583     // not an integer type
584     return false;
585 
586   }
587 
588   unsigned bitw = intTyOp0->getBitWidth();
589   if (bitw == target_bitwidth) {
590 
591     // already the target bitwidth so we have to do nothing here.
592     return true;
593 
594   }
595 
596   LLVMContext &C = M.getContext();
597   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
598   BasicBlock * bb = cmp_inst->getParent();
599   IntegerType *OldIntType = IntegerType::get(C, bitw);
600   IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
601   BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst));
602   CmpInst *    icmp_high, *icmp_low;
603 
604   /* create the comparison of the top halves of the original operands */
605   Value *s_op0, *op0_high, *s_op1, *op1_high;
606 
607   IRBuilder<> IRB(bb->getTerminator());
608 
609   s_op0 = IRB.CreateBinOp(Instruction::LShr, op0,
610                           ConstantInt::get(OldIntType, bitw / 2));
611   op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType);
612 
613   s_op1 = IRB.CreateBinOp(Instruction::LShr, op1,
614                           ConstantInt::get(OldIntType, bitw / 2));
615   op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType);
616   icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high));
617 
618   PHINode *PN = nullptr;
619 
620   /* now we have to destinguish between == != and > < */
621   switch (pred) {
622 
623     case CmpInst::ICMP_EQ:
624     case CmpInst::ICMP_NE: {
625 
626       /* transformation for == and != icmps */
627 
628       /* create a compare for the lower half of the original operands */
629       BasicBlock *cmp_low_bb =
630           BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
631 
632       Value *     op0_low, *op1_low;
633       IRBuilder<> Builder(cmp_low_bb);
634 
635       op0_low = Builder.CreateTrunc(op0, NewIntType);
636       op1_low = Builder.CreateTrunc(op1, NewIntType);
637       icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low));
638 
639       BranchInst::Create(end_bb, cmp_low_bb);
640 
641       /* dependent on the cmp of the high parts go to the end or go on with
642        * the comparison */
643       auto term = bb->getTerminator();
644 
645       if (pred == CmpInst::ICMP_EQ) {
646 
647         BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
648 
649       } else {
650 
651         // CmpInst::ICMP_NE
652         BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
653 
654       }
655 
656       term->eraseFromParent();
657 
658       /* create the PHI and connect the edges accordingly */
659       PN = PHINode::Create(Int1Ty, 2, "");
660       PN->addIncoming(icmp_low, cmp_low_bb);
661       Value *val = nullptr;
662       if (pred == CmpInst::ICMP_EQ) {
663 
664         val = ConstantInt::get(Int1Ty, 0);
665 
666       } else {
667 
668         /* CmpInst::ICMP_NE */
669         val = ConstantInt::get(Int1Ty, 1);
670 
671       }
672 
673       PN->addIncoming(val, icmp_high->getParent());
674       break;
675 
676     }
677 
678     case CmpInst::ICMP_UGT:
679     case CmpInst::ICMP_ULT: {
680 
681       /* transformations for < and > */
682 
683       /* create a basic block which checks for the inverse predicate.
684        * if this is true we can go to the end if not we have to go to the
685        * bb which checks the lower half of the operands */
686       Instruction *op0_low, *op1_low;
687       CmpInst *    icmp_inv_cmp = nullptr;
688       BasicBlock * inv_cmp_bb =
689           BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
690       if (pred == CmpInst::ICMP_UGT) {
691 
692         icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
693                                        op0_high, op1_high);
694 
695       } else {
696 
697         icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
698                                        op0_high, op1_high);
699 
700       }
701 
702       inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
703       worklist.push_back(icmp_inv_cmp);
704 
705       auto term = bb->getTerminator();
706       term->eraseFromParent();
707       BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
708 
709       /* create a bb which handles the cmp of the lower halves */
710       BasicBlock *cmp_low_bb =
711           BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
712       op0_low = new TruncInst(op0, NewIntType);
713       cmp_low_bb->getInstList().push_back(op0_low);
714       op1_low = new TruncInst(op1, NewIntType);
715       cmp_low_bb->getInstList().push_back(op1_low);
716 
717       icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
718       cmp_low_bb->getInstList().push_back(icmp_low);
719       BranchInst::Create(end_bb, cmp_low_bb);
720 
721       BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
722 
723       PN = PHINode::Create(Int1Ty, 3);
724       PN->addIncoming(icmp_low, cmp_low_bb);
725       PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
726       PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
727       break;
728 
729     }
730 
731     default:
732       return false;
733 
734   }
735 
736   BasicBlock::iterator ii(cmp_inst);
737   ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN);
738 
739   // We split the comparison into low and high. If this isn't our target
740   // bitwidth we recursively split the low and high parts again until we have
741   // target bitwidth.
742   if ((bitw / 2) > target_bitwidth) {
743 
744     worklist.push_back(icmp_high);
745     worklist.push_back(icmp_low);
746 
747   }
748 
749   return true;
750 
751 }
752 
simplifyAndSplit(CmpInst * I,Module & M)753 bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) {
754 
755   CmpWorklist worklist;
756 
757   auto op0 = I->getOperand(0);
758   auto op1 = I->getOperand(1);
759   if (!op0 || !op1) { return false; }
760   auto op0Ty = dyn_cast<IntegerType>(op0->getType());
761   if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; }
762 
763   unsigned bitw = op0Ty->getBitWidth();
764 
765 #ifdef VERIFY_TOO_MUCH
766   auto F = I->getParent()->getParent();
767 #endif
768 
769   // we run the comparison simplification on all compares regardless of their
770   // bitwidth.
771   if (I->getPredicate() == CmpInst::ICMP_UGE ||
772       I->getPredicate() == CmpInst::ICMP_SGE ||
773       I->getPredicate() == CmpInst::ICMP_ULE ||
774       I->getPredicate() == CmpInst::ICMP_SLE) {
775 
776     if (!simplifyOrEqualsCompare(I, M, worklist)) {
777 
778       reportError(
779           "Failed to simplify inequality or equals comparison "
780           "(UGE,SGE,ULE,SLE)",
781           I, M);
782 
783     }
784 
785   } else if (I->getPredicate() == CmpInst::ICMP_SGT ||
786 
787              I->getPredicate() == CmpInst::ICMP_SLT) {
788 
789     if (!simplifySignedCompare(I, M, worklist)) {
790 
791       reportError("Failed to simplify signed comparison (SGT,SLT)", I, M);
792 
793     }
794 
795   }
796 
797 #ifdef VERIFY_TOO_MUCH
798   if (verifyFunction(*F, &errs())) {
799 
800     reportError("simpliyfing compare lead to broken function", nullptr, M);
801 
802   }
803 
804 #endif
805 
806   // the simplification methods replace the original CmpInst and push the
807   // resulting new CmpInst into the worklist. If the worklist is empty then
808   // we only have to split the original CmpInst.
809   if (worklist.size() == 0) { worklist.push_back(I); }
810 
811   while (!worklist.empty()) {
812 
813     CmpInst *cmp = worklist.pop_back_val();
814     // we split the simplified compares into comparisons with smaller bitwidths
815     // if they are larger than our target_bitwidth.
816     if (bitw > target_bitwidth) {
817 
818       if (!splitCompare(cmp, M, worklist)) {
819 
820         reportError("Failed to split comparison", cmp, M);
821 
822       }
823 
824 #ifdef VERIFY_TOO_MUCH
825       if (verifyFunction(*F, &errs())) {
826 
827         reportError("splitting compare lead to broken function", nullptr, M);
828 
829       }
830 
831 #endif
832 
833     }
834 
835   }
836 
837   count++;
838   return true;
839 
840 }
841 
nextPowerOfTwo(size_t in)842 size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
843 
844   --in;
845   in |= in >> 1;
846   in |= in >> 2;
847   in |= in >> 4;
848   //  in |= in >> 8;
849   //  in |= in >> 16;
850   return in + 1;
851 
852 }
853 
854 /* splits fcmps into two nested fcmps with sign compare and the rest */
splitFPCompares(Module & M)855 size_t SplitComparesTransform::splitFPCompares(Module &M) {
856 
857   size_t count = 0;
858 
859   LLVMContext &C = M.getContext();
860 
861 #if LLVM_VERSION_MAJOR >= 4 || \
862     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
863   const DataLayout &dl = M.getDataLayout();
864 
865   /* define unions with floating point and (sign, exponent, mantissa)  triples
866    */
867   if (dl.isLittleEndian()) {
868 
869   } else if (dl.isBigEndian()) {
870 
871   } else {
872 
873     return count;
874 
875   }
876 
877 #endif
878 
879   std::vector<CmpInst *> fcomps;
880 
881   /* get all EQ, NE, GT, and LT fcmps. if the other two
882    * functions were executed only these four predicates should exist */
883   for (auto &F : M) {
884 
885     if (!isInInstrumentList(&F, MNAME)) continue;
886 
887     for (auto &BB : F) {
888 
889       for (auto &IN : BB) {
890 
891         CmpInst *selectcmpInst = nullptr;
892 
893         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
894 
895           if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
896               selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
897               selectcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
898               selectcmpInst->getPredicate() == CmpInst::FCMP_UNE ||
899               selectcmpInst->getPredicate() == CmpInst::FCMP_UGT ||
900               selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
901               selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
902               selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) {
903 
904             auto op0 = selectcmpInst->getOperand(0);
905             auto op1 = selectcmpInst->getOperand(1);
906 
907             Type *TyOp0 = op0->getType();
908             Type *TyOp1 = op1->getType();
909 
910             if (TyOp0 != TyOp1) { continue; }
911 
912             if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
913 
914             fcomps.push_back(selectcmpInst);
915 
916           }
917 
918         }
919 
920       }
921 
922     }
923 
924   }
925 
926   if (!fcomps.size()) { return count; }
927 
928   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
929 
930   for (auto &FcmpInst : fcomps) {
931 
932     BasicBlock *bb = FcmpInst->getParent();
933 
934     auto op0 = FcmpInst->getOperand(0);
935     auto op1 = FcmpInst->getOperand(1);
936 
937     unsigned op_size;
938     op_size = op0->getType()->getPrimitiveSizeInBits();
939 
940     if (op_size != op1->getType()->getPrimitiveSizeInBits()) { continue; }
941 
942     const unsigned int sizeInBits = op0->getType()->getPrimitiveSizeInBits();
943 
944     // BUG FIXME TODO: u64 does not work for > 64 bit ... e.g. 80 and 128 bit
945     if (sizeInBits > 64) { continue; }
946 
947     IntegerType *      intType = IntegerType::get(C, op_size);
948     const unsigned int precision = sizeInBits == 32    ? 24
949                                    : sizeInBits == 64  ? 53
950                                    : sizeInBits == 128 ? 113
951                                    : sizeInBits == 16  ? 11
952                                    : sizeInBits == 80  ? 65
953                                                        : sizeInBits - 8;
954 
955     const unsigned           shiftR_exponent = precision - 1;
956     const unsigned long long mask_fraction =
957         (1ULL << (shiftR_exponent - 1)) | ((1ULL << (shiftR_exponent - 1)) - 1);
958     const unsigned long long mask_exponent =
959         (1ULL << (sizeInBits - precision)) - 1;
960 
961     // round up sizes to the next power of two
962     // this should help with integer compare splitting
963     size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3);
964     size_t frTySizeBytes = ((precision - 1ULL + 7) >> 3);
965 
966     IntegerType *IntExponentTy =
967         IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3);
968     IntegerType *IntFractionTy =
969         IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3);
970 
971     //    errs() << "Fractions: IntFractionTy size " <<
972     //     IntFractionTy->getPrimitiveSizeInBits() << ", op_size " << op_size <<
973     //     ", mask " << mask_fraction <<
974     //     ", precision " << precision << "\n";
975 
976     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
977 
978     /* create the integers from floats directly */
979     Instruction *bpre_op0, *bpre_op1;
980     bpre_op0 = CastInst::Create(Instruction::BitCast, op0,
981                                 IntegerType::get(C, op_size));
982     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
983                              bpre_op0);
984 
985     bpre_op1 = CastInst::Create(Instruction::BitCast, op1,
986                                 IntegerType::get(C, op_size));
987     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
988                              bpre_op1);
989 
990     /* Check if any operand is NaN.
991      * If so, all comparisons except unequal (which yields true) yield false */
992 
993     /* build mask for NaN */
994     const unsigned long long NaN_lowend = mask_exponent << precision;
995     //    errs() << "Fractions: IntFractionTy size " <<
996     //     IntFractionTy->getPrimitiveSizeInBits() << ", op_size " << op_size <<
997     //     ", mask_fraction 0x";
998     //    errs().write_hex(mask_fraction);
999     //    errs() << ", precision " << precision <<
1000     //     ", NaN_lowend 0x";
1001     //    errs().write_hex(NaN_lowend); errs() << "\n";
1002 
1003     /* Check op0 for NaN */
1004     /* Shift left 1 Bit, ignore sign bit */
1005     Instruction *nan_op0, *nan_op1;
1006     nan_op0 = BinaryOperator::Create(Instruction::Shl, bpre_op0,
1007                                      ConstantInt::get(bpre_op0->getType(), 1));
1008     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1009                              nan_op0);
1010 
1011     /* compare to NaN interval */
1012     Instruction *is_op0_nan =
1013         CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op0,
1014                         ConstantInt::get(intType, NaN_lowend));
1015     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1016                              is_op0_nan);
1017 
1018     /* Check op1 for NaN */
1019     /* Shift right 1 Bit, ignore sign bit */
1020     nan_op1 = BinaryOperator::Create(Instruction::Shl, bpre_op1,
1021                                      ConstantInt::get(bpre_op1->getType(), 1));
1022     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1023                              nan_op1);
1024 
1025     /* compare to NaN interval */
1026     Instruction *is_op1_nan =
1027         CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op1,
1028                         ConstantInt::get(intType, NaN_lowend));
1029     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
1030                              is_op1_nan);
1031 
1032     /* combine checks */
1033     Instruction *is_nan =
1034         BinaryOperator::Create(Instruction::Or, is_op0_nan, is_op1_nan);
1035     bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), is_nan);
1036 
1037     /* the result of the comparison, when at least one op is NaN
1038        is true only for the "NOT EQUAL" predicates. */
1039     bool NaNcmp_result = FcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
1040                          FcmpInst->getPredicate() == CmpInst::FCMP_UNE;
1041 
1042     BasicBlock *nonan_bb =
1043         BasicBlock::Create(C, "noNaN", end_bb->getParent(), end_bb);
1044 
1045     BranchInst::Create(end_bb, nonan_bb);
1046 
1047     auto term = bb->getTerminator();
1048     /* if no operand is NaN goto nonan_bb else to handleNaN_bb */
1049     BranchInst::Create(end_bb, nonan_bb, is_nan, bb);
1050     term->eraseFromParent();
1051 
1052     /*** now working in nonan_bb ***/
1053 
1054     /* Treat -0.0 as equal to +0.0, that is for -0.0 make it +0.0 */
1055     Instruction *            b_op0, *b_op1;
1056     Instruction *            isMzero_op0, *isMzero_op1;
1057     const unsigned long long MinusZero = 1UL << (sizeInBits - 1U);
1058     const unsigned long long PlusZero = 0;
1059 
1060     isMzero_op0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op0,
1061                                   ConstantInt::get(intType, MinusZero));
1062     nonan_bb->getInstList().insert(
1063         BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op0);
1064 
1065     isMzero_op1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op1,
1066                                   ConstantInt::get(intType, MinusZero));
1067     nonan_bb->getInstList().insert(
1068         BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op1);
1069 
1070     b_op0 = SelectInst::Create(isMzero_op0, ConstantInt::get(intType, PlusZero),
1071                                bpre_op0);
1072     nonan_bb->getInstList().insert(
1073         BasicBlock::iterator(nonan_bb->getTerminator()), b_op0);
1074 
1075     b_op1 = SelectInst::Create(isMzero_op1, ConstantInt::get(intType, PlusZero),
1076                                bpre_op1);
1077     nonan_bb->getInstList().insert(
1078         BasicBlock::iterator(nonan_bb->getTerminator()), b_op1);
1079 
1080     /* isolate signs of value of floating point type */
1081 
1082     /* create a 1 bit compare for the sign bit. to do this shift and trunc
1083      * the original operands so only the first bit remains.*/
1084     Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit;
1085 
1086     s_s0 =
1087         BinaryOperator::Create(Instruction::LShr, b_op0,
1088                                ConstantInt::get(b_op0->getType(), op_size - 1));
1089     nonan_bb->getInstList().insert(
1090         BasicBlock::iterator(nonan_bb->getTerminator()), s_s0);
1091     t_s0 = new TruncInst(s_s0, Int1Ty);
1092     nonan_bb->getInstList().insert(
1093         BasicBlock::iterator(nonan_bb->getTerminator()), t_s0);
1094 
1095     s_s1 =
1096         BinaryOperator::Create(Instruction::LShr, b_op1,
1097                                ConstantInt::get(b_op1->getType(), op_size - 1));
1098     nonan_bb->getInstList().insert(
1099         BasicBlock::iterator(nonan_bb->getTerminator()), s_s1);
1100     t_s1 = new TruncInst(s_s1, Int1Ty);
1101     nonan_bb->getInstList().insert(
1102         BasicBlock::iterator(nonan_bb->getTerminator()), t_s1);
1103 
1104     /* compare of the sign bits */
1105     icmp_sign_bit =
1106         CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
1107     nonan_bb->getInstList().insert(
1108         BasicBlock::iterator(nonan_bb->getTerminator()), icmp_sign_bit);
1109 
1110     /* create a new basic block which is executed if the signedness bits are
1111      * equal */
1112     BasicBlock *signequal_bb =
1113         BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb);
1114 
1115     BranchInst::Create(end_bb, signequal_bb);
1116 
1117     /* create a new bb which is executed if exponents are satisfying the compare
1118      */
1119     BasicBlock *middle_bb =
1120         BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
1121 
1122     BranchInst::Create(end_bb, middle_bb);
1123 
1124     term = nonan_bb->getTerminator();
1125     /* if the signs are different goto end_bb else to signequal_bb */
1126     BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, nonan_bb);
1127     term->eraseFromParent();
1128 
1129     /* insert code for equal signs */
1130 
1131     /* isolate the exponents */
1132     Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1;
1133 
1134     s_e0 = BinaryOperator::Create(
1135         Instruction::LShr, b_op0,
1136         ConstantInt::get(b_op0->getType(), shiftR_exponent));
1137     s_e1 = BinaryOperator::Create(
1138         Instruction::LShr, b_op1,
1139         ConstantInt::get(b_op1->getType(), shiftR_exponent));
1140     signequal_bb->getInstList().insert(
1141         BasicBlock::iterator(signequal_bb->getTerminator()), s_e0);
1142     signequal_bb->getInstList().insert(
1143         BasicBlock::iterator(signequal_bb->getTerminator()), s_e1);
1144 
1145     t_e0 = new TruncInst(s_e0, IntExponentTy);
1146     t_e1 = new TruncInst(s_e1, IntExponentTy);
1147     signequal_bb->getInstList().insert(
1148         BasicBlock::iterator(signequal_bb->getTerminator()), t_e0);
1149     signequal_bb->getInstList().insert(
1150         BasicBlock::iterator(signequal_bb->getTerminator()), t_e1);
1151 
1152     if (sizeInBits - precision < exTySizeBytes * 8) {
1153 
1154       m_e0 = BinaryOperator::Create(
1155           Instruction::And, t_e0,
1156           ConstantInt::get(t_e0->getType(), mask_exponent));
1157       m_e1 = BinaryOperator::Create(
1158           Instruction::And, t_e1,
1159           ConstantInt::get(t_e1->getType(), mask_exponent));
1160       signequal_bb->getInstList().insert(
1161           BasicBlock::iterator(signequal_bb->getTerminator()), m_e0);
1162       signequal_bb->getInstList().insert(
1163           BasicBlock::iterator(signequal_bb->getTerminator()), m_e1);
1164 
1165     } else {
1166 
1167       m_e0 = t_e0;
1168       m_e1 = t_e1;
1169 
1170     }
1171 
1172     /* compare the exponents of the operands */
1173     Instruction *icmp_exponents_equal;
1174     Instruction *icmp_exponent_result;
1175     BasicBlock * signequal2_bb = signequal_bb;
1176     switch (FcmpInst->getPredicate()) {
1177 
1178       case CmpInst::FCMP_UEQ:
1179       case CmpInst::FCMP_OEQ:
1180         icmp_exponent_result =
1181             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1182         break;
1183       case CmpInst::FCMP_ONE:
1184       case CmpInst::FCMP_UNE:
1185         icmp_exponent_result =
1186             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, m_e0, m_e1);
1187         break;
1188       /* compare the exponents of the operands (signs are equal)
1189        * if exponents are equal -> proceed to mantissa comparison
1190        * else get result depending on sign
1191        */
1192       case CmpInst::FCMP_OGT:
1193       case CmpInst::FCMP_UGT:
1194         Instruction *icmp_exponent;
1195         icmp_exponents_equal =
1196             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1197         signequal_bb->getInstList().insert(
1198             BasicBlock::iterator(signequal_bb->getTerminator()),
1199             icmp_exponents_equal);
1200 
1201         // shortcut for unequal exponents
1202         signequal2_bb = signequal_bb->splitBasicBlock(
1203             BasicBlock::iterator(signequal_bb->getTerminator()));
1204 
1205         /* if the exponents are equal goto middle_bb else to signequal2_bb */
1206         term = signequal_bb->getTerminator();
1207         BranchInst::Create(middle_bb, signequal2_bb, icmp_exponents_equal,
1208                            signequal_bb);
1209         term->eraseFromParent();
1210 
1211         icmp_exponent =
1212             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1);
1213         signequal2_bb->getInstList().insert(
1214             BasicBlock::iterator(signequal2_bb->getTerminator()),
1215             icmp_exponent);
1216         icmp_exponent_result =
1217             BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
1218         break;
1219       case CmpInst::FCMP_OLT:
1220       case CmpInst::FCMP_ULT:
1221         icmp_exponents_equal =
1222             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
1223         signequal_bb->getInstList().insert(
1224             BasicBlock::iterator(signequal_bb->getTerminator()),
1225             icmp_exponents_equal);
1226 
1227         // shortcut for unequal exponents
1228         signequal2_bb = signequal_bb->splitBasicBlock(
1229             BasicBlock::iterator(signequal_bb->getTerminator()));
1230 
1231         /* if the exponents are equal goto middle_bb else to signequal2_bb */
1232         term = signequal_bb->getTerminator();
1233         BranchInst::Create(middle_bb, signequal2_bb, icmp_exponents_equal,
1234                            signequal_bb);
1235         term->eraseFromParent();
1236 
1237         icmp_exponent =
1238             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1);
1239         signequal2_bb->getInstList().insert(
1240             BasicBlock::iterator(signequal2_bb->getTerminator()),
1241             icmp_exponent);
1242         icmp_exponent_result =
1243             BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
1244         break;
1245       default:
1246         continue;
1247 
1248     }
1249 
1250     signequal2_bb->getInstList().insert(
1251         BasicBlock::iterator(signequal2_bb->getTerminator()),
1252         icmp_exponent_result);
1253 
1254     {
1255 
1256       term = signequal2_bb->getTerminator();
1257 
1258       switch (FcmpInst->getPredicate()) {
1259 
1260         case CmpInst::FCMP_UEQ:
1261         case CmpInst::FCMP_OEQ:
1262           /* if the exponents are satifying the compare do a fraction cmp in
1263            * middle_bb */
1264           BranchInst::Create(middle_bb, end_bb, icmp_exponent_result,
1265                              signequal2_bb);
1266           break;
1267         case CmpInst::FCMP_ONE:
1268         case CmpInst::FCMP_UNE:
1269           /* if the exponents are satifying the compare do a fraction cmp in
1270            * middle_bb */
1271           BranchInst::Create(end_bb, middle_bb, icmp_exponent_result,
1272                              signequal2_bb);
1273           break;
1274         case CmpInst::FCMP_OGT:
1275         case CmpInst::FCMP_UGT:
1276         case CmpInst::FCMP_OLT:
1277         case CmpInst::FCMP_ULT:
1278           BranchInst::Create(end_bb, signequal2_bb);
1279           break;
1280         default:
1281           continue;
1282 
1283       }
1284 
1285       term->eraseFromParent();
1286 
1287     }
1288 
1289     /* isolate the mantissa aka fraction */
1290     Instruction *t_f0, *t_f1;
1291     bool         needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op_size;
1292 
1293     if (precision - 1 < frTySizeBytes * 8) {
1294 
1295       Instruction *m_f0, *m_f1;
1296       m_f0 = BinaryOperator::Create(
1297           Instruction::And, b_op0,
1298           ConstantInt::get(b_op0->getType(), mask_fraction));
1299       m_f1 = BinaryOperator::Create(
1300           Instruction::And, b_op1,
1301           ConstantInt::get(b_op1->getType(), mask_fraction));
1302       middle_bb->getInstList().insert(
1303           BasicBlock::iterator(middle_bb->getTerminator()), m_f0);
1304       middle_bb->getInstList().insert(
1305           BasicBlock::iterator(middle_bb->getTerminator()), m_f1);
1306 
1307       if (needTrunc) {
1308 
1309         t_f0 = new TruncInst(m_f0, IntFractionTy);
1310         t_f1 = new TruncInst(m_f1, IntFractionTy);
1311         middle_bb->getInstList().insert(
1312             BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
1313         middle_bb->getInstList().insert(
1314             BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
1315 
1316       } else {
1317 
1318         t_f0 = m_f0;
1319         t_f1 = m_f1;
1320 
1321       }
1322 
1323     } else {
1324 
1325       if (needTrunc) {
1326 
1327         t_f0 = new TruncInst(b_op0, IntFractionTy);
1328         t_f1 = new TruncInst(b_op1, IntFractionTy);
1329         middle_bb->getInstList().insert(
1330             BasicBlock::iterator(middle_bb->getTerminator()), t_f0);
1331         middle_bb->getInstList().insert(
1332             BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
1333 
1334       } else {
1335 
1336         t_f0 = b_op0;
1337         t_f1 = b_op1;
1338 
1339       }
1340 
1341     }
1342 
1343     /* compare the fractions of the operands */
1344     Instruction *icmp_fraction_result;
1345     BasicBlock * middle2_bb = middle_bb;
1346     PHINode *    PN2 = nullptr;
1347     switch (FcmpInst->getPredicate()) {
1348 
1349       case CmpInst::FCMP_UEQ:
1350       case CmpInst::FCMP_OEQ:
1351         icmp_fraction_result =
1352             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1);
1353         middle2_bb->getInstList().insert(
1354             BasicBlock::iterator(middle2_bb->getTerminator()),
1355             icmp_fraction_result);
1356 
1357         break;
1358       case CmpInst::FCMP_UNE:
1359       case CmpInst::FCMP_ONE:
1360         icmp_fraction_result =
1361             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1);
1362         middle2_bb->getInstList().insert(
1363             BasicBlock::iterator(middle2_bb->getTerminator()),
1364             icmp_fraction_result);
1365 
1366         break;
1367       case CmpInst::FCMP_OGT:
1368       case CmpInst::FCMP_UGT:
1369       case CmpInst::FCMP_OLT:
1370       case CmpInst::FCMP_ULT: {
1371 
1372         Instruction *icmp_fraction_result2;
1373 
1374         middle2_bb = middle_bb->splitBasicBlock(
1375             BasicBlock::iterator(middle_bb->getTerminator()));
1376 
1377         BasicBlock *negative_bb = BasicBlock::Create(
1378             C, "negative_value", middle2_bb->getParent(), middle2_bb);
1379         BasicBlock *positive_bb = BasicBlock::Create(
1380             C, "positive_value", negative_bb->getParent(), negative_bb);
1381 
1382         if (FcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
1383             FcmpInst->getPredicate() == CmpInst::FCMP_UGT) {
1384 
1385           negative_bb->getInstList().push_back(
1386               icmp_fraction_result = CmpInst::Create(
1387                   Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
1388           positive_bb->getInstList().push_back(
1389               icmp_fraction_result2 = CmpInst::Create(
1390                   Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
1391 
1392         } else {
1393 
1394           negative_bb->getInstList().push_back(
1395               icmp_fraction_result = CmpInst::Create(
1396                   Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
1397           positive_bb->getInstList().push_back(
1398               icmp_fraction_result2 = CmpInst::Create(
1399                   Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
1400 
1401         }
1402 
1403         BranchInst::Create(middle2_bb, negative_bb);
1404         BranchInst::Create(middle2_bb, positive_bb);
1405 
1406         term = middle_bb->getTerminator();
1407         BranchInst::Create(negative_bb, positive_bb, t_s0, middle_bb);
1408         term->eraseFromParent();
1409 
1410         PN2 = PHINode::Create(Int1Ty, 2, "");
1411         PN2->addIncoming(icmp_fraction_result, negative_bb);
1412         PN2->addIncoming(icmp_fraction_result2, positive_bb);
1413         middle2_bb->getInstList().insert(
1414             BasicBlock::iterator(middle2_bb->getTerminator()), PN2);
1415 
1416       } break;
1417 
1418       default:
1419         continue;
1420 
1421     }
1422 
1423     PHINode *PN = PHINode::Create(Int1Ty, 4, "");
1424 
1425     switch (FcmpInst->getPredicate()) {
1426 
1427       case CmpInst::FCMP_UEQ:
1428       case CmpInst::FCMP_OEQ:
1429         /* unequal signs cannot be equal values */
1430         /* goto false branch */
1431         PN->addIncoming(ConstantInt::get(Int1Ty, 0), nonan_bb);
1432         /* unequal exponents cannot be equal values, too */
1433         PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb);
1434         /* fractions comparison */
1435         PN->addIncoming(icmp_fraction_result, middle2_bb);
1436         /* NaNs */
1437         PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1438         break;
1439       case CmpInst::FCMP_ONE:
1440       case CmpInst::FCMP_UNE:
1441         /* unequal signs are unequal values */
1442         /* goto true branch */
1443         PN->addIncoming(ConstantInt::get(Int1Ty, 1), nonan_bb);
1444         /* unequal exponents are unequal values, too */
1445         PN->addIncoming(icmp_exponent_result, signequal_bb);
1446         /* fractions comparison */
1447         PN->addIncoming(icmp_fraction_result, middle2_bb);
1448         /* NaNs */
1449         PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1450         break;
1451       case CmpInst::FCMP_OGT:
1452       case CmpInst::FCMP_UGT:
1453         /* if op1 is negative goto true branch,
1454            else go on comparing */
1455         PN->addIncoming(t_s1, nonan_bb);
1456         PN->addIncoming(icmp_exponent_result, signequal2_bb);
1457         PN->addIncoming(PN2, middle2_bb);
1458         /* NaNs */
1459         PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1460         break;
1461       case CmpInst::FCMP_OLT:
1462       case CmpInst::FCMP_ULT:
1463         /* if op0 is negative goto true branch,
1464            else go on comparing */
1465         PN->addIncoming(t_s0, nonan_bb);
1466         PN->addIncoming(icmp_exponent_result, signequal2_bb);
1467         PN->addIncoming(PN2, middle2_bb);
1468         /* NaNs */
1469         PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb);
1470         break;
1471       default:
1472         continue;
1473 
1474     }
1475 
1476     BasicBlock::iterator ii(FcmpInst);
1477     ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
1478     ++count;
1479 
1480   }
1481 
1482   return count;
1483 
1484 }
1485 
1486 #if LLVM_MAJOR >= 11
run(Module & M,ModuleAnalysisManager & MAM)1487 PreservedAnalyses SplitComparesTransform::run(Module &               M,
1488                                               ModuleAnalysisManager &MAM) {
1489 
1490 #else
1491 bool SplitComparesTransform::runOnModule(Module &M) {
1492 
1493 #endif
1494 
1495   char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
1496   if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
1497   if (bitw_env) { target_bitwidth = atoi(bitw_env); }
1498 
1499   enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
1500 
1501   if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
1502       getenv("AFL_DEBUG") != NULL) {
1503 
1504     errs() << "Split-compare-newpass by laf.intel@gmail.com, extended by "
1505               "heiko@hexco.de (splitting icmp to "
1506            << target_bitwidth << " bit)\n";
1507 
1508     if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; }
1509 
1510   } else {
1511 
1512     be_quiet = 1;
1513 
1514   }
1515 
1516 #if LLVM_MAJOR >= 11
1517   auto PA = PreservedAnalyses::all();
1518 #endif
1519 
1520   if (enableFPSplit) {
1521 
1522     simplifyFPCompares(M);
1523     count = splitFPCompares(M);
1524 
1525     if (!be_quiet && !debug) {
1526 
1527       errs() << "Split-floatingpoint-compare-pass: " << count
1528              << " FP comparisons splitted\n";
1529 
1530     }
1531 
1532   }
1533 
1534   std::vector<CmpInst *> worklist;
1535   /* iterate over all functions, bbs and instruction search for all integer
1536    * compare instructions. Save them into the worklist for later. */
1537   for (auto &F : M) {
1538 
1539     if (!isInInstrumentList(&F, MNAME)) continue;
1540 
1541     for (auto &BB : F) {
1542 
1543       for (auto &IN : BB) {
1544 
1545         if (auto CI = dyn_cast<CmpInst>(&IN)) {
1546 
1547           auto op0 = CI->getOperand(0);
1548           auto op1 = CI->getOperand(1);
1549           if (!op0 || !op1) {
1550 
1551 #if LLVM_MAJOR >= 11
1552             return PA;
1553 #else
1554             return false;
1555 #endif
1556 
1557           }
1558 
1559           auto iTy1 = dyn_cast<IntegerType>(op0->getType());
1560           if (iTy1 && isa<IntegerType>(op1->getType())) {
1561 
1562             unsigned bitw = iTy1->getBitWidth();
1563             if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
1564 
1565           }
1566 
1567         }
1568 
1569       }
1570 
1571     }
1572 
1573   }
1574 
1575   // now that we have a list of all integer comparisons we can start replacing
1576   // them with the splitted alternatives.
1577   for (auto CI : worklist) {
1578 
1579     simplifyAndSplit(CI, M);
1580 
1581   }
1582 
1583   bool brokenDebug = false;
1584   if (verifyModule(M, &errs()
1585 #if LLVM_VERSION_MAJOR >= 4 || \
1586     (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9)
1587                           ,
1588                    &brokenDebug  // 9th May 2016
1589 #endif
1590                    )) {
1591 
1592     reportError(
1593         "Module Verifier failed! Consider reporting a bug with the AFL++ "
1594         "project.",
1595         nullptr, M);
1596 
1597   }
1598 
1599   if (brokenDebug) {
1600 
1601     reportError("Module Verifier reported broken Debug Infos - Stripping!",
1602                 nullptr, M);
1603     StripDebugInfo(M);
1604 
1605   }
1606 
1607   if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
1608       getenv("AFL_DEBUG") != NULL) {
1609 
1610     errs() << count << " comparisons found\n";
1611 
1612   }
1613 
1614 #if LLVM_MAJOR >= 11
1615   /*  if (modified) {
1616 
1617       PA.abandon<XX_Manager>();
1618 
1619     }*/
1620 
1621   return PA;
1622 #else
1623   return true;
1624 #endif
1625 
1626 }
1627 
1628 #if LLVM_MAJOR < 11                                 /* use old pass manager */
1629 
1630 static void registerSplitComparesPass(const PassManagerBuilder &,
1631                                       legacy::PassManagerBase &PM) {
1632 
1633   PM.add(new SplitComparesTransform());
1634 
1635 }
1636 
1637 static RegisterStandardPasses RegisterSplitComparesPass(
1638     PassManagerBuilder::EP_OptimizerLast, registerSplitComparesPass);
1639 
1640 static RegisterStandardPasses RegisterSplitComparesTransPass0(
1641     PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitComparesPass);
1642 
1643   #if LLVM_VERSION_MAJOR >= 11
1644 static RegisterStandardPasses RegisterSplitComparesTransPassLTO(
1645     PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
1646     registerSplitComparesPass);
1647   #endif
1648 
1649 static RegisterPass<SplitComparesTransform> X("splitcompares",
1650                                               "AFL++ split compares",
1651                                               true /* Only looks at CFG */,
1652                                               true /* Analysis Pass */);
1653 #endif
1654 
1655