1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H 10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H 11 12 #include "llvm/Analysis/CallGraph.h" 13 #include "llvm/Analysis/InlineAdvisor.h" 14 #include "llvm/Analysis/MLModelRunner.h" 15 #include "llvm/IR/PassManager.h" 16 17 #include <memory> 18 #include <unordered_map> 19 20 namespace llvm { 21 class Module; 22 class MLInlineAdvice; 23 24 class MLInlineAdvisor : public InlineAdvisor { 25 public: 26 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, 27 std::unique_ptr<MLModelRunner> ModelRunner); 28 callGraph()29 CallGraph *callGraph() const { return CG.get(); } 30 virtual ~MLInlineAdvisor() = default; 31 32 void onPassEntry() override; 33 34 std::unique_ptr<InlineAdvice> getAdvice(CallBase &CB) override; 35 getIRSize(const Function & F)36 int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); } 37 void onSuccessfulInlining(const MLInlineAdvice &Advice, 38 bool CalleeWasDeleted); 39 isForcedToStop()40 bool isForcedToStop() const { return ForceStop; } 41 int64_t getLocalCalls(Function &F); getModelRunner()42 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } 43 44 protected: 45 virtual std::unique_ptr<MLInlineAdvice> 46 getMandatoryAdvice(CallBase &CB, OptimizationRemarkEmitter &ORE); 47 48 virtual std::unique_ptr<MLInlineAdvice> 49 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); 50 51 Module &M; 52 std::unique_ptr<MLModelRunner> ModelRunner; 53 54 private: 55 int64_t getModuleIRSize() const; 56 57 std::unique_ptr<CallGraph> CG; 58 59 int64_t NodeCount = 0; 60 int64_t EdgeCount = 0; 61 std::map<const Function *, unsigned> FunctionLevels; 62 const int32_t InitialIRSize = 0; 63 int32_t CurrentIRSize = 0; 64 65 bool ForceStop = false; 66 }; 67 68 /// InlineAdvice that tracks changes post inlining. For that reason, it only 69 /// overrides the "successful inlining" extension points. 70 class MLInlineAdvice : public InlineAdvice { 71 public: MLInlineAdvice(MLInlineAdvisor * Advisor,CallBase & CB,OptimizationRemarkEmitter & ORE,bool Recommendation)72 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 73 OptimizationRemarkEmitter &ORE, bool Recommendation) 74 : InlineAdvice(Advisor, CB, ORE, Recommendation), 75 CallerIRSize(Advisor->isForcedToStop() ? 0 76 : Advisor->getIRSize(*Caller)), 77 CalleeIRSize(Advisor->isForcedToStop() ? 0 78 : Advisor->getIRSize(*Callee)), 79 CallerAndCalleeEdges(Advisor->isForcedToStop() 80 ? 0 81 : (Advisor->getLocalCalls(*Caller) + 82 Advisor->getLocalCalls(*Callee))) {} 83 virtual ~MLInlineAdvice() = default; 84 85 void recordInliningImpl() override; 86 void recordInliningWithCalleeDeletedImpl() override; 87 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; 88 void recordUnattemptedInliningImpl() override; 89 getCaller()90 Function *getCaller() const { return Caller; } getCallee()91 Function *getCallee() const { return Callee; } 92 93 const int64_t CallerIRSize; 94 const int64_t CalleeIRSize; 95 const int64_t CallerAndCalleeEdges; 96 97 private: 98 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR); 99 getAdvisor()100 MLInlineAdvisor *getAdvisor() const { 101 return static_cast<MLInlineAdvisor *>(Advisor); 102 }; 103 }; 104 105 } // namespace llvm 106 107 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H