• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- AArch64A57FPLoadBalancing.cpp - Balance FP ops statically on A57---===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // For best-case performance on Cortex-A57, we should try to use a balanced
10 // mix of odd and even D-registers when performing a critical sequence of
11 // independent, non-quadword FP/ASIMD floating-point multiply or
12 // multiply-accumulate operations.
13 //
14 // This pass attempts to detect situations where the register allocation may
15 // adversely affect this load balancing and to change the registers used so as
16 // to better utilize the CPU.
17 //
18 // Ideally we'd just take each multiply or multiply-accumulate in turn and
19 // allocate it alternating even or odd registers. However, multiply-accumulates
20 // are most efficiently performed in the same functional unit as their
21 // accumulation operand. Therefore this pass tries to find maximal sequences
22 // ("Chains") of multiply-accumulates linked via their accumulation operand,
23 // and assign them all the same "color" (oddness/evenness).
24 //
25 // This optimization affects S-register and D-register floating point
26 // multiplies and FMADD/FMAs, as well as vector (floating point only) muls and
27 // FMADD/FMA. Q register instructions (and 128-bit vector instructions) are
28 // not affected.
29 //===----------------------------------------------------------------------===//
30 
31 #include "AArch64.h"
32 #include "AArch64InstrInfo.h"
33 #include "AArch64Subtarget.h"
34 #include "llvm/ADT/BitVector.h"
35 #include "llvm/ADT/EquivalenceClasses.h"
36 #include "llvm/CodeGen/MachineFunction.h"
37 #include "llvm/CodeGen/MachineFunctionPass.h"
38 #include "llvm/CodeGen/MachineInstr.h"
39 #include "llvm/CodeGen/MachineInstrBuilder.h"
40 #include "llvm/CodeGen/MachineRegisterInfo.h"
41 #include "llvm/CodeGen/RegisterClassInfo.h"
42 #include "llvm/CodeGen/RegisterScavenging.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/raw_ostream.h"
46 using namespace llvm;
47 
48 #define DEBUG_TYPE "aarch64-a57-fp-load-balancing"
49 
50 // Enforce the algorithm to use the scavenged register even when the original
51 // destination register is the correct color. Used for testing.
52 static cl::opt<bool>
53 TransformAll("aarch64-a57-fp-load-balancing-force-all",
54              cl::desc("Always modify dest registers regardless of color"),
55              cl::init(false), cl::Hidden);
56 
57 // Never use the balance information obtained from chains - return a specific
58 // color always. Used for testing.
59 static cl::opt<unsigned>
60 OverrideBalance("aarch64-a57-fp-load-balancing-override",
61               cl::desc("Ignore balance information, always return "
62                        "(1: Even, 2: Odd)."),
63               cl::init(0), cl::Hidden);
64 
65 //===----------------------------------------------------------------------===//
66 // Helper functions
67 
68 // Is the instruction a type of multiply on 64-bit (or 32-bit) FPRs?
isMul(MachineInstr * MI)69 static bool isMul(MachineInstr *MI) {
70   switch (MI->getOpcode()) {
71   case AArch64::FMULSrr:
72   case AArch64::FNMULSrr:
73   case AArch64::FMULDrr:
74   case AArch64::FNMULDrr:
75     return true;
76   default:
77     return false;
78   }
79 }
80 
81 // Is the instruction a type of FP multiply-accumulate on 64-bit (or 32-bit) FPRs?
isMla(MachineInstr * MI)82 static bool isMla(MachineInstr *MI) {
83   switch (MI->getOpcode()) {
84   case AArch64::FMSUBSrrr:
85   case AArch64::FMADDSrrr:
86   case AArch64::FNMSUBSrrr:
87   case AArch64::FNMADDSrrr:
88   case AArch64::FMSUBDrrr:
89   case AArch64::FMADDDrrr:
90   case AArch64::FNMSUBDrrr:
91   case AArch64::FNMADDDrrr:
92     return true;
93   default:
94     return false;
95   }
96 }
97 
98 namespace llvm {
99 static void initializeAArch64A57FPLoadBalancingPass(PassRegistry &);
100 }
101 
102 //===----------------------------------------------------------------------===//
103 
104 namespace {
/// The two "colors" a register can have: even or odd. These are not colors in
/// any visual sense; the name reflects that the algorithm is conceptually a
/// two-color graph coloring.
enum class Color {
  Even,
  Odd
};

#ifndef NDEBUG
/// Printable names for Color, indexed by the enumerator's integer value.
/// Only compiled when NDEBUG is not defined.
static const char *ColorNames[2] = {"Even", "Odd"};
#endif

// Forward declaration; the pass class below only refers to Chain through
// pointers.
class Chain;
113 
114 class AArch64A57FPLoadBalancing : public MachineFunctionPass {
115   MachineRegisterInfo *MRI;
116   const TargetRegisterInfo *TRI;
117   RegisterClassInfo RCI;
118 
119 public:
120   static char ID;
AArch64A57FPLoadBalancing()121   explicit AArch64A57FPLoadBalancing() : MachineFunctionPass(ID) {
122     initializeAArch64A57FPLoadBalancingPass(*PassRegistry::getPassRegistry());
123   }
124 
125   bool runOnMachineFunction(MachineFunction &F) override;
126 
getRequiredProperties() const127   MachineFunctionProperties getRequiredProperties() const override {
128     return MachineFunctionProperties().set(
129         MachineFunctionProperties::Property::AllVRegsAllocated);
130   }
131 
getPassName() const132   const char *getPassName() const override {
133     return "A57 FP Anti-dependency breaker";
134   }
135 
getAnalysisUsage(AnalysisUsage & AU) const136   void getAnalysisUsage(AnalysisUsage &AU) const override {
137     AU.setPreservesCFG();
138     MachineFunctionPass::getAnalysisUsage(AU);
139   }
140 
141 private:
142   bool runOnBasicBlock(MachineBasicBlock &MBB);
143   bool colorChainSet(std::vector<Chain*> GV, MachineBasicBlock &MBB,
144                      int &Balance);
145   bool colorChain(Chain *G, Color C, MachineBasicBlock &MBB);
146   int scavengeRegister(Chain *G, Color C, MachineBasicBlock &MBB);
147   void scanInstruction(MachineInstr *MI, unsigned Idx,
148                        std::map<unsigned, Chain*> &Active,
149                        std::vector<std::unique_ptr<Chain>> &AllChains);
150   void maybeKillChain(MachineOperand &MO, unsigned Idx,
151                       std::map<unsigned, Chain*> &RegChains);
152   Color getColor(unsigned Register);
153   Chain *getAndEraseNext(Color PreferredColor, std::vector<Chain*> &L);
154 };
155 }
156 
// Storage for the pass identifier; the pass constructor hands this object to
// the MachineFunctionPass base class.
char AArch64A57FPLoadBalancing::ID = 0;

