• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines a utility class to perform loop versioning.  The versioned
// loop speculates that otherwise may-aliasing memory accesses don't overlap and
// emits runtime checks to verify this.
//
//===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/Utils/LoopVersioning.h"
17 #include "llvm/Analysis/LoopAccessAnalysis.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/ScalarEvolutionExpander.h"
20 #include "llvm/IR/Dominators.h"
21 #include "llvm/IR/MDBuilder.h"
22 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
23 #include "llvm/Transforms/Utils/Cloning.h"
24 
25 using namespace llvm;
26 
27 static cl::opt<bool>
28     AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
29                     cl::Hidden,
30                     cl::desc("Add no-alias annotation for instructions that "
31                              "are disambiguated by memchecks"));
32 
LoopVersioning(const LoopAccessInfo & LAI,Loop * L,LoopInfo * LI,DominatorTree * DT,ScalarEvolution * SE,bool UseLAIChecks)33 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
34                                DominatorTree *DT, ScalarEvolution *SE,
35                                bool UseLAIChecks)
36     : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
37       SE(SE) {
38   assert(L->getExitBlock() && "No single exit block");
39   assert(L->getLoopPreheader() && "No preheader");
40   if (UseLAIChecks) {
41     setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
42     setSCEVChecks(LAI.getPSE().getUnionPredicate());
43   }
44 }
45 
setAliasChecks(SmallVector<RuntimePointerChecking::PointerCheck,4> Checks)46 void LoopVersioning::setAliasChecks(
47     SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
48   AliasChecks = std::move(Checks);
49 }
50 
setSCEVChecks(SCEVUnionPredicate Check)51 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
52   Preds = std::move(Check);
53 }
54 
versionLoop(const SmallVectorImpl<Instruction * > & DefsUsedOutside)55 void LoopVersioning::versionLoop(
56     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
57   Instruction *FirstCheckInst;
58   Instruction *MemRuntimeCheck;
59   Value *SCEVRuntimeCheck;
60   Value *RuntimeCheck = nullptr;
61 
62   // Add the memcheck in the original preheader (this is empty initially).
63   BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
64   std::tie(FirstCheckInst, MemRuntimeCheck) =
65       LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
66 
67   const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
68   SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
69                    "scev.check");
70   SCEVRuntimeCheck =
71       Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
72   auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
73 
74   // Discard the SCEV runtime check if it is always true.
75   if (CI && CI->isZero())
76     SCEVRuntimeCheck = nullptr;
77 
78   if (MemRuntimeCheck && SCEVRuntimeCheck) {
79     RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
80                                           SCEVRuntimeCheck, "lver.safe");
81     if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
82       I->insertBefore(RuntimeCheckBB->getTerminator());
83   } else
84     RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
85 
86   assert(RuntimeCheck && "called even though we don't need "
87                          "any runtime checks");
88 
89   // Rename the block to make the IR more readable.
90   RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
91                           ".lver.check");
92 
93   // Create empty preheader for the loop (and after cloning for the
94   // non-versioned loop).
95   BasicBlock *PH =
96       SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
97   PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
98 
99   // Clone the loop including the preheader.
100   //
101   // FIXME: This does not currently preserve SimplifyLoop because the exit
102   // block is a join between the two loops.
103   SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
104   NonVersionedLoop =
105       cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
106                              ".lver.orig", LI, DT, NonVersionedLoopBlocks);
107   remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
108 
109   // Insert the conditional branch based on the result of the memchecks.
110   Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
111   BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
112                      VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
113   OrigTerm->eraseFromParent();
114 
115   // The loops merge in the original exit block.  This is now dominated by the
116   // memchecking block.
117   DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
118 
119   // Adds the necessary PHI nodes for the versioned loops based on the
120   // loop-defined values used outside of the loop.
121   addPHINodes(DefsUsedOutside);
122 }
123 
addPHINodes(const SmallVectorImpl<Instruction * > & DefsUsedOutside)124 void LoopVersioning::addPHINodes(
125     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
126   BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
127   assert(PHIBlock && "No single successor to loop exit block");
128   PHINode *PN;
129 
130   // First add a single-operand PHI for each DefsUsedOutside if one does not
131   // exists yet.
132   for (auto *Inst : DefsUsedOutside) {
133     // See if we have a single-operand PHI with the value defined by the
134     // original loop.
135     for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
136       if (PN->getIncomingValue(0) == Inst)
137         break;
138     }
139     // If not create it.
140     if (!PN) {
141       PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
142                            &PHIBlock->front());
143       for (auto *User : Inst->users())
144         if (!VersionedLoop->contains(cast<Instruction>(User)->getParent()))
145           User->replaceUsesOfWith(Inst, PN);
146       PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
147     }
148   }
149 
150   // Then for each PHI add the operand for the edge from the cloned loop.
151   for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
152     assert(PN->getNumOperands() == 1 &&
153            "Exit block should only have on predecessor");
154 
155     // If the definition was cloned used that otherwise use the same value.
156     Value *ClonedValue = PN->getIncomingValue(0);
157     auto Mapped = VMap.find(ClonedValue);
158     if (Mapped != VMap.end())
159       ClonedValue = Mapped->second;
160 
161     PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
162   }
163 }
164 
prepareNoAliasMetadata()165 void LoopVersioning::prepareNoAliasMetadata() {
166   // We need to turn the no-alias relation between pointer checking groups into
167   // no-aliasing annotations between instructions.
168   //
169   // We accomplish this by mapping each pointer checking group (a set of
170   // pointers memchecked together) to an alias scope and then also mapping each
171   // group to the list of scopes it can't alias.
172 
173   const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
174   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
175 
176   // First allocate an aliasing scope for each pointer checking group.
177   //
178   // While traversing through the checking groups in the loop, also create a
179   // reverse map from pointers to the pointer checking group they were assigned
180   // to.
181   MDBuilder MDB(Context);
182   MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
183 
184   for (const auto &Group : RtPtrChecking->CheckingGroups) {
185     GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
186 
187     for (unsigned PtrIdx : Group.Members)
188       PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
189   }
190 
191   // Go through the checks and for each pointer group, collect the scopes for
192   // each non-aliasing pointer group.
193   DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
194            SmallVector<Metadata *, 4>>
195       GroupToNonAliasingScopes;
196 
197   for (const auto &Check : AliasChecks)
198     GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
199 
200   // Finally, transform the above to actually map to scope list which is what
201   // the metadata uses.
202 
203   for (auto Pair : GroupToNonAliasingScopes)
204     GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
205 }
206 
annotateLoopWithNoAlias()207 void LoopVersioning::annotateLoopWithNoAlias() {
208   if (!AnnotateNoAlias)
209     return;
210 
211   // First prepare the maps.
212   prepareNoAliasMetadata();
213 
214   // Add the scope and no-alias metadata to the instructions.
215   for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
216     annotateInstWithNoAlias(I);
217   }
218 }
219 
annotateInstWithNoAlias(Instruction * VersionedInst,const Instruction * OrigInst)220 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
221                                              const Instruction *OrigInst) {
222   if (!AnnotateNoAlias)
223     return;
224 
225   LLVMContext &Context = VersionedLoop->getHeader()->getContext();
226   const Value *Ptr = isa<LoadInst>(OrigInst)
227                          ? cast<LoadInst>(OrigInst)->getPointerOperand()
228                          : cast<StoreInst>(OrigInst)->getPointerOperand();
229 
230   // Find the group for the pointer and then add the scope metadata.
231   auto Group = PtrToGroup.find(Ptr);
232   if (Group != PtrToGroup.end()) {
233     VersionedInst->setMetadata(
234         LLVMContext::MD_alias_scope,
235         MDNode::concatenate(
236             VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
237             MDNode::get(Context, GroupToScope[Group->second])));
238 
239     // Add the no-alias metadata.
240     auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
241     if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
242       VersionedInst->setMetadata(
243           LLVMContext::MD_noalias,
244           MDNode::concatenate(
245               VersionedInst->getMetadata(LLVMContext::MD_noalias),
246               NonAliasingScopeList->second));
247   }
248 }
249 
250 namespace {
251 /// \brief Also expose this is a pass.  Currently this is only used for
252 /// unit-testing.  It adds all memchecks necessary to remove all may-aliasing
253 /// array accesses from the loop.
254 class LoopVersioningPass : public FunctionPass {
255 public:
LoopVersioningPass()256   LoopVersioningPass() : FunctionPass(ID) {
257     initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry());
258   }
259 
runOnFunction(Function & F)260   bool runOnFunction(Function &F) override {
261     auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
262     auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
263     auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
264     auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
265 
266     // Build up a worklist of inner-loops to version. This is necessary as the
267     // act of versioning a loop creates new loops and can invalidate iterators
268     // across the loops.
269     SmallVector<Loop *, 8> Worklist;
270 
271     for (Loop *TopLevelLoop : *LI)
272       for (Loop *L : depth_first(TopLevelLoop))
273         // We only handle inner-most loops.
274         if (L->empty())
275           Worklist.push_back(L);
276 
277     // Now walk the identified inner loops.
278     bool Changed = false;
279     for (Loop *L : Worklist) {
280       const LoopAccessInfo &LAI = LAA->getInfo(L);
281       if (LAI.getNumRuntimePointerChecks() ||
282           !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) {
283         LoopVersioning LVer(LAI, L, LI, DT, SE);
284         LVer.versionLoop();
285         LVer.annotateLoopWithNoAlias();
286         Changed = true;
287       }
288     }
289 
290     return Changed;
291   }
292 
getAnalysisUsage(AnalysisUsage & AU) const293   void getAnalysisUsage(AnalysisUsage &AU) const override {
294     AU.addRequired<LoopInfoWrapperPass>();
295     AU.addPreserved<LoopInfoWrapperPass>();
296     AU.addRequired<LoopAccessLegacyAnalysis>();
297     AU.addRequired<DominatorTreeWrapperPass>();
298     AU.addPreserved<DominatorTreeWrapperPass>();
299     AU.addRequired<ScalarEvolutionWrapperPass>();
300   }
301 
302   static char ID;
303 };
304 }
305 
306 #define LVER_OPTION "loop-versioning"
307 #define DEBUG_TYPE LVER_OPTION
308 
309 char LoopVersioningPass::ID;
310 static const char LVer_name[] = "Loop Versioning";
311 
312 INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
313 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
314 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
315 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
316 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
317 INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
318 
319 namespace llvm {
createLoopVersioningPass()320 FunctionPass *createLoopVersioningPass() {
321   return new LoopVersioningPass();
322 }
323 }
324