1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising
/// the code to produce a better final result as we go.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "ARM.h"
16 #include "ARMBaseInstrInfo.h"
17 #include "ARMSubtarget.h"
18 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/CodeGen/TargetSubtargetInfo.h"
22 #include "llvm/InitializePasses.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Constant.h"
25 #include "llvm/IR/Constants.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/InstrTypes.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/Intrinsics.h"
33 #include "llvm/IR/IntrinsicsARM.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/PatternMatch.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/IR/Value.h"
38 #include "llvm/Pass.h"
39 #include "llvm/Support/Casting.h"
40 #include <algorithm>
41 #include <cassert>
42
43 using namespace llvm;
44
45 #define DEBUG_TYPE "mve-gather-scatter-lowering"
46
// Command-line switch gating the whole pass; off by default.
// NOTE(review): deliberately not static/anonymous-namespaced — presumably
// referenced via extern elsewhere in the ARM backend; confirm before
// restricting its linkage.
cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
    cl::desc("Enable the generation of masked gathers and scatters"));
50
51 namespace {
52
// Function pass that rewrites llvm.masked.gather intrinsics into the MVE
// gather intrinsics (arm.mve.vldr.gather.*) when the type, alignment and
// addressing pattern allow it.
class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  // Walks the function and attempts to lower each masked gather; returns
  // whether the IR was modified.
  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only replaces instructions within blocks, never the CFG.
    AU.setPreservesCFG();
    // Needed to query the subtarget for MVE support in runOnFunction.
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               unsigned Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder);

  // Lower one masked gather intrinsic; returns true if it was replaced.
  bool lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> Builder);
};
92
93 } // end anonymous namespace
94
char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

