//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.gather and llvm.scatter instructions to
/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to
/// produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP,
                  IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder,
                                   int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                     IRBuilder<> &Builder,
                                     int64_t Increment = 0);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder,
                                    int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder,
                                      int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr,
                                      Value *Ptr, GetElementPtrInst *GEP,
                                      IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                        Value *Ptr, unsigned TypeScale,
                                        IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
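  // (MVE vectors are always 128 bits wide, so the target element size follows
  // directly from the number of lanes.)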
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty,
                                          GetElementPtrInst *GEP,
                                          IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
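  // The returned value is the log2 of the scale applied to each offset lane
  // (2, 1 or 0), or -1 if the combination cannot be encoded.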
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
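  // (the check below only accepts scaled immediates that are a multiple of 4
  // and no larger than 512 in magnitude)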
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
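  // The base+offsets form is tried first, since it can absorb an extending
  // load and constant increments; if that fails we fall back to a plain
  // gather from a vector of pointers.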
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder,
                                                           int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked gathers: loading from vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

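  // Default to a zero-extending gather; a sign-extending user flips this to 0
  // below.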
  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Load)
    return Load;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into a MVE intrinsic
  // Potentially optimising the addressing modes as we do so.
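  // As with gathers, the base+offsets form (which can absorb a truncating
  // store) is tried before falling back to a scatter to a vector of pointers.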
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(
      dbgs()
      << "masked scatters: storing to a vector of pointers with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
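  // e.g. a trunc from <8 x i16> to <8 x i8> can become a truncating scatter
  // of the original i16 vector.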
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr =
      checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder);
  if (!BasePtr)
    return nullptr;
  // Check whether the offset is a constant increment that could be merged into
  // a QI gather
  Value *Store =
      tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder);
  if (Store)
    return Store;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP,
    IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());
  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    // Incrementing gathers are not beneficial outside of a loop
    return nullptr;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    Value *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                 TypeScale, Builder);
    if (Load != nullptr)
      return Load;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return cast<IntrinsicInst>(
        tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate));
  else
    return cast<IntrinsicInst>(
        tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate));
}

Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
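  // i.e. Offsets should be the loop-header phi whose backedge value is an
  // add of that phi and a constant.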
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which the loop variable can be incremented
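  // The transformation turns (Phi * OffsSecondOperand) into a new IV that
  // starts at (Start * OffsSecondOperand) and is advanced by
  // (IncrementPerRound * OffsSecondOperand) every iteration.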
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
  return;
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetics only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = true;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  Instruction *Op;
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
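  // (e.g. an i32 constant added to <8 x i16> offsets is narrowed to i16 and
  // splatted, provided it fits in the narrower element type)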
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can
    // be used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      PointerType *BaseType = cast<PointerType>(Base->getType());
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}