//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

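// Return true if every lane of \p Mask is a compile-time ConstantInt, so the
// mask can be evaluated per lane without emitting any control flow.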
static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = Mask->getType()->getVectorNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
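  // For example, a <16 x i1> mask is bitcast to an i16 "scalar_mask", and lane
  // Idx is then tested with an "and" against (1 << Idx) followed by an
  // "icmp ne 0".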
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                             ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

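// Translate a masked expandload intrinsic like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that load consecutive elements from %ptr into the
// enabled lanes, advancing the pointer only after an enabled lane; disabled
// lanes keep their %passthru value. Roughly, each enabled lane expands to:
//
//  %Elt = load i32, i32* %ptr
//  %Res = insertelement <16 x i32> %prev, i32 %Elt, i32 Idx
//  %ptr.next = getelementptr i32, i32* %ptr, i32 1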
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, NewPtr, 1, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
      ++MemIndex;
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

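// Translate a masked compressstore intrinsic like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that store the enabled lanes of %value to
// consecutive locations starting at %ptr, advancing the pointer only after an
// enabled lane. Roughly, each enabled lane expands to:
//
//  %Elt = extractelement <16 x i32> %value, i32 Idx
//  store i32 %Elt, i32* %ptr
//  %ptr.next = getelementptr i32, i32* %ptr, i32 1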
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  VectorType *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, 1);
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    unsigned Alignment;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedLoad(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_store: {
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType(),
                                  MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_gather:
      Alignment = cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      if (TTI->isLegalMaskedGather(CI->getType(), MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_scatter:
      Alignment = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType(),
                                    MaybeAlign(Alignment)))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_expandload:
      if (TTI->isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}
900