1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/MDBuilder.h"
19 #include "llvm/ProfileData/InstrProfReader.h"
20 #include "llvm/Support/Endian.h"
21 #include "llvm/Support/FileSystem.h"
22 #include "llvm/Support/MD5.h"
23
24 using namespace clang;
25 using namespace CodeGen;
26
setFuncName(llvm::Function * Fn)27 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
28 RawFuncName = Fn->getName();
29
30 // Function names may be prefixed with a binary '1' to indicate
31 // that the backend should not modify the symbols due to any platform
32 // naming convention. Do not include that '1' in the PGO profile name.
33 if (RawFuncName[0] == '\1')
34 RawFuncName = RawFuncName.substr(1);
35
36 if (!Fn->hasLocalLinkage()) {
37 PrefixedFuncName.reset(new std::string(RawFuncName));
38 return;
39 }
40
41 // For local symbols, prepend the main file name to distinguish them.
42 // Do not include the full path in the file name since there's no guarantee
43 // that it will stay the same, e.g., if the files are checked out from
44 // version control in different locations.
45 PrefixedFuncName.reset(new std::string(CGM.getCodeGenOpts().MainFileName));
46 if (PrefixedFuncName->empty())
47 PrefixedFuncName->assign("<unknown>");
48 PrefixedFuncName->append(":");
49 PrefixedFuncName->append(RawFuncName);
50 }
51
getRegisterFunc(CodeGenModule & CGM)52 static llvm::Function *getRegisterFunc(CodeGenModule &CGM) {
53 return CGM.getModule().getFunction("__llvm_profile_register_functions");
54 }
55
getOrInsertRegisterBB(CodeGenModule & CGM)56 static llvm::BasicBlock *getOrInsertRegisterBB(CodeGenModule &CGM) {
57 // Don't do this for Darwin. compiler-rt uses linker magic.
58 if (CGM.getTarget().getTriple().isOSDarwin())
59 return nullptr;
60
61 // Only need to insert this once per module.
62 if (llvm::Function *RegisterF = getRegisterFunc(CGM))
63 return &RegisterF->getEntryBlock();
64
65 // Construct the function.
66 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
67 auto *RegisterFTy = llvm::FunctionType::get(VoidTy, false);
68 auto *RegisterF = llvm::Function::Create(RegisterFTy,
69 llvm::GlobalValue::InternalLinkage,
70 "__llvm_profile_register_functions",
71 &CGM.getModule());
72 RegisterF->setUnnamedAddr(true);
73 if (CGM.getCodeGenOpts().DisableRedZone)
74 RegisterF->addFnAttr(llvm::Attribute::NoRedZone);
75
76 // Construct and return the entry block.
77 auto *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", RegisterF);
78 CGBuilderTy Builder(BB);
79 Builder.CreateRetVoid();
80 return BB;
81 }
82
getOrInsertRuntimeRegister(CodeGenModule & CGM)83 static llvm::Constant *getOrInsertRuntimeRegister(CodeGenModule &CGM) {
84 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
85 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
86 auto *RuntimeRegisterTy = llvm::FunctionType::get(VoidTy, VoidPtrTy, false);
87 return CGM.getModule().getOrInsertFunction("__llvm_profile_register_function",
88 RuntimeRegisterTy);
89 }
90
isMachO(const CodeGenModule & CGM)91 static bool isMachO(const CodeGenModule &CGM) {
92 return CGM.getTarget().getTriple().isOSBinFormatMachO();
93 }
94
getCountersSection(const CodeGenModule & CGM)95 static StringRef getCountersSection(const CodeGenModule &CGM) {
96 return isMachO(CGM) ? "__DATA,__llvm_prf_cnts" : "__llvm_prf_cnts";
97 }
98
getNameSection(const CodeGenModule & CGM)99 static StringRef getNameSection(const CodeGenModule &CGM) {
100 return isMachO(CGM) ? "__DATA,__llvm_prf_names" : "__llvm_prf_names";
101 }
102
getDataSection(const CodeGenModule & CGM)103 static StringRef getDataSection(const CodeGenModule &CGM) {
104 return isMachO(CGM) ? "__DATA,__llvm_prf_data" : "__llvm_prf_data";
105 }
106
buildDataVar()107 llvm::GlobalVariable *CodeGenPGO::buildDataVar() {
108 // Create name variable.
109 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
110 auto *VarName = llvm::ConstantDataArray::getString(Ctx, getFuncName(),
111 false);
112 auto *Name = new llvm::GlobalVariable(CGM.getModule(), VarName->getType(),
113 true, VarLinkage, VarName,
114 getFuncVarName("name"));
115 Name->setSection(getNameSection(CGM));
116 Name->setAlignment(1);
117
118 // Create data variable.
119 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
120 auto *Int64Ty = llvm::Type::getInt64Ty(Ctx);
121 auto *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
122 auto *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
123 llvm::Type *DataTypes[] = {
124 Int32Ty, Int32Ty, Int64Ty, Int8PtrTy, Int64PtrTy
125 };
126 auto *DataTy = llvm::StructType::get(Ctx, makeArrayRef(DataTypes));
127 llvm::Constant *DataVals[] = {
128 llvm::ConstantInt::get(Int32Ty, getFuncName().size()),
129 llvm::ConstantInt::get(Int32Ty, NumRegionCounters),
130 llvm::ConstantInt::get(Int64Ty, FunctionHash),
131 llvm::ConstantExpr::getBitCast(Name, Int8PtrTy),
132 llvm::ConstantExpr::getBitCast(RegionCounters, Int64PtrTy)
133 };
134 auto *Data =
135 new llvm::GlobalVariable(CGM.getModule(), DataTy, true, VarLinkage,
136 llvm::ConstantStruct::get(DataTy, DataVals),
137 getFuncVarName("data"));
138
139 // All the data should be packed into an array in its own section.
140 Data->setSection(getDataSection(CGM));
141 Data->setAlignment(8);
142
143 // Hide all these symbols so that we correctly get a copy for each
144 // executable. The profile format expects names and counters to be
145 // contiguous, so references into shared objects would be invalid.
146 if (!llvm::GlobalValue::isLocalLinkage(VarLinkage)) {
147 Name->setVisibility(llvm::GlobalValue::HiddenVisibility);
148 Data->setVisibility(llvm::GlobalValue::HiddenVisibility);
149 RegionCounters->setVisibility(llvm::GlobalValue::HiddenVisibility);
150 }
151
152 // Make sure the data doesn't get deleted.
153 CGM.addUsedGlobal(Data);
154 return Data;
155 }
156
emitInstrumentationData()157 void CodeGenPGO::emitInstrumentationData() {
158 if (!RegionCounters)
159 return;
160
161 // Build the data.
162 auto *Data = buildDataVar();
163
164 // Register the data.
165 auto *RegisterBB = getOrInsertRegisterBB(CGM);
166 if (!RegisterBB)
167 return;
168 CGBuilderTy Builder(RegisterBB->getTerminator());
169 auto *VoidPtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
170 Builder.CreateCall(getOrInsertRuntimeRegister(CGM),
171 Builder.CreateBitCast(Data, VoidPtrTy));
172 }
173
emitInitialization(CodeGenModule & CGM)174 llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
175 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
176 return nullptr;
177
178 assert(CGM.getModule().getFunction("__llvm_profile_init") == nullptr &&
179 "profile initialization already emitted");
180
181 // Get the function to call at initialization.
182 llvm::Constant *RegisterF = getRegisterFunc(CGM);
183 if (!RegisterF)
184 return nullptr;
185
186 // Create the initialization function.
187 auto *VoidTy = llvm::Type::getVoidTy(CGM.getLLVMContext());
188 auto *F = llvm::Function::Create(llvm::FunctionType::get(VoidTy, false),
189 llvm::GlobalValue::InternalLinkage,
190 "__llvm_profile_init", &CGM.getModule());
191 F->setUnnamedAddr(true);
192 F->addFnAttr(llvm::Attribute::NoInline);
193 if (CGM.getCodeGenOpts().DisableRedZone)
194 F->addFnAttr(llvm::Attribute::NoRedZone);
195
196 // Add the basic block and the necessary calls.
197 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F));
198 Builder.CreateCall(RegisterF);
199 Builder.CreateRetVoid();
200
201 return F;
202 }
203
204 namespace {
205 /// \brief Stable hasher for PGO region counters.
206 ///
207 /// PGOHash produces a stable hash of a given function's control flow.
208 ///
209 /// Changing the output of this hash will invalidate all previously generated
210 /// profiles -- i.e., don't do it.
211 ///
212 /// \note When this hash does eventually change (years?), we still need to
213 /// support old hashes. We'll need to pull in the version number from the
214 /// profile data format and use the matching hash function.
215 class PGOHash {
216 uint64_t Working;
217 unsigned Count;
218 llvm::MD5 MD5;
219
220 static const int NumBitsPerType = 6;
221 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
222 static const unsigned TooBig = 1u << NumBitsPerType;
223
224 public:
225 /// \brief Hash values for AST nodes.
226 ///
227 /// Distinct values for AST nodes that have region counters attached.
228 ///
229 /// These values must be stable. All new members must be added at the end,
230 /// and no members should be removed. Changing the enumeration value for an
231 /// AST node will affect the hash of every function that contains that node.
232 enum HashType : unsigned char {
233 None = 0,
234 LabelStmt = 1,
235 WhileStmt,
236 DoStmt,
237 ForStmt,
238 CXXForRangeStmt,
239 ObjCForCollectionStmt,
240 SwitchStmt,
241 CaseStmt,
242 DefaultStmt,
243 IfStmt,
244 CXXTryStmt,
245 CXXCatchStmt,
246 ConditionalOperator,
247 BinaryOperatorLAnd,
248 BinaryOperatorLOr,
249 BinaryConditionalOperator,
250
251 // Keep this last. It's for the static assert that follows.
252 LastHashType
253 };
254 static_assert(LastHashType <= TooBig, "Too many types in HashType");
255
256 // TODO: When this format changes, take in a version number here, and use the
257 // old hash calculation for file formats that used the old hash.
PGOHash()258 PGOHash() : Working(0), Count(0) {}
259 void combine(HashType Type);
260 uint64_t finalize();
261 };
262 const int PGOHash::NumBitsPerType;
263 const unsigned PGOHash::NumTypesPerWord;
264 const unsigned PGOHash::TooBig;
265
266 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
267 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
268 /// The next counter value to assign.
269 unsigned NextCounter;
270 /// The function hash.
271 PGOHash Hash;
272 /// The map of statements to counters.
273 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
274
MapRegionCounters__anon4141c6920111::MapRegionCounters275 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
276 : NextCounter(0), CounterMap(CounterMap) {}
277
278 // Blocks and lambdas are handled as separate functions, so we need not
279 // traverse them in the parent context.
TraverseBlockExpr__anon4141c6920111::MapRegionCounters280 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anon4141c6920111::MapRegionCounters281 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anon4141c6920111::MapRegionCounters282 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
283
VisitDecl__anon4141c6920111::MapRegionCounters284 bool VisitDecl(const Decl *D) {
285 switch (D->getKind()) {
286 default:
287 break;
288 case Decl::Function:
289 case Decl::CXXMethod:
290 case Decl::CXXConstructor:
291 case Decl::CXXDestructor:
292 case Decl::CXXConversion:
293 case Decl::ObjCMethod:
294 case Decl::Block:
295 case Decl::Captured:
296 CounterMap[D->getBody()] = NextCounter++;
297 break;
298 }
299 return true;
300 }
301
VisitStmt__anon4141c6920111::MapRegionCounters302 bool VisitStmt(const Stmt *S) {
303 auto Type = getHashType(S);
304 if (Type == PGOHash::None)
305 return true;
306
307 CounterMap[S] = NextCounter++;
308 Hash.combine(Type);
309 return true;
310 }
getHashType__anon4141c6920111::MapRegionCounters311 PGOHash::HashType getHashType(const Stmt *S) {
312 switch (S->getStmtClass()) {
313 default:
314 break;
315 case Stmt::LabelStmtClass:
316 return PGOHash::LabelStmt;
317 case Stmt::WhileStmtClass:
318 return PGOHash::WhileStmt;
319 case Stmt::DoStmtClass:
320 return PGOHash::DoStmt;
321 case Stmt::ForStmtClass:
322 return PGOHash::ForStmt;
323 case Stmt::CXXForRangeStmtClass:
324 return PGOHash::CXXForRangeStmt;
325 case Stmt::ObjCForCollectionStmtClass:
326 return PGOHash::ObjCForCollectionStmt;
327 case Stmt::SwitchStmtClass:
328 return PGOHash::SwitchStmt;
329 case Stmt::CaseStmtClass:
330 return PGOHash::CaseStmt;
331 case Stmt::DefaultStmtClass:
332 return PGOHash::DefaultStmt;
333 case Stmt::IfStmtClass:
334 return PGOHash::IfStmt;
335 case Stmt::CXXTryStmtClass:
336 return PGOHash::CXXTryStmt;
337 case Stmt::CXXCatchStmtClass:
338 return PGOHash::CXXCatchStmt;
339 case Stmt::ConditionalOperatorClass:
340 return PGOHash::ConditionalOperator;
341 case Stmt::BinaryConditionalOperatorClass:
342 return PGOHash::BinaryConditionalOperator;
343 case Stmt::BinaryOperatorClass: {
344 const BinaryOperator *BO = cast<BinaryOperator>(S);
345 if (BO->getOpcode() == BO_LAnd)
346 return PGOHash::BinaryOperatorLAnd;
347 if (BO->getOpcode() == BO_LOr)
348 return PGOHash::BinaryOperatorLOr;
349 break;
350 }
351 }
352 return PGOHash::None;
353 }
354 };
355
356 /// A StmtVisitor that propagates the raw counts through the AST and
357 /// records the count at statements where the value may change.
358 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
359 /// PGO state.
360 CodeGenPGO &PGO;
361
362 /// A flag that is set when the current count should be recorded on the
363 /// next statement, such as at the exit of a loop.
364 bool RecordNextStmtCount;
365
366 /// The map of statements to count values.
367 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
368
369 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
370 struct BreakContinue {
371 uint64_t BreakCount;
372 uint64_t ContinueCount;
BreakContinue__anon4141c6920111::ComputeRegionCounts::BreakContinue373 BreakContinue() : BreakCount(0), ContinueCount(0) {}
374 };
375 SmallVector<BreakContinue, 8> BreakContinueStack;
376
ComputeRegionCounts__anon4141c6920111::ComputeRegionCounts377 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
378 CodeGenPGO &PGO)
379 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
380
RecordStmtCount__anon4141c6920111::ComputeRegionCounts381 void RecordStmtCount(const Stmt *S) {
382 if (RecordNextStmtCount) {
383 CountMap[S] = PGO.getCurrentRegionCount();
384 RecordNextStmtCount = false;
385 }
386 }
387
VisitStmt__anon4141c6920111::ComputeRegionCounts388 void VisitStmt(const Stmt *S) {
389 RecordStmtCount(S);
390 for (Stmt::const_child_range I = S->children(); I; ++I) {
391 if (*I)
392 this->Visit(*I);
393 }
394 }
395
VisitFunctionDecl__anon4141c6920111::ComputeRegionCounts396 void VisitFunctionDecl(const FunctionDecl *D) {
397 // Counter tracks entry to the function body.
398 RegionCounter Cnt(PGO, D->getBody());
399 Cnt.beginRegion();
400 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
401 Visit(D->getBody());
402 }
403
404 // Skip lambda expressions. We visit these as FunctionDecls when we're
405 // generating them and aren't interested in the body when generating a
406 // parent context.
VisitLambdaExpr__anon4141c6920111::ComputeRegionCounts407 void VisitLambdaExpr(const LambdaExpr *LE) {}
408
VisitCapturedDecl__anon4141c6920111::ComputeRegionCounts409 void VisitCapturedDecl(const CapturedDecl *D) {
410 // Counter tracks entry to the capture body.
411 RegionCounter Cnt(PGO, D->getBody());
412 Cnt.beginRegion();
413 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
414 Visit(D->getBody());
415 }
416
VisitObjCMethodDecl__anon4141c6920111::ComputeRegionCounts417 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
418 // Counter tracks entry to the method body.
419 RegionCounter Cnt(PGO, D->getBody());
420 Cnt.beginRegion();
421 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
422 Visit(D->getBody());
423 }
424
VisitBlockDecl__anon4141c6920111::ComputeRegionCounts425 void VisitBlockDecl(const BlockDecl *D) {
426 // Counter tracks entry to the block body.
427 RegionCounter Cnt(PGO, D->getBody());
428 Cnt.beginRegion();
429 CountMap[D->getBody()] = PGO.getCurrentRegionCount();
430 Visit(D->getBody());
431 }
432
VisitReturnStmt__anon4141c6920111::ComputeRegionCounts433 void VisitReturnStmt(const ReturnStmt *S) {
434 RecordStmtCount(S);
435 if (S->getRetValue())
436 Visit(S->getRetValue());
437 PGO.setCurrentRegionUnreachable();
438 RecordNextStmtCount = true;
439 }
440
VisitGotoStmt__anon4141c6920111::ComputeRegionCounts441 void VisitGotoStmt(const GotoStmt *S) {
442 RecordStmtCount(S);
443 PGO.setCurrentRegionUnreachable();
444 RecordNextStmtCount = true;
445 }
446
VisitLabelStmt__anon4141c6920111::ComputeRegionCounts447 void VisitLabelStmt(const LabelStmt *S) {
448 RecordNextStmtCount = false;
449 // Counter tracks the block following the label.
450 RegionCounter Cnt(PGO, S);
451 Cnt.beginRegion();
452 CountMap[S] = PGO.getCurrentRegionCount();
453 Visit(S->getSubStmt());
454 }
455
VisitBreakStmt__anon4141c6920111::ComputeRegionCounts456 void VisitBreakStmt(const BreakStmt *S) {
457 RecordStmtCount(S);
458 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
460 PGO.setCurrentRegionUnreachable();
461 RecordNextStmtCount = true;
462 }
463
VisitContinueStmt__anon4141c6920111::ComputeRegionCounts464 void VisitContinueStmt(const ContinueStmt *S) {
465 RecordStmtCount(S);
466 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
468 PGO.setCurrentRegionUnreachable();
469 RecordNextStmtCount = true;
470 }
471
VisitWhileStmt__anon4141c6920111::ComputeRegionCounts472 void VisitWhileStmt(const WhileStmt *S) {
473 RecordStmtCount(S);
474 // Counter tracks the body of the loop.
475 RegionCounter Cnt(PGO, S);
476 BreakContinueStack.push_back(BreakContinue());
477 // Visit the body region first so the break/continue adjustments can be
478 // included when visiting the condition.
479 Cnt.beginRegion();
480 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
481 Visit(S->getBody());
482 Cnt.adjustForControlFlow();
483
484 // ...then go back and propagate counts through the condition. The count
485 // at the start of the condition is the sum of the incoming edges,
486 // the backedge from the end of the loop body, and the edges from
487 // continue statements.
488 BreakContinue BC = BreakContinueStack.pop_back_val();
489 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
490 Cnt.getAdjustedCount() + BC.ContinueCount);
491 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
492 Visit(S->getCond());
493 Cnt.adjustForControlFlow();
494 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
495 RecordNextStmtCount = true;
496 }
497
VisitDoStmt__anon4141c6920111::ComputeRegionCounts498 void VisitDoStmt(const DoStmt *S) {
499 RecordStmtCount(S);
500 // Counter tracks the body of the loop.
501 RegionCounter Cnt(PGO, S);
502 BreakContinueStack.push_back(BreakContinue());
503 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
504 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
505 Visit(S->getBody());
506 Cnt.adjustForControlFlow();
507
508 BreakContinue BC = BreakContinueStack.pop_back_val();
509 // The count at the start of the condition is equal to the count at the
510 // end of the body. The adjusted count does not include either the
511 // fall-through count coming into the loop or the continue count, so add
512 // both of those separately. This is coincidentally the same equation as
513 // with while loops but for different reasons.
514 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
515 Cnt.getAdjustedCount() + BC.ContinueCount);
516 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
517 Visit(S->getCond());
518 Cnt.adjustForControlFlow();
519 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
520 RecordNextStmtCount = true;
521 }
522
VisitForStmt__anon4141c6920111::ComputeRegionCounts523 void VisitForStmt(const ForStmt *S) {
524 RecordStmtCount(S);
525 if (S->getInit())
526 Visit(S->getInit());
527 // Counter tracks the body of the loop.
528 RegionCounter Cnt(PGO, S);
529 BreakContinueStack.push_back(BreakContinue());
530 // Visit the body region first. (This is basically the same as a while
531 // loop; see further comments in VisitWhileStmt.)
532 Cnt.beginRegion();
533 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
534 Visit(S->getBody());
535 Cnt.adjustForControlFlow();
536
537 // The increment is essentially part of the body but it needs to include
538 // the count for all the continue statements.
539 if (S->getInc()) {
540 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
541 BreakContinueStack.back().ContinueCount);
542 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
543 Visit(S->getInc());
544 Cnt.adjustForControlFlow();
545 }
546
547 BreakContinue BC = BreakContinueStack.pop_back_val();
548
549 // ...then go back and propagate counts through the condition.
550 if (S->getCond()) {
551 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
552 Cnt.getAdjustedCount() +
553 BC.ContinueCount);
554 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
555 Visit(S->getCond());
556 Cnt.adjustForControlFlow();
557 }
558 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
559 RecordNextStmtCount = true;
560 }
561
VisitCXXForRangeStmt__anon4141c6920111::ComputeRegionCounts562 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
563 RecordStmtCount(S);
564 Visit(S->getRangeStmt());
565 Visit(S->getBeginEndStmt());
566 // Counter tracks the body of the loop.
567 RegionCounter Cnt(PGO, S);
568 BreakContinueStack.push_back(BreakContinue());
569 // Visit the body region first. (This is basically the same as a while
570 // loop; see further comments in VisitWhileStmt.)
571 Cnt.beginRegion();
572 CountMap[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
573 Visit(S->getLoopVarStmt());
574 Visit(S->getBody());
575 Cnt.adjustForControlFlow();
576
577 // The increment is essentially part of the body but it needs to include
578 // the count for all the continue statements.
579 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
580 BreakContinueStack.back().ContinueCount);
581 CountMap[S->getInc()] = PGO.getCurrentRegionCount();
582 Visit(S->getInc());
583 Cnt.adjustForControlFlow();
584
585 BreakContinue BC = BreakContinueStack.pop_back_val();
586
587 // ...then go back and propagate counts through the condition.
588 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
589 Cnt.getAdjustedCount() +
590 BC.ContinueCount);
591 CountMap[S->getCond()] = PGO.getCurrentRegionCount();
592 Visit(S->getCond());
593 Cnt.adjustForControlFlow();
594 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
595 RecordNextStmtCount = true;
596 }
597
VisitObjCForCollectionStmt__anon4141c6920111::ComputeRegionCounts598 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
599 RecordStmtCount(S);
600 Visit(S->getElement());
601 // Counter tracks the body of the loop.
602 RegionCounter Cnt(PGO, S);
603 BreakContinueStack.push_back(BreakContinue());
604 Cnt.beginRegion();
605 CountMap[S->getBody()] = PGO.getCurrentRegionCount();
606 Visit(S->getBody());
607 BreakContinue BC = BreakContinueStack.pop_back_val();
608 Cnt.adjustForControlFlow();
609 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
610 RecordNextStmtCount = true;
611 }
612
VisitSwitchStmt__anon4141c6920111::ComputeRegionCounts613 void VisitSwitchStmt(const SwitchStmt *S) {
614 RecordStmtCount(S);
615 Visit(S->getCond());
616 PGO.setCurrentRegionUnreachable();
617 BreakContinueStack.push_back(BreakContinue());
618 Visit(S->getBody());
619 // If the switch is inside a loop, add the continue counts.
620 BreakContinue BC = BreakContinueStack.pop_back_val();
621 if (!BreakContinueStack.empty())
622 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
623 // Counter tracks the exit block of the switch.
624 RegionCounter ExitCnt(PGO, S);
625 ExitCnt.beginRegion();
626 RecordNextStmtCount = true;
627 }
628
VisitCaseStmt__anon4141c6920111::ComputeRegionCounts629 void VisitCaseStmt(const CaseStmt *S) {
630 RecordNextStmtCount = false;
631 // Counter for this particular case. This counts only jumps from the
632 // switch header and does not include fallthrough from the case before
633 // this one.
634 RegionCounter Cnt(PGO, S);
635 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
636 CountMap[S] = Cnt.getCount();
637 RecordNextStmtCount = true;
638 Visit(S->getSubStmt());
639 }
640
VisitDefaultStmt__anon4141c6920111::ComputeRegionCounts641 void VisitDefaultStmt(const DefaultStmt *S) {
642 RecordNextStmtCount = false;
643 // Counter for this default case. This does not include fallthrough from
644 // the previous case.
645 RegionCounter Cnt(PGO, S);
646 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
647 CountMap[S] = Cnt.getCount();
648 RecordNextStmtCount = true;
649 Visit(S->getSubStmt());
650 }
651
VisitIfStmt__anon4141c6920111::ComputeRegionCounts652 void VisitIfStmt(const IfStmt *S) {
653 RecordStmtCount(S);
654 // Counter tracks the "then" part of an if statement. The count for
655 // the "else" part, if it exists, will be calculated from this counter.
656 RegionCounter Cnt(PGO, S);
657 Visit(S->getCond());
658
659 Cnt.beginRegion();
660 CountMap[S->getThen()] = PGO.getCurrentRegionCount();
661 Visit(S->getThen());
662 Cnt.adjustForControlFlow();
663
664 if (S->getElse()) {
665 Cnt.beginElseRegion();
666 CountMap[S->getElse()] = PGO.getCurrentRegionCount();
667 Visit(S->getElse());
668 Cnt.adjustForControlFlow();
669 }
670 Cnt.applyAdjustmentsToRegion(0);
671 RecordNextStmtCount = true;
672 }
673
VisitCXXTryStmt__anon4141c6920111::ComputeRegionCounts674 void VisitCXXTryStmt(const CXXTryStmt *S) {
675 RecordStmtCount(S);
676 Visit(S->getTryBlock());
677 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
678 Visit(S->getHandler(I));
679 // Counter tracks the continuation block of the try statement.
680 RegionCounter Cnt(PGO, S);
681 Cnt.beginRegion();
682 RecordNextStmtCount = true;
683 }
684
VisitCXXCatchStmt__anon4141c6920111::ComputeRegionCounts685 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
686 RecordNextStmtCount = false;
687 // Counter tracks the catch statement's handler block.
688 RegionCounter Cnt(PGO, S);
689 Cnt.beginRegion();
690 CountMap[S] = PGO.getCurrentRegionCount();
691 Visit(S->getHandlerBlock());
692 }
693
VisitAbstractConditionalOperator__anon4141c6920111::ComputeRegionCounts694 void VisitAbstractConditionalOperator(
695 const AbstractConditionalOperator *E) {
696 RecordStmtCount(E);
697 // Counter tracks the "true" part of a conditional operator. The
698 // count in the "false" part will be calculated from this counter.
699 RegionCounter Cnt(PGO, E);
700 Visit(E->getCond());
701
702 Cnt.beginRegion();
703 CountMap[E->getTrueExpr()] = PGO.getCurrentRegionCount();
704 Visit(E->getTrueExpr());
705 Cnt.adjustForControlFlow();
706
707 Cnt.beginElseRegion();
708 CountMap[E->getFalseExpr()] = PGO.getCurrentRegionCount();
709 Visit(E->getFalseExpr());
710 Cnt.adjustForControlFlow();
711
712 Cnt.applyAdjustmentsToRegion(0);
713 RecordNextStmtCount = true;
714 }
715
VisitBinLAnd__anon4141c6920111::ComputeRegionCounts716 void VisitBinLAnd(const BinaryOperator *E) {
717 RecordStmtCount(E);
718 // Counter tracks the right hand side of a logical and operator.
719 RegionCounter Cnt(PGO, E);
720 Visit(E->getLHS());
721 Cnt.beginRegion();
722 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
723 Visit(E->getRHS());
724 Cnt.adjustForControlFlow();
725 Cnt.applyAdjustmentsToRegion(0);
726 RecordNextStmtCount = true;
727 }
728
VisitBinLOr__anon4141c6920111::ComputeRegionCounts729 void VisitBinLOr(const BinaryOperator *E) {
730 RecordStmtCount(E);
731 // Counter tracks the right hand side of a logical or operator.
732 RegionCounter Cnt(PGO, E);
733 Visit(E->getLHS());
734 Cnt.beginRegion();
735 CountMap[E->getRHS()] = PGO.getCurrentRegionCount();
736 Visit(E->getRHS());
737 Cnt.adjustForControlFlow();
738 Cnt.applyAdjustmentsToRegion(0);
739 RecordNextStmtCount = true;
740 }
741 };
742 }
743
combine(HashType Type)744 void PGOHash::combine(HashType Type) {
745 // Check that we never combine 0 and only have six bits.
746 assert(Type && "Hash is invalid: unexpected type 0");
747 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
748
749 // Pass through MD5 if enough work has built up.
750 if (Count && Count % NumTypesPerWord == 0) {
751 using namespace llvm::support;
752 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
753 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
754 Working = 0;
755 }
756
757 // Accumulate the current type.
758 ++Count;
759 Working = Working << NumBitsPerType | Type;
760 }
761
/// Produce the final 64-bit hash for everything passed to combine().
///
/// \returns Working itself when at most NumTypesPerWord types were combined
/// (MD5 was never used); otherwise a little-endian uint64_t read from the
/// start of the finalized MD5 result.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  // NOTE(review): combine() feeds MD5 the eight little-endian bytes of the
  // word, but this call passes the uint64_t itself. Presumably that converts
  // implicitly to a one-element byte array, so only the low 8 bits of Working
  // are hashed -- confirm against the llvm::MD5::update overloads. Changing
  // this alters the hash of existing profiles, so any fix has to be tied to
  // a profile format version.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return endian::read<uint64_t, little, unaligned>(Result);
}
780
emitRuntimeHook(CodeGenModule & CGM)781 static void emitRuntimeHook(CodeGenModule &CGM) {
782 const char *const RuntimeVarName = "__llvm_profile_runtime";
783 const char *const RuntimeUserName = "__llvm_profile_runtime_user";
784 if (CGM.getModule().getGlobalVariable(RuntimeVarName))
785 return;
786
787 // Declare the runtime hook.
788 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
789 auto *Int32Ty = llvm::Type::getInt32Ty(Ctx);
790 auto *Var = new llvm::GlobalVariable(CGM.getModule(), Int32Ty, false,
791 llvm::GlobalValue::ExternalLinkage,
792 nullptr, RuntimeVarName);
793
794 // Make a function that uses it.
795 auto *User = llvm::Function::Create(llvm::FunctionType::get(Int32Ty, false),
796 llvm::GlobalValue::LinkOnceODRLinkage,
797 RuntimeUserName, &CGM.getModule());
798 User->addFnAttr(llvm::Attribute::NoInline);
799 if (CGM.getCodeGenOpts().DisableRedZone)
800 User->addFnAttr(llvm::Attribute::NoRedZone);
801 CGBuilderTy Builder(llvm::BasicBlock::Create(CGM.getLLVMContext(), "", User));
802 auto *Load = Builder.CreateLoad(Var);
803 Builder.CreateRet(Load);
804
805 // Create a use of the function. Now the definition of the runtime variable
806 // should get pulled in, along with any static initializears.
807 CGM.addUsedGlobal(User);
808 }
809
assignRegionCounters(const Decl * D,llvm::Function * Fn)810 void CodeGenPGO::assignRegionCounters(const Decl *D, llvm::Function *Fn) {
811 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
812 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
813 if (!InstrumentRegions && !PGOReader)
814 return;
815 if (D->isImplicit())
816 return;
817 setFuncName(Fn);
818
819 // Set the linkage for variables based on the function linkage. Usually, we
820 // want to match it, but available_externally and extern_weak both have the
821 // wrong semantics.
822 VarLinkage = Fn->getLinkage();
823 switch (VarLinkage) {
824 case llvm::GlobalValue::ExternalWeakLinkage:
825 VarLinkage = llvm::GlobalValue::LinkOnceAnyLinkage;
826 break;
827 case llvm::GlobalValue::AvailableExternallyLinkage:
828 VarLinkage = llvm::GlobalValue::LinkOnceODRLinkage;
829 break;
830 default:
831 break;
832 }
833
834 mapRegionCounters(D);
835 if (InstrumentRegions) {
836 emitRuntimeHook(CGM);
837 emitCounterVariables();
838 }
839 if (PGOReader) {
840 SourceManager &SM = CGM.getContext().getSourceManager();
841 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
842 computeRegionCounts(D);
843 applyFunctionAttributes(PGOReader, Fn);
844 }
845 }
846
mapRegionCounters(const Decl * D)847 void CodeGenPGO::mapRegionCounters(const Decl *D) {
848 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
849 MapRegionCounters Walker(*RegionCounterMap);
850 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
851 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
852 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
853 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
854 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
855 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
856 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
857 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
858 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
859 NumRegionCounters = Walker.NextCounter;
860 FunctionHash = Walker.Hash.finalize();
861 }
862
computeRegionCounts(const Decl * D)863 void CodeGenPGO::computeRegionCounts(const Decl *D) {
864 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
865 ComputeRegionCounts Walker(*StmtCountMap, *this);
866 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
867 Walker.VisitFunctionDecl(FD);
868 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
869 Walker.VisitObjCMethodDecl(MD);
870 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
871 Walker.VisitBlockDecl(BD);
872 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
873 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
874 }
875
876 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)877 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
878 llvm::Function *Fn) {
879 if (!haveRegionCounts())
880 return;
881
882 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
883 uint64_t FunctionCount = getRegionCount(0);
884 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
885 // Turn on InlineHint attribute for hot functions.
886 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
887 Fn->addFnAttr(llvm::Attribute::InlineHint);
888 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
889 // Turn on Cold attribute for cold functions.
890 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
891 Fn->addFnAttr(llvm::Attribute::Cold);
892 }
893
emitCounterVariables()894 void CodeGenPGO::emitCounterVariables() {
895 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
896 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
897 NumRegionCounters);
898 RegionCounters =
899 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false, VarLinkage,
900 llvm::Constant::getNullValue(CounterTy),
901 getFuncVarName("counters"));
902 RegionCounters->setAlignment(8);
903 RegionCounters->setSection(getCountersSection(CGM));
904 }
905
emitCounterIncrement(CGBuilderTy & Builder,unsigned Counter)906 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
907 if (!RegionCounters)
908 return;
909 llvm::Value *Addr =
910 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
911 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
912 Count = Builder.CreateAdd(Count, Builder.getInt64(1));
913 Builder.CreateStore(Count, Addr);
914 }
915
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)916 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
917 bool IsInMainFile) {
918 CGM.getPGOStats().addVisited(IsInMainFile);
919 RegionCounts.reset(new std::vector<uint64_t>);
920 uint64_t Hash;
921 if (PGOReader->getFunctionCounts(getFuncName(), Hash, *RegionCounts)) {
922 CGM.getPGOStats().addMissing(IsInMainFile);
923 RegionCounts.reset();
924 } else if (Hash != FunctionHash ||
925 RegionCounts->size() != NumRegionCounters) {
926 CGM.getPGOStats().addMismatched(IsInMainFile);
927 RegionCounts.reset();
928 }
929 }
930
destroyRegionCounters()931 void CodeGenPGO::destroyRegionCounters() {
932 RegionCounterMap.reset();
933 StmtCountMap.reset();
934 RegionCounts.reset();
935 RegionCounters = nullptr;
936 }
937
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
///
/// \returns 1 when \p MaxWeight is already below UINT32_MAX, otherwise
/// \p MaxWeight / UINT32_MAX + 1.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Weights that already fit need no scaling.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
945
/// \brief Scale an individual branch weight (and add 1).
///
/// Divides the 64-bit \p Weight by \p Scale so that it fits in 32 bits.
///
/// Following Laplace's Rule of Succession, the weight is computed from the
/// count plus 1, so 1 is added to every scaled value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
///
/// \returns \p Weight / \p Scale + 1, narrowed to 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Scaled = Quotient + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
961
createBranchWeights(uint64_t TrueCount,uint64_t FalseCount)962 llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
963 uint64_t FalseCount) {
964 // Check for empty weights.
965 if (!TrueCount && !FalseCount)
966 return nullptr;
967
968 // Calculate how to scale down to 32-bits.
969 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
970
971 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
972 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
973 scaleBranchWeight(FalseCount, Scale));
974 }
975
createBranchWeights(ArrayRef<uint64_t> Weights)976 llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
977 // We need at least two elements to create meaningful weights.
978 if (Weights.size() < 2)
979 return nullptr;
980
981 // Check for empty weights.
982 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
983 if (MaxWeight == 0)
984 return nullptr;
985
986 // Calculate how to scale down to 32-bits.
987 uint64_t Scale = calculateWeightScale(MaxWeight);
988
989 SmallVector<uint32_t, 16> ScaledWeights;
990 ScaledWeights.reserve(Weights.size());
991 for (uint64_t W : Weights)
992 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
993
994 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
995 return MDHelper.createBranchWeights(ScaledWeights);
996 }
997
createLoopWeights(const Stmt * Cond,RegionCounter & Cnt)998 llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
999 RegionCounter &Cnt) {
1000 if (!haveRegionCounts())
1001 return nullptr;
1002 uint64_t LoopCount = Cnt.getCount();
1003 uint64_t CondCount = 0;
1004 bool Found = getStmtCount(Cond, CondCount);
1005 assert(Found && "missing expected loop condition count");
1006 (void)Found;
1007 if (CondCount == 0)
1008 return nullptr;
1009 return createBranchWeights(LoopCount,
1010 std::max(CondCount, LoopCount) - LoopCount);
1011 }
1012