1 //=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// Post-legalization combines on generic MachineInstrs.
11 ///
12 /// The combines here must preserve instruction legality.
13 ///
14 /// Lowering combines (e.g. pseudo matching) should be handled by
15 /// AArch64PostLegalizerLowering.
16 ///
17 /// Combines which don't rely on instruction legality should go in the
18 /// AArch64PreLegalizerCombiner.
19 ///
20 //===----------------------------------------------------------------------===//
21
22 #include "AArch64TargetMachine.h"
23 #include "llvm/CodeGen/GlobalISel/Combiner.h"
24 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
25 #include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
26 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
27 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
28 #include "llvm/CodeGen/GlobalISel/Utils.h"
29 #include "llvm/CodeGen/MachineDominators.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineRegisterInfo.h"
32 #include "llvm/CodeGen/TargetOpcodes.h"
33 #include "llvm/CodeGen/TargetPassConfig.h"
34 #include "llvm/Support/Debug.h"
35
36 #define DEBUG_TYPE "aarch64-postlegalizer-combiner"
37
38 using namespace llvm;
39
40 /// This combine tries do what performExtractVectorEltCombine does in SDAG.
41 /// Rewrite for pairwise fadd pattern
42 /// (s32 (g_extract_vector_elt
43 /// (g_fadd (vXs32 Other)
44 /// (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
45 /// ->
46 /// (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
47 /// (g_extract_vector_elt (vXs32 Other) 1))
matchExtractVecEltPairwiseAdd(MachineInstr & MI,MachineRegisterInfo & MRI,std::tuple<unsigned,LLT,Register> & MatchInfo)48 bool matchExtractVecEltPairwiseAdd(
49 MachineInstr &MI, MachineRegisterInfo &MRI,
50 std::tuple<unsigned, LLT, Register> &MatchInfo) {
51 Register Src1 = MI.getOperand(1).getReg();
52 Register Src2 = MI.getOperand(2).getReg();
53 LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
54
55 auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
56 if (!Cst || Cst->Value != 0)
57 return false;
58 // SDAG also checks for FullFP16, but this looks to be beneficial anyway.
59
60 // Now check for an fadd operation. TODO: expand this for integer add?
61 auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
62 if (!FAddMI)
63 return false;
64
65 // If we add support for integer add, must restrict these types to just s64.
66 unsigned DstSize = DstTy.getSizeInBits();
67 if (DstSize != 16 && DstSize != 32 && DstSize != 64)
68 return false;
69
70 Register Src1Op1 = FAddMI->getOperand(1).getReg();
71 Register Src1Op2 = FAddMI->getOperand(2).getReg();
72 MachineInstr *Shuffle =
73 getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
74 MachineInstr *Other = MRI.getVRegDef(Src1Op1);
75 if (!Shuffle) {
76 Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
77 Other = MRI.getVRegDef(Src1Op2);
78 }
79
80 // We're looking for a shuffle that moves the second element to index 0.
81 if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
82 Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
83 std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
84 std::get<1>(MatchInfo) = DstTy;
85 std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
86 return true;
87 }
88 return false;
89 }
90
applyExtractVecEltPairwiseAdd(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,std::tuple<unsigned,LLT,Register> & MatchInfo)91 bool applyExtractVecEltPairwiseAdd(
92 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
93 std::tuple<unsigned, LLT, Register> &MatchInfo) {
94 unsigned Opc = std::get<0>(MatchInfo);
95 assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
96 // We want to generate two extracts of elements 0 and 1, and add them.
97 LLT Ty = std::get<1>(MatchInfo);
98 Register Src = std::get<2>(MatchInfo);
99 LLT s64 = LLT::scalar(64);
100 B.setInstrAndDebugLoc(MI);
101 auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
102 auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
103 B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
104 MI.eraseFromParent();
105 return true;
106 }
107
isSignExtended(Register R,MachineRegisterInfo & MRI)108 static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
109 // TODO: check if extended build vector as well.
110 unsigned Opc = MRI.getVRegDef(R)->getOpcode();
111 return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
112 }
113
isZeroExtended(Register R,MachineRegisterInfo & MRI)114 static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
115 // TODO: check if extended build vector as well.
116 return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
117 }
118
matchAArch64MulConstCombine(MachineInstr & MI,MachineRegisterInfo & MRI,std::function<void (MachineIRBuilder & B,Register DstReg)> & ApplyFn)119 bool matchAArch64MulConstCombine(
120 MachineInstr &MI, MachineRegisterInfo &MRI,
121 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
122 assert(MI.getOpcode() == TargetOpcode::G_MUL);
123 Register LHS = MI.getOperand(1).getReg();
124 Register RHS = MI.getOperand(2).getReg();
125 Register Dst = MI.getOperand(0).getReg();
126 const LLT Ty = MRI.getType(LHS);
127
128 // The below optimizations require a constant RHS.
129 auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
130 if (!Const)
131 return false;
132
133 const APInt &ConstValue = APInt(Ty.getSizeInBits(), Const->Value, true);
134 // The following code is ported from AArch64ISelLowering.
135 // Multiplication of a power of two plus/minus one can be done more
136 // cheaply as as shift+add/sub. For now, this is true unilaterally. If
137 // future CPUs have a cheaper MADD instruction, this may need to be
138 // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
139 // 64-bit is 5 cycles, so this is always a win.
140 // More aggressively, some multiplications N0 * C can be lowered to
141 // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
142 // e.g. 6=3*2=(2+1)*2.
143 // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
144 // which equals to (1+2)*16-(1+2).
145 // TrailingZeroes is used to test if the mul can be lowered to
146 // shift+add+shift.
147 unsigned TrailingZeroes = ConstValue.countTrailingZeros();
148 if (TrailingZeroes) {
149 // Conservatively do not lower to shift+add+shift if the mul might be
150 // folded into smul or umul.
151 if (MRI.hasOneNonDBGUse(LHS) &&
152 (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
153 return false;
154 // Conservatively do not lower to shift+add+shift if the mul might be
155 // folded into madd or msub.
156 if (MRI.hasOneNonDBGUse(Dst)) {
157 MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
158 if (UseMI.getOpcode() == TargetOpcode::G_ADD ||
159 UseMI.getOpcode() == TargetOpcode::G_SUB)
160 return false;
161 }
162 }
163 // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
164 // and shift+add+shift.
165 APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
166
167 unsigned ShiftAmt, AddSubOpc;
168 // Is the shifted value the LHS operand of the add/sub?
169 bool ShiftValUseIsLHS = true;
170 // Do we need to negate the result?
171 bool NegateResult = false;
172
173 if (ConstValue.isNonNegative()) {
174 // (mul x, 2^N + 1) => (add (shl x, N), x)
175 // (mul x, 2^N - 1) => (sub (shl x, N), x)
176 // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
177 APInt SCVMinus1 = ShiftedConstValue - 1;
178 APInt CVPlus1 = ConstValue + 1;
179 if (SCVMinus1.isPowerOf2()) {
180 ShiftAmt = SCVMinus1.logBase2();
181 AddSubOpc = TargetOpcode::G_ADD;
182 } else if (CVPlus1.isPowerOf2()) {
183 ShiftAmt = CVPlus1.logBase2();
184 AddSubOpc = TargetOpcode::G_SUB;
185 } else
186 return false;
187 } else {
188 // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
189 // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
190 APInt CVNegPlus1 = -ConstValue + 1;
191 APInt CVNegMinus1 = -ConstValue - 1;
192 if (CVNegPlus1.isPowerOf2()) {
193 ShiftAmt = CVNegPlus1.logBase2();
194 AddSubOpc = TargetOpcode::G_SUB;
195 ShiftValUseIsLHS = false;
196 } else if (CVNegMinus1.isPowerOf2()) {
197 ShiftAmt = CVNegMinus1.logBase2();
198 AddSubOpc = TargetOpcode::G_ADD;
199 NegateResult = true;
200 } else
201 return false;
202 }
203
204 if (NegateResult && TrailingZeroes)
205 return false;
206
207 ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
208 auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
209 auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
210
211 Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
212 Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
213 auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
214 assert(!(NegateResult && TrailingZeroes) &&
215 "NegateResult and TrailingZeroes cannot both be true for now.");
216 // Negate the result.
217 if (NegateResult) {
218 B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
219 return;
220 }
221 // Shift the result.
222 if (TrailingZeroes) {
223 B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
224 return;
225 }
226 B.buildCopy(DstReg, Res.getReg(0));
227 };
228 return true;
229 }
230
applyAArch64MulConstCombine(MachineInstr & MI,MachineRegisterInfo & MRI,MachineIRBuilder & B,std::function<void (MachineIRBuilder & B,Register DstReg)> & ApplyFn)231 bool applyAArch64MulConstCombine(
232 MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
233 std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
234 B.setInstrAndDebugLoc(MI);
235 ApplyFn(B, MI.getOperand(0).getReg());
236 MI.eraseFromParent();
237 return true;
238 }
239
240 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
241 #include "AArch64GenPostLegalizeGICombiner.inc"
242 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
243
244 namespace {
245 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
246 #include "AArch64GenPostLegalizeGICombiner.inc"
247 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
248
249 class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
250 GISelKnownBits *KB;
251 MachineDominatorTree *MDT;
252
253 public:
254 AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;
255
AArch64PostLegalizerCombinerInfo(bool EnableOpt,bool OptSize,bool MinSize,GISelKnownBits * KB,MachineDominatorTree * MDT)256 AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
257 GISelKnownBits *KB,
258 MachineDominatorTree *MDT)
259 : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
260 /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
261 KB(KB), MDT(MDT) {
262 if (!GeneratedRuleCfg.parseCommandLineOption())
263 report_fatal_error("Invalid rule identifier");
264 }
265
266 virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
267 MachineIRBuilder &B) const override;
268 };
269
combine(GISelChangeObserver & Observer,MachineInstr & MI,MachineIRBuilder & B) const270 bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
271 MachineInstr &MI,
272 MachineIRBuilder &B) const {
273 const auto *LI =
274 MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
275 CombinerHelper Helper(Observer, B, KB, MDT, LI);
276 AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
277 return Generated.tryCombineAll(Observer, MI, B, Helper);
278 }
279
280 #define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
281 #include "AArch64GenPostLegalizeGICombiner.inc"
282 #undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
283
284 class AArch64PostLegalizerCombiner : public MachineFunctionPass {
285 public:
286 static char ID;
287
288 AArch64PostLegalizerCombiner(bool IsOptNone = false);
289
getPassName() const290 StringRef getPassName() const override {
291 return "AArch64PostLegalizerCombiner";
292 }
293
294 bool runOnMachineFunction(MachineFunction &MF) override;
295 void getAnalysisUsage(AnalysisUsage &AU) const override;
296
297 private:
298 bool IsOptNone;
299 };
300 } // end anonymous namespace
301
getAnalysisUsage(AnalysisUsage & AU) const302 void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
303 AU.addRequired<TargetPassConfig>();
304 AU.setPreservesCFG();
305 getSelectionDAGFallbackAnalysisUsage(AU);
306 AU.addRequired<GISelKnownBitsAnalysis>();
307 AU.addPreserved<GISelKnownBitsAnalysis>();
308 if (!IsOptNone) {
309 AU.addRequired<MachineDominatorTree>();
310 AU.addPreserved<MachineDominatorTree>();
311 }
312 MachineFunctionPass::getAnalysisUsage(AU);
313 }
314
AArch64PostLegalizerCombiner(bool IsOptNone)315 AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
316 : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
317 initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
318 }
319
runOnMachineFunction(MachineFunction & MF)320 bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
321 if (MF.getProperties().hasProperty(
322 MachineFunctionProperties::Property::FailedISel))
323 return false;
324 assert(MF.getProperties().hasProperty(
325 MachineFunctionProperties::Property::Legalized) &&
326 "Expected a legalized function?");
327 auto *TPC = &getAnalysis<TargetPassConfig>();
328 const Function &F = MF.getFunction();
329 bool EnableOpt =
330 MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
331 GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
332 MachineDominatorTree *MDT =
333 IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
334 AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
335 F.hasMinSize(), KB, MDT);
336 Combiner C(PCInfo, TPC);
337 return C.combineMachineInstrs(MF, /*CSEInfo*/ nullptr);
338 }
339
340 char AArch64PostLegalizerCombiner::ID = 0;
341 INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
342 "Combine AArch64 MachineInstrs after legalization", false,
343 false)
344 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
345 INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
346 INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
347 "Combine AArch64 MachineInstrs after legalization", false,
348 false)
349
350 namespace llvm {
createAArch64PostLegalizerCombiner(bool IsOptNone)351 FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
352 return new AArch64PostLegalizerCombiner(IsOptNone);
353 }
354 } // end namespace llvm
355