• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "transforms/passes/gep_propagation.h"
17 
18 #include "transforms/gc_utils.h"
19 #include "transforms/transform_utils.h"
20 
21 #include <llvm/ADT/DenseMap.h>
22 #include <llvm/ADT/PostOrderIterator.h>
23 #include <llvm/IR/IRBuilder.h>
24 #include <llvm/Pass.h>
25 #include <llvm/Support/CommandLine.h>
26 
27 #define DEBUG_TYPE "gep-propagation"
28 
29 // Basic classes
30 using llvm::BasicBlock;
31 using llvm::DenseMap;
32 using llvm::Function;
33 using llvm::SmallVector;
34 using llvm::Type;
35 using llvm::Use;
36 using llvm::Value;
37 // Instructions
38 using llvm::CastInst;
39 using llvm::Constant;
40 using llvm::GetElementPtrInst;
41 using llvm::Instruction;
42 using llvm::PHINode;
43 using llvm::SelectInst;
44 // Gc utils
45 using ark::llvmbackend::gc_utils::IsGcRefType;
46 
47 /// Optimize no-op PHINodes and Selects in place.
48 // NOLINTNEXTLINE(fuchsia-statically-constructed-objects)
49 static llvm::cl::opt<bool> g_optimizeNoop("gprop-optimize", llvm::cl::Hidden, llvm::cl::init(true));
50 
51 namespace ark::llvmbackend::passes {
52 
run(llvm::Function & function,llvm::FunctionAnalysisManager &)53 llvm::PreservedAnalyses GepPropagation::run(llvm::Function &function, llvm::FunctionAnalysisManager & /*AM*/)
54 {
55     LLVM_DEBUG(llvm::dbgs() << "Function: " << function.getName() << "\n");
56     if (!gc_utils::IsGcFunction(function) || gc_utils::IsFunctionSupplemental(function)) {
57         return llvm::PreservedAnalyses::all();
58     }
59 
60     Propagate(&function);
61     return llvm::PreservedAnalyses::none();
62 }
63 
AddToVector(Instruction * inst,SmallVector<Instruction * > * toExpand,SmallVector<Instruction * > * selectors)64 void GepPropagation::AddToVector(Instruction *inst, SmallVector<Instruction *> *toExpand,
65                                  SmallVector<Instruction *> *selectors)
66 {
67     switch (inst->getOpcode()) {
68         // Derived references
69         case Instruction::GetElementPtr:
70             if (IsGcRefType(inst->getType())) {
71                 toExpand->push_back(inst);
72             }
73             break;
74         // Escaping managed scope
75         case Instruction::PtrToInt:
76         case Instruction::AddrSpaceCast:
77             if (IsGcRefType(inst->getOperand(0)->getType())) {
78                 toExpand->push_back(inst);
79             }
80             break;
81         case Instruction::Select:
82         case Instruction::PHI:
83             if ((IsGcRefType(inst->getType()) && gc_utils::IsDerived(inst)) || gc_utils::HasBeenGcRef(inst, false)) {
84                 selectors->push_back(inst);
85                 toExpand->push_back(inst);
86             }
87             break;
88         default:
89             break;
90     }
91 }
92 
Propagate(Function * function)93 void GepPropagation::Propagate(Function *function)
94 {
95     SmallVector<Instruction *> toExpand;
96     SmallVector<Instruction *> selectors;
97     llvm::ReversePostOrderTraversal<Function *> rpo(function);
98     for (auto block : rpo) {
99         for (auto &inst : *block) {
100             AddToVector(&inst, &toExpand, &selectors);
101         }
102     }
103 
104     DenseMap<Instruction *, Instruction *> sgeps;
105     SplitGepSelectors(function, &selectors, &sgeps);
106 
107     while (!toExpand.empty()) {
108         auto gep = toExpand.pop_back_val();
109         if (sgeps.find(gep) != sgeps.end()) {
110             gep = sgeps[gep];
111         }
112         SmallVector<Instruction *, 1> seq;
113         ReplaceRecursively(gep, &seq);
114     }
115 }
116 
FindSplitGep(Function * function,Instruction * inst,Instruction ** ipoint)117 static std::pair<Instruction *, Instruction *> FindSplitGep(Function *function, Instruction *inst, Instruction **ipoint)
118 {
119     auto &ctx = function->getContext();
120     auto bptrTy = llvm::PointerType::get(ctx, ark::llvmbackend::LLVMArkInterface::GC_ADDR_SPACE);
121     auto undefBase = llvm::UndefValue::get(bptrTy);
122     auto undefOffset = llvm::UndefValue::get(Type::getInt32Ty(ctx));
123     auto phi = llvm::dyn_cast<PHINode>(inst);
124     Instruction *mbase;
125     Instruction *moff;
126     if (phi != nullptr) {
127         mbase = PHINode::Create(bptrTy, phi->getNumIncomingValues(), "base", phi);
128         moff = PHINode::Create(Type::getInt32Ty(ctx), phi->getNumIncomingValues(), "gepoff", phi);
129         for (size_t i = 0; i < phi->getNumIncomingValues(); ++i) {
130             auto bb = phi->getIncomingBlock(i);
131             llvm::cast<PHINode>(mbase)->addIncoming(undefBase, bb);
132             llvm::cast<PHINode>(moff)->addIncoming(undefOffset, bb);
133         }
134         *ipoint = phi->getParent()->getFirstNonPHI();
135     } else {
136         auto condition = llvm::cast<SelectInst>(inst)->getCondition();
137         mbase = SelectInst::Create(condition, undefBase, undefBase, "base", inst);
138         moff = SelectInst::Create(condition, undefOffset, undefOffset, "gepoff", inst);
139         *ipoint = inst;
140     }
141     return {mbase, moff};
142 }
143 
SplitGepSelectors(Function * function,SmallVector<Instruction * > * selectors,DenseMap<Instruction *,Instruction * > * sgeps)144 void GepPropagation::SplitGepSelectors(Function *function, SmallVector<Instruction *> *selectors,
145                                        DenseMap<Instruction *, Instruction *> *sgeps)
146 {
147     auto &ctx = function->getContext();
148     SelectorSplitMap mapping;
149     for (auto inst : *selectors) {
150         Instruction *ipoint = nullptr;
151         auto splitGep = FindSplitGep(function, inst, &ipoint);
152         mapping[inst] = splitGep;
153         auto gep = GetElementPtrInst::CreateInBounds(Type::getInt8Ty(ctx), splitGep.first, splitGep.second, "", ipoint);
154         gep->takeName(inst);
155         sgeps->insert({inst, gep});
156     }
157 
158     for (auto inst : *selectors) {
159         ASSERT((llvm::isa<PHINode, SelectInst>(inst)));
160         GenerateSelectorInputs(inst, mapping);
161         ReplaceWithSplitGep(inst, sgeps->lookup(inst));
162     }
163 
164     // Erase GEP selectors as we replaced them with their split versions.
165     while (!selectors->empty()) {
166         auto val = selectors->pop_back_val();
167         [[maybe_unused]] auto safeToDelete = [selectors](auto user) {
168             auto inst = llvm::cast<Instruction>(user);
169             if (!llvm::isa<PHINode, SelectInst>(inst)) {
170                 return false;
171             }
172             return std::find(selectors->begin(), selectors->end(), inst) != selectors->end();
173         };
174         ASSERT(std::find_if_not(val->user_begin(), val->user_end(), safeToDelete) == val->user_end());
175         val->replaceAllUsesWith(llvm::UndefValue::get(val->getType()));
176         val->eraseFromParent();
177     }
178 
179     if (g_optimizeNoop) {
180         OptimizeSelectors(mapping);
181     }
182 }
183 
GenerateInput(Value * input,Instruction * inst,Instruction * inPoint,const SelectorSplitMap & mapping)184 std::pair<Value *, Value *> GepPropagation::GenerateInput(Value *input, Instruction *inst, Instruction *inPoint,
185                                                           const SelectorSplitMap &mapping)
186 {
187     auto [mbase, moff] = mapping.lookup(inst);
188     auto offTy = moff->getType();
189     auto instInput = llvm::dyn_cast<Instruction>(input);
190     if (instInput != nullptr) {
191         auto mapit = mapping.find(instInput);
192         if (mapit != mapping.end()) {
193             return mapit->second;
194         }
195     }
196     if (auto nulloffset = llvm::dyn_cast<Constant>(input)) {
197         return {Constant::getNullValue(mbase->getType()), GetConstantOffset(nulloffset, offTy)};
198     }
199 
200     auto [base, derived] = GetBasePointer(input);
201     Value *offset = nullptr;
202     auto instBase = llvm::dyn_cast<Instruction>(base);
203     if (instBase != nullptr) {
204         auto mapit = mapping.find(instBase);
205         if (mapit != mapping.end()) {
206             base = mapit->second.first;
207             offset = mapit->second.second;
208         }
209     }
210 
211     if (derived) {
212         // Calculate offset
213         auto baseRaw = CastInst::Create(Instruction::PtrToInt, base, offTy, "", inPoint);
214         Value *gepRaw = nullptr;
215 
216         ASSERT(input->getType()->isIntOrPtrTy() && "Unexpected type of gep selector");
217         if (input->getType()->isPointerTy()) {
218             gepRaw = CastInst::Create(Instruction::PtrToInt, input, offTy, "", inPoint);
219         } else if (input->getType()->getScalarSizeInBits() < offTy->getScalarSizeInBits()) {
220             gepRaw = CastInst::Create(Instruction::ZExt, input, offTy, "", inPoint);
221         } else if (input->getType()->getScalarSizeInBits() > offTy->getScalarSizeInBits()) {
222             gepRaw = CastInst::Create(Instruction::Trunc, input, offTy, "", inPoint);
223         } else {
224             gepRaw = input;
225         }
226         offset = llvm::BinaryOperator::Create(Instruction::Sub, gepRaw, baseRaw, "", inPoint);
227     } else if (offset == nullptr) {
228         offset = llvm::ConstantInt::getSigned(offTy, 0);
229     }
230     return {base, offset};
231 }
232 
GenerateSelectorInputs(Instruction * inst,const SelectorSplitMap & mapping)233 void GepPropagation::GenerateSelectorInputs(Instruction *inst, const SelectorSplitMap &mapping)
234 {
235     ASSERT((llvm::isa<SelectInst, PHINode>(inst)));
236     auto [mbase, moff] = mapping.lookup(inst);
237 
238     auto setInputs = [mbase = mbase, moff = moff](auto idx, auto ibase, auto ioff) {
239         if (auto phi = llvm::dyn_cast<PHINode>(mbase)) {
240             auto bb = phi->getIncomingBlock(idx);
241             phi->setIncomingValueForBlock(bb, ibase);
242             llvm::cast<PHINode>(moff)->setIncomingValueForBlock(bb, ioff);
243         } else {
244             mbase->setOperand(idx, ibase);
245             moff->setOperand(idx, ioff);
246         }
247     };
248     // Generate inputs
249     bool select = llvm::isa<SelectInst>(inst);
250     for (size_t i = select ? 1 : 0; i < inst->getNumOperands(); ++i) {
251         auto input = inst->getOperand(i);
252         if (!llvm::isa<llvm::UndefValue>(mbase->getOperand(i))) {
253             continue;
254         }
255         auto inPoint = select ? mbase : llvm::cast<PHINode>(inst)->getIncomingBlock(i)->getTerminator();
256         auto [base, offset] = GenerateInput(input, inst, inPoint, mapping);
257         setInputs(i, base, offset);
258     }
259 }
260 
261 /// Generate a ConstaintInt of type TYPE that represents an OFFSET.
GetConstantOffset(Constant * offset,Type * type)262 Value *GepPropagation::GetConstantOffset(Constant *offset, Type *type)
263 {
264     offset = offset->stripPointerCasts();
265     if (llvm::isa<llvm::ConstantExpr>(offset)) {
266         ASSERT(offset->getNumOperands() == 1);
267         auto offsetRaw = offset->getOperand(0);
268         return llvm::ConstantInt::getSigned(type, llvm::cast<llvm::ConstantInt>(offsetRaw)->getSExtValue());
269     }
270     if (offset->isNullValue() || llvm::isa<llvm::PoisonValue, llvm::UndefValue>(offset)) {
271         return llvm::ConstantInt::getNullValue(type);
272     }
273 
274     return llvm::ConstantInt::getSigned(type, llvm::cast<llvm::ConstantInt>(offset)->getSExtValue());
275 }
276 
GetBasePointer(Value * value)277 std::pair<Value *, bool> GepPropagation::GetBasePointer(Value *value)
278 {
279     bool derived = false;
280     auto base = value;
281     // This loop is needed to get to addrspace(271) through addrspace (0).
282     // It is required if the value is an escaped reference.
283     while (!IsGcRefType(base->getType()) && llvm::isa<GetElementPtrInst, CastInst>(base)) {
284         derived |= llvm::isa<GetElementPtrInst>(base);
285         base = llvm::cast<Instruction>(base)->getOperand(0);
286     }
287     // Here we try to find a base pointer inside addrspace (271).
288     while (llvm::isa<GetElementPtrInst, CastInst>(base)) {
289         derived |= llvm::isa<GetElementPtrInst>(base);
290         auto next = llvm::cast<Instruction>(base)->getOperand(0);
291         // Go until instruction introducing a GC reference.
292         if (!IsGcRefType(next->getType())) {
293             break;
294         }
295         base = next;
296     }
297     ASSERT(IsGcRefType(base->getType()));
298     return {base, derived};
299 }
300 
301 /// Returns a constant input of a selector.
GetConstantInput(Instruction * inst)302 Value *GepPropagation::GetConstantInput(Instruction *inst)
303 {
304     llvm::DenseSet<Value *> visited;
305     llvm::SmallVector<Instruction *> queueLocal;
306     queueLocal.push_back(inst);
307     Value *single = nullptr;
308     while (!queueLocal.empty()) {
309         auto elt = queueLocal.pop_back_val();
310         if (!visited.insert(elt).second) {
311             continue;
312         }
313         bool select = llvm::isa<SelectInst>(elt);
314         for (size_t i = select ? 1 : 0; i < elt->getNumOperands(); ++i) {
315             auto input = elt->getOperand(i);
316             if (llvm::isa<SelectInst, PHINode>(input)) {
317                 queueLocal.push_back(llvm::cast<Instruction>(input));
318                 continue;
319             }
320             if (single == nullptr) {
321                 single = input;
322             }
323             if (single != input) {
324                 return nullptr;
325             }
326         }
327     }
328     return single;
329 }
330 
OptimizeGepoffs(SelectorSplitMap & mapping)331 void GepPropagation::OptimizeGepoffs(SelectorSplitMap &mapping)
332 {
333     llvm::SmallVector<Instruction *> queue;
334     for (auto entry : mapping) {
335         queue.push_back(entry.second.second);
336     }
337     while (!queue.empty()) {
338         auto off = queue.pop_back_val();
339         auto coff = GetConstantInput(off);
340         if (coff == nullptr) {
341             continue;
342         }
343         for (auto uiter = off->use_begin(), uend = off->use_end(); uiter != uend;) {
344             Use &use = *uiter++;
345             auto user = use.getUser();
346             use.set(coff);
347 
348             if (llvm::isa<PHINode, SelectInst>(user)) {
349                 queue.push_back(llvm::cast<Instruction>(user));
350                 continue;
351             }
352             auto gep = llvm::dyn_cast<GetElementPtrInst>(user);
353             if (gep == nullptr) {
354                 continue;
355             }
356             if (gep->hasAllZeroIndices()) {
357                 auto pointer = gep->getPointerOperand();
358                 pointer->takeName(gep);
359                 gep->replaceAllUsesWith(pointer);
360             }
361         }
362     }
363 }
364 
365 /// Basic optimization of generated PHI and Select instructions that have identical inputs.
OptimizeSelectors(SelectorSplitMap & mapping)366 void GepPropagation::OptimizeSelectors(SelectorSplitMap &mapping)
367 {
368     llvm::SmallVector<Instruction *> queue;
369     // Optimize bases
370     for (auto entry : mapping) {
371         queue.push_back(entry.second.first);
372     }
373     while (!queue.empty()) {
374         auto base = queue.pop_back_val();
375         auto cbase = GetConstantInput(base);
376         if (cbase == nullptr) {
377             continue;
378         }
379         for (auto uiter = base->use_begin(), uend = base->use_end(); uiter != uend;) {
380             Use &use = *uiter++;
381             auto user = use.getUser();
382             use.set(cbase);
383 
384             if (llvm::isa<PHINode, SelectInst>(user)) {
385                 queue.push_back(llvm::cast<Instruction>(user));
386             }
387         }
388     }
389     // Optimize gepoffs
390     OptimizeGepoffs(mapping);
391     // Remove replaced instructions
392     for (auto entry : mapping) {
393         auto [base, off] = entry.second;
394         if (base->user_empty()) {
395             base->eraseFromParent();
396         }
397         if (off->user_empty()) {
398             off->eraseFromParent();
399         }
400     }
401 }
402 
403 /// Update users of GEP selector using GEP obtained after selector splitting.
ReplaceWithSplitGep(Instruction * inst,Instruction * splitGep)404 void GepPropagation::ReplaceWithSplitGep(Instruction *inst, Instruction *splitGep)
405 {
406     bool needCast = splitGep->getType() != inst->getType();
407     auto generateCast = [splitGep, type = inst->getType()]() {
408         auto cast = llvm::Instruction::CastOps::CastOpsEnd;
409         if (type->isIntegerTy()) {
410             cast = Instruction::PtrToInt;
411         } else if (type->isPointerTy()) {
412             cast = IsGcRefType(type) ? Instruction::BitCast : Instruction::AddrSpaceCast;
413         }
414         ASSERT(cast != llvm::Instruction::CastOps::CastOpsEnd && "Unsupported selector type");
415         return CastInst::Create(cast, splitGep, type);
416     };
417     for (auto uiter = inst->use_begin(), uend = inst->use_end(); uiter != uend;) {
418         Use &use = *uiter++;
419         auto *uinst = llvm::cast<Instruction>(use.getUser());
420         if (llvm::isa<PHINode, SelectInst>(uinst)) {
421             continue;
422         }
423         if (needCast) {
424             auto cast = generateCast();
425             cast->insertBefore(uinst);
426             use.set(cast);
427         } else {
428             use.set(splitGep);
429         }
430     }
431 }
432 
433 /// Recursively collect a sequence of instructions to clone them for each use of this sequence.
ReplaceRecursively(Instruction * inst,SmallVector<Instruction *,1> * seq)434 void GepPropagation::ReplaceRecursively(Instruction *inst, SmallVector<Instruction *, 1> *seq)
435 {
436     llvm::IRBuilder<> builder(inst);
437     // Mapping of (phi, incoming basic block) -> expanded value
438     DenseMap<std::pair<Instruction *, BasicBlock *>, Instruction *> phiCache;
439     seq->push_back(inst);
440     for (auto iter = inst->use_begin(), end = inst->use_end(); iter != end;) {
441         Use &use = *iter++;
442         auto *uinst = llvm::cast<Instruction>(use.getUser());
443 
444         if (uinst->isCast()) {
445             // This assert guaranties that we can set zero operands during
446             // cloning the sequence of casts
447             ASSERT(uinst->getNumOperands() == 1);
448             ReplaceRecursively(uinst, seq);
449             continue;
450         }
451         // Reached the end of sequence
452         Instruction *clone = nullptr;
453         // If use a phi, insert it to the block where inst comes from.
454         if (llvm::isa<PHINode>(uinst)) {
455             auto incoming = llvm::cast<PHINode>(uinst)->getIncomingBlock(use);
456             auto visited = phiCache.find({uinst, incoming});
457             if (visited != phiCache.end()) {
458                 use.set(visited->second);
459                 continue;
460             }
461             builder.SetInsertPoint(incoming->getTerminator());
462             clone = CloneSequence(&builder, seq);
463             phiCache.insert({{uinst, incoming}, clone});
464         } else {
465             builder.SetInsertPoint(uinst);
466             clone = CloneSequence(&builder, seq);
467         }
468         use.set(clone);
469     }
470     seq->pop_back();
471     // Since we have replaced all uses the original can be erased
472     ASSERT(inst->user_empty());
473     inst->eraseFromParent();
474 }
475 
CloneSequence(llvm::IRBuilder<> * builder,SmallVector<Instruction *,1> * seq)476 Instruction *GepPropagation::CloneSequence(llvm::IRBuilder<> *builder, SmallVector<Instruction *, 1> *seq)
477 {
478     ASSERT(!seq->empty());
479     auto prev = (*seq->begin())->clone();
480     builder->Insert(prev);
481     for (auto siter = seq->begin() + 1; siter != seq->end(); siter++) {
482         auto clone = (*siter)->clone();
483         clone->setOperand(0, prev);
484         builder->Insert(clone);
485         prev = clone;
486     }
487     return prev;
488 }
489 
490 }  // namespace ark::llvmbackend::passes
491