// Pass-registration boilerplate. The macros are defined outside this file;
// DEBUG_TYPE is passed as the second argument of both.
INITIALIZE_PASS_BEGIN(AArch64A57FPLoadBalancing, DEBUG_TYPE,
                      "AArch64 A57 FP Load-Balancing", false, false)
INITIALIZE_PASS_END(AArch64A57FPLoadBalancing, DEBUG_TYPE,
                    "AArch64 A57 FP Load-Balancing", false, false)
163 
164 namespace {
165 /// A Chain is a sequence of instructions that are linked together by
166 /// an accumulation operand. For example:
167 ///
168 ///   fmul d0<def>, ?
169 ///   fmla d1<def>, ?, ?, d0<kill>
170 ///   fmla d2<def>, ?, ?, d1<kill>
171 ///
172 /// There may be other instructions interleaved in the sequence that
173 /// do not belong to the chain. These other instructions must not use
174 /// the "chain" register at any point.
175 ///
176 /// We currently only support chains where the "chain" operand is killed
177 /// at each link in the chain for simplicity.
178 /// A chain has three important instructions - Start, Last and Kill.
179 ///   * The start instruction is the first instruction in the chain.
180 ///   * Last is the final instruction in the chain.
181 ///   * Kill may or may not be defined. If defined, Kill is the instruction
182 ///     where the outgoing value of the Last instruction is killed.
183 ///     This information is important as if we know the outgoing value is
184 ///     killed with no intervening uses, we can safely change its register.
185 ///
186 /// Without a kill instruction, we must assume the outgoing value escapes
187 /// beyond our model and either must not change its register or must
188 /// create a fixup FMOV to keep the old register value consistent.
189 ///
190 class Chain {
191 public:
192   /// The important (marker) instructions.
193   MachineInstr *StartInst, *LastInst, *KillInst;
194   /// The index, from the start of the basic block, that each marker
195   /// appears. These are stored so we can do quick interval tests.
196   unsigned StartInstIdx, LastInstIdx, KillInstIdx;
197   /// All instructions in the chain.
198   std::set<MachineInstr*> Insts;
199   /// True if KillInst cannot be modified. If this is true,
200   /// we cannot change LastInst's outgoing register.
201   /// This will be true for tied values and regmasks.
202   bool KillIsImmutable;
203   /// The "color" of LastInst. This will be the preferred chain color,
204   /// as changing intermediate nodes is easy but changing the last
205   /// instruction can be more tricky.
206   Color LastColor;
207 
Chain(MachineInstr * MI,unsigned Idx,Color C)208   Chain(MachineInstr *MI, unsigned Idx, Color C)
209       : StartInst(MI), LastInst(MI), KillInst(nullptr),
210         StartInstIdx(Idx), LastInstIdx(Idx), KillInstIdx(0),
211         LastColor(C) {
212     Insts.insert(MI);
213   }
214 
215   /// Add a new instruction into the chain. The instruction's dest operand
216   /// has the given color.
add(MachineInstr * MI,unsigned Idx,Color C)217   void add(MachineInstr *MI, unsigned Idx, Color C) {
218     LastInst = MI;
219     LastInstIdx = Idx;
220     LastColor = C;
221     assert((KillInstIdx == 0 || LastInstIdx < KillInstIdx) &&
222            "Chain: broken invariant. A Chain can only be killed after its last "
223            "def");
224 
225     Insts.insert(MI);
226   }
227 
228   /// Return true if MI is a member of the chain.
contains(MachineInstr & MI)229   bool contains(MachineInstr &MI) { return Insts.count(&MI) > 0; }
230 
231   /// Return the number of instructions in the chain.
size() const232   unsigned size() const {
233     return Insts.size();
234   }
235 
236   /// Inform the chain that its last active register (the dest register of
237   /// LastInst) is killed by MI with no intervening uses or defs.
setKill(MachineInstr * MI,unsigned Idx,bool Immutable)238   void setKill(MachineInstr *MI, unsigned Idx, bool Immutable) {
239     KillInst = MI;
240     KillInstIdx = Idx;
241     KillIsImmutable = Immutable;
242     assert((KillInstIdx == 0 || LastInstIdx < KillInstIdx) &&
243            "Chain: broken invariant. A Chain can only be killed after its last "
244            "def");
245   }
246 
247   /// Return the first instruction in the chain.
getStart() const248   MachineInstr *getStart() const { return StartInst; }
249   /// Return the last instruction in the chain.
getLast() const250   MachineInstr *getLast() const { return LastInst; }
251   /// Return the "kill" instruction (as set with setKill()) or NULL.
getKill() const252   MachineInstr *getKill() const { return KillInst; }
253   /// Return an instruction that can be used as an iterator for the end
254   /// of the chain. This is the maximum of KillInst (if set) and LastInst.
end() const255   MachineBasicBlock::iterator end() const {
256     return ++MachineBasicBlock::iterator(KillInst ? KillInst : LastInst);
257   }
begin() const258   MachineBasicBlock::iterator begin() const { return getStart(); }
259 
260   /// Can the Kill instruction (assuming one exists) be modified?
isKillImmutable() const261   bool isKillImmutable() const { return KillIsImmutable; }
262 
263   /// Return the preferred color of this chain.
getPreferredColor()264   Color getPreferredColor() {
265     if (OverrideBalance != 0)
266       return OverrideBalance == 1 ? Color::Even : Color::Odd;
267     return LastColor;
268   }
269 
270   /// Return true if this chain (StartInst..KillInst) overlaps with Other.
rangeOverlapsWith(const Chain & Other) const271   bool rangeOverlapsWith(const Chain &Other) const {
272     unsigned End = KillInst ? KillInstIdx : LastInstIdx;
273     unsigned OtherEnd = Other.KillInst ?
274       Other.KillInstIdx : Other.LastInstIdx;
275 
276     return StartInstIdx <= OtherEnd && Other.StartInstIdx <= End;
277   }
278 
279   /// Return true if this chain starts before Other.
startsBefore(const Chain * Other) const280   bool startsBefore(const Chain *Other) const {
281     return StartInstIdx < Other->StartInstIdx;
282   }
283 
284   /// Return true if the group will require a fixup MOV at the end.
requiresFixup() const285   bool requiresFixup() const {
286     return (getKill() && isKillImmutable()) || !getKill();
287   }
288 
289   /// Return a simple string representation of the chain.
str() const290   std::string str() const {
291     std::string S;
292     raw_string_ostream OS(S);
293 
294     OS << "{";
295     StartInst->print(OS, /* SkipOpers= */true);
296     OS << " -> ";
297     LastInst->print(OS, /* SkipOpers= */true);
298     if (KillInst) {
299       OS << " (kill @ ";
300       KillInst->print(OS, /* SkipOpers= */true);
301       OS << ")";
302     }
303     OS << "}";
304 
305     return OS.str();
306   }
307 
308 };
309 
310 } // end anonymous namespace
311 
312 //===----------------------------------------------------------------------===//
313 
runOnMachineFunction(MachineFunction & F)314 bool AArch64A57FPLoadBalancing::runOnMachineFunction(MachineFunction &F) {
315   if (skipFunction(*F.getFunction()))
316     return false;
317 
318   if (!F.getSubtarget<AArch64Subtarget>().balanceFPOps())
319     return false;
320 
321   bool Changed = false;
322   DEBUG(dbgs() << "***** AArch64A57FPLoadBalancing *****\n");
323 
324   MRI = &F.getRegInfo();
325   TRI = F.getRegInfo().getTargetRegisterInfo();
326   RCI.runOnMachineFunction(F);
327 
328   for (auto &MBB : F) {
329     Changed |= runOnBasicBlock(MBB);
330   }
331 
332   return Changed;
333 }
334 
runOnBasicBlock(MachineBasicBlock & MBB)335 bool AArch64A57FPLoadBalancing::runOnBasicBlock(MachineBasicBlock &MBB) {
336   bool Changed = false;
337   DEBUG(dbgs() << "Running on MBB: " << MBB << " - scanning instructions...\n");
338 
339   // First, scan the basic block producing a set of chains.
340 
341   // The currently "active" chains - chains that can be added to and haven't
342   // been killed yet. This is keyed by register - all chains can only have one
343   // "link" register between each inst in the chain.
344   std::map<unsigned, Chain*> ActiveChains;
345   std::vector<std::unique_ptr<Chain>> AllChains;
346   unsigned Idx = 0;
347   for (auto &MI : MBB)
348     scanInstruction(&MI, Idx++, ActiveChains, AllChains);
349 
350   DEBUG(dbgs() << "Scan complete, "<< AllChains.size() << " chains created.\n");
351 
352   // Group the chains into disjoint sets based on their liveness range. This is
353   // a poor-man's version of graph coloring. Ideally we'd create an interference
354   // graph and perform full-on graph coloring on that, but;
355   //   (a) That's rather heavyweight for only two colors.
356   //   (b) We expect multiple disjoint interference regions - in practice the live
357   //       range of chains is quite small and they are clustered between loads
358   //       and stores.
359   EquivalenceClasses<Chain*> EC;
360   for (auto &I : AllChains)
361     EC.insert(I.get());
362 
363   for (auto &I : AllChains)
364     for (auto &J : AllChains)
365       if (I != J && I->rangeOverlapsWith(*J))
366         EC.unionSets(I.get(), J.get());
367   DEBUG(dbgs() << "Created " << EC.getNumClasses() << " disjoint sets.\n");
368 
369   // Now we assume that every member of an equivalence class interferes
370   // with every other member of that class, and with no members of other classes.
371 
372   // Convert the EquivalenceClasses to a simpler set of sets.
373   std::vector<std::vector<Chain*> > V;
374   for (auto I = EC.begin(), E = EC.end(); I != E; ++I) {
375     std::vector<Chain*> Cs(EC.member_begin(I), EC.member_end());
376     if (Cs.empty()) continue;
377     V.push_back(std::move(Cs));
378   }
379 
380   // Now we have a set of sets, order them by start address so
381   // we can iterate over them sequentially.
382   std::sort(V.begin(), V.end(),
383             [](const std::vector<Chain*> &A,
384                const std::vector<Chain*> &B) {
385       return A.front()->startsBefore(B.front());
386     });
387 
388   // As we only have two colors, we can track the global (BB-level) balance of
389   // odds versus evens. We aim to keep this near zero to keep both execution
390   // units fed.
391   // Positive means we're even-heavy, negative we're odd-heavy.
392   //
393   // FIXME: If chains have interdependencies, for example:
394   //   mul r0, r1, r2
395   //   mul r3, r0, r1
396   // We do not model this and may color each one differently, assuming we'll
397   // get ILP when we obviously can't. This hasn't been seen to be a problem
398   // in practice so far, so we simplify the algorithm by ignoring it.
399   int Parity = 0;
400 
401   for (auto &I : V)
402     Changed |= colorChainSet(std::move(I), MBB, Parity);
403 
404   return Changed;
405 }
406 
getAndEraseNext(Color PreferredColor,std::vector<Chain * > & L)407 Chain *AArch64A57FPLoadBalancing::getAndEraseNext(Color PreferredColor,
408                                                   std::vector<Chain*> &L) {
409   if (L.empty())
410     return nullptr;
411 
412   // We try and get the best candidate from L to color next, given that our
413   // preferred color is "PreferredColor". L is ordered from larger to smaller
414   // chains. It is beneficial to color the large chains before the small chains,
415   // but if we can't find a chain of the maximum length with the preferred color,
416   // we fuzz the size and look for slightly smaller chains before giving up and
417   // returning a chain that must be recolored.
418 
419   // FIXME: Does this need to be configurable?
420   const unsigned SizeFuzz = 1;
421   unsigned MinSize = L.front()->size() - SizeFuzz;
422   for (auto I = L.begin(), E = L.end(); I != E; ++I) {
423     if ((*I)->size() <= MinSize) {
424       // We've gone past the size limit. Return the previous item.
425       Chain *Ch = *--I;
426       L.erase(I);
427       return Ch;
428     }
429 
430     if ((*I)->getPreferredColor() == PreferredColor) {
431       Chain *Ch = *I;
432       L.erase(I);
433       return Ch;
434     }
435   }
436 
437   // Bailout case - just return the first item.
438   Chain *Ch = L.front();
439   L.erase(L.begin());
440   return Ch;
441 }
442 
colorChainSet(std::vector<Chain * > GV,MachineBasicBlock & MBB,int & Parity)443 bool AArch64A57FPLoadBalancing::colorChainSet(std::vector<Chain*> GV,
444                                               MachineBasicBlock &MBB,
445                                               int &Parity) {
446   bool Changed = false;
447   DEBUG(dbgs() << "colorChainSet(): #sets=" << GV.size() << "\n");
448 
449   // Sort by descending size order so that we allocate the most important
450   // sets first.
451   // Tie-break equivalent sizes by sorting chains requiring fixups before
452   // those without fixups. The logic here is that we should look at the
453   // chains that we cannot change before we look at those we can,
454   // so the parity counter is updated and we know what color we should
455   // change them to!
456   // Final tie-break with instruction order so pass output is stable (i.e. not
457   // dependent on malloc'd pointer values).
458   std::sort(GV.begin(), GV.end(), [](const Chain *G1, const Chain *G2) {
459       if (G1->size() != G2->size())
460         return G1->size() > G2->size();
461       if (G1->requiresFixup() != G2->requiresFixup())
462         return G1->requiresFixup() > G2->requiresFixup();
463       // Make sure startsBefore() produces a stable final order.
464       assert((G1 == G2 || (G1->startsBefore(G2) ^ G2->startsBefore(G1))) &&
465              "Starts before not total order!");
466       return G1->startsBefore(G2);
467     });
468 
469   Color PreferredColor = Parity < 0 ? Color::Even : Color::Odd;
470   while (Chain *G = getAndEraseNext(PreferredColor, GV)) {
471     // Start off by assuming we'll color to our own preferred color.
472     Color C = PreferredColor;
473     if (Parity == 0)
474       // But if we really don't care, use the chain's preferred color.
475       C = G->getPreferredColor();
476 
477     DEBUG(dbgs() << " - Parity=" << Parity << ", Color="
478           << ColorNames[(int)C] << "\n");
479 
480     // If we'll need a fixup FMOV, don't bother. Testing has shown that this
481     // happens infrequently and when it does it has at least a 50% chance of
482     // slowing code down instead of speeding it up.
483     if (G->requiresFixup() && C != G->getPreferredColor()) {
484       C = G->getPreferredColor();
485       DEBUG(dbgs() << " - " << G->str() << " - not worthwhile changing; "
486             "color remains " << ColorNames[(int)C] << "\n");
487     }
488 
489     Changed |= colorChain(G, C, MBB);
490 
491     Parity += (C == Color::Even) ? G->size() : -G->size();
492     PreferredColor = Parity < 0 ? Color::Even : Color::Odd;
493   }
494 
495   return Changed;
496 }
497 
scavengeRegister(Chain * G,Color C,MachineBasicBlock & MBB)498 int AArch64A57FPLoadBalancing::scavengeRegister(Chain *G, Color C,
499                                                 MachineBasicBlock &MBB) {
500   RegScavenger RS;
501   RS.enterBasicBlock(MBB);
502   RS.forward(MachineBasicBlock::iterator(G->getStart()));
503 
504   // Can we find an appropriate register that is available throughout the life
505   // of the chain?
506   unsigned RegClassID = G->getStart()->getDesc().OpInfo[0].RegClass;
507   BitVector AvailableRegs = RS.getRegsAvailable(TRI->getRegClass(RegClassID));
508   for (MachineBasicBlock::iterator I = G->begin(), E = G->end(); I != E; ++I) {
509     RS.forward(I);
510     AvailableRegs &= RS.getRegsAvailable(TRI->getRegClass(RegClassID));
511 
512     // Remove any registers clobbered by a regmask or any def register that is
513     // immediately dead.
514     for (auto J : I->operands()) {
515       if (J.isRegMask())
516         AvailableRegs.clearBitsNotInMask(J.getRegMask());
517 
518       if (J.isReg() && J.isDef()) {
519         MCRegAliasIterator AI(J.getReg(), TRI, /*IncludeSelf=*/true);
520         if (J.isDead())
521           for (; AI.isValid(); ++AI)
522             AvailableRegs.reset(*AI);
523 #ifndef NDEBUG
524         else
525           for (; AI.isValid(); ++AI)
526             assert(!AvailableRegs[*AI] &&
527                    "Non-dead def should have been removed by now!");
528 #endif
529       }
530     }
531   }
532 
533   // Make sure we allocate in-order, to get the cheapest registers first.
534   auto Ord = RCI.getOrder(TRI->getRegClass(RegClassID));
535   for (auto Reg : Ord) {
536     if (!AvailableRegs[Reg])
537       continue;
538     if (C == getColor(Reg))
539       return Reg;
540   }
541 
542   return -1;
543 }
544 
colorChain(Chain * G,Color C,MachineBasicBlock & MBB)545 bool AArch64A57FPLoadBalancing::colorChain(Chain *G, Color C,
546                                            MachineBasicBlock &MBB) {
547   bool Changed = false;
548   DEBUG(dbgs() << " - colorChain(" << G->str() << ", "
549         << ColorNames[(int)C] << ")\n");
550 
551   // Try and obtain a free register of the right class. Without a register
552   // to play with we cannot continue.
553   int Reg = scavengeRegister(G, C, MBB);
554   if (Reg == -1) {
555     DEBUG(dbgs() << "Scavenging (thus coloring) failed!\n");
556     return false;
557   }
558   DEBUG(dbgs() << " - Scavenged register: " << TRI->getName(Reg) << "\n");
559 
560   std::map<unsigned, unsigned> Substs;
561   for (MachineInstr &I : *G) {
562     if (!G->contains(I) && (&I != G->getKill() || G->isKillImmutable()))
563       continue;
564 
565     // I is a member of G, or I is a mutable instruction that kills G.
566 
567     std::vector<unsigned> ToErase;
568     for (auto &U : I.operands()) {
569       if (U.isReg() && U.isUse() && Substs.find(U.getReg()) != Substs.end()) {
570         unsigned OrigReg = U.getReg();
571         U.setReg(Substs[OrigReg]);
572         if (U.isKill())
573           // Don't erase straight away, because there may be other operands
574           // that also reference this substitution!
575           ToErase.push_back(OrigReg);
576       } else if (U.isRegMask()) {
577         for (auto J : Substs) {
578           if (U.clobbersPhysReg(J.first))
579             ToErase.push_back(J.first);
580         }
581       }
582     }
583     // Now it's safe to remove the substs identified earlier.
584     for (auto J : ToErase)
585       Substs.erase(J);
586 
587     // Only change the def if this isn't the last instruction.
588     if (&I != G->getKill()) {
589       MachineOperand &MO = I.getOperand(0);
590 
591       bool Change = TransformAll || getColor(MO.getReg()) != C;
592       if (G->requiresFixup() && &I == G->getLast())
593         Change = false;
594 
595       if (Change) {
596         Substs[MO.getReg()] = Reg;
597         MO.setReg(Reg);
598 
599         Changed = true;
600       }
601     }
602   }
603   assert(Substs.size() == 0 && "No substitutions should be left active!");
604 
605   if (G->getKill()) {
606     DEBUG(dbgs() << " - Kill instruction seen.\n");
607   } else {
608     // We didn't have a kill instruction, but we didn't seem to need to change
609     // the destination register anyway.
610     DEBUG(dbgs() << " - Destination register not changed.\n");
611   }
612   return Changed;
613 }
614 
scanInstruction(MachineInstr * MI,unsigned Idx,std::map<unsigned,Chain * > & ActiveChains,std::vector<std::unique_ptr<Chain>> & AllChains)615 void AArch64A57FPLoadBalancing::scanInstruction(
616     MachineInstr *MI, unsigned Idx, std::map<unsigned, Chain *> &ActiveChains,
617     std::vector<std::unique_ptr<Chain>> &AllChains) {
618   // Inspect "MI", updating ActiveChains and AllChains.
619 
620   if (isMul(MI)) {
621 
622     for (auto &I : MI->uses())
623       maybeKillChain(I, Idx, ActiveChains);
624     for (auto &I : MI->defs())
625       maybeKillChain(I, Idx, ActiveChains);
626 
627     // Create a new chain. Multiplies don't require forwarding so can go on any
628     // unit.
629     unsigned DestReg = MI->getOperand(0).getReg();
630 
631     DEBUG(dbgs() << "New chain started for register "
632           << TRI->getName(DestReg) << " at " << *MI);
633 
634     auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
635     ActiveChains[DestReg] = G.get();
636     AllChains.push_back(std::move(G));
637 
638   } else if (isMla(MI)) {
639 
640     // It is beneficial to keep MLAs on the same functional unit as their
641     // accumulator operand.
642     unsigned DestReg  = MI->getOperand(0).getReg();
643     unsigned AccumReg = MI->getOperand(3).getReg();
644 
645     maybeKillChain(MI->getOperand(1), Idx, ActiveChains);
646     maybeKillChain(MI->getOperand(2), Idx, ActiveChains);
647     if (DestReg != AccumReg)
648       maybeKillChain(MI->getOperand(0), Idx, ActiveChains);
649 
650     if (ActiveChains.find(AccumReg) != ActiveChains.end()) {
651       DEBUG(dbgs() << "Chain found for accumulator register "
652             << TRI->getName(AccumReg) << " in MI " << *MI);
653 
654       // For simplicity we only chain together sequences of MULs/MLAs where the
655       // accumulator register is killed on each instruction. This means we don't
656       // need to track other uses of the registers we want to rewrite.
657       //
658       // FIXME: We could extend to handle the non-kill cases for more coverage.
659       if (MI->getOperand(3).isKill()) {
660         // Add to chain.
661         DEBUG(dbgs() << "Instruction was successfully added to chain.\n");
662         ActiveChains[AccumReg]->add(MI, Idx, getColor(DestReg));
663         // Handle cases where the destination is not the same as the accumulator.
664         if (DestReg != AccumReg) {
665           ActiveChains[DestReg] = ActiveChains[AccumReg];
666           ActiveChains.erase(AccumReg);
667         }
668         return;
669       }
670 
671       DEBUG(dbgs() << "Cannot add to chain because accumulator operand wasn't "
672             << "marked <kill>!\n");
673       maybeKillChain(MI->getOperand(3), Idx, ActiveChains);
674     }
675 
676     DEBUG(dbgs() << "Creating new chain for dest register "
677           << TRI->getName(DestReg) << "\n");
678     auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
679     ActiveChains[DestReg] = G.get();
680     AllChains.push_back(std::move(G));
681 
682   } else {
683 
684     // Non-MUL or MLA instruction. Invalidate any chain in the uses or defs
685     // lists.
686     for (auto &I : MI->uses())
687       maybeKillChain(I, Idx, ActiveChains);
688     for (auto &I : MI->defs())
689       maybeKillChain(I, Idx, ActiveChains);
690 
691   }
692 }
693 
694 void AArch64A57FPLoadBalancing::
maybeKillChain(MachineOperand & MO,unsigned Idx,std::map<unsigned,Chain * > & ActiveChains)695 maybeKillChain(MachineOperand &MO, unsigned Idx,
696                std::map<unsigned, Chain*> &ActiveChains) {
697   // Given an operand and the set of active chains (keyed by register),
698   // determine if a chain should be ended and remove from ActiveChains.
699   MachineInstr *MI = MO.getParent();
700 
701   if (MO.isReg()) {
702 
703     // If this is a KILL of a current chain, record it.
704     if (MO.isKill() && ActiveChains.find(MO.getReg()) != ActiveChains.end()) {
705       DEBUG(dbgs() << "Kill seen for chain " << TRI->getName(MO.getReg())
706             << "\n");
707       ActiveChains[MO.getReg()]->setKill(MI, Idx, /*Immutable=*/MO.isTied());
708     }
709     ActiveChains.erase(MO.getReg());
710 
711   } else if (MO.isRegMask()) {
712 
713     for (auto I = ActiveChains.begin(), E = ActiveChains.end();
714          I != E;) {
715       if (MO.clobbersPhysReg(I->first)) {
716         DEBUG(dbgs() << "Kill (regmask) seen for chain "
717               << TRI->getName(I->first) << "\n");
718         I->second->setKill(MI, Idx, /*Immutable=*/true);
719         ActiveChains.erase(I++);
720       } else
721         ++I;
722     }
723 
724   }
725 }
726 
getColor(unsigned Reg)727 Color AArch64A57FPLoadBalancing::getColor(unsigned Reg) {
728   if ((TRI->getEncodingValue(Reg) % 2) == 0)
729     return Color::Even;
730   else
731     return Color::Odd;
732 }
733 
734 // Factory function used by AArch64TargetMachine to add the pass to the passmanager.
createAArch64A57FPLoadBalancing()735 FunctionPass *llvm::createAArch64A57FPLoadBalancing() {
736   return new AArch64A57FPLoadBalancing();
737 }
738