1 //===- InstCombineMulDivRem.cpp -------------------------------------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the visit functions for mul, fmul, sdiv, udiv, fdiv,
11 // srem, urem, frem.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "InstCombineInternal.h"
16 #include "llvm/Analysis/InstructionSimplify.h"
17 #include "llvm/IR/IntrinsicInst.h"
18 #include "llvm/IR/PatternMatch.h"
19 using namespace llvm;
20 using namespace PatternMatch;
21
22 #define DEBUG_TYPE "instcombine"
23
24
25 /// The specific integer value is used in a context where it is known to be
26 /// non-zero. If this allows us to simplify the computation, do so and return
27 /// the new operand, otherwise return null.
simplifyValueKnownNonZero(Value * V,InstCombiner & IC,Instruction & CxtI)28 static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
29 Instruction &CxtI) {
30 // If V has multiple uses, then we would have to do more analysis to determine
31 // if this is safe. For example, the use could be in dynamically unreached
32 // code.
33 if (!V->hasOneUse()) return nullptr;
34
35 bool MadeChange = false;
36
37 // ((1 << A) >>u B) --> (1 << (A-B))
38 // Because V cannot be zero, we know that B is less than A.
39 Value *A = nullptr, *B = nullptr, *One = nullptr;
40 if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(One), m_Value(A))), m_Value(B))) &&
41 match(One, m_One())) {
42 A = IC.Builder->CreateSub(A, B);
43 return IC.Builder->CreateShl(One, A);
44 }
45
46 // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
47 // inexact. Similarly for <<.
48 BinaryOperator *I = dyn_cast<BinaryOperator>(V);
49 if (I && I->isLogicalShift() &&
50 isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0,
51 IC.getAssumptionCache(), &CxtI,
52 IC.getDominatorTree())) {
53 // We know that this is an exact/nuw shift and that the input is a
54 // non-zero context as well.
55 if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
56 I->setOperand(0, V2);
57 MadeChange = true;
58 }
59
60 if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
61 I->setIsExact();
62 MadeChange = true;
63 }
64
65 if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
66 I->setHasNoUnsignedWrap();
67 MadeChange = true;
68 }
69 }
70
71 // TODO: Lots more we could do here:
72 // If V is a phi node, we can call this on each of its operands.
73 // "select cond, X, 0" can simplify to "X".
74
75 return MadeChange ? V : nullptr;
76 }
77
78
79 /// True if the multiply can not be expressed in an int this size.
MultiplyOverflows(const APInt & C1,const APInt & C2,APInt & Product,bool IsSigned)80 static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product,
81 bool IsSigned) {
82 bool Overflow;
83 if (IsSigned)
84 Product = C1.smul_ov(C2, Overflow);
85 else
86 Product = C1.umul_ov(C2, Overflow);
87
88 return Overflow;
89 }
90
91 /// \brief True if C2 is a multiple of C1. Quotient contains C2/C1.
IsMultiple(const APInt & C1,const APInt & C2,APInt & Quotient,bool IsSigned)92 static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient,
93 bool IsSigned) {
94 assert(C1.getBitWidth() == C2.getBitWidth() &&
95 "Inconsistent width of constants!");
96
97 // Bail if we will divide by zero.
98 if (C2.isMinValue())
99 return false;
100
101 // Bail if we would divide INT_MIN by -1.
102 if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue())
103 return false;
104
105 APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned);
106 if (IsSigned)
107 APInt::sdivrem(C1, C2, Quotient, Remainder);
108 else
109 APInt::udivrem(C1, C2, Quotient, Remainder);
110
111 return Remainder.isMinValue();
112 }
113
114 /// \brief A helper routine of InstCombiner::visitMul().
115 ///
116 /// If C is a vector of known powers of 2, then this function returns
117 /// a new vector obtained from C replacing each element with its logBase2.
118 /// Return a null pointer otherwise.
getLogBase2Vector(ConstantDataVector * CV)119 static Constant *getLogBase2Vector(ConstantDataVector *CV) {
120 const APInt *IVal;
121 SmallVector<Constant *, 4> Elts;
122
123 for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) {
124 Constant *Elt = CV->getElementAsConstant(I);
125 if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
126 return nullptr;
127 Elts.push_back(ConstantInt::get(Elt->getType(), IVal->logBase2()));
128 }
129
130 return ConstantVector::get(Elts);
131 }
132
133 /// \brief Return true if we can prove that:
134 /// (mul LHS, RHS) === (mul nsw LHS, RHS)
WillNotOverflowSignedMul(Value * LHS,Value * RHS,Instruction & CxtI)135 bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
136 Instruction &CxtI) {
137 // Multiplying n * m significant bits yields a result of n + m significant
138 // bits. If the total number of significant bits does not exceed the
139 // result bit width (minus 1), there is no overflow.
140 // This means if we have enough leading sign bits in the operands
141 // we can guarantee that the result does not overflow.
142 // Ref: "Hacker's Delight" by Henry Warren
143 unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
144
145 // Note that underestimating the number of sign bits gives a more
146 // conservative answer.
147 unsigned SignBits =
148 ComputeNumSignBits(LHS, 0, &CxtI) + ComputeNumSignBits(RHS, 0, &CxtI);
149
150 // First handle the easy case: if we have enough sign bits there's
151 // definitely no overflow.
152 if (SignBits > BitWidth + 1)
153 return true;
154
155 // There are two ambiguous cases where there can be no overflow:
156 // SignBits == BitWidth + 1 and
157 // SignBits == BitWidth
158 // The second case is difficult to check, therefore we only handle the
159 // first case.
160 if (SignBits == BitWidth + 1) {
161 // It overflows only when both arguments are negative and the true
162 // product is exactly the minimum negative number.
163 // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
164 // For simplicity we just check if at least one side is not negative.
165 bool LHSNonNegative, LHSNegative;
166 bool RHSNonNegative, RHSNegative;
167 ComputeSignBit(LHS, LHSNonNegative, LHSNegative, /*Depth=*/0, &CxtI);
168 ComputeSignBit(RHS, RHSNonNegative, RHSNegative, /*Depth=*/0, &CxtI);
169 if (LHSNonNegative || RHSNonNegative)
170 return true;
171 }
172 return false;
173 }
174
visitMul(BinaryOperator & I)175 Instruction *InstCombiner::visitMul(BinaryOperator &I) {
176 bool Changed = SimplifyAssociativeOrCommutative(I);
177 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
178
179 if (Value *V = SimplifyVectorOp(I))
180 return replaceInstUsesWith(I, V);
181
182 if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC))
183 return replaceInstUsesWith(I, V);
184
185 if (Value *V = SimplifyUsingDistributiveLaws(I))
186 return replaceInstUsesWith(I, V);
187
188 // X * -1 == 0 - X
189 if (match(Op1, m_AllOnes())) {
190 BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName());
191 if (I.hasNoSignedWrap())
192 BO->setHasNoSignedWrap();
193 return BO;
194 }
195
196 // Also allow combining multiply instructions on vectors.
197 {
198 Value *NewOp;
199 Constant *C1, *C2;
200 const APInt *IVal;
201 if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)),
202 m_Constant(C1))) &&
203 match(C1, m_APInt(IVal))) {
204 // ((X << C2)*C1) == (X * (C1 << C2))
205 Constant *Shl = ConstantExpr::getShl(C1, C2);
206 BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0));
207 BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl);
208 if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap())
209 BO->setHasNoUnsignedWrap();
210 if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() &&
211 Shl->isNotMinSignedValue())
212 BO->setHasNoSignedWrap();
213 return BO;
214 }
215
216 if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
217 Constant *NewCst = nullptr;
218 if (match(C1, m_APInt(IVal)) && IVal->isPowerOf2())
219 // Replace X*(2^C) with X << C, where C is either a scalar or a splat.
220 NewCst = ConstantInt::get(NewOp->getType(), IVal->logBase2());
221 else if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(C1))
222 // Replace X*(2^C) with X << C, where C is a vector of known
223 // constant powers of 2.
224 NewCst = getLogBase2Vector(CV);
225
226 if (NewCst) {
227 unsigned Width = NewCst->getType()->getPrimitiveSizeInBits();
228 BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst);
229
230 if (I.hasNoUnsignedWrap())
231 Shl->setHasNoUnsignedWrap();
232 if (I.hasNoSignedWrap()) {
233 uint64_t V;
234 if (match(NewCst, m_ConstantInt(V)) && V != Width - 1)
235 Shl->setHasNoSignedWrap();
236 }
237
238 return Shl;
239 }
240 }
241 }
242
243 if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
244 // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
245 // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
246 // The "* (2**n)" thus becomes a potential shifting opportunity.
247 {
248 const APInt & Val = CI->getValue();
249 const APInt &PosVal = Val.abs();
250 if (Val.isNegative() && PosVal.isPowerOf2()) {
251 Value *X = nullptr, *Y = nullptr;
252 if (Op0->hasOneUse()) {
253 ConstantInt *C1;
254 Value *Sub = nullptr;
255 if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
256 Sub = Builder->CreateSub(X, Y, "suba");
257 else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
258 Sub = Builder->CreateSub(Builder->CreateNeg(C1), Y, "subc");
259 if (Sub)
260 return
261 BinaryOperator::CreateMul(Sub,
262 ConstantInt::get(Y->getType(), PosVal));
263 }
264 }
265 }
266 }
267
268 // Simplify mul instructions with a constant RHS.
269 if (isa<Constant>(Op1)) {
270 // Try to fold constant mul into select arguments.
271 if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
272 if (Instruction *R = FoldOpIntoSelect(I, SI))
273 return R;
274
275 if (isa<PHINode>(Op0))
276 if (Instruction *NV = FoldOpIntoPhi(I))
277 return NV;
278
279 // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
280 {
281 Value *X;
282 Constant *C1;
283 if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) {
284 Value *Mul = Builder->CreateMul(C1, Op1);
285 // Only go forward with the transform if C1*CI simplifies to a tidier
286 // constant.
287 if (!match(Mul, m_Mul(m_Value(), m_Value())))
288 return BinaryOperator::CreateAdd(Builder->CreateMul(X, Op1), Mul);
289 }
290 }
291 }
292
293 if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y
294 if (Value *Op1v = dyn_castNegVal(Op1)) {
295 BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v);
296 if (I.hasNoSignedWrap() &&
297 match(Op0, m_NSWSub(m_Value(), m_Value())) &&
298 match(Op1, m_NSWSub(m_Value(), m_Value())))
299 BO->setHasNoSignedWrap();
300 return BO;
301 }
302 }
303
304 // (X / Y) * Y = X - (X % Y)
305 // (X / Y) * -Y = (X % Y) - X
306 {
307 Value *Op1C = Op1;
308 BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0);
309 if (!BO ||
310 (BO->getOpcode() != Instruction::UDiv &&
311 BO->getOpcode() != Instruction::SDiv)) {
312 Op1C = Op0;
313 BO = dyn_cast<BinaryOperator>(Op1);
314 }
315 Value *Neg = dyn_castNegVal(Op1C);
316 if (BO && BO->hasOneUse() &&
317 (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) &&
318 (BO->getOpcode() == Instruction::UDiv ||
319 BO->getOpcode() == Instruction::SDiv)) {
320 Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1);
321
322 // If the division is exact, X % Y is zero, so we end up with X or -X.
323 if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO))
324 if (SDiv->isExact()) {
325 if (Op1BO == Op1C)
326 return replaceInstUsesWith(I, Op0BO);
327 return BinaryOperator::CreateNeg(Op0BO);
328 }
329
330 Value *Rem;
331 if (BO->getOpcode() == Instruction::UDiv)
332 Rem = Builder->CreateURem(Op0BO, Op1BO);
333 else
334 Rem = Builder->CreateSRem(Op0BO, Op1BO);
335 Rem->takeName(BO);
336
337 if (Op1BO == Op1C)
338 return BinaryOperator::CreateSub(Op0BO, Rem);
339 return BinaryOperator::CreateSub(Rem, Op0BO);
340 }
341 }
342
343 /// i1 mul -> i1 and.
344 if (I.getType()->getScalarType()->isIntegerTy(1))
345 return BinaryOperator::CreateAnd(Op0, Op1);
346
347 // X*(1 << Y) --> X << Y
348 // (1 << Y)*X --> X << Y
349 {
350 Value *Y;
351 BinaryOperator *BO = nullptr;
352 bool ShlNSW = false;
353 if (match(Op0, m_Shl(m_One(), m_Value(Y)))) {
354 BO = BinaryOperator::CreateShl(Op1, Y);
355 ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap();
356 } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) {
357 BO = BinaryOperator::CreateShl(Op0, Y);
358 ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap();
359 }
360 if (BO) {
361 if (I.hasNoUnsignedWrap())
362 BO->setHasNoUnsignedWrap();
363 if (I.hasNoSignedWrap() && ShlNSW)
364 BO->setHasNoSignedWrap();
365 return BO;
366 }
367 }
368
369 // If one of the operands of the multiply is a cast from a boolean value, then
370 // we know the bool is either zero or one, so this is a 'masking' multiply.
371 // X * Y (where Y is 0 or 1) -> X & (0-Y)
372 if (!I.getType()->isVectorTy()) {
373 // -2 is "-1 << 1" so it is all bits set except the low one.
374 APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true);
375
376 Value *BoolCast = nullptr, *OtherOp = nullptr;
377 if (MaskedValueIsZero(Op0, Negative2, 0, &I)) {
378 BoolCast = Op0;
379 OtherOp = Op1;
380 } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) {
381 BoolCast = Op1;
382 OtherOp = Op0;
383 }
384
385 if (BoolCast) {
386 Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
387 BoolCast);
388 return BinaryOperator::CreateAnd(V, OtherOp);
389 }
390 }
391
392 if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, I)) {
393 Changed = true;
394 I.setHasNoSignedWrap(true);
395 }
396
397 if (!I.hasNoUnsignedWrap() &&
398 computeOverflowForUnsignedMul(Op0, Op1, &I) ==
399 OverflowResult::NeverOverflows) {
400 Changed = true;
401 I.setHasNoUnsignedWrap(true);
402 }
403
404 return Changed ? &I : nullptr;
405 }
406
407 /// Detect pattern log2(Y * 0.5) with corresponding fast math flags.
detectLog2OfHalf(Value * & Op,Value * & Y,IntrinsicInst * & Log2)408 static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) {
409 if (!Op->hasOneUse())
410 return;
411
412 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op);
413 if (!II)
414 return;
415 if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra())
416 return;
417 Log2 = II;
418
419 Value *OpLog2Of = II->getArgOperand(0);
420 if (!OpLog2Of->hasOneUse())
421 return;
422
423 Instruction *I = dyn_cast<Instruction>(OpLog2Of);
424 if (!I)
425 return;
426 if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
427 return;
428
429 if (match(I->getOperand(0), m_SpecificFP(0.5)))
430 Y = I->getOperand(1);
431 else if (match(I->getOperand(1), m_SpecificFP(0.5)))
432 Y = I->getOperand(0);
433 }
434
isFiniteNonZeroFp(Constant * C)435 static bool isFiniteNonZeroFp(Constant *C) {
436 if (C->getType()->isVectorTy()) {
437 for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E;
438 ++I) {
439 ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I));
440 if (!CFP || !CFP->getValueAPF().isFiniteNonZero())
441 return false;
442 }
443 return true;
444 }
445
446 return isa<ConstantFP>(C) &&
447 cast<ConstantFP>(C)->getValueAPF().isFiniteNonZero();
448 }
449
isNormalFp(Constant * C)450 static bool isNormalFp(Constant *C) {
451 if (C->getType()->isVectorTy()) {
452 for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E;
453 ++I) {
454 ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I));
455 if (!CFP || !CFP->getValueAPF().isNormal())
456 return false;
457 }
458 return true;
459 }
460
461 return isa<ConstantFP>(C) && cast<ConstantFP>(C)->getValueAPF().isNormal();
462 }
463
464 /// Helper function of InstCombiner::visitFMul(BinaryOperator(). It returns
465 /// true iff the given value is FMul or FDiv with one and only one operand
466 /// being a normal constant (i.e. not Zero/NaN/Infinity).
isFMulOrFDivWithConstant(Value * V)467 static bool isFMulOrFDivWithConstant(Value *V) {
468 Instruction *I = dyn_cast<Instruction>(V);
469 if (!I || (I->getOpcode() != Instruction::FMul &&
470 I->getOpcode() != Instruction::FDiv))
471 return false;
472
473 Constant *C0 = dyn_cast<Constant>(I->getOperand(0));
474 Constant *C1 = dyn_cast<Constant>(I->getOperand(1));
475
476 if (C0 && C1)
477 return false;
478
479 return (C0 && isFiniteNonZeroFp(C0)) || (C1 && isFiniteNonZeroFp(C1));
480 }
481
482 /// foldFMulConst() is a helper routine of InstCombiner::visitFMul().
483 /// The input \p FMulOrDiv is a FMul/FDiv with one and only one operand
484 /// being a constant (i.e. isFMulOrFDivWithConstant(FMulOrDiv) == true).
485 /// This function is to simplify "FMulOrDiv * C" and returns the
486 /// resulting expression. Note that this function could return NULL in
487 /// case the constants cannot be folded into a normal floating-point.
488 ///
foldFMulConst(Instruction * FMulOrDiv,Constant * C,Instruction * InsertBefore)489 Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C,
490 Instruction *InsertBefore) {
491 assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid");
492
493 Value *Opnd0 = FMulOrDiv->getOperand(0);
494 Value *Opnd1 = FMulOrDiv->getOperand(1);
495
496 Constant *C0 = dyn_cast<Constant>(Opnd0);
497 Constant *C1 = dyn_cast<Constant>(Opnd1);
498
499 BinaryOperator *R = nullptr;
500
501 // (X * C0) * C => X * (C0*C)
502 if (FMulOrDiv->getOpcode() == Instruction::FMul) {
503 Constant *F = ConstantExpr::getFMul(C1 ? C1 : C0, C);
504 if (isNormalFp(F))
505 R = BinaryOperator::CreateFMul(C1 ? Opnd0 : Opnd1, F);
506 } else {
507 if (C0) {
508 // (C0 / X) * C => (C0 * C) / X
509 if (FMulOrDiv->hasOneUse()) {
510 // It would otherwise introduce another div.
511 Constant *F = ConstantExpr::getFMul(C0, C);
512 if (isNormalFp(F))
513 R = BinaryOperator::CreateFDiv(F, Opnd1);
514 }
515 } else {
516 // (X / C1) * C => X * (C/C1) if C/C1 is not a denormal
517 Constant *F = ConstantExpr::getFDiv(C, C1);
518 if (isNormalFp(F)) {
519 R = BinaryOperator::CreateFMul(Opnd0, F);
520 } else {
521 // (X / C1) * C => X / (C1/C)
522 Constant *F = ConstantExpr::getFDiv(C1, C);
523 if (isNormalFp(F))
524 R = BinaryOperator::CreateFDiv(Opnd0, F);
525 }
526 }
527 }
528
529 if (R) {
530 R->setHasUnsafeAlgebra(true);
531 InsertNewInstWith(R, *InsertBefore);
532 }
533
534 return R;
535 }
536
visitFMul(BinaryOperator & I)537 Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
538 bool Changed = SimplifyAssociativeOrCommutative(I);
539 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
540
541 if (Value *V = SimplifyVectorOp(I))
542 return replaceInstUsesWith(I, V);
543
544 if (isa<Constant>(Op0))
545 std::swap(Op0, Op1);
546
547 if (Value *V =
548 SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC))
549 return replaceInstUsesWith(I, V);
550
551 bool AllowReassociate = I.hasUnsafeAlgebra();
552
553 // Simplify mul instructions with a constant RHS.
554 if (isa<Constant>(Op1)) {
555 // Try to fold constant mul into select arguments.
556 if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
557 if (Instruction *R = FoldOpIntoSelect(I, SI))
558 return R;
559
560 if (isa<PHINode>(Op0))
561 if (Instruction *NV = FoldOpIntoPhi(I))
562 return NV;
563
564 // (fmul X, -1.0) --> (fsub -0.0, X)
565 if (match(Op1, m_SpecificFP(-1.0))) {
566 Constant *NegZero = ConstantFP::getNegativeZero(Op1->getType());
567 Instruction *RI = BinaryOperator::CreateFSub(NegZero, Op0);
568 RI->copyFastMathFlags(&I);
569 return RI;
570 }
571
572 Constant *C = cast<Constant>(Op1);
573 if (AllowReassociate && isFiniteNonZeroFp(C)) {
574 // Let MDC denote an expression in one of these forms:
575 // X * C, C/X, X/C, where C is a constant.
576 //
577 // Try to simplify "MDC * Constant"
578 if (isFMulOrFDivWithConstant(Op0))
579 if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I))
580 return replaceInstUsesWith(I, V);
581
582 // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C)
583 Instruction *FAddSub = dyn_cast<Instruction>(Op0);
584 if (FAddSub &&
585 (FAddSub->getOpcode() == Instruction::FAdd ||
586 FAddSub->getOpcode() == Instruction::FSub)) {
587 Value *Opnd0 = FAddSub->getOperand(0);
588 Value *Opnd1 = FAddSub->getOperand(1);
589 Constant *C0 = dyn_cast<Constant>(Opnd0);
590 Constant *C1 = dyn_cast<Constant>(Opnd1);
591 bool Swap = false;
592 if (C0) {
593 std::swap(C0, C1);
594 std::swap(Opnd0, Opnd1);
595 Swap = true;
596 }
597
598 if (C1 && isFiniteNonZeroFp(C1) && isFMulOrFDivWithConstant(Opnd0)) {
599 Value *M1 = ConstantExpr::getFMul(C1, C);
600 Value *M0 = isNormalFp(cast<Constant>(M1)) ?
601 foldFMulConst(cast<Instruction>(Opnd0), C, &I) :
602 nullptr;
603 if (M0 && M1) {
604 if (Swap && FAddSub->getOpcode() == Instruction::FSub)
605 std::swap(M0, M1);
606
607 Instruction *RI = (FAddSub->getOpcode() == Instruction::FAdd)
608 ? BinaryOperator::CreateFAdd(M0, M1)
609 : BinaryOperator::CreateFSub(M0, M1);
610 RI->copyFastMathFlags(&I);
611 return RI;
612 }
613 }
614 }
615 }
616 }
617
618 if (Op0 == Op1) {
619 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
620 // sqrt(X) * sqrt(X) -> X
621 if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt)
622 return replaceInstUsesWith(I, II->getOperand(0));
623
624 // fabs(X) * fabs(X) -> X * X
625 if (II->getIntrinsicID() == Intrinsic::fabs) {
626 Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0),
627 II->getOperand(0),
628 I.getName());
629 FMulVal->copyFastMathFlags(&I);
630 return FMulVal;
631 }
632 }
633 }
634
635 // Under unsafe algebra do:
636 // X * log2(0.5*Y) = X*log2(Y) - X
637 if (AllowReassociate) {
638 Value *OpX = nullptr;
639 Value *OpY = nullptr;
640 IntrinsicInst *Log2;
641 detectLog2OfHalf(Op0, OpY, Log2);
642 if (OpY) {
643 OpX = Op1;
644 } else {
645 detectLog2OfHalf(Op1, OpY, Log2);
646 if (OpY) {
647 OpX = Op0;
648 }
649 }
650 // if pattern detected emit alternate sequence
651 if (OpX && OpY) {
652 BuilderTy::FastMathFlagGuard Guard(*Builder);
653 Builder->setFastMathFlags(Log2->getFastMathFlags());
654 Log2->setArgOperand(0, OpY);
655 Value *FMulVal = Builder->CreateFMul(OpX, Log2);
656 Value *FSub = Builder->CreateFSub(FMulVal, OpX);
657 FSub->takeName(&I);
658 return replaceInstUsesWith(I, FSub);
659 }
660 }
661
662 // Handle symmetric situation in a 2-iteration loop
663 Value *Opnd0 = Op0;
664 Value *Opnd1 = Op1;
665 for (int i = 0; i < 2; i++) {
666 bool IgnoreZeroSign = I.hasNoSignedZeros();
667 if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) {
668 BuilderTy::FastMathFlagGuard Guard(*Builder);
669 Builder->setFastMathFlags(I.getFastMathFlags());
670
671 Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign);
672 Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign);
673
674 // -X * -Y => X*Y
675 if (N1) {
676 Value *FMul = Builder->CreateFMul(N0, N1);
677 FMul->takeName(&I);
678 return replaceInstUsesWith(I, FMul);
679 }
680
681 if (Opnd0->hasOneUse()) {
682 // -X * Y => -(X*Y) (Promote negation as high as possible)
683 Value *T = Builder->CreateFMul(N0, Opnd1);
684 Value *Neg = Builder->CreateFNeg(T);
685 Neg->takeName(&I);
686 return replaceInstUsesWith(I, Neg);
687 }
688 }
689
690 // (X*Y) * X => (X*X) * Y where Y != X
691 // The purpose is two-fold:
692 // 1) to form a power expression (of X).
693 // 2) potentially shorten the critical path: After transformation, the
694 // latency of the instruction Y is amortized by the expression of X*X,
695 // and therefore Y is in a "less critical" position compared to what it
696 // was before the transformation.
697 //
698 if (AllowReassociate) {
699 Value *Opnd0_0, *Opnd0_1;
700 if (Opnd0->hasOneUse() &&
701 match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) {
702 Value *Y = nullptr;
703 if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1)
704 Y = Opnd0_1;
705 else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1)
706 Y = Opnd0_0;
707
708 if (Y) {
709 BuilderTy::FastMathFlagGuard Guard(*Builder);
710 Builder->setFastMathFlags(I.getFastMathFlags());
711 Value *T = Builder->CreateFMul(Opnd1, Opnd1);
712
713 Value *R = Builder->CreateFMul(T, Y);
714 R->takeName(&I);
715 return replaceInstUsesWith(I, R);
716 }
717 }
718 }
719
720 if (!isa<Constant>(Op1))
721 std::swap(Opnd0, Opnd1);
722 else
723 break;
724 }
725
726 return Changed ? &I : nullptr;
727 }
728
729 /// Try to fold a divide or remainder of a select instruction.
SimplifyDivRemOfSelect(BinaryOperator & I)730 bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) {
731 SelectInst *SI = cast<SelectInst>(I.getOperand(1));
732
733 // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y
734 int NonNullOperand = -1;
735 if (Constant *ST = dyn_cast<Constant>(SI->getOperand(1)))
736 if (ST->isNullValue())
737 NonNullOperand = 2;
738 // div/rem X, (Cond ? Y : 0) -> div/rem X, Y
739 if (Constant *ST = dyn_cast<Constant>(SI->getOperand(2)))
740 if (ST->isNullValue())
741 NonNullOperand = 1;
742
743 if (NonNullOperand == -1)
744 return false;
745
746 Value *SelectCond = SI->getOperand(0);
747
748 // Change the div/rem to use 'Y' instead of the select.
749 I.setOperand(1, SI->getOperand(NonNullOperand));
750
751 // Okay, we know we replace the operand of the div/rem with 'Y' with no
752 // problem. However, the select, or the condition of the select may have
753 // multiple uses. Based on our knowledge that the operand must be non-zero,
754 // propagate the known value for the select into other uses of it, and
755 // propagate a known value of the condition into its other users.
756
757 // If the select and condition only have a single use, don't bother with this,
758 // early exit.
759 if (SI->use_empty() && SelectCond->hasOneUse())
760 return true;
761
762 // Scan the current block backward, looking for other uses of SI.
763 BasicBlock::iterator BBI = I.getIterator(), BBFront = I.getParent()->begin();
764
765 while (BBI != BBFront) {
766 --BBI;
767 // If we found a call to a function, we can't assume it will return, so
768 // information from below it cannot be propagated above it.
769 if (isa<CallInst>(BBI) && !isa<IntrinsicInst>(BBI))
770 break;
771
772 // Replace uses of the select or its condition with the known values.
773 for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end();
774 I != E; ++I) {
775 if (*I == SI) {
776 *I = SI->getOperand(NonNullOperand);
777 Worklist.Add(&*BBI);
778 } else if (*I == SelectCond) {
779 *I = Builder->getInt1(NonNullOperand == 1);
780 Worklist.Add(&*BBI);
781 }
782 }
783
784 // If we past the instruction, quit looking for it.
785 if (&*BBI == SI)
786 SI = nullptr;
787 if (&*BBI == SelectCond)
788 SelectCond = nullptr;
789
790 // If we ran out of things to eliminate, break out of the loop.
791 if (!SelectCond && !SI)
792 break;
793
794 }
795 return true;
796 }
797
798
799 /// This function implements the transforms common to both integer division
800 /// instructions (udiv and sdiv). It is called by the visitors to those integer
801 /// division instructions.
802 /// @brief Common integer divide transforms
commonIDivTransforms(BinaryOperator & I)803 Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
804 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
805
806 // The RHS is known non-zero.
807 if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
808 I.setOperand(1, V);
809 return &I;
810 }
811
812 // Handle cases involving: [su]div X, (select Cond, Y, Z)
813 // This does not apply for fdiv.
814 if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
815 return &I;
816
817 if (Instruction *LHS = dyn_cast<Instruction>(Op0)) {
818 const APInt *C2;
819 if (match(Op1, m_APInt(C2))) {
820 Value *X;
821 const APInt *C1;
822 bool IsSigned = I.getOpcode() == Instruction::SDiv;
823
824 // (X / C1) / C2 -> X / (C1*C2)
825 if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) ||
826 (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) {
827 APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
828 if (!MultiplyOverflows(*C1, *C2, Product, IsSigned))
829 return BinaryOperator::Create(I.getOpcode(), X,
830 ConstantInt::get(I.getType(), Product));
831 }
832
833 if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
834 (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) {
835 APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
836
837 // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
838 if (IsMultiple(*C2, *C1, Quotient, IsSigned)) {
839 BinaryOperator *BO = BinaryOperator::Create(
840 I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient));
841 BO->setIsExact(I.isExact());
842 return BO;
843 }
844
845 // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2.
846 if (IsMultiple(*C1, *C2, Quotient, IsSigned)) {
847 BinaryOperator *BO = BinaryOperator::Create(
848 Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient));
849 BO->setHasNoUnsignedWrap(
850 !IsSigned &&
851 cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap());
852 BO->setHasNoSignedWrap(
853 cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap());
854 return BO;
855 }
856 }
857
858 if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) &&
859 *C1 != C1->getBitWidth() - 1) ||
860 (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) {
861 APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned);
862 APInt C1Shifted = APInt::getOneBitSet(
863 C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue()));
864
865 // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1.
866 if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) {
867 BinaryOperator *BO = BinaryOperator::Create(
868 I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient));
869 BO->setIsExact(I.isExact());
870 return BO;
871 }
872
873 // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2.
874 if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) {
875 BinaryOperator *BO = BinaryOperator::Create(
876 Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient));
877 BO->setHasNoUnsignedWrap(
878 !IsSigned &&
879 cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap());
880 BO->setHasNoSignedWrap(
881 cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap());
882 return BO;
883 }
884 }
885
886 if (*C2 != 0) { // avoid X udiv 0
887 if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
888 if (Instruction *R = FoldOpIntoSelect(I, SI))
889 return R;
890 if (isa<PHINode>(Op0))
891 if (Instruction *NV = FoldOpIntoPhi(I))
892 return NV;
893 }
894 }
895 }
896
897 if (ConstantInt *One = dyn_cast<ConstantInt>(Op0)) {
898 if (One->isOne() && !I.getType()->isIntegerTy(1)) {
899 bool isSigned = I.getOpcode() == Instruction::SDiv;
900 if (isSigned) {
901 // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the
902 // result is one, if Op1 is -1 then the result is minus one, otherwise
903 // it's zero.
904 Value *Inc = Builder->CreateAdd(Op1, One);
905 Value *Cmp = Builder->CreateICmpULT(
906 Inc, ConstantInt::get(I.getType(), 3));
907 return SelectInst::Create(Cmp, Op1, ConstantInt::get(I.getType(), 0));
908 } else {
909 // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the
910 // result is one, otherwise it's zero.
911 return new ZExtInst(Builder->CreateICmpEQ(Op1, One), I.getType());
912 }
913 }
914 }
915
916 // See if we can fold away this div instruction.
917 if (SimplifyDemandedInstructionBits(I))
918 return &I;
919
920 // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y
921 Value *X = nullptr, *Z = nullptr;
922 if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) { // (X - Z) / Y; Y = Op1
923 bool isSigned = I.getOpcode() == Instruction::SDiv;
924 if ((isSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) ||
925 (!isSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1)))))
926 return BinaryOperator::Create(I.getOpcode(), X, Op1);
927 }
928
929 return nullptr;
930 }
931
932 /// dyn_castZExtVal - Checks if V is a zext or constant that can
933 /// be truncated to Ty without losing bits.
dyn_castZExtVal(Value * V,Type * Ty)934 static Value *dyn_castZExtVal(Value *V, Type *Ty) {
935 if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
936 if (Z->getSrcTy() == Ty)
937 return Z->getOperand(0);
938 } else if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
939 if (C->getValue().getActiveBits() <= cast<IntegerType>(Ty)->getBitWidth())
940 return ConstantExpr::getTrunc(C, Ty);
941 }
942 return nullptr;
943 }
944
945 namespace {
946 const unsigned MaxDepth = 6;
947 typedef Instruction *(*FoldUDivOperandCb)(Value *Op0, Value *Op1,
948 const BinaryOperator &I,
949 InstCombiner &IC);
950
951 /// \brief Used to maintain state for visitUDivOperand().
952 struct UDivFoldAction {
953 FoldUDivOperandCb FoldAction; ///< Informs visitUDiv() how to fold this
954 ///< operand. This can be zero if this action
955 ///< joins two actions together.
956
957 Value *OperandToFold; ///< Which operand to fold.
958 union {
959 Instruction *FoldResult; ///< The instruction returned when FoldAction is
960 ///< invoked.
961
962 size_t SelectLHSIdx; ///< Stores the LHS action index if this action
963 ///< joins two actions together.
964 };
965
UDivFoldAction__anon736161690111::UDivFoldAction966 UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand)
967 : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {}
UDivFoldAction__anon736161690111::UDivFoldAction968 UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS)
969 : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {}
970 };
971 }
972
973 // X udiv 2^C -> X >> C
foldUDivPow2Cst(Value * Op0,Value * Op1,const BinaryOperator & I,InstCombiner & IC)974 static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1,
975 const BinaryOperator &I, InstCombiner &IC) {
976 const APInt &C = cast<Constant>(Op1)->getUniqueInteger();
977 BinaryOperator *LShr = BinaryOperator::CreateLShr(
978 Op0, ConstantInt::get(Op0->getType(), C.logBase2()));
979 if (I.isExact())
980 LShr->setIsExact();
981 return LShr;
982 }
983
984 // X udiv C, where C >= signbit
foldUDivNegCst(Value * Op0,Value * Op1,const BinaryOperator & I,InstCombiner & IC)985 static Instruction *foldUDivNegCst(Value *Op0, Value *Op1,
986 const BinaryOperator &I, InstCombiner &IC) {
987 Value *ICI = IC.Builder->CreateICmpULT(Op0, cast<ConstantInt>(Op1));
988
989 return SelectInst::Create(ICI, Constant::getNullValue(I.getType()),
990 ConstantInt::get(I.getType(), 1));
991 }
992
993 // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2)
foldUDivShl(Value * Op0,Value * Op1,const BinaryOperator & I,InstCombiner & IC)994 static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
995 InstCombiner &IC) {
996 Instruction *ShiftLeft = cast<Instruction>(Op1);
997 if (isa<ZExtInst>(ShiftLeft))
998 ShiftLeft = cast<Instruction>(ShiftLeft->getOperand(0));
999
1000 const APInt &CI =
1001 cast<Constant>(ShiftLeft->getOperand(0))->getUniqueInteger();
1002 Value *N = ShiftLeft->getOperand(1);
1003 if (CI != 1)
1004 N = IC.Builder->CreateAdd(N, ConstantInt::get(N->getType(), CI.logBase2()));
1005 if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
1006 N = IC.Builder->CreateZExt(N, Z->getDestTy());
1007 BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N);
1008 if (I.isExact())
1009 LShr->setIsExact();
1010 return LShr;
1011 }
1012
1013 // \brief Recursively visits the possible right hand operands of a udiv
1014 // instruction, seeing through select instructions, to determine if we can
1015 // replace the udiv with something simpler. If we find that an operand is not
1016 // able to simplify the udiv, we abort the entire transformation.
visitUDivOperand(Value * Op0,Value * Op1,const BinaryOperator & I,SmallVectorImpl<UDivFoldAction> & Actions,unsigned Depth=0)1017 static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I,
1018 SmallVectorImpl<UDivFoldAction> &Actions,
1019 unsigned Depth = 0) {
1020 // Check to see if this is an unsigned division with an exact power of 2,
1021 // if so, convert to a right shift.
1022 if (match(Op1, m_Power2())) {
1023 Actions.push_back(UDivFoldAction(foldUDivPow2Cst, Op1));
1024 return Actions.size();
1025 }
1026
1027 if (ConstantInt *C = dyn_cast<ConstantInt>(Op1))
1028 // X udiv C, where C >= signbit
1029 if (C->getValue().isNegative()) {
1030 Actions.push_back(UDivFoldAction(foldUDivNegCst, C));
1031 return Actions.size();
1032 }
1033
1034 // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2)
1035 if (match(Op1, m_Shl(m_Power2(), m_Value())) ||
1036 match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) {
1037 Actions.push_back(UDivFoldAction(foldUDivShl, Op1));
1038 return Actions.size();
1039 }
1040
1041 // The remaining tests are all recursive, so bail out if we hit the limit.
1042 if (Depth++ == MaxDepth)
1043 return 0;
1044
1045 if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1046 if (size_t LHSIdx =
1047 visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth))
1048 if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) {
1049 Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1));
1050 return Actions.size();
1051 }
1052
1053 return 0;
1054 }
1055
visitUDiv(BinaryOperator & I)1056 Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
1057 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1058
1059 if (Value *V = SimplifyVectorOp(I))
1060 return replaceInstUsesWith(I, V);
1061
1062 if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC))
1063 return replaceInstUsesWith(I, V);
1064
1065 // Handle the integer div common cases
1066 if (Instruction *Common = commonIDivTransforms(I))
1067 return Common;
1068
1069 // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
1070 {
1071 Value *X;
1072 const APInt *C1, *C2;
1073 if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) &&
1074 match(Op1, m_APInt(C2))) {
1075 bool Overflow;
1076 APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow);
1077 if (!Overflow) {
1078 bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value()));
1079 BinaryOperator *BO = BinaryOperator::CreateUDiv(
1080 X, ConstantInt::get(X->getType(), C2ShlC1));
1081 if (IsExact)
1082 BO->setIsExact();
1083 return BO;
1084 }
1085 }
1086 }
1087
1088 // (zext A) udiv (zext B) --> zext (A udiv B)
1089 if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
1090 if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
1091 return new ZExtInst(
1092 Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()),
1093 I.getType());
1094
1095 // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...))))
1096 SmallVector<UDivFoldAction, 6> UDivActions;
1097 if (visitUDivOperand(Op0, Op1, I, UDivActions))
1098 for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) {
1099 FoldUDivOperandCb Action = UDivActions[i].FoldAction;
1100 Value *ActionOp1 = UDivActions[i].OperandToFold;
1101 Instruction *Inst;
1102 if (Action)
1103 Inst = Action(Op0, ActionOp1, I, *this);
1104 else {
1105 // This action joins two actions together. The RHS of this action is
1106 // simply the last action we processed, we saved the LHS action index in
1107 // the joining action.
1108 size_t SelectRHSIdx = i - 1;
1109 Value *SelectRHS = UDivActions[SelectRHSIdx].FoldResult;
1110 size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx;
1111 Value *SelectLHS = UDivActions[SelectLHSIdx].FoldResult;
1112 Inst = SelectInst::Create(cast<SelectInst>(ActionOp1)->getCondition(),
1113 SelectLHS, SelectRHS);
1114 }
1115
1116 // If this is the last action to process, return it to the InstCombiner.
1117 // Otherwise, we insert it before the UDiv and record it so that we may
1118 // use it as part of a joining action (i.e., a SelectInst).
1119 if (e - i != 1) {
1120 Inst->insertBefore(&I);
1121 UDivActions[i].FoldResult = Inst;
1122 } else
1123 return Inst;
1124 }
1125
1126 return nullptr;
1127 }
1128
visitSDiv(BinaryOperator & I)1129 Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
1130 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1131
1132 if (Value *V = SimplifyVectorOp(I))
1133 return replaceInstUsesWith(I, V);
1134
1135 if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC))
1136 return replaceInstUsesWith(I, V);
1137
1138 // Handle the integer div common cases
1139 if (Instruction *Common = commonIDivTransforms(I))
1140 return Common;
1141
1142 const APInt *Op1C;
1143 if (match(Op1, m_APInt(Op1C))) {
1144 // sdiv X, -1 == -X
1145 if (Op1C->isAllOnesValue())
1146 return BinaryOperator::CreateNeg(Op0);
1147
1148 // sdiv exact X, C --> ashr exact X, log2(C)
1149 if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) {
1150 Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2());
1151 return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
1152 }
1153
1154 // If the dividend is sign-extended and the constant divisor is small enough
1155 // to fit in the source type, shrink the division to the narrower type:
1156 // (sext X) sdiv C --> sext (X sdiv C)
1157 Value *Op0Src;
1158 if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
1159 Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
1160
1161 // In the general case, we need to make sure that the dividend is not the
1162 // minimum signed value because dividing that by -1 is UB. But here, we
1163 // know that the -1 divisor case is already handled above.
1164
1165 Constant *NarrowDivisor =
1166 ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType());
1167 Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor);
1168 return new SExtInst(NarrowOp, Op0->getType());
1169 }
1170 }
1171
1172 if (Constant *RHS = dyn_cast<Constant>(Op1)) {
1173 // X/INT_MIN -> X == INT_MIN
1174 if (RHS->isMinSignedValue())
1175 return new ZExtInst(Builder->CreateICmpEQ(Op0, Op1), I.getType());
1176
1177 // -X/C --> X/-C provided the negation doesn't overflow.
1178 Value *X;
1179 if (match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) {
1180 auto *BO = BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(RHS));
1181 BO->setIsExact(I.isExact());
1182 return BO;
1183 }
1184 }
1185
1186 // If the sign bits of both operands are zero (i.e. we can prove they are
1187 // unsigned inputs), turn this into a udiv.
1188 if (I.getType()->isIntegerTy()) {
1189 APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
1190 if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
1191 if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
1192 // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
1193 auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
1194 BO->setIsExact(I.isExact());
1195 return BO;
1196 }
1197
1198 if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) {
1199 // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
1200 // Safe because the only negative value (1 << Y) can take on is
1201 // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
1202 // the sign bit set.
1203 auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
1204 BO->setIsExact(I.isExact());
1205 return BO;
1206 }
1207 }
1208 }
1209
1210 return nullptr;
1211 }
1212
1213 /// CvtFDivConstToReciprocal tries to convert X/C into X*1/C if C not a special
1214 /// FP value and:
1215 /// 1) 1/C is exact, or
1216 /// 2) reciprocal is allowed.
1217 /// If the conversion was successful, the simplified expression "X * 1/C" is
1218 /// returned; otherwise, NULL is returned.
1219 ///
CvtFDivConstToReciprocal(Value * Dividend,Constant * Divisor,bool AllowReciprocal)1220 static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor,
1221 bool AllowReciprocal) {
1222 if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors.
1223 return nullptr;
1224
1225 const APFloat &FpVal = cast<ConstantFP>(Divisor)->getValueAPF();
1226 APFloat Reciprocal(FpVal.getSemantics());
1227 bool Cvt = FpVal.getExactInverse(&Reciprocal);
1228
1229 if (!Cvt && AllowReciprocal && FpVal.isFiniteNonZero()) {
1230 Reciprocal = APFloat(FpVal.getSemantics(), 1.0f);
1231 (void)Reciprocal.divide(FpVal, APFloat::rmNearestTiesToEven);
1232 Cvt = !Reciprocal.isDenormal();
1233 }
1234
1235 if (!Cvt)
1236 return nullptr;
1237
1238 ConstantFP *R;
1239 R = ConstantFP::get(Dividend->getType()->getContext(), Reciprocal);
1240 return BinaryOperator::CreateFMul(Dividend, R);
1241 }
1242
visitFDiv(BinaryOperator & I)1243 Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
1244 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1245
1246 if (Value *V = SimplifyVectorOp(I))
1247 return replaceInstUsesWith(I, V);
1248
1249 if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(),
1250 DL, TLI, DT, AC))
1251 return replaceInstUsesWith(I, V);
1252
1253 if (isa<Constant>(Op0))
1254 if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1255 if (Instruction *R = FoldOpIntoSelect(I, SI))
1256 return R;
1257
1258 bool AllowReassociate = I.hasUnsafeAlgebra();
1259 bool AllowReciprocal = I.hasAllowReciprocal();
1260
1261 if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
1262 if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
1263 if (Instruction *R = FoldOpIntoSelect(I, SI))
1264 return R;
1265
1266 if (AllowReassociate) {
1267 Constant *C1 = nullptr;
1268 Constant *C2 = Op1C;
1269 Value *X;
1270 Instruction *Res = nullptr;
1271
1272 if (match(Op0, m_FMul(m_Value(X), m_Constant(C1)))) {
1273 // (X*C1)/C2 => X * (C1/C2)
1274 //
1275 Constant *C = ConstantExpr::getFDiv(C1, C2);
1276 if (isNormalFp(C))
1277 Res = BinaryOperator::CreateFMul(X, C);
1278 } else if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
1279 // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed]
1280 //
1281 Constant *C = ConstantExpr::getFMul(C1, C2);
1282 if (isNormalFp(C)) {
1283 Res = CvtFDivConstToReciprocal(X, C, AllowReciprocal);
1284 if (!Res)
1285 Res = BinaryOperator::CreateFDiv(X, C);
1286 }
1287 }
1288
1289 if (Res) {
1290 Res->setFastMathFlags(I.getFastMathFlags());
1291 return Res;
1292 }
1293 }
1294
1295 // X / C => X * 1/C
1296 if (Instruction *T = CvtFDivConstToReciprocal(Op0, Op1C, AllowReciprocal)) {
1297 T->copyFastMathFlags(&I);
1298 return T;
1299 }
1300
1301 return nullptr;
1302 }
1303
1304 if (AllowReassociate && isa<Constant>(Op0)) {
1305 Constant *C1 = cast<Constant>(Op0), *C2;
1306 Constant *Fold = nullptr;
1307 Value *X;
1308 bool CreateDiv = true;
1309
1310 // C1 / (X*C2) => (C1/C2) / X
1311 if (match(Op1, m_FMul(m_Value(X), m_Constant(C2))))
1312 Fold = ConstantExpr::getFDiv(C1, C2);
1313 else if (match(Op1, m_FDiv(m_Value(X), m_Constant(C2)))) {
1314 // C1 / (X/C2) => (C1*C2) / X
1315 Fold = ConstantExpr::getFMul(C1, C2);
1316 } else if (match(Op1, m_FDiv(m_Constant(C2), m_Value(X)))) {
1317 // C1 / (C2/X) => (C1/C2) * X
1318 Fold = ConstantExpr::getFDiv(C1, C2);
1319 CreateDiv = false;
1320 }
1321
1322 if (Fold && isNormalFp(Fold)) {
1323 Instruction *R = CreateDiv ? BinaryOperator::CreateFDiv(Fold, X)
1324 : BinaryOperator::CreateFMul(X, Fold);
1325 R->setFastMathFlags(I.getFastMathFlags());
1326 return R;
1327 }
1328 return nullptr;
1329 }
1330
1331 if (AllowReassociate) {
1332 Value *X, *Y;
1333 Value *NewInst = nullptr;
1334 Instruction *SimpR = nullptr;
1335
1336 if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) {
1337 // (X/Y) / Z => X / (Y*Z)
1338 //
1339 if (!isa<Constant>(Y) || !isa<Constant>(Op1)) {
1340 NewInst = Builder->CreateFMul(Y, Op1);
1341 if (Instruction *RI = dyn_cast<Instruction>(NewInst)) {
1342 FastMathFlags Flags = I.getFastMathFlags();
1343 Flags &= cast<Instruction>(Op0)->getFastMathFlags();
1344 RI->setFastMathFlags(Flags);
1345 }
1346 SimpR = BinaryOperator::CreateFDiv(X, NewInst);
1347 }
1348 } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) {
1349 // Z / (X/Y) => Z*Y / X
1350 //
1351 if (!isa<Constant>(Y) || !isa<Constant>(Op0)) {
1352 NewInst = Builder->CreateFMul(Op0, Y);
1353 if (Instruction *RI = dyn_cast<Instruction>(NewInst)) {
1354 FastMathFlags Flags = I.getFastMathFlags();
1355 Flags &= cast<Instruction>(Op1)->getFastMathFlags();
1356 RI->setFastMathFlags(Flags);
1357 }
1358 SimpR = BinaryOperator::CreateFDiv(NewInst, X);
1359 }
1360 }
1361
1362 if (NewInst) {
1363 if (Instruction *T = dyn_cast<Instruction>(NewInst))
1364 T->setDebugLoc(I.getDebugLoc());
1365 SimpR->setFastMathFlags(I.getFastMathFlags());
1366 return SimpR;
1367 }
1368 }
1369
1370 return nullptr;
1371 }
1372
1373 /// This function implements the transforms common to both integer remainder
1374 /// instructions (urem and srem). It is called by the visitors to those integer
1375 /// remainder instructions.
1376 /// @brief Common integer remainder transforms
commonIRemTransforms(BinaryOperator & I)1377 Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
1378 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1379
1380 // The RHS is known non-zero.
1381 if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
1382 I.setOperand(1, V);
1383 return &I;
1384 }
1385
1386 // Handle cases involving: rem X, (select Cond, Y, Z)
1387 if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
1388 return &I;
1389
1390 if (isa<Constant>(Op1)) {
1391 if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
1392 if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
1393 if (Instruction *R = FoldOpIntoSelect(I, SI))
1394 return R;
1395 } else if (isa<PHINode>(Op0I)) {
1396 using namespace llvm::PatternMatch;
1397 const APInt *Op1Int;
1398 if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
1399 (I.getOpcode() == Instruction::URem ||
1400 !Op1Int->isMinSignedValue())) {
1401 // FoldOpIntoPhi will speculate instructions to the end of the PHI's
1402 // predecessor blocks, so do this only if we know the srem or urem
1403 // will not fault.
1404 if (Instruction *NV = FoldOpIntoPhi(I))
1405 return NV;
1406 }
1407 }
1408
1409 // See if we can fold away this rem instruction.
1410 if (SimplifyDemandedInstructionBits(I))
1411 return &I;
1412 }
1413 }
1414
1415 return nullptr;
1416 }
1417
visitURem(BinaryOperator & I)1418 Instruction *InstCombiner::visitURem(BinaryOperator &I) {
1419 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1420
1421 if (Value *V = SimplifyVectorOp(I))
1422 return replaceInstUsesWith(I, V);
1423
1424 if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC))
1425 return replaceInstUsesWith(I, V);
1426
1427 if (Instruction *common = commonIRemTransforms(I))
1428 return common;
1429
1430 // (zext A) urem (zext B) --> zext (A urem B)
1431 if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
1432 if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
1433 return new ZExtInst(Builder->CreateURem(ZOp0->getOperand(0), ZOp1),
1434 I.getType());
1435
1436 // X urem Y -> X and Y-1, where Y is a power of 2,
1437 if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) {
1438 Constant *N1 = Constant::getAllOnesValue(I.getType());
1439 Value *Add = Builder->CreateAdd(Op1, N1);
1440 return BinaryOperator::CreateAnd(Op0, Add);
1441 }
1442
1443 // 1 urem X -> zext(X != 1)
1444 if (match(Op0, m_One())) {
1445 Value *Cmp = Builder->CreateICmpNE(Op1, Op0);
1446 Value *Ext = Builder->CreateZExt(Cmp, I.getType());
1447 return replaceInstUsesWith(I, Ext);
1448 }
1449
1450 return nullptr;
1451 }
1452
visitSRem(BinaryOperator & I)1453 Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
1454 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1455
1456 if (Value *V = SimplifyVectorOp(I))
1457 return replaceInstUsesWith(I, V);
1458
1459 if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC))
1460 return replaceInstUsesWith(I, V);
1461
1462 // Handle the integer rem common cases
1463 if (Instruction *Common = commonIRemTransforms(I))
1464 return Common;
1465
1466 {
1467 const APInt *Y;
1468 // X % -Y -> X % Y
1469 if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) {
1470 Worklist.AddValue(I.getOperand(1));
1471 I.setOperand(1, ConstantInt::get(I.getType(), -*Y));
1472 return &I;
1473 }
1474 }
1475
1476 // If the sign bits of both operands are zero (i.e. we can prove they are
1477 // unsigned inputs), turn this into a urem.
1478 if (I.getType()->isIntegerTy()) {
1479 APInt Mask(APInt::getSignBit(I.getType()->getPrimitiveSizeInBits()));
1480 if (MaskedValueIsZero(Op1, Mask, 0, &I) &&
1481 MaskedValueIsZero(Op0, Mask, 0, &I)) {
1482 // X srem Y -> X urem Y, iff X and Y don't have sign bit set
1483 return BinaryOperator::CreateURem(Op0, Op1, I.getName());
1484 }
1485 }
1486
1487 // If it's a constant vector, flip any negative values positive.
1488 if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
1489 Constant *C = cast<Constant>(Op1);
1490 unsigned VWidth = C->getType()->getVectorNumElements();
1491
1492 bool hasNegative = false;
1493 bool hasMissing = false;
1494 for (unsigned i = 0; i != VWidth; ++i) {
1495 Constant *Elt = C->getAggregateElement(i);
1496 if (!Elt) {
1497 hasMissing = true;
1498 break;
1499 }
1500
1501 if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
1502 if (RHS->isNegative())
1503 hasNegative = true;
1504 }
1505
1506 if (hasNegative && !hasMissing) {
1507 SmallVector<Constant *, 16> Elts(VWidth);
1508 for (unsigned i = 0; i != VWidth; ++i) {
1509 Elts[i] = C->getAggregateElement(i); // Handle undef, etc.
1510 if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
1511 if (RHS->isNegative())
1512 Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
1513 }
1514 }
1515
1516 Constant *NewRHSV = ConstantVector::get(Elts);
1517 if (NewRHSV != C) { // Don't loop on -MININT
1518 Worklist.AddValue(I.getOperand(1));
1519 I.setOperand(1, NewRHSV);
1520 return &I;
1521 }
1522 }
1523 }
1524
1525 return nullptr;
1526 }
1527
visitFRem(BinaryOperator & I)1528 Instruction *InstCombiner::visitFRem(BinaryOperator &I) {
1529 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1530
1531 if (Value *V = SimplifyVectorOp(I))
1532 return replaceInstUsesWith(I, V);
1533
1534 if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(),
1535 DL, TLI, DT, AC))
1536 return replaceInstUsesWith(I, V);
1537
1538 // Handle cases involving: rem X, (select Cond, Y, Z)
1539 if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
1540 return &I;
1541
1542 return nullptr;
1543 }
1544