• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25 
26 using namespace clang;
27 using namespace CodeGen;
28 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)29 void CodeGenPGO::setFuncName(StringRef Name,
30                              llvm::GlobalValue::LinkageTypes Linkage) {
31   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
32   FuncName = llvm::getPGOFuncName(
33       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
34       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
35 
36   // If we're generating a profile, create a variable for the name.
37   if (CGM.getCodeGenOpts().ProfileInstrGenerate)
38     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
39 }
40 
setFuncName(llvm::Function * Fn)41 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
42   setFuncName(Fn->getName(), Fn->getLinkage());
43 }
44 
45 namespace {
46 /// \brief Stable hasher for PGO region counters.
47 ///
48 /// PGOHash produces a stable hash of a given function's control flow.
49 ///
50 /// Changing the output of this hash will invalidate all previously generated
51 /// profiles -- i.e., don't do it.
52 ///
53 /// \note  When this hash does eventually change (years?), we still need to
54 /// support old hashes.  We'll need to pull in the version number from the
55 /// profile data format and use the matching hash function.
56 class PGOHash {
57   uint64_t Working;
58   unsigned Count;
59   llvm::MD5 MD5;
60 
61   static const int NumBitsPerType = 6;
62   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
63   static const unsigned TooBig = 1u << NumBitsPerType;
64 
65 public:
66   /// \brief Hash values for AST nodes.
67   ///
68   /// Distinct values for AST nodes that have region counters attached.
69   ///
70   /// These values must be stable.  All new members must be added at the end,
71   /// and no members should be removed.  Changing the enumeration value for an
72   /// AST node will affect the hash of every function that contains that node.
73   enum HashType : unsigned char {
74     None = 0,
75     LabelStmt = 1,
76     WhileStmt,
77     DoStmt,
78     ForStmt,
79     CXXForRangeStmt,
80     ObjCForCollectionStmt,
81     SwitchStmt,
82     CaseStmt,
83     DefaultStmt,
84     IfStmt,
85     CXXTryStmt,
86     CXXCatchStmt,
87     ConditionalOperator,
88     BinaryOperatorLAnd,
89     BinaryOperatorLOr,
90     BinaryConditionalOperator,
91 
92     // Keep this last.  It's for the static assert that follows.
93     LastHashType
94   };
95   static_assert(LastHashType <= TooBig, "Too many types in HashType");
96 
97   // TODO: When this format changes, take in a version number here, and use the
98   // old hash calculation for file formats that used the old hash.
PGOHash()99   PGOHash() : Working(0), Count(0) {}
100   void combine(HashType Type);
101   uint64_t finalize();
102 };
103 const int PGOHash::NumBitsPerType;
104 const unsigned PGOHash::NumTypesPerWord;
105 const unsigned PGOHash::TooBig;
106 
107 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
108 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
109   /// The next counter value to assign.
110   unsigned NextCounter;
111   /// The function hash.
112   PGOHash Hash;
113   /// The map of statements to counters.
114   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
115 
MapRegionCounters__anon65c1757c0111::MapRegionCounters116   MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
117       : NextCounter(0), CounterMap(CounterMap) {}
118 
119   // Blocks and lambdas are handled as separate functions, so we need not
120   // traverse them in the parent context.
TraverseBlockExpr__anon65c1757c0111::MapRegionCounters121   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anon65c1757c0111::MapRegionCounters122   bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anon65c1757c0111::MapRegionCounters123   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
124 
VisitDecl__anon65c1757c0111::MapRegionCounters125   bool VisitDecl(const Decl *D) {
126     switch (D->getKind()) {
127     default:
128       break;
129     case Decl::Function:
130     case Decl::CXXMethod:
131     case Decl::CXXConstructor:
132     case Decl::CXXDestructor:
133     case Decl::CXXConversion:
134     case Decl::ObjCMethod:
135     case Decl::Block:
136     case Decl::Captured:
137       CounterMap[D->getBody()] = NextCounter++;
138       break;
139     }
140     return true;
141   }
142 
VisitStmt__anon65c1757c0111::MapRegionCounters143   bool VisitStmt(const Stmt *S) {
144     auto Type = getHashType(S);
145     if (Type == PGOHash::None)
146       return true;
147 
148     CounterMap[S] = NextCounter++;
149     Hash.combine(Type);
150     return true;
151   }
getHashType__anon65c1757c0111::MapRegionCounters152   PGOHash::HashType getHashType(const Stmt *S) {
153     switch (S->getStmtClass()) {
154     default:
155       break;
156     case Stmt::LabelStmtClass:
157       return PGOHash::LabelStmt;
158     case Stmt::WhileStmtClass:
159       return PGOHash::WhileStmt;
160     case Stmt::DoStmtClass:
161       return PGOHash::DoStmt;
162     case Stmt::ForStmtClass:
163       return PGOHash::ForStmt;
164     case Stmt::CXXForRangeStmtClass:
165       return PGOHash::CXXForRangeStmt;
166     case Stmt::ObjCForCollectionStmtClass:
167       return PGOHash::ObjCForCollectionStmt;
168     case Stmt::SwitchStmtClass:
169       return PGOHash::SwitchStmt;
170     case Stmt::CaseStmtClass:
171       return PGOHash::CaseStmt;
172     case Stmt::DefaultStmtClass:
173       return PGOHash::DefaultStmt;
174     case Stmt::IfStmtClass:
175       return PGOHash::IfStmt;
176     case Stmt::CXXTryStmtClass:
177       return PGOHash::CXXTryStmt;
178     case Stmt::CXXCatchStmtClass:
179       return PGOHash::CXXCatchStmt;
180     case Stmt::ConditionalOperatorClass:
181       return PGOHash::ConditionalOperator;
182     case Stmt::BinaryConditionalOperatorClass:
183       return PGOHash::BinaryConditionalOperator;
184     case Stmt::BinaryOperatorClass: {
185       const BinaryOperator *BO = cast<BinaryOperator>(S);
186       if (BO->getOpcode() == BO_LAnd)
187         return PGOHash::BinaryOperatorLAnd;
188       if (BO->getOpcode() == BO_LOr)
189         return PGOHash::BinaryOperatorLOr;
190       break;
191     }
192     }
193     return PGOHash::None;
194   }
195 };
196 
197 /// A StmtVisitor that propagates the raw counts through the AST and
198 /// records the count at statements where the value may change.
199 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
200   /// PGO state.
201   CodeGenPGO &PGO;
202 
203   /// A flag that is set when the current count should be recorded on the
204   /// next statement, such as at the exit of a loop.
205   bool RecordNextStmtCount;
206 
207   /// The count at the current location in the traversal.
208   uint64_t CurrentCount;
209 
210   /// The map of statements to count values.
211   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
212 
213   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
214   struct BreakContinue {
215     uint64_t BreakCount;
216     uint64_t ContinueCount;
BreakContinue__anon65c1757c0111::ComputeRegionCounts::BreakContinue217     BreakContinue() : BreakCount(0), ContinueCount(0) {}
218   };
219   SmallVector<BreakContinue, 8> BreakContinueStack;
220 
ComputeRegionCounts__anon65c1757c0111::ComputeRegionCounts221   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
222                       CodeGenPGO &PGO)
223       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
224 
RecordStmtCount__anon65c1757c0111::ComputeRegionCounts225   void RecordStmtCount(const Stmt *S) {
226     if (RecordNextStmtCount) {
227       CountMap[S] = CurrentCount;
228       RecordNextStmtCount = false;
229     }
230   }
231 
232   /// Set and return the current count.
setCount__anon65c1757c0111::ComputeRegionCounts233   uint64_t setCount(uint64_t Count) {
234     CurrentCount = Count;
235     return Count;
236   }
237 
VisitStmt__anon65c1757c0111::ComputeRegionCounts238   void VisitStmt(const Stmt *S) {
239     RecordStmtCount(S);
240     for (const Stmt *Child : S->children())
241       if (Child)
242         this->Visit(Child);
243   }
244 
VisitFunctionDecl__anon65c1757c0111::ComputeRegionCounts245   void VisitFunctionDecl(const FunctionDecl *D) {
246     // Counter tracks entry to the function body.
247     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
248     CountMap[D->getBody()] = BodyCount;
249     Visit(D->getBody());
250   }
251 
252   // Skip lambda expressions. We visit these as FunctionDecls when we're
253   // generating them and aren't interested in the body when generating a
254   // parent context.
VisitLambdaExpr__anon65c1757c0111::ComputeRegionCounts255   void VisitLambdaExpr(const LambdaExpr *LE) {}
256 
VisitCapturedDecl__anon65c1757c0111::ComputeRegionCounts257   void VisitCapturedDecl(const CapturedDecl *D) {
258     // Counter tracks entry to the capture body.
259     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
260     CountMap[D->getBody()] = BodyCount;
261     Visit(D->getBody());
262   }
263 
VisitObjCMethodDecl__anon65c1757c0111::ComputeRegionCounts264   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
265     // Counter tracks entry to the method body.
266     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
267     CountMap[D->getBody()] = BodyCount;
268     Visit(D->getBody());
269   }
270 
VisitBlockDecl__anon65c1757c0111::ComputeRegionCounts271   void VisitBlockDecl(const BlockDecl *D) {
272     // Counter tracks entry to the block body.
273     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
274     CountMap[D->getBody()] = BodyCount;
275     Visit(D->getBody());
276   }
277 
VisitReturnStmt__anon65c1757c0111::ComputeRegionCounts278   void VisitReturnStmt(const ReturnStmt *S) {
279     RecordStmtCount(S);
280     if (S->getRetValue())
281       Visit(S->getRetValue());
282     CurrentCount = 0;
283     RecordNextStmtCount = true;
284   }
285 
VisitCXXThrowExpr__anon65c1757c0111::ComputeRegionCounts286   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
287     RecordStmtCount(E);
288     if (E->getSubExpr())
289       Visit(E->getSubExpr());
290     CurrentCount = 0;
291     RecordNextStmtCount = true;
292   }
293 
VisitGotoStmt__anon65c1757c0111::ComputeRegionCounts294   void VisitGotoStmt(const GotoStmt *S) {
295     RecordStmtCount(S);
296     CurrentCount = 0;
297     RecordNextStmtCount = true;
298   }
299 
VisitLabelStmt__anon65c1757c0111::ComputeRegionCounts300   void VisitLabelStmt(const LabelStmt *S) {
301     RecordNextStmtCount = false;
302     // Counter tracks the block following the label.
303     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
304     CountMap[S] = BlockCount;
305     Visit(S->getSubStmt());
306   }
307 
VisitBreakStmt__anon65c1757c0111::ComputeRegionCounts308   void VisitBreakStmt(const BreakStmt *S) {
309     RecordStmtCount(S);
310     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
311     BreakContinueStack.back().BreakCount += CurrentCount;
312     CurrentCount = 0;
313     RecordNextStmtCount = true;
314   }
315 
VisitContinueStmt__anon65c1757c0111::ComputeRegionCounts316   void VisitContinueStmt(const ContinueStmt *S) {
317     RecordStmtCount(S);
318     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
319     BreakContinueStack.back().ContinueCount += CurrentCount;
320     CurrentCount = 0;
321     RecordNextStmtCount = true;
322   }
323 
VisitWhileStmt__anon65c1757c0111::ComputeRegionCounts324   void VisitWhileStmt(const WhileStmt *S) {
325     RecordStmtCount(S);
326     uint64_t ParentCount = CurrentCount;
327 
328     BreakContinueStack.push_back(BreakContinue());
329     // Visit the body region first so the break/continue adjustments can be
330     // included when visiting the condition.
331     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
332     CountMap[S->getBody()] = CurrentCount;
333     Visit(S->getBody());
334     uint64_t BackedgeCount = CurrentCount;
335 
336     // ...then go back and propagate counts through the condition. The count
337     // at the start of the condition is the sum of the incoming edges,
338     // the backedge from the end of the loop body, and the edges from
339     // continue statements.
340     BreakContinue BC = BreakContinueStack.pop_back_val();
341     uint64_t CondCount =
342         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
343     CountMap[S->getCond()] = CondCount;
344     Visit(S->getCond());
345     setCount(BC.BreakCount + CondCount - BodyCount);
346     RecordNextStmtCount = true;
347   }
348 
VisitDoStmt__anon65c1757c0111::ComputeRegionCounts349   void VisitDoStmt(const DoStmt *S) {
350     RecordStmtCount(S);
351     uint64_t LoopCount = PGO.getRegionCount(S);
352 
353     BreakContinueStack.push_back(BreakContinue());
354     // The count doesn't include the fallthrough from the parent scope. Add it.
355     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
356     CountMap[S->getBody()] = BodyCount;
357     Visit(S->getBody());
358     uint64_t BackedgeCount = CurrentCount;
359 
360     BreakContinue BC = BreakContinueStack.pop_back_val();
361     // The count at the start of the condition is equal to the count at the
362     // end of the body, plus any continues.
363     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
364     CountMap[S->getCond()] = CondCount;
365     Visit(S->getCond());
366     setCount(BC.BreakCount + CondCount - LoopCount);
367     RecordNextStmtCount = true;
368   }
369 
VisitForStmt__anon65c1757c0111::ComputeRegionCounts370   void VisitForStmt(const ForStmt *S) {
371     RecordStmtCount(S);
372     if (S->getInit())
373       Visit(S->getInit());
374 
375     uint64_t ParentCount = CurrentCount;
376 
377     BreakContinueStack.push_back(BreakContinue());
378     // Visit the body region first. (This is basically the same as a while
379     // loop; see further comments in VisitWhileStmt.)
380     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
381     CountMap[S->getBody()] = BodyCount;
382     Visit(S->getBody());
383     uint64_t BackedgeCount = CurrentCount;
384     BreakContinue BC = BreakContinueStack.pop_back_val();
385 
386     // The increment is essentially part of the body but it needs to include
387     // the count for all the continue statements.
388     if (S->getInc()) {
389       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
390       CountMap[S->getInc()] = IncCount;
391       Visit(S->getInc());
392     }
393 
394     // ...then go back and propagate counts through the condition.
395     uint64_t CondCount =
396         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
397     if (S->getCond()) {
398       CountMap[S->getCond()] = CondCount;
399       Visit(S->getCond());
400     }
401     setCount(BC.BreakCount + CondCount - BodyCount);
402     RecordNextStmtCount = true;
403   }
404 
VisitCXXForRangeStmt__anon65c1757c0111::ComputeRegionCounts405   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
406     RecordStmtCount(S);
407     Visit(S->getLoopVarStmt());
408     Visit(S->getRangeStmt());
409     Visit(S->getBeginEndStmt());
410 
411     uint64_t ParentCount = CurrentCount;
412     BreakContinueStack.push_back(BreakContinue());
413     // Visit the body region first. (This is basically the same as a while
414     // loop; see further comments in VisitWhileStmt.)
415     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
416     CountMap[S->getBody()] = BodyCount;
417     Visit(S->getBody());
418     uint64_t BackedgeCount = CurrentCount;
419     BreakContinue BC = BreakContinueStack.pop_back_val();
420 
421     // The increment is essentially part of the body but it needs to include
422     // the count for all the continue statements.
423     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
424     CountMap[S->getInc()] = IncCount;
425     Visit(S->getInc());
426 
427     // ...then go back and propagate counts through the condition.
428     uint64_t CondCount =
429         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
430     CountMap[S->getCond()] = CondCount;
431     Visit(S->getCond());
432     setCount(BC.BreakCount + CondCount - BodyCount);
433     RecordNextStmtCount = true;
434   }
435 
VisitObjCForCollectionStmt__anon65c1757c0111::ComputeRegionCounts436   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
437     RecordStmtCount(S);
438     Visit(S->getElement());
439     uint64_t ParentCount = CurrentCount;
440     BreakContinueStack.push_back(BreakContinue());
441     // Counter tracks the body of the loop.
442     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
443     CountMap[S->getBody()] = BodyCount;
444     Visit(S->getBody());
445     uint64_t BackedgeCount = CurrentCount;
446     BreakContinue BC = BreakContinueStack.pop_back_val();
447 
448     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
449              BodyCount);
450     RecordNextStmtCount = true;
451   }
452 
VisitSwitchStmt__anon65c1757c0111::ComputeRegionCounts453   void VisitSwitchStmt(const SwitchStmt *S) {
454     RecordStmtCount(S);
455     Visit(S->getCond());
456     CurrentCount = 0;
457     BreakContinueStack.push_back(BreakContinue());
458     Visit(S->getBody());
459     // If the switch is inside a loop, add the continue counts.
460     BreakContinue BC = BreakContinueStack.pop_back_val();
461     if (!BreakContinueStack.empty())
462       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
463     // Counter tracks the exit block of the switch.
464     setCount(PGO.getRegionCount(S));
465     RecordNextStmtCount = true;
466   }
467 
VisitSwitchCase__anon65c1757c0111::ComputeRegionCounts468   void VisitSwitchCase(const SwitchCase *S) {
469     RecordNextStmtCount = false;
470     // Counter for this particular case. This counts only jumps from the
471     // switch header and does not include fallthrough from the case before
472     // this one.
473     uint64_t CaseCount = PGO.getRegionCount(S);
474     setCount(CurrentCount + CaseCount);
475     // We need the count without fallthrough in the mapping, so it's more useful
476     // for branch probabilities.
477     CountMap[S] = CaseCount;
478     RecordNextStmtCount = true;
479     Visit(S->getSubStmt());
480   }
481 
VisitIfStmt__anon65c1757c0111::ComputeRegionCounts482   void VisitIfStmt(const IfStmt *S) {
483     RecordStmtCount(S);
484     uint64_t ParentCount = CurrentCount;
485     Visit(S->getCond());
486 
487     // Counter tracks the "then" part of an if statement. The count for
488     // the "else" part, if it exists, will be calculated from this counter.
489     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
490     CountMap[S->getThen()] = ThenCount;
491     Visit(S->getThen());
492     uint64_t OutCount = CurrentCount;
493 
494     uint64_t ElseCount = ParentCount - ThenCount;
495     if (S->getElse()) {
496       setCount(ElseCount);
497       CountMap[S->getElse()] = ElseCount;
498       Visit(S->getElse());
499       OutCount += CurrentCount;
500     } else
501       OutCount += ElseCount;
502     setCount(OutCount);
503     RecordNextStmtCount = true;
504   }
505 
VisitCXXTryStmt__anon65c1757c0111::ComputeRegionCounts506   void VisitCXXTryStmt(const CXXTryStmt *S) {
507     RecordStmtCount(S);
508     Visit(S->getTryBlock());
509     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
510       Visit(S->getHandler(I));
511     // Counter tracks the continuation block of the try statement.
512     setCount(PGO.getRegionCount(S));
513     RecordNextStmtCount = true;
514   }
515 
VisitCXXCatchStmt__anon65c1757c0111::ComputeRegionCounts516   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
517     RecordNextStmtCount = false;
518     // Counter tracks the catch statement's handler block.
519     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
520     CountMap[S] = CatchCount;
521     Visit(S->getHandlerBlock());
522   }
523 
VisitAbstractConditionalOperator__anon65c1757c0111::ComputeRegionCounts524   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
525     RecordStmtCount(E);
526     uint64_t ParentCount = CurrentCount;
527     Visit(E->getCond());
528 
529     // Counter tracks the "true" part of a conditional operator. The
530     // count in the "false" part will be calculated from this counter.
531     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
532     CountMap[E->getTrueExpr()] = TrueCount;
533     Visit(E->getTrueExpr());
534     uint64_t OutCount = CurrentCount;
535 
536     uint64_t FalseCount = setCount(ParentCount - TrueCount);
537     CountMap[E->getFalseExpr()] = FalseCount;
538     Visit(E->getFalseExpr());
539     OutCount += CurrentCount;
540 
541     setCount(OutCount);
542     RecordNextStmtCount = true;
543   }
544 
VisitBinLAnd__anon65c1757c0111::ComputeRegionCounts545   void VisitBinLAnd(const BinaryOperator *E) {
546     RecordStmtCount(E);
547     uint64_t ParentCount = CurrentCount;
548     Visit(E->getLHS());
549     // Counter tracks the right hand side of a logical and operator.
550     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
551     CountMap[E->getRHS()] = RHSCount;
552     Visit(E->getRHS());
553     setCount(ParentCount + RHSCount - CurrentCount);
554     RecordNextStmtCount = true;
555   }
556 
VisitBinLOr__anon65c1757c0111::ComputeRegionCounts557   void VisitBinLOr(const BinaryOperator *E) {
558     RecordStmtCount(E);
559     uint64_t ParentCount = CurrentCount;
560     Visit(E->getLHS());
561     // Counter tracks the right hand side of a logical or operator.
562     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
563     CountMap[E->getRHS()] = RHSCount;
564     Visit(E->getRHS());
565     setCount(ParentCount + RHSCount - CurrentCount);
566     RecordNextStmtCount = true;
567   }
568 };
569 } // end anonymous namespace
570 
combine(HashType Type)571 void PGOHash::combine(HashType Type) {
572   // Check that we never combine 0 and only have six bits.
573   assert(Type && "Hash is invalid: unexpected type 0");
574   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
575 
576   // Pass through MD5 if enough work has built up.
577   if (Count && Count % NumTypesPerWord == 0) {
578     using namespace llvm::support;
579     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
580     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
581     Working = 0;
582   }
583 
584   // Accumulate the current type.
585   ++Count;
586   Working = Working << NumBitsPerType | Type;
587 }
588 
finalize()589 uint64_t PGOHash::finalize() {
590   // Use Working as the hash directly if we never used MD5.
591   if (Count <= NumTypesPerWord)
592     // No need to byte swap here, since none of the math was endian-dependent.
593     // This number will be byte-swapped as required on endianness transitions,
594     // so we will see the same value on the other side.
595     return Working;
596 
597   // Check for remaining work in Working.
598   if (Working)
599     MD5.update(Working);
600 
601   // Finalize the MD5 and return the hash.
602   llvm::MD5::MD5Result Result;
603   MD5.final(Result);
604   using namespace llvm::support;
605   return endian::read<uint64_t, little, unaligned>(Result);
606 }
607 
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)608 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
609   const Decl *D = GD.getDecl();
610   bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
611   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
612   if (!InstrumentRegions && !PGOReader)
613     return;
614   if (D->isImplicit())
615     return;
616   // Constructors and destructors may be represented by several functions in IR.
617   // If so, instrument only base variant, others are implemented by delegation
618   // to the base one, it would be counted twice otherwise.
619   if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
620       ((isa<CXXConstructorDecl>(GD.getDecl()) &&
621         GD.getCtorType() != Ctor_Base) ||
622        (isa<CXXDestructorDecl>(GD.getDecl()) &&
623         GD.getDtorType() != Dtor_Base))) {
624       return;
625   }
626   CGM.ClearUnusedCoverageMapping(D);
627   setFuncName(Fn);
628 
629   mapRegionCounters(D);
630   if (CGM.getCodeGenOpts().CoverageMapping)
631     emitCounterRegionMapping(D);
632   if (PGOReader) {
633     SourceManager &SM = CGM.getContext().getSourceManager();
634     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
635     computeRegionCounts(D);
636     applyFunctionAttributes(PGOReader, Fn);
637   }
638 }
639 
mapRegionCounters(const Decl * D)640 void CodeGenPGO::mapRegionCounters(const Decl *D) {
641   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
642   MapRegionCounters Walker(*RegionCounterMap);
643   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
644     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
645   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
646     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
647   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
648     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
649   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
650     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
651   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
652   NumRegionCounters = Walker.NextCounter;
653   FunctionHash = Walker.Hash.finalize();
654 }
655 
emitCounterRegionMapping(const Decl * D)656 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
657   if (SkipCoverageMapping)
658     return;
659   // Don't map the functions inside the system headers
660   auto Loc = D->getBody()->getLocStart();
661   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
662     return;
663 
664   std::string CoverageMapping;
665   llvm::raw_string_ostream OS(CoverageMapping);
666   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
667                                 CGM.getContext().getSourceManager(),
668                                 CGM.getLangOpts(), RegionCounterMap.get());
669   MappingGen.emitCounterMapping(D, OS);
670   OS.flush();
671 
672   if (CoverageMapping.empty())
673     return;
674 
675   CGM.getCoverageMapping()->addFunctionMappingRecord(
676       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
677 }
678 
679 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)680 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
681                                     llvm::GlobalValue::LinkageTypes Linkage) {
682   if (SkipCoverageMapping)
683     return;
684   // Don't map the functions inside the system headers
685   auto Loc = D->getBody()->getLocStart();
686   if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
687     return;
688 
689   std::string CoverageMapping;
690   llvm::raw_string_ostream OS(CoverageMapping);
691   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
692                                 CGM.getContext().getSourceManager(),
693                                 CGM.getLangOpts());
694   MappingGen.emitEmptyMapping(D, OS);
695   OS.flush();
696 
697   if (CoverageMapping.empty())
698     return;
699 
700   setFuncName(Name, Linkage);
701   CGM.getCoverageMapping()->addFunctionMappingRecord(
702       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
703 }
704 
computeRegionCounts(const Decl * D)705 void CodeGenPGO::computeRegionCounts(const Decl *D) {
706   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
707   ComputeRegionCounts Walker(*StmtCountMap, *this);
708   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
709     Walker.VisitFunctionDecl(FD);
710   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
711     Walker.VisitObjCMethodDecl(MD);
712   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
713     Walker.VisitBlockDecl(BD);
714   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
715     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
716 }
717 
718 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)719 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
720                                     llvm::Function *Fn) {
721   if (!haveRegionCounts())
722     return;
723 
724   uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
725   uint64_t FunctionCount = getRegionCount(nullptr);
726   if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
727     // Turn on InlineHint attribute for hot functions.
728     // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
729     Fn->addFnAttr(llvm::Attribute::InlineHint);
730   else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
731     // Turn on Cold attribute for cold functions.
732     // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
733     Fn->addFnAttr(llvm::Attribute::Cold);
734 
735   Fn->setEntryCount(FunctionCount);
736 }
737 
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S)738 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
739   if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
740     return;
741   if (!Builder.GetInsertBlock())
742     return;
743 
744   unsigned Counter = (*RegionCounterMap)[S];
745   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
746   Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
747                      {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
748                       Builder.getInt64(FunctionHash),
749                       Builder.getInt32(NumRegionCounters),
750                       Builder.getInt32(Counter)});
751 }
752 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)753 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
754                                   bool IsInMainFile) {
755   CGM.getPGOStats().addVisited(IsInMainFile);
756   RegionCounts.clear();
757   if (std::error_code EC =
758           PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
759     if (EC == llvm::instrprof_error::unknown_function)
760       CGM.getPGOStats().addMissing(IsInMainFile);
761     else if (EC == llvm::instrprof_error::hash_mismatch)
762       CGM.getPGOStats().addMismatched(IsInMainFile);
763     else if (EC == llvm::instrprof_error::malformed)
764       // TODO: Consider a more specific warning for this case.
765       CGM.getPGOStats().addMismatched(IsInMainFile);
766     RegionCounts.clear();
767   }
768 }
769 
/// \brief Calculate what to divide by to scale weights.
///
/// Returns 1 when MaxWeight already fits below UINT32_MAX. Otherwise returns
/// MaxWeight / UINT32_MAX + 1, so that MaxWeight / Scale + 1 fits in 32 bits.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
777 
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale is nonzero and was calculated by \a calculateWeightScale()
/// from a maximum weight no less than \c Weight. If Scale came from a smaller
/// weight, the result may not fit in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
793 
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount)794 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
795                                                     uint64_t FalseCount) {
796   // Check for empty weights.
797   if (!TrueCount && !FalseCount)
798     return nullptr;
799 
800   // Calculate how to scale down to 32-bits.
801   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
802 
803   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
804   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
805                                       scaleBranchWeight(FalseCount, Scale));
806 }
807 
808 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights)809 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
810   // We need at least two elements to create meaningful weights.
811   if (Weights.size() < 2)
812     return nullptr;
813 
814   // Check for empty weights.
815   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
816   if (MaxWeight == 0)
817     return nullptr;
818 
819   // Calculate how to scale down to 32-bits.
820   uint64_t Scale = calculateWeightScale(MaxWeight);
821 
822   SmallVector<uint32_t, 16> ScaledWeights;
823   ScaledWeights.reserve(Weights.size());
824   for (uint64_t W : Weights)
825     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
826 
827   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
828   return MDHelper.createBranchWeights(ScaledWeights);
829 }
830 
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount)831 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
832                                                            uint64_t LoopCount) {
833   if (!PGO.haveRegionCounts())
834     return nullptr;
835   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
836   assert(CondCount.hasValue() && "missing expected loop condition count");
837   if (*CondCount == 0)
838     return nullptr;
839   return createProfileWeights(LoopCount,
840                               std::max(*CondCount, LoopCount) - LoopCount);
841 }
842