1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <cassert>
29
30 using namespace llvm;
31
32 #define DEBUG_TYPE "machine-scheduler"
33
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35 LLVM_DUMP_METHOD
printLivesAt(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)36 void llvm::printLivesAt(SlotIndex SI,
37 const LiveIntervals &LIS,
38 const MachineRegisterInfo &MRI) {
39 dbgs() << "Live regs at " << SI << ": "
40 << *LIS.getInstructionFromIndex(SI);
41 unsigned Num = 0;
42 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43 const unsigned Reg = Register::index2VirtReg(I);
44 if (!LIS.hasInterval(Reg))
45 continue;
46 const auto &LI = LIS.getInterval(Reg);
47 if (LI.hasSubRanges()) {
48 bool firstTime = true;
49 for (const auto &S : LI.subranges()) {
50 if (!S.liveAt(SI)) continue;
51 if (firstTime) {
52 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo())
53 << '\n';
54 firstTime = false;
55 }
56 dbgs() << " " << S << '\n';
57 ++Num;
58 }
59 } else if (LI.liveAt(SI)) {
60 dbgs() << " " << LI << '\n';
61 ++Num;
62 }
63 }
64 if (!Num) dbgs() << " <none>\n";
65 }
66 #endif
67
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
69 const GCNRPTracker::LiveRegSet &S2) {
70 if (S1.size() != S2.size())
71 return false;
72
73 for (const auto &P : S1) {
74 auto I = S2.find(P.first);
75 if (I == S2.end() || I->second != P.second)
76 return false;
77 }
78 return true;
79 }
80
81
82 ///////////////////////////////////////////////////////////////////////////////
83 // GCNRegPressure
84
getRegKind(unsigned Reg,const MachineRegisterInfo & MRI)85 unsigned GCNRegPressure::getRegKind(unsigned Reg,
86 const MachineRegisterInfo &MRI) {
87 assert(Register::isVirtualRegister(Reg));
88 const auto RC = MRI.getRegClass(Reg);
89 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90 return STI->isSGPRClass(RC) ?
91 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92 STI->hasAGPRs(RC) ?
93 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
94 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
95 }
96
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)97 void GCNRegPressure::inc(unsigned Reg,
98 LaneBitmask PrevMask,
99 LaneBitmask NewMask,
100 const MachineRegisterInfo &MRI) {
101 if (NewMask == PrevMask)
102 return;
103
104 int Sign = 1;
105 if (NewMask < PrevMask) {
106 std::swap(NewMask, PrevMask);
107 Sign = -1;
108 }
109 #ifndef NDEBUG
110 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
111 #endif
112 switch (auto Kind = getRegKind(Reg, MRI)) {
113 case SGPR32:
114 case VGPR32:
115 case AGPR32:
116 assert(PrevMask.none() && NewMask == MaxMask);
117 Value[Kind] += Sign;
118 break;
119
120 case SGPR_TUPLE:
121 case VGPR_TUPLE:
122 case AGPR_TUPLE:
123 assert(NewMask < MaxMask || NewMask == MaxMask);
124 assert(PrevMask < NewMask);
125
126 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
127 Sign * (~PrevMask & NewMask).getNumLanes();
128
129 if (PrevMask.none()) {
130 assert(NewMask.any());
131 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
132 }
133 break;
134
135 default: llvm_unreachable("Unknown register kind");
136 }
137 }
138
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const139 bool GCNRegPressure::less(const GCNSubtarget &ST,
140 const GCNRegPressure& O,
141 unsigned MaxOccupancy) const {
142 const auto SGPROcc = std::min(MaxOccupancy,
143 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
144 const auto VGPROcc = std::min(MaxOccupancy,
145 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
146 const auto OtherSGPROcc = std::min(MaxOccupancy,
147 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
148 const auto OtherVGPROcc = std::min(MaxOccupancy,
149 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
150
151 const auto Occ = std::min(SGPROcc, VGPROcc);
152 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
153 if (Occ != OtherOcc)
154 return Occ > OtherOcc;
155
156 bool SGPRImportant = SGPROcc < VGPROcc;
157 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
158
159 // if both pressures disagree on what is more important compare vgprs
160 if (SGPRImportant != OtherSGPRImportant) {
161 SGPRImportant = false;
162 }
163
164 // compare large regs pressure
165 bool SGPRFirst = SGPRImportant;
166 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
167 if (SGPRFirst) {
168 auto SW = getSGPRTuplesWeight();
169 auto OtherSW = O.getSGPRTuplesWeight();
170 if (SW != OtherSW)
171 return SW < OtherSW;
172 } else {
173 auto VW = getVGPRTuplesWeight();
174 auto OtherVW = O.getVGPRTuplesWeight();
175 if (VW != OtherVW)
176 return VW < OtherVW;
177 }
178 }
179 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
180 (getVGPRNum() < O.getVGPRNum());
181 }
182
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
184 LLVM_DUMP_METHOD
print(raw_ostream & OS,const GCNSubtarget * ST) const185 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
186 OS << "VGPRs: " << Value[VGPR32] << ' ';
187 OS << "AGPRs: " << Value[AGPR32];
188 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
189 OS << ", SGPRs: " << getSGPRNum();
190 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
191 OS << ", LVGPR WT: " << getVGPRTuplesWeight()
192 << ", LSGPR WT: " << getSGPRTuplesWeight();
193 if (ST) OS << " -> Occ: " << getOccupancy(*ST);
194 OS << '\n';
195 }
196 #endif
197
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)198 static LaneBitmask getDefRegMask(const MachineOperand &MO,
199 const MachineRegisterInfo &MRI) {
200 assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
201
202 // We don't rely on read-undef flag because in case of tentative schedule
203 // tracking it isn't set correctly yet. This works correctly however since
204 // use mask has been tracked before using LIS.
205 return MO.getSubReg() == 0 ?
206 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
207 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
208 }
209
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)210 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
211 const MachineRegisterInfo &MRI,
212 const LiveIntervals &LIS) {
213 assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
214
215 if (auto SubReg = MO.getSubReg())
216 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
217
218 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
219 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
220 return MaxMask;
221
222 // For a tentative schedule LIS isn't updated yet but livemask should remain
223 // the same on any schedule. Subreg defs can be reordered but they all must
224 // dominate uses anyway.
225 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
226 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
227 }
228
229 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)230 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
231 const MachineRegisterInfo &MRI) {
232 SmallVector<RegisterMaskPair, 8> Res;
233 for (const auto &MO : MI.operands()) {
234 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
235 continue;
236 if (!MO.isUse() || !MO.readsReg())
237 continue;
238
239 auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
240
241 auto Reg = MO.getReg();
242 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
243 return RM.RegUnit == Reg;
244 });
245 if (I != Res.end())
246 I->LaneMask |= UsedMask;
247 else
248 Res.push_back(RegisterMaskPair(Reg, UsedMask));
249 }
250 return Res;
251 }
252
253 ///////////////////////////////////////////////////////////////////////////////
254 // GCNRPTracker
255
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)256 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
257 SlotIndex SI,
258 const LiveIntervals &LIS,
259 const MachineRegisterInfo &MRI) {
260 LaneBitmask LiveMask;
261 const auto &LI = LIS.getInterval(Reg);
262 if (LI.hasSubRanges()) {
263 for (const auto &S : LI.subranges())
264 if (S.liveAt(SI)) {
265 LiveMask |= S.LaneMask;
266 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
267 LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
268 }
269 } else if (LI.liveAt(SI)) {
270 LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
271 }
272 return LiveMask;
273 }
274
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)275 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
276 const LiveIntervals &LIS,
277 const MachineRegisterInfo &MRI) {
278 GCNRPTracker::LiveRegSet LiveRegs;
279 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
280 auto Reg = Register::index2VirtReg(I);
281 if (!LIS.hasInterval(Reg))
282 continue;
283 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
284 if (LiveMask.any())
285 LiveRegs[Reg] = LiveMask;
286 }
287 return LiveRegs;
288 }
289
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)290 void GCNRPTracker::reset(const MachineInstr &MI,
291 const LiveRegSet *LiveRegsCopy,
292 bool After) {
293 const MachineFunction &MF = *MI.getMF();
294 MRI = &MF.getRegInfo();
295 if (LiveRegsCopy) {
296 if (&LiveRegs != LiveRegsCopy)
297 LiveRegs = *LiveRegsCopy;
298 } else {
299 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
300 : getLiveRegsBefore(MI, LIS);
301 }
302
303 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
304 }
305
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)306 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
307 const LiveRegSet *LiveRegsCopy) {
308 GCNRPTracker::reset(MI, LiveRegsCopy, true);
309 }
310
recede(const MachineInstr & MI)311 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
312 assert(MRI && "call reset first");
313
314 LastTrackedMI = &MI;
315
316 if (MI.isDebugInstr())
317 return;
318
319 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
320
321 // calc pressure at the MI (defs + uses)
322 auto AtMIPressure = CurPressure;
323 for (const auto &U : RegUses) {
324 auto LiveMask = LiveRegs[U.RegUnit];
325 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
326 }
327 // update max pressure
328 MaxPressure = max(AtMIPressure, MaxPressure);
329
330 for (const auto &MO : MI.defs()) {
331 if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
332 continue;
333
334 auto Reg = MO.getReg();
335 auto I = LiveRegs.find(Reg);
336 if (I == LiveRegs.end())
337 continue;
338 auto &LiveMask = I->second;
339 auto PrevMask = LiveMask;
340 LiveMask &= ~getDefRegMask(MO, *MRI);
341 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
342 if (LiveMask.none())
343 LiveRegs.erase(I);
344 }
345 for (const auto &U : RegUses) {
346 auto &LiveMask = LiveRegs[U.RegUnit];
347 auto PrevMask = LiveMask;
348 LiveMask |= U.LaneMask;
349 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
350 }
351 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
352 }
353
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)354 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
355 const LiveRegSet *LiveRegsCopy) {
356 MRI = &MI.getParent()->getParent()->getRegInfo();
357 LastTrackedMI = nullptr;
358 MBBEnd = MI.getParent()->end();
359 NextMI = &MI;
360 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
361 if (NextMI == MBBEnd)
362 return false;
363 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
364 return true;
365 }
366
advanceBeforeNext()367 bool GCNDownwardRPTracker::advanceBeforeNext() {
368 assert(MRI && "call reset first");
369
370 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
371 if (NextMI == MBBEnd)
372 return false;
373
374 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
375 assert(SI.isValid());
376
377 // Remove dead registers or mask bits.
378 for (auto &It : LiveRegs) {
379 const LiveInterval &LI = LIS.getInterval(It.first);
380 if (LI.hasSubRanges()) {
381 for (const auto &S : LI.subranges()) {
382 if (!S.liveAt(SI)) {
383 auto PrevMask = It.second;
384 It.second &= ~S.LaneMask;
385 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
386 }
387 }
388 } else if (!LI.liveAt(SI)) {
389 auto PrevMask = It.second;
390 It.second = LaneBitmask::getNone();
391 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
392 }
393 if (It.second.none())
394 LiveRegs.erase(It.first);
395 }
396
397 MaxPressure = max(MaxPressure, CurPressure);
398
399 return true;
400 }
401
advanceToNext()402 void GCNDownwardRPTracker::advanceToNext() {
403 LastTrackedMI = &*NextMI++;
404
405 // Add new registers or mask bits.
406 for (const auto &MO : LastTrackedMI->defs()) {
407 if (!MO.isReg())
408 continue;
409 Register Reg = MO.getReg();
410 if (!Register::isVirtualRegister(Reg))
411 continue;
412 auto &LiveMask = LiveRegs[Reg];
413 auto PrevMask = LiveMask;
414 LiveMask |= getDefRegMask(MO, *MRI);
415 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
416 }
417
418 MaxPressure = max(MaxPressure, CurPressure);
419 }
420
advance()421 bool GCNDownwardRPTracker::advance() {
422 // If we have just called reset live set is actual.
423 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
424 return false;
425 advanceToNext();
426 return true;
427 }
428
advance(MachineBasicBlock::const_iterator End)429 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
430 while (NextMI != End)
431 if (!advance()) return false;
432 return true;
433 }
434
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)435 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
436 MachineBasicBlock::const_iterator End,
437 const LiveRegSet *LiveRegsCopy) {
438 reset(*Begin, LiveRegsCopy);
439 return advance(End);
440 }
441
442 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
443 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)444 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
445 const GCNRPTracker::LiveRegSet &TrackedLR,
446 const TargetRegisterInfo *TRI) {
447 for (auto const &P : TrackedLR) {
448 auto I = LISLR.find(P.first);
449 if (I == LISLR.end()) {
450 dbgs() << " " << printReg(P.first, TRI)
451 << ":L" << PrintLaneMask(P.second)
452 << " isn't found in LIS reported set\n";
453 }
454 else if (I->second != P.second) {
455 dbgs() << " " << printReg(P.first, TRI)
456 << " masks doesn't match: LIS reported "
457 << PrintLaneMask(I->second)
458 << ", tracked "
459 << PrintLaneMask(P.second)
460 << '\n';
461 }
462 }
463 for (auto const &P : LISLR) {
464 auto I = TrackedLR.find(P.first);
465 if (I == TrackedLR.end()) {
466 dbgs() << " " << printReg(P.first, TRI)
467 << ":L" << PrintLaneMask(P.second)
468 << " isn't found in tracked set\n";
469 }
470 }
471 }
472
isValid() const473 bool GCNUpwardRPTracker::isValid() const {
474 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
475 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
476 const auto &TrackedLR = LiveRegs;
477
478 if (!isEqual(LISLR, TrackedLR)) {
479 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
480 " LIS reported livesets mismatch:\n";
481 printLivesAt(SI, LIS, *MRI);
482 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
483 return false;
484 }
485
486 auto LISPressure = getRegPressure(*MRI, LISLR);
487 if (LISPressure != CurPressure) {
488 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
489 CurPressure.print(dbgs());
490 dbgs() << "LIS rpt: ";
491 LISPressure.print(dbgs());
492 return false;
493 }
494 return true;
495 }
496
printLiveRegs(raw_ostream & OS,const LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)497 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
498 const MachineRegisterInfo &MRI) {
499 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
500 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
501 unsigned Reg = Register::index2VirtReg(I);
502 auto It = LiveRegs.find(Reg);
503 if (It != LiveRegs.end() && It->second.any())
504 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
505 << PrintLaneMask(It->second);
506 }
507 OS << '\n';
508 }
509 #endif
510