1 //===-- AArch64CondBrTuning.cpp --- Conditional branch tuning for AArch64 -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file contains a pass that transforms CBZ/CBNZ/TBZ/TBNZ instructions
10 /// into a conditional branch (B.cond), when the NZCV flags can be set for
11 /// "free". This is preferred on targets that have more flexibility when
12 /// scheduling B.cond instructions as compared to CBZ/CBNZ/TBZ/TBNZ (assuming
13 /// all other variables are equal). This can also reduce register pressure.
14 ///
15 /// A few examples:
16 ///
17 /// 1) add w8, w0, w1 -> cmn w0, w1 ; CMN is an alias of ADDS.
///    cbz w8, .LBB0_2 -> b.eq .LBB0_2
19 ///
20 /// 2) add w8, w0, w1 -> adds w8, w0, w1 ; w8 has multiple uses.
21 /// cbz w8, .LBB1_2 -> b.eq .LBB1_2
22 ///
23 /// 3) sub w8, w0, w1 -> subs w8, w0, w1 ; w8 has multiple uses.
24 /// tbz w8, #31, .LBB6_2 -> b.pl .LBB6_2
25 ///
26 //===----------------------------------------------------------------------===//
27
28 #include "AArch64.h"
29 #include "AArch64Subtarget.h"
30 #include "llvm/CodeGen/MachineFunction.h"
31 #include "llvm/CodeGen/MachineFunctionPass.h"
32 #include "llvm/CodeGen/MachineInstrBuilder.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/Passes.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/CodeGen/TargetSubtargetInfo.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40
41 using namespace llvm;
42
43 #define DEBUG_TYPE "aarch64-cond-br-tuning"
44 #define AARCH64_CONDBR_TUNING_NAME "AArch64 Conditional Branch Tuning"
45
46 namespace {
47 class AArch64CondBrTuning : public MachineFunctionPass {
48 const AArch64InstrInfo *TII;
49 const TargetRegisterInfo *TRI;
50
51 MachineRegisterInfo *MRI;
52
53 public:
54 static char ID;
AArch64CondBrTuning()55 AArch64CondBrTuning() : MachineFunctionPass(ID) {
56 initializeAArch64CondBrTuningPass(*PassRegistry::getPassRegistry());
57 }
58 void getAnalysisUsage(AnalysisUsage &AU) const override;
59 bool runOnMachineFunction(MachineFunction &MF) override;
getPassName() const60 StringRef getPassName() const override { return AARCH64_CONDBR_TUNING_NAME; }
61
62 private:
63 MachineInstr *getOperandDef(const MachineOperand &MO);
64 MachineInstr *convertToFlagSetting(MachineInstr &MI, bool IsFlagSetting);
65 MachineInstr *convertToCondBr(MachineInstr &MI);
66 bool tryToTuneBranch(MachineInstr &MI, MachineInstr &DefMI);
67 };
68 } // end anonymous namespace
69
70 char AArch64CondBrTuning::ID = 0;
71
72 INITIALIZE_PASS(AArch64CondBrTuning, "aarch64-cond-br-tuning",
73 AARCH64_CONDBR_TUNING_NAME, false, false)
74
getAnalysisUsage(AnalysisUsage & AU) const75 void AArch64CondBrTuning::getAnalysisUsage(AnalysisUsage &AU) const {
76 AU.setPreservesCFG();
77 MachineFunctionPass::getAnalysisUsage(AU);
78 }
79
getOperandDef(const MachineOperand & MO)80 MachineInstr *AArch64CondBrTuning::getOperandDef(const MachineOperand &MO) {
81 if (!Register::isVirtualRegister(MO.getReg()))
82 return nullptr;
83 return MRI->getUniqueVRegDef(MO.getReg());
84 }
85
convertToFlagSetting(MachineInstr & MI,bool IsFlagSetting)86 MachineInstr *AArch64CondBrTuning::convertToFlagSetting(MachineInstr &MI,
87 bool IsFlagSetting) {
88 // If this is already the flag setting version of the instruction (e.g., SUBS)
89 // just make sure the implicit-def of NZCV isn't marked dead.
90 if (IsFlagSetting) {
91 for (unsigned I = MI.getNumExplicitOperands(), E = MI.getNumOperands();
92 I != E; ++I) {
93 MachineOperand &MO = MI.getOperand(I);
94 if (MO.isReg() && MO.isDead() && MO.getReg() == AArch64::NZCV)
95 MO.setIsDead(false);
96 }
97 return &MI;
98 }
99 bool Is64Bit;
100 unsigned NewOpc = TII->convertToFlagSettingOpc(MI.getOpcode(), Is64Bit);
101 Register NewDestReg = MI.getOperand(0).getReg();
102 if (MRI->hasOneNonDBGUse(MI.getOperand(0).getReg()))
103 NewDestReg = Is64Bit ? AArch64::XZR : AArch64::WZR;
104
105 MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
106 TII->get(NewOpc), NewDestReg);
107 for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I)
108 MIB.add(MI.getOperand(I));
109
110 return MIB;
111 }
112
convertToCondBr(MachineInstr & MI)113 MachineInstr *AArch64CondBrTuning::convertToCondBr(MachineInstr &MI) {
114 AArch64CC::CondCode CC;
115 MachineBasicBlock *TargetMBB = TII->getBranchDestBlock(MI);
116 switch (MI.getOpcode()) {
117 default:
118 llvm_unreachable("Unexpected opcode!");
119
120 case AArch64::CBZW:
121 case AArch64::CBZX:
122 CC = AArch64CC::EQ;
123 break;
124 case AArch64::CBNZW:
125 case AArch64::CBNZX:
126 CC = AArch64CC::NE;
127 break;
128 case AArch64::TBZW:
129 case AArch64::TBZX:
130 CC = AArch64CC::PL;
131 break;
132 case AArch64::TBNZW:
133 case AArch64::TBNZX:
134 CC = AArch64CC::MI;
135 break;
136 }
137 return BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::Bcc))
138 .addImm(CC)
139 .addMBB(TargetMBB);
140 }
141
tryToTuneBranch(MachineInstr & MI,MachineInstr & DefMI)142 bool AArch64CondBrTuning::tryToTuneBranch(MachineInstr &MI,
143 MachineInstr &DefMI) {
144 // We don't want NZCV bits live across blocks.
145 if (MI.getParent() != DefMI.getParent())
146 return false;
147
148 bool IsFlagSetting = true;
149 unsigned MIOpc = MI.getOpcode();
150 MachineInstr *NewCmp = nullptr, *NewBr = nullptr;
151 switch (DefMI.getOpcode()) {
152 default:
153 return false;
154 case AArch64::ADDWri:
155 case AArch64::ADDWrr:
156 case AArch64::ADDWrs:
157 case AArch64::ADDWrx:
158 case AArch64::ANDWri:
159 case AArch64::ANDWrr:
160 case AArch64::ANDWrs:
161 case AArch64::BICWrr:
162 case AArch64::BICWrs:
163 case AArch64::SUBWri:
164 case AArch64::SUBWrr:
165 case AArch64::SUBWrs:
166 case AArch64::SUBWrx:
167 IsFlagSetting = false;
168 LLVM_FALLTHROUGH;
169 case AArch64::ADDSWri:
170 case AArch64::ADDSWrr:
171 case AArch64::ADDSWrs:
172 case AArch64::ADDSWrx:
173 case AArch64::ANDSWri:
174 case AArch64::ANDSWrr:
175 case AArch64::ANDSWrs:
176 case AArch64::BICSWrr:
177 case AArch64::BICSWrs:
178 case AArch64::SUBSWri:
179 case AArch64::SUBSWrr:
180 case AArch64::SUBSWrs:
181 case AArch64::SUBSWrx:
182 switch (MIOpc) {
183 default:
184 llvm_unreachable("Unexpected opcode!");
185
186 case AArch64::CBZW:
187 case AArch64::CBNZW:
188 case AArch64::TBZW:
189 case AArch64::TBNZW:
190 // Check to see if the TBZ/TBNZ is checking the sign bit.
191 if ((MIOpc == AArch64::TBZW || MIOpc == AArch64::TBNZW) &&
192 MI.getOperand(1).getImm() != 31)
193 return false;
194
195 // There must not be any instruction between DefMI and MI that clobbers or
196 // reads NZCV.
197 MachineBasicBlock::iterator I(DefMI), E(MI);
198 for (I = std::next(I); I != E; ++I) {
199 if (I->modifiesRegister(AArch64::NZCV, TRI) ||
200 I->readsRegister(AArch64::NZCV, TRI))
201 return false;
202 }
203 LLVM_DEBUG(dbgs() << " Replacing instructions:\n ");
204 LLVM_DEBUG(DefMI.print(dbgs()));
205 LLVM_DEBUG(dbgs() << " ");
206 LLVM_DEBUG(MI.print(dbgs()));
207
208 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting);
209 NewBr = convertToCondBr(MI);
210 break;
211 }
212 break;
213
214 case AArch64::ADDXri:
215 case AArch64::ADDXrr:
216 case AArch64::ADDXrs:
217 case AArch64::ADDXrx:
218 case AArch64::ANDXri:
219 case AArch64::ANDXrr:
220 case AArch64::ANDXrs:
221 case AArch64::BICXrr:
222 case AArch64::BICXrs:
223 case AArch64::SUBXri:
224 case AArch64::SUBXrr:
225 case AArch64::SUBXrs:
226 case AArch64::SUBXrx:
227 IsFlagSetting = false;
228 LLVM_FALLTHROUGH;
229 case AArch64::ADDSXri:
230 case AArch64::ADDSXrr:
231 case AArch64::ADDSXrs:
232 case AArch64::ADDSXrx:
233 case AArch64::ANDSXri:
234 case AArch64::ANDSXrr:
235 case AArch64::ANDSXrs:
236 case AArch64::BICSXrr:
237 case AArch64::BICSXrs:
238 case AArch64::SUBSXri:
239 case AArch64::SUBSXrr:
240 case AArch64::SUBSXrs:
241 case AArch64::SUBSXrx:
242 switch (MIOpc) {
243 default:
244 llvm_unreachable("Unexpected opcode!");
245
246 case AArch64::CBZX:
247 case AArch64::CBNZX:
248 case AArch64::TBZX:
249 case AArch64::TBNZX: {
250 // Check to see if the TBZ/TBNZ is checking the sign bit.
251 if ((MIOpc == AArch64::TBZX || MIOpc == AArch64::TBNZX) &&
252 MI.getOperand(1).getImm() != 63)
253 return false;
254 // There must not be any instruction between DefMI and MI that clobbers or
255 // reads NZCV.
256 MachineBasicBlock::iterator I(DefMI), E(MI);
257 for (I = std::next(I); I != E; ++I) {
258 if (I->modifiesRegister(AArch64::NZCV, TRI) ||
259 I->readsRegister(AArch64::NZCV, TRI))
260 return false;
261 }
262 LLVM_DEBUG(dbgs() << " Replacing instructions:\n ");
263 LLVM_DEBUG(DefMI.print(dbgs()));
264 LLVM_DEBUG(dbgs() << " ");
265 LLVM_DEBUG(MI.print(dbgs()));
266
267 NewCmp = convertToFlagSetting(DefMI, IsFlagSetting);
268 NewBr = convertToCondBr(MI);
269 break;
270 }
271 }
272 break;
273 }
274 (void)NewCmp; (void)NewBr;
275 assert(NewCmp && NewBr && "Expected new instructions.");
276
277 LLVM_DEBUG(dbgs() << " with instruction:\n ");
278 LLVM_DEBUG(NewCmp->print(dbgs()));
279 LLVM_DEBUG(dbgs() << " ");
280 LLVM_DEBUG(NewBr->print(dbgs()));
281
282 // If this was a flag setting version of the instruction, we use the original
283 // instruction by just clearing the dead marked on the implicit-def of NCZV.
284 // Therefore, we should not erase this instruction.
285 if (!IsFlagSetting)
286 DefMI.eraseFromParent();
287 MI.eraseFromParent();
288 return true;
289 }
290
runOnMachineFunction(MachineFunction & MF)291 bool AArch64CondBrTuning::runOnMachineFunction(MachineFunction &MF) {
292 if (skipFunction(MF.getFunction()))
293 return false;
294
295 LLVM_DEBUG(
296 dbgs() << "********** AArch64 Conditional Branch Tuning **********\n"
297 << "********** Function: " << MF.getName() << '\n');
298
299 TII = static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
300 TRI = MF.getSubtarget().getRegisterInfo();
301 MRI = &MF.getRegInfo();
302
303 bool Changed = false;
304 for (MachineBasicBlock &MBB : MF) {
305 bool LocalChange = false;
306 for (MachineBasicBlock::iterator I = MBB.getFirstTerminator(),
307 E = MBB.end();
308 I != E; ++I) {
309 MachineInstr &MI = *I;
310 switch (MI.getOpcode()) {
311 default:
312 break;
313 case AArch64::CBZW:
314 case AArch64::CBZX:
315 case AArch64::CBNZW:
316 case AArch64::CBNZX:
317 case AArch64::TBZW:
318 case AArch64::TBZX:
319 case AArch64::TBNZW:
320 case AArch64::TBNZX:
321 MachineInstr *DefMI = getOperandDef(MI.getOperand(0));
322 LocalChange = (DefMI && tryToTuneBranch(MI, *DefMI));
323 break;
324 }
325 // If the optimization was successful, we can't optimize any other
326 // branches because doing so would clobber the NZCV flags.
327 if (LocalChange) {
328 Changed = true;
329 break;
330 }
331 }
332 }
333 return Changed;
334 }
335
createAArch64CondBrTuning()336 FunctionPass *llvm::createAArch64CondBrTuning() {
337 return new AArch64CondBrTuning();
338 }
339