• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "transforms/passes/gep_propagation.h"
17 
18 #include "transforms/gc_utils.h"
19 #include "transforms/transform_utils.h"
20 
21 #include <llvm/ADT/DenseMap.h>
22 #include <llvm/ADT/PostOrderIterator.h>
23 #include <llvm/IR/IRBuilder.h>
24 #include <llvm/Pass.h>
25 #include <llvm/Support/CommandLine.h>
26 
27 #define DEBUG_TYPE "gep-propagation"
28 
29 // Basic classes
30 using llvm::BasicBlock;
31 using llvm::DenseMap;
32 using llvm::Function;
33 using llvm::SmallVector;
34 using llvm::Type;
35 using llvm::Use;
36 using llvm::Value;
37 // Instructions
38 using llvm::CastInst;
39 using llvm::Constant;
40 using llvm::GetElementPtrInst;
41 using llvm::Instruction;
42 using llvm::PHINode;
43 using llvm::SelectInst;
44 // Gc utils
45 using ark::llvmbackend::gc_utils::IsGcRefType;
46 
47 /// Optimize no-op PHINodes and Selects in place.
48 // NOLINTNEXTLINE(fuchsia-statically-constructed-objects)
49 static llvm::cl::opt<bool> g_optimizeNoop("gprop-optimize", llvm::cl::Hidden, llvm::cl::init(true));
50 
51 namespace ark::llvmbackend::passes {
52 
run(llvm::Function & function,llvm::FunctionAnalysisManager &)53 llvm::PreservedAnalyses GepPropagation::run(llvm::Function &function, llvm::FunctionAnalysisManager & /*AM*/)
54 {
55     LLVM_DEBUG(llvm::dbgs() << "Function: " << function.getName() << "\n");
56     if (!gc_utils::IsGcFunction(function) || gc_utils::IsFunctionSupplemental(function)) {
57         return llvm::PreservedAnalyses::all();
58     }
59 
60     Propagate(&function);
61     return llvm::PreservedAnalyses::none();
62 }
63 
AddToVector(Instruction * inst,SmallVector<Instruction * > * toExpand,SmallVector<Instruction * > * selectors)64 void GepPropagation::AddToVector(Instruction *inst, SmallVector<Instruction *> *toExpand,
65                                  SmallVector<Instruction *> *selectors)
66 {
67     switch (inst->getOpcode()) {
68         // Derived references
69         case Instruction::GetElementPtr:
70             if (IsGcRefType(inst->getType())) {
71                 toExpand->push_back(inst);
72             }
73             break;
74         // Escaping managed scope
75         case Instruction::PtrToInt:
76         case Instruction::AddrSpaceCast:
77             if (IsGcRefType(inst->getOperand(0)->getType())) {
78                 toExpand->push_back(inst);
79             }
80             break;
81         case Instruction::Select:
82         case Instruction::PHI:
83             if ((IsGcRefType(inst->getType()) && gc_utils::IsDerived(inst)) || gc_utils::HasBeenGcRef(inst, false)) {
84                 selectors->push_back(inst);
85                 toExpand->push_back(inst);
86             }
87             break;
88         default:
89             break;
90     }
91 }
92 
Propagate(Function * function)93 void GepPropagation::Propagate(Function *function)
94 {
95     SmallVector<Instruction *> toExpand;
96     SmallVector<Instruction *> selectors;
97     llvm::ReversePostOrderTraversal<Function *> rpo(function);
98     for (auto block : rpo) {
99         for (auto &inst : *block) {
100             AddToVector(&inst, &toExpand, &selectors);
101         }
102     }
103 
104     DenseMap<Instruction *, Instruction *> sgeps;
105     SplitGepSelectors(function, &selectors, &sgeps);
106 
107     while (!toExpand.empty()) {
108         auto gep = toExpand.pop_back_val();
109         if (sgeps.find(gep) != sgeps.end()) {
110             gep = sgeps[gep];
111         }
112         SmallVector<Instruction *, 1> seq;
113         ReplaceRecursively(gep, &seq);
114     }
115 }
116 
FindSplitGep(Function * function,Instruction * inst,Instruction ** ipoint)117 static std::pair<Instruction *, Instruction *> FindSplitGep(Function *function, Instruction *inst, Instruction **ipoint)
118 {
119     auto &ctx = function->getContext();
120     auto bptrTy = llvm::PointerType::get(ctx, ark::llvmbackend::LLVMArkInterface::GC_ADDR_SPACE);
121     auto undefBase = llvm::UndefValue::get(bptrTy);
122     auto undefOffset = llvm::UndefValue::get(Type::getInt32Ty(ctx));
123     auto phi = llvm::dyn_cast<PHINode>(inst);
124     Instruction *mbase;
125     Instruction *moff;
126     if (phi != nullptr) {
127         mbase = PHINode::Create(bptrTy, phi->getNumIncomingValues(), "base", phi);
128         moff = PHINode::Create(Type::getInt32Ty(ctx), phi->getNumIncomingValues(), "gepoff", phi);
129         for (size_t i = 0; i < phi->getNumIncomingValues(); ++i) {
130             auto bb = phi->getIncomingBlock(i);
131             llvm::cast<PHINode>(mbase)->addIncoming(undefBase, bb);
132             llvm::cast<PHINode>(moff)->addIncoming(undefOffset, bb);
133         }
134         *ipoint = phi->getParent()->getFirstNonPHI();
135     } else {
136         auto condition = llvm::cast<SelectInst>(inst)->getCondition();
137         mbase = SelectInst::Create(condition, undefBase, undefBase, "base", inst);
138         moff = SelectInst::Create(condition, undefOffset, undefOffset, "gepoff", inst);
139         *ipoint = inst;
140     }
141     return {mbase, moff};
142 }
143 
SplitGepSelectors(Function * function,SmallVector<Instruction * > * selectors,DenseMap<Instruction *,Instruction * > * sgeps)144 void GepPropagation::SplitGepSelectors(Function *function, SmallVector<Instruction *> *selectors,
145                                        DenseMap<Instruction *, Instruction *> *sgeps)
146 {
147     auto &ctx = function->getContext();
148     SelectorSplitMap mapping;
149     for (auto inst : *selectors) {
150         Instruction *ipoint = nullptr;
151         auto splitGep = FindSplitGep(function, inst, &ipoint);
152         mapping[inst] = splitGep;
153         auto gep = GetElementPtrInst::CreateInBounds(Type::getInt8Ty(ctx), splitGep.first, splitGep.second, "", ipoint);
154         gep->takeName(inst);
155         sgeps->insert({inst, gep});
156     }
157 
158     for (auto inst : *selectors) {
159         ASSERT((llvm::isa<PHINode, SelectInst>(inst)));
160         GenerateSelectorInputs(inst, mapping);
161         ReplaceWithSplitGep(inst, sgeps->lookup(inst));
162     }
163 
164     // Erase GEP selectors as we replaced them with their split versions.
165     while (!selectors->empty()) {
166         auto val = selectors->pop_back_val();
167         [[maybe_unused]] auto safeToDelete = [selectors](auto user) {
168             auto inst = llvm::cast<Instruction>(user);
169             if (!llvm::isa<PHINode, SelectInst>(inst)) {
170                 return false;
171             }
172             return std::find(selectors->begin(), selectors->end(), inst) != selectors->end();
173         };
174         ASSERT(std::find_if_not(val->user_begin(), val->user_end(), safeToDelete) == val->user_end());
175         val->replaceAllUsesWith(llvm::UndefValue::get(val->getType()));
176         val->eraseFromParent();
177     }
178 
179     if (g_optimizeNoop) {
180         OptimizeSelectors(mapping);
181     }
182 }
183 
GenerateInput(Value * input,Instruction * inst,Instruction * inPoint,const SelectorSplitMap & mapping)184 std::pair<Value *, Value *> GepPropagation::GenerateInput(Value *input, Instruction *inst, Instruction *inPoint,
185                                                           const SelectorSplitMap &mapping)
186 {
187     auto [mbase, moff] = mapping.lookup(inst);
188     auto offTy = moff->getType();
189     auto instInput = llvm::dyn_cast<Instruction>(input);
190     if (instInput != nullptr) {
191         auto mapit = mapping.find(instInput);
192         if (mapit != mapping.end()) {
193             return mapit->second;
194         }
195     }
196     if (auto nulloffset = llvm::dyn_cast<Constant>(input)) {
197         return {Constant::getNullValue(mbase->getType()), GetConstantOffset(nulloffset, offTy)};
198     }
199 
200     auto [base, derived] = GetBasePointer(input);
201     Value *offset = nullptr;
202     auto instBase = llvm::dyn_cast<Instruction>(base);
203     if (instBase != nullptr) {
204         auto mapit = mapping.find(instBase);
205         if (mapit != mapping.end()) {
206             base = mapit->second.first;
207             offset = mapit->second.second;
208         }
209     }
210 
211     if (derived) {
212         // Calculate offset
213         auto baseRaw = CastInst::Create(Instruction::PtrToInt, base, offTy, "", inPoint);
214         Value *gepRaw = nullptr;
215 
216         ASSERT(input->getType()->isIntOrPtrTy() && "Unexpected type of gep selector");
217         if (input->getType()->isPointerTy()) {
218             gepRaw = CastInst::Create(Instruction::PtrToInt, input, offTy, "", inPoint);
219         } else if (input->getType()->getScalarSizeInBits() < offTy->getScalarSizeInBits()) {
220             gepRaw = CastInst::Create(Instruction::ZExt, input, offTy, "", inPoint);
221         } else if (input->getType()->getScalarSizeInBits() > offTy->getScalarSizeInBits()) {
222             gepRaw = CastInst::Create(Instruction::Trunc, input, offTy, "", inPoint);
223         } else {
224             gepRaw = input;
225         }
226         offset = llvm::BinaryOperator::Create(Instruction::Sub, gepRaw, baseRaw, "", inPoint);
227     } else if (offset == nullptr) {
228         offset = llvm::ConstantInt::getSigned(offTy, 0);
229     }
230     return {base, offset};
231 }
232 
GenerateSelectorInputs(Instruction * inst,const SelectorSplitMap & mapping)233 void GepPropagation::GenerateSelectorInputs(Instruction *inst, const SelectorSplitMap &mapping)
234 {
235     ASSERT((llvm::isa<SelectInst, PHINode>(inst)));
236     auto [mbase, moff] = mapping.lookup(inst);
237 
238     auto setInputs = [mbase = mbase, moff = moff](auto idx, auto ibase, auto ioff) {
239         if (auto phi = llvm::dyn_cast<PHINode>(mbase)) {
240             auto bb = phi->getIncomingBlock(idx);
241             phi->setIncomingValueForBlock(bb, ibase);
242             llvm::cast<PHINode>(moff)->setIncomingValueForBlock(bb, ioff);
243         } else {
244             mbase->setOperand(idx, ibase);
245             moff->setOperand(idx, ioff);
246         }
247     };
248     // Generate inputs
249     bool select = llvm::isa<SelectInst>(inst);
250     for (size_t i = select ? 1 : 0; i < inst->getNumOperands(); ++i) {
251         auto input = inst->getOperand(i);
252         if (!llvm::isa<llvm::UndefValue>(mbase->getOperand(i))) {
253             continue;
254         }
255         auto inPoint = select ? mbase : llvm::cast<PHINode>(inst)->getIncomingBlock(i)->getTerminator();
256         auto [base, offset] = GenerateInput(input, inst, inPoint, mapping);
257         setInputs(i, base, offset);
258     }
259 }
260 
261 /// Generate a ConstaintInt of type TYPE that represents an OFFSET.
GetConstantOffset(Constant * offset,Type * type)262 Value *GepPropagation::GetConstantOffset(Constant *offset, Type *type)
263 {
264     offset = offset->stripPointerCasts();
265     if (llvm::isa<llvm::ConstantExpr>(offset)) {
266         ASSERT(offset->getNumOperands() == 1);
267         auto offsetRaw = offset->getOperand(0);
268         return llvm::ConstantInt::getSigned(type, llvm::cast<llvm::ConstantInt>(offsetRaw)->getSExtValue());
269     }
270     if (offset->isNullValue()) {
271         return llvm::ConstantInt::getSigned(type, 0);
272     }
273     if (llvm::isa<llvm::PoisonValue>(offset)) {
274         return llvm::PoisonValue::get(type);
275     }
276 
277     return llvm::ConstantInt::getSigned(type, llvm::cast<llvm::ConstantInt>(offset)->getSExtValue());
278 }
279 
GetBasePointer(Value * value)280 std::pair<Value *, bool> GepPropagation::GetBasePointer(Value *value)
281 {
282     bool derived = false;
283     auto base = value;
284     // This loop is needed to get to addrspace(271) through addrspace (0).
285     // It is required if the value is an escaped reference.
286     while (!IsGcRefType(base->getType()) && llvm::isa<GetElementPtrInst, CastInst>(base)) {
287         derived |= llvm::isa<GetElementPtrInst>(base);
288         base = llvm::cast<Instruction>(base)->getOperand(0);
289     }
290     // Here we try to find a base pointer inside addrspace (271).
291     while (llvm::isa<GetElementPtrInst, CastInst>(base)) {
292         derived |= llvm::isa<GetElementPtrInst>(base);
293         auto next = llvm::cast<Instruction>(base)->getOperand(0);
294         // Go until instruction introducing a GC reference.
295         if (!IsGcRefType(next->getType())) {
296             break;
297         }
298         base = next;
299     }
300     ASSERT(IsGcRefType(base->getType()));
301     return {base, derived};
302 }
303 
304 /// Returns a constant input of a selector.
GetConstantInput(Instruction * inst)305 Value *GepPropagation::GetConstantInput(Instruction *inst)
306 {
307     llvm::DenseSet<Value *> visited;
308     llvm::SmallVector<Instruction *> queueLocal;
309     queueLocal.push_back(inst);
310     Value *single = nullptr;
311     while (!queueLocal.empty()) {
312         auto elt = queueLocal.pop_back_val();
313         if (!visited.insert(elt).second) {
314             continue;
315         }
316         bool select = llvm::isa<SelectInst>(elt);
317         for (size_t i = select ? 1 : 0; i < elt->getNumOperands(); ++i) {
318             auto input = elt->getOperand(i);
319             if (llvm::isa<SelectInst, PHINode>(input)) {
320                 queueLocal.push_back(llvm::cast<Instruction>(input));
321                 continue;
322             }
323             if (single == nullptr) {
324                 single = input;
325             }
326             if (single != input) {
327                 return nullptr;
328             }
329         }
330     }
331     return single;
332 }
333 
OptimizeGepoffs(SelectorSplitMap & mapping)334 void GepPropagation::OptimizeGepoffs(SelectorSplitMap &mapping)
335 {
336     llvm::SmallVector<Instruction *> queue;
337     for (auto entry : mapping) {
338         queue.push_back(entry.second.second);
339     }
340     while (!queue.empty()) {
341         auto off = queue.pop_back_val();
342         auto coff = GetConstantInput(off);
343         if (coff == nullptr) {
344             continue;
345         }
346         for (auto uiter = off->use_begin(), uend = off->use_end(); uiter != uend;) {
347             Use &use = *uiter++;
348             auto user = use.getUser();
349             use.set(coff);
350 
351             if (llvm::isa<PHINode, SelectInst>(user)) {
352                 queue.push_back(llvm::cast<Instruction>(user));
353                 continue;
354             }
355             auto gep = llvm::dyn_cast<GetElementPtrInst>(user);
356             if (gep == nullptr) {
357                 continue;
358             }
359             if (gep->hasAllZeroIndices()) {
360                 auto pointer = gep->getPointerOperand();
361                 pointer->takeName(gep);
362                 gep->replaceAllUsesWith(pointer);
363             }
364         }
365     }
366 }
367 
368 /// Basic optimization of generated PHI and Select instructions that have identical inputs.
OptimizeSelectors(SelectorSplitMap & mapping)369 void GepPropagation::OptimizeSelectors(SelectorSplitMap &mapping)
370 {
371     llvm::SmallVector<Instruction *> queue;
372     // Optimize bases
373     for (auto entry : mapping) {
374         queue.push_back(entry.second.first);
375     }
376     while (!queue.empty()) {
377         auto base = queue.pop_back_val();
378         auto cbase = GetConstantInput(base);
379         if (cbase == nullptr) {
380             continue;
381         }
382         for (auto uiter = base->use_begin(), uend = base->use_end(); uiter != uend;) {
383             Use &use = *uiter++;
384             auto user = use.getUser();
385             use.set(cbase);
386 
387             if (llvm::isa<PHINode, SelectInst>(user)) {
388                 queue.push_back(llvm::cast<Instruction>(user));
389             }
390         }
391     }
392     // Optimize gepoffs
393     OptimizeGepoffs(mapping);
394     // Remove replaced instructions
395     for (auto entry : mapping) {
396         auto [base, off] = entry.second;
397         if (base->user_empty()) {
398             base->eraseFromParent();
399         }
400         if (off->user_empty()) {
401             off->eraseFromParent();
402         }
403     }
404 }
405 
406 /// Update users of GEP selector using GEP obtained after selector splitting.
ReplaceWithSplitGep(Instruction * inst,Instruction * splitGep)407 void GepPropagation::ReplaceWithSplitGep(Instruction *inst, Instruction *splitGep)
408 {
409     bool needCast = splitGep->getType() != inst->getType();
410     auto generateCast = [splitGep, type = inst->getType()]() {
411         auto cast = llvm::Instruction::CastOps::CastOpsEnd;
412         if (type->isIntegerTy()) {
413             cast = Instruction::PtrToInt;
414         } else if (type->isPointerTy()) {
415             cast = IsGcRefType(type) ? Instruction::BitCast : Instruction::AddrSpaceCast;
416         }
417         ASSERT(cast != llvm::Instruction::CastOps::CastOpsEnd && "Unsupported selector type");
418         return CastInst::Create(cast, splitGep, type);
419     };
420     for (auto uiter = inst->use_begin(), uend = inst->use_end(); uiter != uend;) {
421         Use &use = *uiter++;
422         auto *uinst = llvm::cast<Instruction>(use.getUser());
423         if (llvm::isa<PHINode, SelectInst>(uinst)) {
424             continue;
425         }
426         if (needCast) {
427             auto cast = generateCast();
428             cast->insertBefore(uinst);
429             use.set(cast);
430         } else {
431             use.set(splitGep);
432         }
433     }
434 }
435 
436 /// Recursively collect a sequence of instructions to clone them for each use of this sequence.
ReplaceRecursively(Instruction * inst,SmallVector<Instruction *,1> * seq)437 void GepPropagation::ReplaceRecursively(Instruction *inst, SmallVector<Instruction *, 1> *seq)
438 {
439     llvm::IRBuilder<> builder(inst);
440     // Mapping of (phi, incoming basic block) -> expanded value
441     DenseMap<std::pair<Instruction *, BasicBlock *>, Instruction *> phiCache;
442     seq->push_back(inst);
443     for (auto iter = inst->use_begin(), end = inst->use_end(); iter != end;) {
444         Use &use = *iter++;
445         auto *uinst = llvm::cast<Instruction>(use.getUser());
446 
447         if (uinst->isCast()) {
448             // This assert guaranties that we can set zero operands during
449             // cloning the sequence of casts
450             ASSERT(uinst->getNumOperands() == 1);
451             ReplaceRecursively(uinst, seq);
452             continue;
453         }
454         // Reached the end of sequence
455         Instruction *clone = nullptr;
456         // If use a phi, insert it to the block where inst comes from.
457         if (llvm::isa<PHINode>(uinst)) {
458             auto incoming = llvm::cast<PHINode>(uinst)->getIncomingBlock(use);
459             auto visited = phiCache.find({uinst, incoming});
460             if (visited != phiCache.end()) {
461                 use.set(visited->second);
462                 continue;
463             }
464             builder.SetInsertPoint(incoming->getTerminator());
465             clone = CloneSequence(&builder, seq);
466             phiCache.insert({{uinst, incoming}, clone});
467         } else {
468             builder.SetInsertPoint(uinst);
469             clone = CloneSequence(&builder, seq);
470         }
471         use.set(clone);
472     }
473     seq->pop_back();
474     // Since we have replaced all uses the original can be erased
475     ASSERT(inst->user_empty());
476     inst->eraseFromParent();
477 }
478 
CloneSequence(llvm::IRBuilder<> * builder,SmallVector<Instruction *,1> * seq)479 Instruction *GepPropagation::CloneSequence(llvm::IRBuilder<> *builder, SmallVector<Instruction *, 1> *seq)
480 {
481     ASSERT(!seq->empty());
482     auto prev = (*seq->begin())->clone();
483     builder->Insert(prev);
484     for (auto siter = seq->begin() + 1; siter != seq->end(); siter++) {
485         auto clone = (*siter)->clone();
486         clone->setOperand(0, prev);
487         builder->Insert(clone);
488         prev = clone;
489     }
490     return prev;
491 }
492 
493 }  // namespace ark::llvmbackend::passes
494