• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
11 
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <map>
#include <memory>
#include <unordered_map>
19 
20 namespace llvm {
21 class Module;
22 class MLInlineAdvice;
23 
24 class MLInlineAdvisor : public InlineAdvisor {
25 public:
26   MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
27                   std::unique_ptr<MLModelRunner> ModelRunner);
28 
callGraph()29   CallGraph *callGraph() const { return CG.get(); }
30   virtual ~MLInlineAdvisor() = default;
31 
32   void onPassEntry() override;
33 
34   std::unique_ptr<InlineAdvice> getAdvice(CallBase &CB) override;
35 
getIRSize(const Function & F)36   int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
37   void onSuccessfulInlining(const MLInlineAdvice &Advice,
38                             bool CalleeWasDeleted);
39 
isForcedToStop()40   bool isForcedToStop() const { return ForceStop; }
41   int64_t getLocalCalls(Function &F);
getModelRunner()42   const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
43 
44 protected:
45   virtual std::unique_ptr<MLInlineAdvice>
46   getMandatoryAdvice(CallBase &CB, OptimizationRemarkEmitter &ORE);
47 
48   virtual std::unique_ptr<MLInlineAdvice>
49   getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
50 
51   Module &M;
52   std::unique_ptr<MLModelRunner> ModelRunner;
53 
54 private:
55   int64_t getModuleIRSize() const;
56 
57   std::unique_ptr<CallGraph> CG;
58 
59   int64_t NodeCount = 0;
60   int64_t EdgeCount = 0;
61   std::map<const Function *, unsigned> FunctionLevels;
62   const int32_t InitialIRSize = 0;
63   int32_t CurrentIRSize = 0;
64 
65   bool ForceStop = false;
66 };
67 
68 /// InlineAdvice that tracks changes post inlining. For that reason, it only
69 /// overrides the "successful inlining" extension points.
70 class MLInlineAdvice : public InlineAdvice {
71 public:
MLInlineAdvice(MLInlineAdvisor * Advisor,CallBase & CB,OptimizationRemarkEmitter & ORE,bool Recommendation)72   MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
73                  OptimizationRemarkEmitter &ORE, bool Recommendation)
74       : InlineAdvice(Advisor, CB, ORE, Recommendation),
75         CallerIRSize(Advisor->isForcedToStop() ? 0
76                                                : Advisor->getIRSize(*Caller)),
77         CalleeIRSize(Advisor->isForcedToStop() ? 0
78                                                : Advisor->getIRSize(*Callee)),
79         CallerAndCalleeEdges(Advisor->isForcedToStop()
80                                  ? 0
81                                  : (Advisor->getLocalCalls(*Caller) +
82                                     Advisor->getLocalCalls(*Callee))) {}
83   virtual ~MLInlineAdvice() = default;
84 
85   void recordInliningImpl() override;
86   void recordInliningWithCalleeDeletedImpl() override;
87   void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
88   void recordUnattemptedInliningImpl() override;
89 
getCaller()90   Function *getCaller() const { return Caller; }
getCallee()91   Function *getCallee() const { return Callee; }
92 
93   const int64_t CallerIRSize;
94   const int64_t CalleeIRSize;
95   const int64_t CallerAndCalleeEdges;
96 
97 private:
98   void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
99 
getAdvisor()100   MLInlineAdvisor *getAdvisor() const {
101     return static_cast<MLInlineAdvisor *>(Advisor);
102   };
103 };
104 
105 } // namespace llvm
106 
107 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H