• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//
13 
14 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
15 #include "llvm/IR/LegacyPassManager.h"
16 #include "llvm/Support/InitLLVM.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/Transforms/IPO.h"
20 #include "llvm/Transforms/Scalar.h"
21 
22 #include "../ExampleModules.h"
23 
24 using namespace llvm;
25 using namespace llvm::orc;
26 
27 ExitOnError ExitOnErr;
28 
29 // Example IR module.
30 //
31 // This IR contains a recursive definition of the factorial function:
32 //
33 // fac(n) | n == 0    = 1
34 //        | otherwise = n * fac(n - 1)
35 //
36 // It also contains an entry function which calls the factorial function with
37 // an input value of 5.
38 //
39 // We expect the IR optimization transform that we build below to transform
40 // this into a non-recursive factorial function and an entry function that
41 // returns a constant value of 5!, or 120.
42 
43 const llvm::StringRef MainMod =
44     R"(
45 
46   define i32 @fac(i32 %n) {
47   entry:
48     %tobool = icmp eq i32 %n, 0
49     br i1 %tobool, label %return, label %if.then
50 
51   if.then:                                          ; preds = %entry
52     %arg = add nsw i32 %n, -1
53     %call_result = call i32 @fac(i32 %arg)
54     %result = mul nsw i32 %n, %call_result
55     br label %return
56 
57   return:                                           ; preds = %entry, %if.then
58     %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
59     ret i32 %final_result
60   }
61 
62   define i32 @entry() {
63   entry:
64     %result = call i32 @fac(i32 5)
65     ret i32 %result
66   }
67 
68 )";
69 
70 // A function object that creates a simple pass pipeline to apply to each
71 // module as it passes through the IRTransformLayer.
72 class MyOptimizationTransform {
73 public:
MyOptimizationTransform()74   MyOptimizationTransform() : PM(std::make_unique<legacy::PassManager>()) {
75     PM->add(createTailCallEliminationPass());
76     PM->add(createFunctionInliningPass());
77     PM->add(createIndVarSimplifyPass());
78     PM->add(createCFGSimplificationPass());
79   }
80 
operator ()(ThreadSafeModule TSM,MaterializationResponsibility & R)81   Expected<ThreadSafeModule> operator()(ThreadSafeModule TSM,
82                                         MaterializationResponsibility &R) {
83     TSM.withModuleDo([this](Module &M) {
84       dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M << "\n";
85       PM->run(M);
86       dbgs() << "--- AFTER OPTIMIZATION ---\n" << M << "\n";
87     });
88     return std::move(TSM);
89   }
90 
91 private:
92   std::unique_ptr<legacy::PassManager> PM;
93 };
94 
main(int argc,char * argv[])95 int main(int argc, char *argv[]) {
96   // Initialize LLVM.
97   InitLLVM X(argc, argv);
98 
99   InitializeNativeTarget();
100   InitializeNativeTargetAsmPrinter();
101 
102   ExitOnErr.setBanner(std::string(argv[0]) + ": ");
103 
104   // (1) Create LLJIT instance.
105   auto J = ExitOnErr(LLJITBuilder().create());
106 
107   // (2) Install transform to optimize modules when they're materialized.
108   J->getIRTransformLayer().setTransform(MyOptimizationTransform());
109 
110   // (3) Add modules.
111   ExitOnErr(J->addIRModule(ExitOnErr(parseExampleModule(MainMod, "MainMod"))));
112 
113   // (4) Look up the JIT'd function and call it.
114   auto EntrySym = ExitOnErr(J->lookup("entry"));
115   auto *Entry = (int (*)())EntrySym.getAddress();
116 
117   int Result = Entry();
118   outs() << "--- Result ---\n"
119          << "entry() = " << Result << "\n";
120 
121   return 0;
122 }
123