1 //===- AggressiveInstCombine.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
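// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//  * Truncate instruction (via TruncInstCombine)
//  * Guarded funnel shifts and rotates, any/all-bits-set tests, popcount,
//    saturating fp-to-int conversion, sqrt libcalls, table-based cttz and
//    consecutive-load merging
//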
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16 #include "AggressiveInstCombineInternal.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/BasicAliasAnalysis.h"
21 #include "llvm/Analysis/GlobalsModRef.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/Analysis/TargetTransformInfo.h"
24 #include "llvm/Analysis/ValueTracking.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/Dominators.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/PatternMatch.h"
30 #include "llvm/Transforms/Utils/BuildLibCalls.h"
31 #include "llvm/Transforms/Utils/Local.h"
32
33 using namespace llvm;
34 using namespace PatternMatch;
35
36 #define DEBUG_TYPE "aggressive-instcombine"
37
38 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
39 STATISTIC(NumGuardedRotates,
40 "Number of guarded rotates transformed into funnel shifts");
41 STATISTIC(NumGuardedFunnelShifts,
42 "Number of guarded funnel shifts transformed into funnel shifts");
43 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
44
45 static cl::opt<unsigned> MaxInstrsToScan(
46 "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
47 cl::desc("Max number of instructions to scan for aggressive instcombine."));
48
49 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
50 /// against undefined behavior by branching around the funnel-shift/rotation
51 /// when the shift amount is 0.
foldGuardedFunnelShift(Instruction & I,const DominatorTree & DT)52 static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
53 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
54 return false;
55
56 // As with the one-use checks below, this is not strictly necessary, but we
57 // are being cautious to avoid potential perf regressions on targets that
58 // do not actually have a funnel/rotate instruction (where the funnel shift
59 // would be expanded back into math/shift/logic ops).
60 if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
61 return false;
62
63 // Match V to funnel shift left/right and capture the source operands and
64 // shift amount.
65 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
66 Value *&ShAmt) {
67 Value *SubAmt;
68 unsigned Width = V->getType()->getScalarSizeInBits();
69
70 // fshl(ShVal0, ShVal1, ShAmt)
71 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
72 if (match(V, m_OneUse(m_c_Or(
73 m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
74 m_LShr(m_Value(ShVal1),
75 m_Sub(m_SpecificInt(Width), m_Value(SubAmt))))))) {
76 if (ShAmt == SubAmt) // TODO: Use m_Specific
77 return Intrinsic::fshl;
78 }
79
80 // fshr(ShVal0, ShVal1, ShAmt)
81 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
82 if (match(V,
83 m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
84 m_Value(SubAmt))),
85 m_LShr(m_Value(ShVal1), m_Value(ShAmt)))))) {
86 if (ShAmt == SubAmt) // TODO: Use m_Specific
87 return Intrinsic::fshr;
88 }
89
90 return Intrinsic::not_intrinsic;
91 };
92
93 // One phi operand must be a funnel/rotate operation, and the other phi
94 // operand must be the source value of that funnel/rotate operation:
95 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
96 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
97 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
98 PHINode &Phi = cast<PHINode>(I);
99 unsigned FunnelOp = 0, GuardOp = 1;
100 Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
101 Value *ShVal0, *ShVal1, *ShAmt;
102 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
103 if (IID == Intrinsic::not_intrinsic ||
104 (IID == Intrinsic::fshl && ShVal0 != P1) ||
105 (IID == Intrinsic::fshr && ShVal1 != P1)) {
106 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
107 if (IID == Intrinsic::not_intrinsic ||
108 (IID == Intrinsic::fshl && ShVal0 != P0) ||
109 (IID == Intrinsic::fshr && ShVal1 != P0))
110 return false;
111 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
112 "Pattern must match funnel shift left or right");
113 std::swap(FunnelOp, GuardOp);
114 }
115
116 // The incoming block with our source operand must be the "guard" block.
117 // That must contain a cmp+branch to avoid the funnel/rotate when the shift
118 // amount is equal to 0. The other incoming block is the block with the
119 // funnel/rotate.
120 BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
121 BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
122 Instruction *TermI = GuardBB->getTerminator();
123
124 // Ensure that the shift values dominate each block.
125 if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
126 return false;
127
128 ICmpInst::Predicate Pred;
129 BasicBlock *PhiBB = Phi.getParent();
130 if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
131 m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
132 return false;
133
134 if (Pred != CmpInst::ICMP_EQ)
135 return false;
136
137 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
138
139 if (ShVal0 == ShVal1)
140 ++NumGuardedRotates;
141 else
142 ++NumGuardedFunnelShifts;
143
144 // If this is not a rotate then the select was blocking poison from the
145 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
146 bool IsFshl = IID == Intrinsic::fshl;
147 if (ShVal0 != ShVal1) {
148 if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
149 ShVal1 = Builder.CreateFreeze(ShVal1);
150 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
151 ShVal0 = Builder.CreateFreeze(ShVal0);
152 }
153
154 // We matched a variation of this IR pattern:
155 // GuardBB:
156 // %cmp = icmp eq i32 %ShAmt, 0
157 // br i1 %cmp, label %PhiBB, label %FunnelBB
158 // FunnelBB:
159 // %sub = sub i32 32, %ShAmt
160 // %shr = lshr i32 %ShVal1, %sub
161 // %shl = shl i32 %ShVal0, %ShAmt
162 // %fsh = or i32 %shr, %shl
163 // br label %PhiBB
164 // PhiBB:
165 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
166 // -->
167 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
168 Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
169 Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
170 return true;
171 }
172
173 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
174 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
175 /// of 'and' ops, then we also need to capture the fact that we saw an
176 /// "and X, 1", so that's an extra return value for that case.
177 struct MaskOps {
178 Value *Root = nullptr;
179 APInt Mask;
180 bool MatchAndChain;
181 bool FoundAnd1 = false;
182
MaskOpsMaskOps183 MaskOps(unsigned BitWidth, bool MatchAnds)
184 : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
185 };
186
187 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
188 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
189 /// value. Examples:
190 /// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
191 /// returns { X, 0x129 }
192 /// and (and (X >> 1), 1), (X >> 4)
193 /// returns { X, 0x12 }
matchAndOrChain(Value * V,MaskOps & MOps)194 static bool matchAndOrChain(Value *V, MaskOps &MOps) {
195 Value *Op0, *Op1;
196 if (MOps.MatchAndChain) {
197 // Recurse through a chain of 'and' operands. This requires an extra check
198 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
199 // in the chain to know that all of the high bits are cleared.
200 if (match(V, m_And(m_Value(Op0), m_One()))) {
201 MOps.FoundAnd1 = true;
202 return matchAndOrChain(Op0, MOps);
203 }
204 if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
205 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
206 } else {
207 // Recurse through a chain of 'or' operands.
208 if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
209 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
210 }
211
212 // We need a shift-right or a bare value representing a compare of bit 0 of
213 // the original source operand.
214 Value *Candidate;
215 const APInt *BitIndex = nullptr;
216 if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
217 Candidate = V;
218
219 // Initialize result source operand.
220 if (!MOps.Root)
221 MOps.Root = Candidate;
222
223 // The shift constant is out-of-range? This code hasn't been simplified.
224 if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
225 return false;
226
227 // Fill in the mask bit derived from the shift constant.
228 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
229 return MOps.Root == Candidate;
230 }
231
232 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
233 /// These will include a chain of 'or' or 'and'-shifted bits from a
234 /// common source value:
235 /// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
236 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
237 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
238 /// that differ only with a final 'not' of the result. We expect that final
239 /// 'not' to be folded with the compare that we create here (invert predicate).
foldAnyOrAllBitsSet(Instruction & I)240 static bool foldAnyOrAllBitsSet(Instruction &I) {
241 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
242 // final "and X, 1" instruction must be the final op in the sequence.
243 bool MatchAllBitsSet;
244 if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
245 MatchAllBitsSet = true;
246 else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
247 MatchAllBitsSet = false;
248 else
249 return false;
250
251 MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
252 if (MatchAllBitsSet) {
253 if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
254 return false;
255 } else {
256 if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
257 return false;
258 }
259
260 // The pattern was found. Create a masked compare that replaces all of the
261 // shift and logic ops.
262 IRBuilder<> Builder(&I);
263 Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
264 Value *And = Builder.CreateAnd(MOps.Root, Mask);
265 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
266 : Builder.CreateIsNotNull(And);
267 Value *Zext = Builder.CreateZExt(Cmp, I.getType());
268 I.replaceAllUsesWith(Zext);
269 ++NumAnyOrAllBitsSet;
270 return true;
271 }
272
273 // Try to recognize below function as popcount intrinsic.
274 // This is the "best" algorithm from
275 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
276 // Also used in TargetLowering::expandCTPOP().
277 //
278 // int popcount(unsigned int i) {
279 // i = i - ((i >> 1) & 0x55555555);
280 // i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
281 // i = ((i + (i >> 4)) & 0x0F0F0F0F);
282 // return (i * 0x01010101) >> 24;
283 // }
tryToRecognizePopCount(Instruction & I)284 static bool tryToRecognizePopCount(Instruction &I) {
285 if (I.getOpcode() != Instruction::LShr)
286 return false;
287
288 Type *Ty = I.getType();
289 if (!Ty->isIntOrIntVectorTy())
290 return false;
291
292 unsigned Len = Ty->getScalarSizeInBits();
293 // FIXME: fix Len == 8 and other irregular type lengths.
294 if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
295 return false;
296
297 APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
298 APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
299 APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
300 APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
301 APInt MaskShift = APInt(Len, Len - 8);
302
303 Value *Op0 = I.getOperand(0);
304 Value *Op1 = I.getOperand(1);
305 Value *MulOp0;
306 // Matching "(i * 0x01010101...) >> 24".
307 if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
308 match(Op1, m_SpecificInt(MaskShift))) {
309 Value *ShiftOp0;
310 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
311 if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
312 m_Deferred(ShiftOp0)),
313 m_SpecificInt(Mask0F)))) {
314 Value *AndOp0;
315 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
316 if (match(ShiftOp0,
317 m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
318 m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
319 m_SpecificInt(Mask33))))) {
320 Value *Root, *SubOp1;
321 // Matching "i - ((i >> 1) & 0x55555555...)".
322 if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
323 match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
324 m_SpecificInt(Mask55)))) {
325 LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
326 IRBuilder<> Builder(&I);
327 Function *Func = Intrinsic::getDeclaration(
328 I.getModule(), Intrinsic::ctpop, I.getType());
329 I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
330 ++NumPopCountRecognized;
331 return true;
332 }
333 }
334 }
335 }
336
337 return false;
338 }
339
340 /// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
341 /// C2 saturate the value of the fp conversion. The transform is not reversable
342 /// as the fptosi.sat is more defined than the input - all values produce a
343 /// valid value for the fptosi.sat, where as some produce poison for original
344 /// that were out of range of the integer conversion. The reversed pattern may
345 /// use fmax and fmin instead. As we cannot directly reverse the transform, and
346 /// it is not always profitable, we make it conditional on the cost being
347 /// reported as lower by TTI.
tryToFPToSat(Instruction & I,TargetTransformInfo & TTI)348 static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
349 // Look for min(max(fptosi, converting to fptosi_sat.
350 Value *In;
351 const APInt *MinC, *MaxC;
352 if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
353 m_APInt(MinC))),
354 m_APInt(MaxC))) &&
355 !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
356 m_APInt(MaxC))),
357 m_APInt(MinC))))
358 return false;
359
360 // Check that the constants clamp a saturate.
361 if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
362 return false;
363
364 Type *IntTy = I.getType();
365 Type *FpTy = In->getType();
366 Type *SatTy =
367 IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
368 if (auto *VecTy = dyn_cast<VectorType>(IntTy))
369 SatTy = VectorType::get(SatTy, VecTy->getElementCount());
370
371 // Get the cost of the intrinsic, and check that against the cost of
372 // fptosi+smin+smax
373 InstructionCost SatCost = TTI.getIntrinsicInstrCost(
374 IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
375 TTI::TCK_RecipThroughput);
376 SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy,
377 TTI::CastContextHint::None,
378 TTI::TCK_RecipThroughput);
379
380 InstructionCost MinMaxCost = TTI.getCastInstrCost(
381 Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
382 TTI::TCK_RecipThroughput);
383 MinMaxCost += TTI.getIntrinsicInstrCost(
384 IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
385 TTI::TCK_RecipThroughput);
386 MinMaxCost += TTI.getIntrinsicInstrCost(
387 IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
388 TTI::TCK_RecipThroughput);
389
390 if (SatCost >= MinMaxCost)
391 return false;
392
393 IRBuilder<> Builder(&I);
394 Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat,
395 {SatTy, FpTy});
396 Value *Sat = Builder.CreateCall(Fn, In);
397 I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
398 return true;
399 }
400
401 /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
402 /// pessimistic codegen that has to account for setting errno and can enable
403 /// vectorization.
404 static bool
foldSqrt(Instruction & I,TargetTransformInfo & TTI,TargetLibraryInfo & TLI)405 foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) {
406 // Match a call to sqrt mathlib function.
407 auto *Call = dyn_cast<CallInst>(&I);
408 if (!Call)
409 return false;
410
411 Module *M = Call->getModule();
412 LibFunc Func;
413 if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
414 return false;
415
416 if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
417 return false;
418
419 // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
420 // (because NNAN or the operand arg must not be less than -0.0) and (2) we
421 // would not end up lowering to a libcall anyway (which could change the value
422 // of errno), then:
423 // (1) errno won't be set.
424 // (2) it is safe to convert this to an intrinsic call.
425 Type *Ty = Call->getType();
426 Value *Arg = Call->getArgOperand(0);
427 if (TTI.haveFastSqrt(Ty) &&
428 (Call->hasNoNaNs() || CannotBeOrderedLessThanZero(Arg, &TLI))) {
429 IRBuilder<> Builder(&I);
430 IRBuilderBase::FastMathFlagGuard Guard(Builder);
431 Builder.setFastMathFlags(Call->getFastMathFlags());
432
433 Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
434 Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
435 I.replaceAllUsesWith(NewSqrt);
436
437 // Explicitly erase the old call because a call with side effects is not
438 // trivially dead.
439 I.eraseFromParent();
440 return true;
441 }
442
443 return false;
444 }
445
446 // Check if this array of constants represents a cttz table.
447 // Iterate over the elements from \p Table by trying to find/match all
448 // the numbers from 0 to \p InputBits that should represent cttz results.
isCTTZTable(const ConstantDataArray & Table,uint64_t Mul,uint64_t Shift,uint64_t InputBits)449 static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
450 uint64_t Shift, uint64_t InputBits) {
451 unsigned Length = Table.getNumElements();
452 if (Length < InputBits || Length > InputBits * 2)
453 return false;
454
455 APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
456 unsigned Matched = 0;
457
458 for (unsigned i = 0; i < Length; i++) {
459 uint64_t Element = Table.getElementAsInteger(i);
460 if (Element >= InputBits)
461 continue;
462
463 // Check if \p Element matches a concrete answer. It could fail for some
464 // elements that are never accessed, so we keep iterating over each element
465 // from the table. The number of matched elements should be equal to the
466 // number of potential right answers which is \p InputBits actually.
467 if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
468 Matched++;
469 }
470
471 return Matched == InputBits;
472 }
473
474 // Try to recognize table-based ctz implementation.
475 // E.g., an example in C (for more cases please see the llvm/tests):
476 // int f(unsigned x) {
477 // static const char table[32] =
478 // {0, 1, 28, 2, 29, 14, 24, 3, 30,
479 // 22, 20, 15, 25, 17, 4, 8, 31, 27,
480 // 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
481 // return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
482 // }
483 // this can be lowered to `cttz` instruction.
484 // There is also a special case when the element is 0.
485 //
486 // Here are some examples or LLVM IR for a 64-bit target:
487 //
488 // CASE 1:
489 // %sub = sub i32 0, %x
490 // %and = and i32 %sub, %x
491 // %mul = mul i32 %and, 125613361
492 // %shr = lshr i32 %mul, 27
493 // %idxprom = zext i32 %shr to i64
494 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
495 // i64 %idxprom %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
496 //
497 // CASE 2:
498 // %sub = sub i32 0, %x
499 // %and = and i32 %sub, %x
500 // %mul = mul i32 %and, 72416175
501 // %shr = lshr i32 %mul, 26
502 // %idxprom = zext i32 %shr to i64
503 // %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, i64
504 // 0, i64 %idxprom %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
505 //
506 // CASE 3:
507 // %sub = sub i32 0, %x
508 // %and = and i32 %sub, %x
509 // %mul = mul i32 %and, 81224991
510 // %shr = lshr i32 %mul, 27
511 // %idxprom = zext i32 %shr to i64
512 // %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, i64
513 // 0, i64 %idxprom %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
514 //
515 // CASE 4:
516 // %sub = sub i64 0, %x
517 // %and = and i64 %sub, %x
518 // %mul = mul i64 %and, 283881067100198605
519 // %shr = lshr i64 %mul, 58
520 // %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, i64
521 // %shr %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
522 //
523 // All this can be lowered to @llvm.cttz.i32/64 intrinsic.
tryToRecognizeTableBasedCttz(Instruction & I)524 static bool tryToRecognizeTableBasedCttz(Instruction &I) {
525 LoadInst *LI = dyn_cast<LoadInst>(&I);
526 if (!LI)
527 return false;
528
529 Type *AccessType = LI->getType();
530 if (!AccessType->isIntegerTy())
531 return false;
532
533 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
534 if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
535 return false;
536
537 if (!GEP->getSourceElementType()->isArrayTy())
538 return false;
539
540 uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
541 if (ArraySize != 32 && ArraySize != 64)
542 return false;
543
544 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
545 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
546 return false;
547
548 ConstantDataArray *ConstData =
549 dyn_cast<ConstantDataArray>(GVTable->getInitializer());
550 if (!ConstData)
551 return false;
552
553 if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
554 return false;
555
556 Value *Idx2 = std::next(GEP->idx_begin())->get();
557 Value *X1;
558 uint64_t MulConst, ShiftConst;
559 // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
560 // probably fail for other (e.g. 32-bit) targets.
561 if (!match(Idx2, m_ZExtOrSelf(
562 m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
563 m_ConstantInt(MulConst)),
564 m_ConstantInt(ShiftConst)))))
565 return false;
566
567 unsigned InputBits = X1->getType()->getScalarSizeInBits();
568 if (InputBits != 32 && InputBits != 64)
569 return false;
570
571 // Shift should extract top 5..7 bits.
572 if (InputBits - Log2_32(InputBits) != ShiftConst &&
573 InputBits - Log2_32(InputBits) - 1 != ShiftConst)
574 return false;
575
576 if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
577 return false;
578
579 auto ZeroTableElem = ConstData->getElementAsInteger(0);
580 bool DefinedForZero = ZeroTableElem == InputBits;
581
582 IRBuilder<> B(LI);
583 ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
584 Type *XType = X1->getType();
585 auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
586 Value *ZExtOrTrunc = nullptr;
587
588 if (DefinedForZero) {
589 ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
590 } else {
591 // If the value in elem 0 isn't the same as InputBits, we still want to
592 // produce the value from the table.
593 auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
594 auto Select =
595 B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
596
597 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
598 // it should be handled as: `cttz(x) & (typeSize - 1)`.
599
600 ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
601 }
602
603 LI->replaceAllUsesWith(ZExtOrTrunc);
604
605 return true;
606 }
607
608 /// This is used by foldLoadsRecursive() to capture a Root Load node which is
609 /// of type or(load, load) and recursively build the wide load. Also capture the
610 /// shift amount, zero extend type and loadSize.
611 struct LoadOps {
612 LoadInst *Root = nullptr;
613 LoadInst *RootInsert = nullptr;
614 bool FoundRoot = false;
615 uint64_t LoadSize = 0;
616 Value *Shift = nullptr;
617 Type *ZextType;
618 AAMDNodes AATags;
619 };
620
621 // Identify and Merge consecutive loads recursively which is of the form
622 // (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
623 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
foldLoadsRecursive(Value * V,LoadOps & LOps,const DataLayout & DL,AliasAnalysis & AA)624 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
625 AliasAnalysis &AA) {
626 Value *ShAmt2 = nullptr;
627 Value *X;
628 Instruction *L1, *L2;
629
630 // Go to the last node with loads.
631 if (match(V, m_OneUse(m_c_Or(
632 m_Value(X),
633 m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
634 m_Value(ShAmt2)))))) ||
635 match(V, m_OneUse(m_Or(m_Value(X),
636 m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
637 if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
638 // Avoid Partial chain merge.
639 return false;
640 } else
641 return false;
642
643 // Check if the pattern has loads
644 LoadInst *LI1 = LOps.Root;
645 Value *ShAmt1 = LOps.Shift;
646 if (LOps.FoundRoot == false &&
647 (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
648 match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
649 m_Value(ShAmt1)))))) {
650 LI1 = dyn_cast<LoadInst>(L1);
651 }
652 LoadInst *LI2 = dyn_cast<LoadInst>(L2);
653
654 // Check if loads are same, atomic, volatile and having same address space.
655 if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
656 LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
657 return false;
658
659 // Check if Loads come from same BB.
660 if (LI1->getParent() != LI2->getParent())
661 return false;
662
663 // Find the data layout
664 bool IsBigEndian = DL.isBigEndian();
665
666 // Check if loads are consecutive and same size.
667 Value *Load1Ptr = LI1->getPointerOperand();
668 APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
669 Load1Ptr =
670 Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
671 /* AllowNonInbounds */ true);
672
673 Value *Load2Ptr = LI2->getPointerOperand();
674 APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
675 Load2Ptr =
676 Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
677 /* AllowNonInbounds */ true);
678
679 // Verify if both loads have same base pointers and load sizes are same.
680 uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
681 uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
682 if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
683 return false;
684
685 // Support Loadsizes greater or equal to 8bits and only power of 2.
686 if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
687 return false;
688
689 // Alias Analysis to check for stores b/w the loads.
690 LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
691 MemoryLocation Loc;
692 if (!Start->comesBefore(End)) {
693 std::swap(Start, End);
694 Loc = MemoryLocation::get(End);
695 if (LOps.FoundRoot)
696 Loc = Loc.getWithNewSize(LOps.LoadSize);
697 } else
698 Loc = MemoryLocation::get(End);
699 unsigned NumScanned = 0;
700 for (Instruction &Inst :
701 make_range(Start->getIterator(), End->getIterator())) {
702 if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
703 return false;
704 if (++NumScanned > MaxInstrsToScan)
705 return false;
706 }
707
708 // Make sure Load with lower Offset is at LI1
709 bool Reverse = false;
710 if (Offset2.slt(Offset1)) {
711 std::swap(LI1, LI2);
712 std::swap(ShAmt1, ShAmt2);
713 std::swap(Offset1, Offset2);
714 std::swap(Load1Ptr, Load2Ptr);
715 std::swap(LoadSize1, LoadSize2);
716 Reverse = true;
717 }
718
719 // Big endian swap the shifts
720 if (IsBigEndian)
721 std::swap(ShAmt1, ShAmt2);
722
723 // Find Shifts values.
724 const APInt *Temp;
725 uint64_t Shift1 = 0, Shift2 = 0;
726 if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
727 Shift1 = Temp->getZExtValue();
728 if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
729 Shift2 = Temp->getZExtValue();
730
731 // First load is always LI1. This is where we put the new load.
732 // Use the merged load size available from LI1 for forward loads.
733 if (LOps.FoundRoot) {
734 if (!Reverse)
735 LoadSize1 = LOps.LoadSize;
736 else
737 LoadSize2 = LOps.LoadSize;
738 }
739
740 // Verify if shift amount and load index aligns and verifies that loads
741 // are consecutive.
742 uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
743 uint64_t PrevSize =
744 DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
745 if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
746 return false;
747
748 // Update LOps
749 AAMDNodes AATags1 = LOps.AATags;
750 AAMDNodes AATags2 = LI2->getAAMetadata();
751 if (LOps.FoundRoot == false) {
752 LOps.FoundRoot = true;
753 AATags1 = LI1->getAAMetadata();
754 }
755 LOps.LoadSize = LoadSize1 + LoadSize2;
756 LOps.RootInsert = Start;
757
758 // Concatenate the AATags of the Merged Loads.
759 LOps.AATags = AATags1.concat(AATags2);
760
761 LOps.Root = LI1;
762 LOps.Shift = ShAmt1;
763 LOps.ZextType = X->getType();
764 return true;
765 }
766
767 // For a given BB instruction, evaluate all loads in the chain that form a
768 // pattern which suggests that the loads can be combined. The one and only use
769 // of the loads is to form a wider load.
foldConsecutiveLoads(Instruction & I,const DataLayout & DL,TargetTransformInfo & TTI,AliasAnalysis & AA)770 static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
771 TargetTransformInfo &TTI, AliasAnalysis &AA) {
772 // Only consider load chains of scalar values.
773 if (isa<VectorType>(I.getType()))
774 return false;
775
776 LoadOps LOps;
777 if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
778 return false;
779
780 IRBuilder<> Builder(&I);
781 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
782
783 IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
784 // TTI based checks if we want to proceed with wider load
785 bool Allowed = TTI.isTypeLegal(WiderType);
786 if (!Allowed)
787 return false;
788
789 unsigned AS = LI1->getPointerAddressSpace();
790 unsigned Fast = 0;
791 Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
792 AS, LI1->getAlign(), &Fast);
793 if (!Allowed || !Fast)
794 return false;
795
796 // Make sure the Load pointer of type GEP/non-GEP is above insert point
797 Instruction *Inst = dyn_cast<Instruction>(LI1->getPointerOperand());
798 if (Inst && Inst->getParent() == LI1->getParent() &&
799 !Inst->comesBefore(LOps.RootInsert))
800 Inst->moveBefore(LOps.RootInsert);
801
802 // New load can be generated
803 Value *Load1Ptr = LI1->getPointerOperand();
804 Builder.SetInsertPoint(LOps.RootInsert);
805 Value *NewPtr = Builder.CreateBitCast(Load1Ptr, WiderType->getPointerTo(AS));
806 NewLoad = Builder.CreateAlignedLoad(WiderType, NewPtr, LI1->getAlign(),
807 LI1->isVolatile(), "");
808 NewLoad->takeName(LI1);
809 // Set the New Load AATags Metadata.
810 if (LOps.AATags)
811 NewLoad->setAAMetadata(LOps.AATags);
812
813 Value *NewOp = NewLoad;
814 // Check if zero extend needed.
815 if (LOps.ZextType)
816 NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);
817
818 // Check if shift needed. We need to shift with the amount of load1
819 // shift if not zero.
820 if (LOps.Shift)
821 NewOp = Builder.CreateShl(NewOp, LOps.Shift);
822 I.replaceAllUsesWith(NewOp);
823
824 return true;
825 }
826
827 /// This is the entry point for folds that could be implemented in regular
828 /// InstCombine, but they are separated because they are not expected to
829 /// occur frequently and/or have more than a constant-length pattern match.
foldUnusualPatterns(Function & F,DominatorTree & DT,TargetTransformInfo & TTI,TargetLibraryInfo & TLI,AliasAnalysis & AA)830 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
831 TargetTransformInfo &TTI,
832 TargetLibraryInfo &TLI, AliasAnalysis &AA) {
833 bool MadeChange = false;
834 for (BasicBlock &BB : F) {
835 // Ignore unreachable basic blocks.
836 if (!DT.isReachableFromEntry(&BB))
837 continue;
838
839 const DataLayout &DL = F.getParent()->getDataLayout();
840
841 // Walk the block backwards for efficiency. We're matching a chain of
842 // use->defs, so we're more likely to succeed by starting from the bottom.
843 // Also, we want to avoid matching partial patterns.
844 // TODO: It would be more efficient if we removed dead instructions
845 // iteratively in this loop rather than waiting until the end.
846 for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
847 MadeChange |= foldAnyOrAllBitsSet(I);
848 MadeChange |= foldGuardedFunnelShift(I, DT);
849 MadeChange |= tryToRecognizePopCount(I);
850 MadeChange |= tryToFPToSat(I, TTI);
851 MadeChange |= tryToRecognizeTableBasedCttz(I);
852 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA);
853 // NOTE: This function introduces erasing of the instruction `I`, so it
854 // needs to be called at the end of this sequence, otherwise we may make
855 // bugs.
856 MadeChange |= foldSqrt(I, TTI, TLI);
857 }
858 }
859
860 // We're done with transforms, so remove dead instructions.
861 if (MadeChange)
862 for (BasicBlock &BB : F)
863 SimplifyInstructionsInBlock(&BB);
864
865 return MadeChange;
866 }
867
868 /// This is the entry point for all transforms. Pass manager differences are
869 /// handled in the callers of this function.
runImpl(Function & F,AssumptionCache & AC,TargetTransformInfo & TTI,TargetLibraryInfo & TLI,DominatorTree & DT,AliasAnalysis & AA)870 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
871 TargetLibraryInfo &TLI, DominatorTree &DT,
872 AliasAnalysis &AA) {
873 bool MadeChange = false;
874 const DataLayout &DL = F.getParent()->getDataLayout();
875 TruncInstCombine TIC(AC, TLI, DL, DT);
876 MadeChange |= TIC.run(F);
877 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA);
878 return MadeChange;
879 }
880
run(Function & F,FunctionAnalysisManager & AM)881 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
882 FunctionAnalysisManager &AM) {
883 auto &AC = AM.getResult<AssumptionAnalysis>(F);
884 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
885 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
886 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
887 auto &AA = AM.getResult<AAManager>(F);
888 if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
889 // No changes, all analyses are preserved.
890 return PreservedAnalyses::all();
891 }
892 // Mark all the analyses that instcombine updates as preserved.
893 PreservedAnalyses PA;
894 PA.preserveSet<CFGAnalyses>();
895 return PA;
896 }
897