1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21
22 using namespace llvm;
23
24 namespace {
getBlockByName(Function * F,StringRef name)25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26 for (auto &BB : *F)
27 if (BB.getName() == name)
28 return &BB;
29 return nullptr;
30 }
31
// Extracts a region whose two exiting blocks (body1, body2) both branch to the
// same non-extracted block. The exit PHI must be split: the outlined function
// merges the two region values in "notExtracted.split", and the PHI left in
// the caller receives a single value from the code replacer block.
TEST(CodeExtractor, ExitStub) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    define i32 @foo(i32 %x, i32 %y, i32 %z) {
    header:
      %0 = icmp ugt i32 %x, %y
      br i1 %0, label %body1, label %body2

    body1:
      %1 = add i32 %z, 2
      br label %notExtracted

    body2:
      %2 = mul i32 %z, 7
      br label %notExtracted

    notExtracted:
      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
      %4 = add i32 %3, %x
      ret i32 %4
    }
  )invalid",
                                                Err, Ctx));
  // Report the parser diagnostic and stop instead of dereferencing null.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
                                          getBlockByName(Func, "body1"),
                                          getBlockByName(Func, "body2")};

  CodeExtractor CE(Candidates);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Everything below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
  // Ensure that the PHI in the exit block has only one incoming value (from
  // the code replacer block).
  ASSERT_TRUE(Exit);
  EXPECT_EQ(cast<PHINode>(Exit->front()).getNumIncomingValues(), 1u);
  // Ensure that there is a PHI in the outlined function with 2 incoming
  // values.
  ASSERT_TRUE(ExitSplit);
  EXPECT_EQ(cast<PHINode>(ExitSplit->front()).getNumIncomingValues(), 2u);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
}
79
// Extracts {extracted1, extracted2}, where each exit block (exit1, exit2)
// receives exactly one value from the extracted region and one from %pred.
// Because each exit PHI has a single in-region predecessor, the extractor must
// not split those PHIs.
TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    define i32 @foo() {
    header:
      br i1 undef, label %extracted1, label %pred

    pred:
      br i1 undef, label %exit1, label %exit2

    extracted1:
      br i1 undef, label %extracted2, label %exit1

    extracted2:
      br label %exit2

    exit1:
      %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
      ret i32 %0

    exit2:
      %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
      ret i32 %1
    }
  )invalid", Err, Ctx));
  // Report the parser diagnostic and stop instead of dereferencing null.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 2> ExtractedBlocks{
    getBlockByName(Func, "extracted1"),
    getBlockByName(Func, "extracted2")
  };

  CodeExtractor CE(ExtractedBlocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Everything below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  BasicBlock *Exit1 = getBlockByName(Func, "exit1");
  BasicBlock *Exit2 = getBlockByName(Func, "exit2");
  // Ensure that the PHIs in the exits are not split, since each of them has
  // only one incoming value from the extracted region.
  ASSERT_TRUE(Exit1);
  EXPECT_EQ(cast<PHINode>(Exit1->front()).getNumIncomingValues(), 2u);
  ASSERT_TRUE(Exit2);
  EXPECT_EQ(cast<PHINode>(Exit2->front()).getNumIncomingValues(), 2u);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
}
130
// Extracts a region containing invokes that unwind to a landing pad (lpad2)
// inside the same region. lpad2 defines %ex.1, which is live out of the region
// (used by %ex.2 in finally.catchall). Both the caller and the outlined
// function must still verify after extraction.
TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    declare i8 @hoge()

    define i32 @foo() personality i8* null {
    entry:
      %call = invoke i8 @hoge()
          to label %invoke.cont unwind label %lpad

    invoke.cont:                                      ; preds = %entry
      unreachable

    lpad:                                             ; preds = %entry
      %0 = landingpad { i8*, i32 }
              catch i8* null
      br i1 undef, label %catch, label %finally.catchall

    catch:                                            ; preds = %lpad
      %call2 = invoke i8 @hoge()
          to label %invoke.cont2 unwind label %lpad2

    invoke.cont2:                                     ; preds = %catch
      %call3 = invoke i8 @hoge()
          to label %invoke.cont3 unwind label %lpad2

    invoke.cont3:                                     ; preds = %invoke.cont2
      unreachable

    lpad2:                                            ; preds = %invoke.cont2, %catch
      %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
      %1 = landingpad { i8*, i32 }
              catch i8* null
      br label %finally.catchall

    finally.catchall:                                 ; preds = %lpad2, %lpad
      %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
      unreachable
    }
  )invalid", Err, Ctx));

  // Report the parser diagnostic and fail only this test. The previous
  // exit(1) aborted the whole gtest binary.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  EXPECT_FALSE(verifyFunction(*Func, &errs()));

  SmallVector<BasicBlock *, 4> ExtractedBlocks{
    getBlockByName(Func, "catch"),
    getBlockByName(Func, "invoke.cont2"),
    getBlockByName(Func, "invoke.cont3"),
    getBlockByName(Func, "lpad2")
  };

  CodeExtractor CE(ExtractedBlocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Verification below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
197
// Extracts {entry, lpad}, where the invoke result %0 is used outside the
// region in %exit. That result must be stored as an output of the outlined
// function, and both functions must still verify.
TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    declare i32 @bar()

    define i32 @foo() personality i8* null {
    entry:
      %0 = invoke i32 @bar() to label %exit unwind label %lpad

    exit:
      ret i32 %0

    lpad:
      %1 = landingpad { i8*, i32 }
              cleanup
      resume { i8*, i32 } %1
    }
  )invalid",
                                                Err, Ctx));
  // Report the parser diagnostic and stop instead of dereferencing null.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  // Inline capacity matches the number of candidate blocks.
  SmallVector<BasicBlock *, 2> Blocks{getBlockByName(Func, "entry"),
                                      getBlockByName(Func, "lpad")};

  CodeExtractor CE(Blocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Verification below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
}
232
// Extracts a block containing an llvm.assume call while an AssumptionCache is
// passed to the extractor. After extraction, CE.verifyAssumptionCache is
// expected to return false for the caller/outlined pair (per the EXPECT_FALSE
// below).
TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
        target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
        target triple = "aarch64"

        %b = type { i64 }
        declare void @g(i8*)

        declare void @llvm.assume(i1) #0

        define void @test() {
        entry:
          br label %label

        label:
          %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
          %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
          %2 = load i64, i64* %1, align 8
          %3 = icmp ugt i64 %2, 1
          br i1 %3, label %if.then, label %if.else

        if.then:
          unreachable

        if.else:
          call void @g(i8* undef)
          store i64 undef, i64* null, align 536870912
          %4 = icmp eq i64 %2, 0
          call void @llvm.assume(i1 %4)
          unreachable
        }

        attributes #0 = { nounwind willreturn }
  )ir",
                                                Err, Ctx));

  // A fatal gtest assertion instead of assert(): assert() is compiled out in
  // release builds, which would leave a null dereference on parse failure.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M) << "Could not parse module?";
  Function *Func = M->getFunction("test");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "if.else")};
  AssumptionCache AC(*Func);
  CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Verification below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
  EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
}
285
// The bitcast %1 defined in the extracted block is also used by a
// lifetime.end marker outside the region. The test checks that
// findInputsOutputs reports no outputs for the region, and that both
// functions still verify after extraction.
TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
    target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
    target triple = "x86_64-unknown-linux-gnu"

    declare void @use(i32*)
    declare void @llvm.lifetime.start.p0i8(i64, i8*)
    declare void @llvm.lifetime.end.p0i8(i64, i8*)

    define void @foo() {
    entry:
      %0 = alloca i32
      br label %extract

    extract:
      %1 = bitcast i32* %0 to i8*
      call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
      call void @use(i32* %0)
      br label %exit

    exit:
      call void @use(i32* %0)
      call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
      ret void
    }
  )ir",
                                                Err, Ctx));
  // Report the parser diagnostic and stop instead of dereferencing null.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};

  CodeExtractor CE(Blocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  BasicBlock *CommonExit = nullptr;
  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
  EXPECT_EQ(Outputs.size(), 0U);

  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Verification below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
}
334 } // end anonymous namespace
335