1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines a utility class to perform loop versioning. The versioned
10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and
11 // emits checks to prove this.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Utils/LoopVersioning.h"
16 #include "llvm/Analysis/LoopAccessAnalysis.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/Analysis/ScalarEvolutionExpander.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
24 #include "llvm/Transforms/Utils/Cloning.h"
25
26 using namespace llvm;
27
28 static cl::opt<bool>
29 AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true),
30 cl::Hidden,
31 cl::desc("Add no-alias annotation for instructions that "
32 "are disambiguated by memchecks"));
33
LoopVersioning(const LoopAccessInfo & LAI,Loop * L,LoopInfo * LI,DominatorTree * DT,ScalarEvolution * SE,bool UseLAIChecks)34 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
35 DominatorTree *DT, ScalarEvolution *SE,
36 bool UseLAIChecks)
37 : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
38 SE(SE) {
39 assert(L->getExitBlock() && "No single exit block");
40 assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form");
41 if (UseLAIChecks) {
42 setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
43 setSCEVChecks(LAI.getPSE().getUnionPredicate());
44 }
45 }
46
setAliasChecks(SmallVector<RuntimePointerChecking::PointerCheck,4> Checks)47 void LoopVersioning::setAliasChecks(
48 SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
49 AliasChecks = std::move(Checks);
50 }
51
setSCEVChecks(SCEVUnionPredicate Check)52 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
53 Preds = std::move(Check);
54 }
55
versionLoop(const SmallVectorImpl<Instruction * > & DefsUsedOutside)56 void LoopVersioning::versionLoop(
57 const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
58 Instruction *FirstCheckInst;
59 Instruction *MemRuntimeCheck;
60 Value *SCEVRuntimeCheck;
61 Value *RuntimeCheck = nullptr;
62
63 // Add the memcheck in the original preheader (this is empty initially).
64 BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
65 std::tie(FirstCheckInst, MemRuntimeCheck) =
66 LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
67
68 const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate();
69 SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
70 "scev.check");
71 SCEVRuntimeCheck =
72 Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
73 auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck);
74
75 // Discard the SCEV runtime check if it is always true.
76 if (CI && CI->isZero())
77 SCEVRuntimeCheck = nullptr;
78
79 if (MemRuntimeCheck && SCEVRuntimeCheck) {
80 RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
81 SCEVRuntimeCheck, "lver.safe");
82 if (auto *I = dyn_cast<Instruction>(RuntimeCheck))
83 I->insertBefore(RuntimeCheckBB->getTerminator());
84 } else
85 RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
86
87 assert(RuntimeCheck && "called even though we don't need "
88 "any runtime checks");
89
90 // Rename the block to make the IR more readable.
91 RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
92 ".lver.check");
93
94 // Create empty preheader for the loop (and after cloning for the
95 // non-versioned loop).
96 BasicBlock *PH =
97 SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI,
98 nullptr, VersionedLoop->getHeader()->getName() + ".ph");
99
100 // Clone the loop including the preheader.
101 //
102 // FIXME: This does not currently preserve SimplifyLoop because the exit
103 // block is a join between the two loops.
104 SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks;
105 NonVersionedLoop =
106 cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
107 ".lver.orig", LI, DT, NonVersionedLoopBlocks);
108 remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
109
110 // Insert the conditional branch based on the result of the memchecks.
111 Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
112 BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
113 VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
114 OrigTerm->eraseFromParent();
115
116 // The loops merge in the original exit block. This is now dominated by the
117 // memchecking block.
118 DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
119
120 // Adds the necessary PHI nodes for the versioned loops based on the
121 // loop-defined values used outside of the loop.
122 addPHINodes(DefsUsedOutside);
123 }
124
addPHINodes(const SmallVectorImpl<Instruction * > & DefsUsedOutside)125 void LoopVersioning::addPHINodes(
126 const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
127 BasicBlock *PHIBlock = VersionedLoop->getExitBlock();
128 assert(PHIBlock && "No single successor to loop exit block");
129 PHINode *PN;
130
131 // First add a single-operand PHI for each DefsUsedOutside if one does not
132 // exists yet.
133 for (auto *Inst : DefsUsedOutside) {
134 // See if we have a single-operand PHI with the value defined by the
135 // original loop.
136 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
137 if (PN->getIncomingValue(0) == Inst)
138 break;
139 }
140 // If not create it.
141 if (!PN) {
142 PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver",
143 &PHIBlock->front());
144 SmallVector<User*, 8> UsersToUpdate;
145 for (User *U : Inst->users())
146 if (!VersionedLoop->contains(cast<Instruction>(U)->getParent()))
147 UsersToUpdate.push_back(U);
148 for (User *U : UsersToUpdate)
149 U->replaceUsesOfWith(Inst, PN);
150 PN->addIncoming(Inst, VersionedLoop->getExitingBlock());
151 }
152 }
153
154 // Then for each PHI add the operand for the edge from the cloned loop.
155 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) {
156 assert(PN->getNumOperands() == 1 &&
157 "Exit block should only have on predecessor");
158
159 // If the definition was cloned used that otherwise use the same value.
160 Value *ClonedValue = PN->getIncomingValue(0);
161 auto Mapped = VMap.find(ClonedValue);
162 if (Mapped != VMap.end())
163 ClonedValue = Mapped->second;
164
165 PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock());
166 }
167 }
168
prepareNoAliasMetadata()169 void LoopVersioning::prepareNoAliasMetadata() {
170 // We need to turn the no-alias relation between pointer checking groups into
171 // no-aliasing annotations between instructions.
172 //
173 // We accomplish this by mapping each pointer checking group (a set of
174 // pointers memchecked together) to an alias scope and then also mapping each
175 // group to the list of scopes it can't alias.
176
177 const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking();
178 LLVMContext &Context = VersionedLoop->getHeader()->getContext();
179
180 // First allocate an aliasing scope for each pointer checking group.
181 //
182 // While traversing through the checking groups in the loop, also create a
183 // reverse map from pointers to the pointer checking group they were assigned
184 // to.
185 MDBuilder MDB(Context);
186 MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain");
187
188 for (const auto &Group : RtPtrChecking->CheckingGroups) {
189 GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain);
190
191 for (unsigned PtrIdx : Group.Members)
192 PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group;
193 }
194
195 // Go through the checks and for each pointer group, collect the scopes for
196 // each non-aliasing pointer group.
197 DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
198 SmallVector<Metadata *, 4>>
199 GroupToNonAliasingScopes;
200
201 for (const auto &Check : AliasChecks)
202 GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]);
203
204 // Finally, transform the above to actually map to scope list which is what
205 // the metadata uses.
206
207 for (auto Pair : GroupToNonAliasingScopes)
208 GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second);
209 }
210
annotateLoopWithNoAlias()211 void LoopVersioning::annotateLoopWithNoAlias() {
212 if (!AnnotateNoAlias)
213 return;
214
215 // First prepare the maps.
216 prepareNoAliasMetadata();
217
218 // Add the scope and no-alias metadata to the instructions.
219 for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) {
220 annotateInstWithNoAlias(I);
221 }
222 }
223
annotateInstWithNoAlias(Instruction * VersionedInst,const Instruction * OrigInst)224 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst,
225 const Instruction *OrigInst) {
226 if (!AnnotateNoAlias)
227 return;
228
229 LLVMContext &Context = VersionedLoop->getHeader()->getContext();
230 const Value *Ptr = isa<LoadInst>(OrigInst)
231 ? cast<LoadInst>(OrigInst)->getPointerOperand()
232 : cast<StoreInst>(OrigInst)->getPointerOperand();
233
234 // Find the group for the pointer and then add the scope metadata.
235 auto Group = PtrToGroup.find(Ptr);
236 if (Group != PtrToGroup.end()) {
237 VersionedInst->setMetadata(
238 LLVMContext::MD_alias_scope,
239 MDNode::concatenate(
240 VersionedInst->getMetadata(LLVMContext::MD_alias_scope),
241 MDNode::get(Context, GroupToScope[Group->second])));
242
243 // Add the no-alias metadata.
244 auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second);
245 if (NonAliasingScopeList != GroupToNonAliasingScopeList.end())
246 VersionedInst->setMetadata(
247 LLVMContext::MD_noalias,
248 MDNode::concatenate(
249 VersionedInst->getMetadata(LLVMContext::MD_noalias),
250 NonAliasingScopeList->second));
251 }
252 }
253
254 namespace {
255 /// Also expose this is a pass. Currently this is only used for
256 /// unit-testing. It adds all memchecks necessary to remove all may-aliasing
257 /// array accesses from the loop.
258 class LoopVersioningPass : public FunctionPass {
259 public:
LoopVersioningPass()260 LoopVersioningPass() : FunctionPass(ID) {
261 initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry());
262 }
263
runOnFunction(Function & F)264 bool runOnFunction(Function &F) override {
265 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
266 auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
267 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
268 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
269
270 // Build up a worklist of inner-loops to version. This is necessary as the
271 // act of versioning a loop creates new loops and can invalidate iterators
272 // across the loops.
273 SmallVector<Loop *, 8> Worklist;
274
275 for (Loop *TopLevelLoop : *LI)
276 for (Loop *L : depth_first(TopLevelLoop))
277 // We only handle inner-most loops.
278 if (L->empty())
279 Worklist.push_back(L);
280
281 // Now walk the identified inner loops.
282 bool Changed = false;
283 for (Loop *L : Worklist) {
284 const LoopAccessInfo &LAI = LAA->getInfo(L);
285 if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() &&
286 (LAI.getNumRuntimePointerChecks() ||
287 !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) {
288 LoopVersioning LVer(LAI, L, LI, DT, SE);
289 LVer.versionLoop();
290 LVer.annotateLoopWithNoAlias();
291 Changed = true;
292 }
293 }
294
295 return Changed;
296 }
297
getAnalysisUsage(AnalysisUsage & AU) const298 void getAnalysisUsage(AnalysisUsage &AU) const override {
299 AU.addRequired<LoopInfoWrapperPass>();
300 AU.addPreserved<LoopInfoWrapperPass>();
301 AU.addRequired<LoopAccessLegacyAnalysis>();
302 AU.addRequired<DominatorTreeWrapperPass>();
303 AU.addPreserved<DominatorTreeWrapperPass>();
304 AU.addRequired<ScalarEvolutionWrapperPass>();
305 }
306
307 static char ID;
308 };
309 }
310
// Command-line name of the legacy pass, also reused as this file's DEBUG_TYPE.
#define LVER_OPTION "loop-versioning"
#define DEBUG_TYPE LVER_OPTION

// Storage for the static pass ID handed to FunctionPass(ID) in the
// LoopVersioningPass constructor.
char LoopVersioningPass::ID;
// Human-readable pass name used in the registration below.
static const char LVer_name[] = "Loop Versioning";

// Register the pass together with the analyses it requires. The trailing
// `false, false` are presumably the CFG-only / is-analysis flags of
// INITIALIZE_PASS_BEGIN — confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false)
323
324 namespace llvm {
createLoopVersioningPass()325 FunctionPass *createLoopVersioningPass() {
326 return new LoopVersioningPass();
327 }
328 }
329