// Factory used by the ARM target to insert this pass into its pipeline.
Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}
103
isLegalTypeAndAlignment(unsigned NumElements,unsigned ElemSize,unsigned Alignment)104 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
105 unsigned ElemSize,
106 unsigned Alignment) {
107 // Do only allow non-extending gathers for now
108 if (((NumElements == 4 && ElemSize == 32) ||
109 (NumElements == 8 && ElemSize == 16) ||
110 (NumElements == 16 && ElemSize == 8)) &&
111 ElemSize / 8 <= Alignment)
112 return true;
113 LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid "
114 << "alignment or vector type \n");
115 return false;
116 }
117
// Try to decompose Ptr into a scalar base pointer plus a vector of integer
// offsets suitable for an MVE base+offsets gather. On success returns the
// base and writes the offsets through the Offsets out-parameter; returns
// nullptr when the gather has to be expanded instead. Ty is the gather's
// result vector type, which fixes the offset vector type the intrinsic
// expects.
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
                                          IRBuilder<> Builder) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading"
                    << " from base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  // A vector base would itself need a gather to materialise; not handled.
  if (GEPPtr->getType()->isVectorTy()) {
    LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers"
                      << " hidden behind a getelementptr currently not"
                      << " supported. Expanding.\n");
    return nullptr;
  }
  // Only the simple "base + one index vector" form is supported.
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  // SExt offsets inside masked gathers are not permitted by the architecture;
  // we therefore can't fold them
  if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
    Offsets = ZextOffs->getOperand(0);
  Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty));
  // If the offset we found does not have the type the intrinsic expects,
  // i.e., the same type as the gather itself, we need to convert it (only i
  // types) or fall back to expanding the gather
  if (OffsType != Offsets->getType()) {
    if (OffsType->getScalarSizeInBits() >
        Offsets->getType()->getScalarSizeInBits()) {
      // Narrower offsets (e.g. after stripping a zext above) can be safely
      // re-zero-extended to the expected width.
      LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n");
      Offsets = Builder.CreateZExt(Offsets, OffsType, "");
    } else {
      // Wider offsets can't be truncated without changing semantics.
      LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't"
                        << " create masked gather\n");
      return nullptr;
    }
  }
  // If none of the checks failed, return the gep's base pointer
  return GEPPtr;
}
162
lookThroughBitcast(Value * & Ptr)163 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
164 // Look through bitcast instruction if #elements is the same
165 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
166 Type *BCTy = BitCast->getType();
167 Type *BCSrcTy = BitCast->getOperand(0)->getType();
168 if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) {
169 LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n");
170 Ptr = BitCast->getOperand(0);
171 }
172 }
173 }
174
lowerGather(IntrinsicInst * I)175 bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
176 using namespace PatternMatch;
177 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");
178
179 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
180 // Attempt to turn the masked gather in I into a MVE intrinsic
181 // Potentially optimising the addressing modes as we do so.
182 Type *Ty = I->getType();
183 Value *Ptr = I->getArgOperand(0);
184 unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
185 Value *Mask = I->getArgOperand(2);
186 Value *PassThru = I->getArgOperand(3);
187
188 if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(),
189 Ty->getScalarSizeInBits(), Alignment))
190 return false;
191 lookThroughBitcast(Ptr);
192 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
193
194 IRBuilder<> Builder(I->getContext());
195 Builder.SetInsertPoint(I);
196 Builder.SetCurrentDebugLocation(I->getDebugLoc());
197 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder);
198 if (!Load)
199 Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
200 if (!Load)
201 return false;
202
203 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
204 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
205 << "creating select\n");
206 Load = Builder.CreateSelect(Mask, Load, PassThru);
207 }
208
209 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
210 I->replaceAllUsesWith(Load);
211 I->eraseFromParent();
212 return true;
213 }
214
tryCreateMaskedGatherBase(IntrinsicInst * I,Value * Ptr,IRBuilder<> Builder)215 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
216 IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
217 using namespace PatternMatch;
218 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
219 Type *Ty = I->getType();
220 if (Ty->getVectorNumElements() != 4)
221 // Can't build an intrinsic for this
222 return nullptr;
223 Value *Mask = I->getArgOperand(2);
224 if (match(Mask, m_One()))
225 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
226 {Ty, Ptr->getType()},
227 {Ptr, Builder.getInt32(0)});
228 else
229 return Builder.CreateIntrinsic(
230 Intrinsic::arm_mve_vldr_gather_base_predicated,
231 {Ty, Ptr->getType(), Mask->getType()},
232 {Ptr, Builder.getInt32(0), Mask});
233 }
234
tryCreateMaskedGatherOffset(IntrinsicInst * I,Value * Ptr,IRBuilder<> Builder)235 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
236 IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
237 using namespace PatternMatch;
238 Type *Ty = I->getType();
239 Value *Offsets;
240 Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder);
241 if (!BasePtr)
242 return nullptr;
243
244 unsigned Scale;
245 int GEPElemSize =
246 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits();
247 int ResultElemSize = Ty->getScalarSizeInBits();
248 // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a
249 // 8bit, 16bit or 32bit load scaled by 1
250 if (GEPElemSize == 32 && ResultElemSize == 32) {
251 Scale = 2;
252 } else if (GEPElemSize == 16 && ResultElemSize == 16) {
253 Scale = 1;
254 } else if (GEPElemSize == 8) {
255 Scale = 0;
256 } else {
257 LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't"
258 << " create masked gather\n");
259 return nullptr;
260 }
261
262 Value *Mask = I->getArgOperand(2);
263 if (!match(Mask, m_One()))
264 return Builder.CreateIntrinsic(
265 Intrinsic::arm_mve_vldr_gather_offset_predicated,
266 {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()},
267 {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
268 Builder.getInt32(Scale), Builder.getInt32(1), Mask});
269 else
270 return Builder.CreateIntrinsic(
271 Intrinsic::arm_mve_vldr_gather_offset,
272 {Ty, BasePtr->getType(), Offsets->getType()},
273 {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
274 Builder.getInt32(Scale), Builder.getInt32(1)});
275 }
276
runOnFunction(Function & F)277 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
278 if (!EnableMaskedGatherScatters)
279 return false;
280 auto &TPC = getAnalysis<TargetPassConfig>();
281 auto &TM = TPC.getTM<TargetMachine>();
282 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
283 if (!ST->hasMVEIntegerOps())
284 return false;
285 SmallVector<IntrinsicInst *, 4> Gathers;
286 for (BasicBlock &BB : F) {
287 for (Instruction &I : BB) {
288 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
289 if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
290 Gathers.push_back(II);
291 }
292 }
293
294 if (Gathers.empty())
295 return false;
296
297 for (IntrinsicInst *I : Gathers)
298 lowerGather(I);
299
300 return true;
301 }
302