• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- SampleContextTracker.cpp - Context-sensitive Profile Tracker -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SampleContextTracker used by CSSPGO.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/IPO/SampleContextTracker.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/IR/DebugInfoMetadata.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/ProfileData/SampleProf.h"
19 #include <map>
20 #include <queue>
21 #include <vector>
22 
23 using namespace llvm;
24 using namespace sampleprof;
25 
26 #define DEBUG_TYPE "sample-context-tracker"
27 
28 namespace llvm {
29 
getChildContext(const LineLocation & CallSite,StringRef CalleeName)30 ContextTrieNode *ContextTrieNode::getChildContext(const LineLocation &CallSite,
31                                                   StringRef CalleeName) {
32   if (CalleeName.empty())
33     return getChildContext(CallSite);
34 
35   uint32_t Hash = nodeHash(CalleeName, CallSite);
36   auto It = AllChildContext.find(Hash);
37   if (It != AllChildContext.end())
38     return &It->second;
39   return nullptr;
40 }
41 
42 ContextTrieNode *
getChildContext(const LineLocation & CallSite)43 ContextTrieNode::getChildContext(const LineLocation &CallSite) {
44   // CSFDO-TODO: This could be slow, change AllChildContext so we can
45   // do point look up for child node by call site alone.
46   // CSFDO-TODO: Return the child with max count for indirect call
47   ContextTrieNode *ChildNodeRet = nullptr;
48   for (auto &It : AllChildContext) {
49     ContextTrieNode &ChildNode = It.second;
50     if (ChildNode.CallSiteLoc == CallSite) {
51       if (ChildNodeRet)
52         return nullptr;
53       else
54         ChildNodeRet = &ChildNode;
55     }
56   }
57 
58   return ChildNodeRet;
59 }
60 
moveToChildContext(const LineLocation & CallSite,ContextTrieNode && NodeToMove,StringRef ContextStrToRemove,bool DeleteNode)61 ContextTrieNode &ContextTrieNode::moveToChildContext(
62     const LineLocation &CallSite, ContextTrieNode &&NodeToMove,
63     StringRef ContextStrToRemove, bool DeleteNode) {
64   uint32_t Hash = nodeHash(NodeToMove.getFuncName(), CallSite);
65   assert(!AllChildContext.count(Hash) && "Node to remove must exist");
66   LineLocation OldCallSite = NodeToMove.CallSiteLoc;
67   ContextTrieNode &OldParentContext = *NodeToMove.getParentContext();
68   AllChildContext[Hash] = NodeToMove;
69   ContextTrieNode &NewNode = AllChildContext[Hash];
70   NewNode.CallSiteLoc = CallSite;
71 
72   // Walk through nodes in the moved the subtree, and update
73   // FunctionSamples' context as for the context promotion.
74   // We also need to set new parant link for all children.
75   std::queue<ContextTrieNode *> NodeToUpdate;
76   NewNode.setParentContext(this);
77   NodeToUpdate.push(&NewNode);
78 
79   while (!NodeToUpdate.empty()) {
80     ContextTrieNode *Node = NodeToUpdate.front();
81     NodeToUpdate.pop();
82     FunctionSamples *FSamples = Node->getFunctionSamples();
83 
84     if (FSamples) {
85       FSamples->getContext().promoteOnPath(ContextStrToRemove);
86       FSamples->getContext().setState(SyntheticContext);
87       LLVM_DEBUG(dbgs() << "  Context promoted to: " << FSamples->getContext()
88                         << "\n");
89     }
90 
91     for (auto &It : Node->getAllChildContext()) {
92       ContextTrieNode *ChildNode = &It.second;
93       ChildNode->setParentContext(Node);
94       NodeToUpdate.push(ChildNode);
95     }
96   }
97 
98   // Original context no longer needed, destroy if requested.
99   if (DeleteNode)
100     OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName());
101 
102   return NewNode;
103 }
104 
removeChildContext(const LineLocation & CallSite,StringRef CalleeName)105 void ContextTrieNode::removeChildContext(const LineLocation &CallSite,
106                                          StringRef CalleeName) {
107   uint32_t Hash = nodeHash(CalleeName, CallSite);
108   // Note this essentially calls dtor and destroys that child context
109   AllChildContext.erase(Hash);
110 }
111 
getAllChildContext()112 std::map<uint32_t, ContextTrieNode> &ContextTrieNode::getAllChildContext() {
113   return AllChildContext;
114 }
115 
getFuncName() const116 const StringRef ContextTrieNode::getFuncName() const { return FuncName; }
117 
getFunctionSamples() const118 FunctionSamples *ContextTrieNode::getFunctionSamples() const {
119   return FuncSamples;
120 }
121 
setFunctionSamples(FunctionSamples * FSamples)122 void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) {
123   FuncSamples = FSamples;
124 }
125 
getCallSiteLoc() const126 LineLocation ContextTrieNode::getCallSiteLoc() const { return CallSiteLoc; }
127 
getParentContext() const128 ContextTrieNode *ContextTrieNode::getParentContext() const {
129   return ParentContext;
130 }
131 
setParentContext(ContextTrieNode * Parent)132 void ContextTrieNode::setParentContext(ContextTrieNode *Parent) {
133   ParentContext = Parent;
134 }
135 
dump()136 void ContextTrieNode::dump() {
137   dbgs() << "Node: " << FuncName << "\n"
138          << "  Callsite: " << CallSiteLoc << "\n"
139          << "  Children:\n";
140 
141   for (auto &It : AllChildContext) {
142     dbgs() << "    Node: " << It.second.getFuncName() << "\n";
143   }
144 }
145 
nodeHash(StringRef ChildName,const LineLocation & Callsite)146 uint32_t ContextTrieNode::nodeHash(StringRef ChildName,
147                                    const LineLocation &Callsite) {
148   // We still use child's name for child hash, this is
149   // because for children of root node, we don't have
150   // different line/discriminator, and we'll rely on name
151   // to differentiate children.
152   uint32_t NameHash = std::hash<std::string>{}(ChildName.str());
153   uint32_t LocId = (Callsite.LineOffset << 16) | Callsite.Discriminator;
154   return NameHash + (LocId << 5) + LocId;
155 }
156 
getOrCreateChildContext(const LineLocation & CallSite,StringRef CalleeName,bool AllowCreate)157 ContextTrieNode *ContextTrieNode::getOrCreateChildContext(
158     const LineLocation &CallSite, StringRef CalleeName, bool AllowCreate) {
159   uint32_t Hash = nodeHash(CalleeName, CallSite);
160   auto It = AllChildContext.find(Hash);
161   if (It != AllChildContext.end()) {
162     assert(It->second.getFuncName() == CalleeName &&
163            "Hash collision for child context node");
164     return &It->second;
165   }
166 
167   if (!AllowCreate)
168     return nullptr;
169 
170   AllChildContext[Hash] = ContextTrieNode(this, CalleeName, nullptr, CallSite);
171   return &AllChildContext[Hash];
172 }
173 
174 // Profiler tracker than manages profiles and its associated context
SampleContextTracker(StringMap<FunctionSamples> & Profiles)175 SampleContextTracker::SampleContextTracker(
176     StringMap<FunctionSamples> &Profiles) {
177   for (auto &FuncSample : Profiles) {
178     FunctionSamples *FSamples = &FuncSample.second;
179     SampleContext Context(FuncSample.first(), RawContext);
180     LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context << "\n");
181     if (!Context.isBaseContext())
182       FuncToCtxtProfileSet[Context.getName()].insert(FSamples);
183     ContextTrieNode *NewNode = getOrCreateContextPath(Context, true);
184     assert(!NewNode->getFunctionSamples() &&
185            "New node can't have sample profile");
186     NewNode->setFunctionSamples(FSamples);
187   }
188 }
189 
190 FunctionSamples *
getCalleeContextSamplesFor(const CallBase & Inst,StringRef CalleeName)191 SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst,
192                                                  StringRef CalleeName) {
193   LLVM_DEBUG(dbgs() << "Getting callee context for instr: " << Inst << "\n");
194   // CSFDO-TODO: We use CalleeName to differentiate indirect call
195   // We need to get sample for indirect callee too.
196   DILocation *DIL = Inst.getDebugLoc();
197   if (!DIL)
198     return nullptr;
199 
200   ContextTrieNode *CalleeContext = getCalleeContextFor(DIL, CalleeName);
201   if (CalleeContext) {
202     FunctionSamples *FSamples = CalleeContext->getFunctionSamples();
203     LLVM_DEBUG(if (FSamples) {
204       dbgs() << "  Callee context found: " << FSamples->getContext() << "\n";
205     });
206     return FSamples;
207   }
208 
209   return nullptr;
210 }
211 
212 FunctionSamples *
getContextSamplesFor(const DILocation * DIL)213 SampleContextTracker::getContextSamplesFor(const DILocation *DIL) {
214   assert(DIL && "Expect non-null location");
215 
216   ContextTrieNode *ContextNode = getContextFor(DIL);
217   if (!ContextNode)
218     return nullptr;
219 
220   // We may have inlined callees during pre-LTO compilation, in which case
221   // we need to rely on the inline stack from !dbg to mark context profile
222   // as inlined, instead of `MarkContextSamplesInlined` during inlining.
223   // Sample profile loader walks through all instructions to get profile,
224   // which calls this function. So once that is done, all previously inlined
225   // context profile should be marked properly.
226   FunctionSamples *Samples = ContextNode->getFunctionSamples();
227   if (Samples && ContextNode->getParentContext() != &RootContext)
228     Samples->getContext().setState(InlinedContext);
229 
230   return Samples;
231 }
232 
233 FunctionSamples *
getContextSamplesFor(const SampleContext & Context)234 SampleContextTracker::getContextSamplesFor(const SampleContext &Context) {
235   ContextTrieNode *Node = getContextFor(Context);
236   if (!Node)
237     return nullptr;
238 
239   return Node->getFunctionSamples();
240 }
241 
getBaseSamplesFor(const Function & Func,bool MergeContext)242 FunctionSamples *SampleContextTracker::getBaseSamplesFor(const Function &Func,
243                                                          bool MergeContext) {
244   StringRef CanonName = FunctionSamples::getCanonicalFnName(Func);
245   return getBaseSamplesFor(CanonName, MergeContext);
246 }
247 
getBaseSamplesFor(StringRef Name,bool MergeContext)248 FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name,
249                                                          bool MergeContext) {
250   LLVM_DEBUG(dbgs() << "Getting base profile for function: " << Name << "\n");
251   // Base profile is top-level node (child of root node), so try to retrieve
252   // existing top-level node for given function first. If it exists, it could be
253   // that we've merged base profile before, or there's actually context-less
254   // profile from the input (e.g. due to unreliable stack walking).
255   ContextTrieNode *Node = getTopLevelContextNode(Name);
256   if (MergeContext) {
257     LLVM_DEBUG(dbgs() << "  Merging context profile into base profile: " << Name
258                       << "\n");
259 
260     // We have profile for function under different contexts,
261     // create synthetic base profile and merge context profiles
262     // into base profile.
263     for (auto *CSamples : FuncToCtxtProfileSet[Name]) {
264       SampleContext &Context = CSamples->getContext();
265       ContextTrieNode *FromNode = getContextFor(Context);
266       if (FromNode == Node)
267         continue;
268 
269       // Skip inlined context profile and also don't re-merge any context
270       if (Context.hasState(InlinedContext) || Context.hasState(MergedContext))
271         continue;
272 
273       ContextTrieNode &ToNode = promoteMergeContextSamplesTree(*FromNode);
274       assert((!Node || Node == &ToNode) && "Expect only one base profile");
275       Node = &ToNode;
276     }
277   }
278 
279   // Still no profile even after merge/promotion (if allowed)
280   if (!Node)
281     return nullptr;
282 
283   return Node->getFunctionSamples();
284 }
285 
markContextSamplesInlined(const FunctionSamples * InlinedSamples)286 void SampleContextTracker::markContextSamplesInlined(
287     const FunctionSamples *InlinedSamples) {
288   assert(InlinedSamples && "Expect non-null inlined samples");
289   LLVM_DEBUG(dbgs() << "Marking context profile as inlined: "
290                     << InlinedSamples->getContext() << "\n");
291   InlinedSamples->getContext().setState(InlinedContext);
292 }
293 
promoteMergeContextSamplesTree(const Instruction & Inst,StringRef CalleeName)294 void SampleContextTracker::promoteMergeContextSamplesTree(
295     const Instruction &Inst, StringRef CalleeName) {
296   LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
297                     << Inst << "\n");
298   // CSFDO-TODO: We also need to promote context profile from indirect
299   // calls. We won't have callee names from those from call instr.
300   if (CalleeName.empty())
301     return;
302 
303   // Get the caller context for the call instruction, we don't use callee
304   // name from call because there can be context from indirect calls too.
305   DILocation *DIL = Inst.getDebugLoc();
306   ContextTrieNode *CallerNode = getContextFor(DIL);
307   if (!CallerNode)
308     return;
309 
310   // Get the context that needs to be promoted
311   LineLocation CallSite(FunctionSamples::getOffset(DIL),
312                         DIL->getBaseDiscriminator());
313   ContextTrieNode *NodeToPromo =
314       CallerNode->getChildContext(CallSite, CalleeName);
315   if (!NodeToPromo)
316     return;
317 
318   promoteMergeContextSamplesTree(*NodeToPromo);
319 }
320 
promoteMergeContextSamplesTree(ContextTrieNode & NodeToPromo)321 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
322     ContextTrieNode &NodeToPromo) {
323   // Promote the input node to be directly under root. This can happen
324   // when we decided to not inline a function under context represented
325   // by the input node. The promote and merge is then needed to reflect
326   // the context profile in the base (context-less) profile.
327   FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples();
328   assert(FromSamples && "Shouldn't promote a context without profile");
329   LLVM_DEBUG(dbgs() << "  Found context tree root to promote: "
330                     << FromSamples->getContext() << "\n");
331 
332   StringRef ContextStrToRemove = FromSamples->getContext().getCallingContext();
333   return promoteMergeContextSamplesTree(NodeToPromo, RootContext,
334                                         ContextStrToRemove);
335 }
336 
dump()337 void SampleContextTracker::dump() {
338   dbgs() << "Context Profile Tree:\n";
339   std::queue<ContextTrieNode *> NodeQueue;
340   NodeQueue.push(&RootContext);
341 
342   while (!NodeQueue.empty()) {
343     ContextTrieNode *Node = NodeQueue.front();
344     NodeQueue.pop();
345     Node->dump();
346 
347     for (auto &It : Node->getAllChildContext()) {
348       ContextTrieNode *ChildNode = &It.second;
349       NodeQueue.push(ChildNode);
350     }
351   }
352 }
353 
354 ContextTrieNode *
getContextFor(const SampleContext & Context)355 SampleContextTracker::getContextFor(const SampleContext &Context) {
356   return getOrCreateContextPath(Context, false);
357 }
358 
359 ContextTrieNode *
getCalleeContextFor(const DILocation * DIL,StringRef CalleeName)360 SampleContextTracker::getCalleeContextFor(const DILocation *DIL,
361                                           StringRef CalleeName) {
362   assert(DIL && "Expect non-null location");
363 
364   // CSSPGO-TODO: need to support indirect callee
365   if (CalleeName.empty())
366     return nullptr;
367 
368   ContextTrieNode *CallContext = getContextFor(DIL);
369   if (!CallContext)
370     return nullptr;
371 
372   return CallContext->getChildContext(
373       LineLocation(FunctionSamples::getOffset(DIL),
374                    DIL->getBaseDiscriminator()),
375       CalleeName);
376 }
377 
getContextFor(const DILocation * DIL)378 ContextTrieNode *SampleContextTracker::getContextFor(const DILocation *DIL) {
379   assert(DIL && "Expect non-null location");
380   SmallVector<std::pair<LineLocation, StringRef>, 10> S;
381 
382   // Use C++ linkage name if possible.
383   const DILocation *PrevDIL = DIL;
384   for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) {
385     StringRef Name = PrevDIL->getScope()->getSubprogram()->getLinkageName();
386     if (Name.empty())
387       Name = PrevDIL->getScope()->getSubprogram()->getName();
388     S.push_back(
389         std::make_pair(LineLocation(FunctionSamples::getOffset(DIL),
390                                     DIL->getBaseDiscriminator()), Name));
391     PrevDIL = DIL;
392   }
393 
394   // Push root node, note that root node like main may only
395   // a name, but not linkage name.
396   StringRef RootName = PrevDIL->getScope()->getSubprogram()->getLinkageName();
397   if (RootName.empty())
398     RootName = PrevDIL->getScope()->getSubprogram()->getName();
399   S.push_back(std::make_pair(LineLocation(0, 0), RootName));
400 
401   ContextTrieNode *ContextNode = &RootContext;
402   int I = S.size();
403   while (--I >= 0 && ContextNode) {
404     LineLocation &CallSite = S[I].first;
405     StringRef &CalleeName = S[I].second;
406     ContextNode = ContextNode->getChildContext(CallSite, CalleeName);
407   }
408 
409   if (I < 0)
410     return ContextNode;
411 
412   return nullptr;
413 }
414 
415 ContextTrieNode *
getOrCreateContextPath(const SampleContext & Context,bool AllowCreate)416 SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
417                                              bool AllowCreate) {
418   ContextTrieNode *ContextNode = &RootContext;
419   StringRef ContextRemain = Context;
420   StringRef ChildContext;
421   StringRef CalleeName;
422   LineLocation CallSiteLoc(0, 0);
423 
424   while (ContextNode && !ContextRemain.empty()) {
425     auto ContextSplit = SampleContext::splitContextString(ContextRemain);
426     ChildContext = ContextSplit.first;
427     ContextRemain = ContextSplit.second;
428     LineLocation NextCallSiteLoc(0, 0);
429     SampleContext::decodeContextString(ChildContext, CalleeName,
430                                        NextCallSiteLoc);
431 
432     // Create child node at parent line/disc location
433     if (AllowCreate) {
434       ContextNode =
435           ContextNode->getOrCreateChildContext(CallSiteLoc, CalleeName);
436     } else {
437       ContextNode = ContextNode->getChildContext(CallSiteLoc, CalleeName);
438     }
439     CallSiteLoc = NextCallSiteLoc;
440   }
441 
442   assert((!AllowCreate || ContextNode) &&
443          "Node must exist if creation is allowed");
444   return ContextNode;
445 }
446 
getTopLevelContextNode(StringRef FName)447 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
448   return RootContext.getChildContext(LineLocation(0, 0), FName);
449 }
450 
addTopLevelContextNode(StringRef FName)451 ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) {
452   assert(!getTopLevelContextNode(FName) && "Node to add must not exist");
453   return *RootContext.getOrCreateChildContext(LineLocation(0, 0), FName);
454 }
455 
mergeContextNode(ContextTrieNode & FromNode,ContextTrieNode & ToNode,StringRef ContextStrToRemove)456 void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode,
457                                             ContextTrieNode &ToNode,
458                                             StringRef ContextStrToRemove) {
459   FunctionSamples *FromSamples = FromNode.getFunctionSamples();
460   FunctionSamples *ToSamples = ToNode.getFunctionSamples();
461   if (FromSamples && ToSamples) {
462     // Merge/duplicate FromSamples into ToSamples
463     ToSamples->merge(*FromSamples);
464     ToSamples->getContext().setState(SyntheticContext);
465     FromSamples->getContext().setState(MergedContext);
466   } else if (FromSamples) {
467     // Transfer FromSamples from FromNode to ToNode
468     ToNode.setFunctionSamples(FromSamples);
469     FromSamples->getContext().setState(SyntheticContext);
470     FromSamples->getContext().promoteOnPath(ContextStrToRemove);
471     FromNode.setFunctionSamples(nullptr);
472   }
473 }
474 
promoteMergeContextSamplesTree(ContextTrieNode & FromNode,ContextTrieNode & ToNodeParent,StringRef ContextStrToRemove)475 ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
476     ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent,
477     StringRef ContextStrToRemove) {
478   assert(!ContextStrToRemove.empty() && "Context to remove can't be empty");
479 
480   // Ignore call site location if destination is top level under root
481   LineLocation NewCallSiteLoc = LineLocation(0, 0);
482   LineLocation OldCallSiteLoc = FromNode.getCallSiteLoc();
483   ContextTrieNode &FromNodeParent = *FromNode.getParentContext();
484   ContextTrieNode *ToNode = nullptr;
485   bool MoveToRoot = (&ToNodeParent == &RootContext);
486   if (!MoveToRoot) {
487     NewCallSiteLoc = OldCallSiteLoc;
488   }
489 
490   // Locate destination node, create/move if not existing
491   ToNode = ToNodeParent.getChildContext(NewCallSiteLoc, FromNode.getFuncName());
492   if (!ToNode) {
493     // Do not delete node to move from its parent here because
494     // caller is iterating over children of that parent node.
495     ToNode = &ToNodeParent.moveToChildContext(
496         NewCallSiteLoc, std::move(FromNode), ContextStrToRemove, false);
497   } else {
498     // Destination node exists, merge samples for the context tree
499     mergeContextNode(FromNode, *ToNode, ContextStrToRemove);
500     LLVM_DEBUG(dbgs() << "  Context promoted and merged to: "
501                       << ToNode->getFunctionSamples()->getContext() << "\n");
502 
503     // Recursively promote and merge children
504     for (auto &It : FromNode.getAllChildContext()) {
505       ContextTrieNode &FromChildNode = It.second;
506       promoteMergeContextSamplesTree(FromChildNode, *ToNode,
507                                      ContextStrToRemove);
508     }
509 
510     // Remove children once they're all merged
511     FromNode.getAllChildContext().clear();
512   }
513 
514   // For root of subtree, remove itself from old parent too
515   if (MoveToRoot)
516     FromNodeParent.removeChildContext(OldCallSiteLoc, ToNode->getFuncName());
517 
518   return *ToNode;
519 }
520 
521 } // namespace llvm
522