//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};

} // end anonymous namespace

char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  %3 = icmp eq i1 %2, true
//  br i1 %3, label %cond.load, label %else
//
// cond.load:                                       ; preds = %0
//  %4 = getelementptr i32* %1, i32 0
//  %5 = load i32* %4
//  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
//  br label %else
//
// else:                                            ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
//  %7 = extractelement <16 x i1> %mask, i32 1
//  %8 = icmp eq i1 %7, true
//  br i1 %8, label %cond.load1, label %else2
//
// cond.load1:                                      ; preds = %else
//  %9 = getelementptr i32* %1, i32 1
//  %10 = load i32* %9
//  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
//  br label %else2
//
// else2:                                           ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
//  %12 = extractelement <16 x i1> %mask, i32 2
//  %13 = icmp eq i1 %12, true
//  br i1 %13, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
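  // A scalar access can be no more aligned than its element type, and no more
  // aligned than the original vector access was.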
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

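  // If the mask is a compile-time constant vector, no control flow is needed:
  // emit a plain load for every element whose mask bit is set and select the
  // result against the pass-through value.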
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

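  // Otherwise the mask is only known at run time: emit one cond.load/else
  // block pair per element and thread the partial result through phi nodes.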
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  %to_load = icmp eq i1 %mask_1, true
    //  br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 Idx
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
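    // splitBasicBlock ended IfBlock with an unconditional branch; replace it
    // with a conditional branch on the mask bit so cond.load can be skipped.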
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  %3 = icmp eq i1 %2, true
//  br i1 %3, label %cond.store, label %else
//
// cond.store:                                      ; preds = %0
//  %4 = extractelement <16 x i32> %val, i32 0
//  %5 = getelementptr i32* %1, i32 0
//  store i32 %4, i32* %5
//  br label %else
//
// else:                                            ; preds = %0, %cond.store
//  %6 = extractelement <16 x i1> %mask, i32 1
//  %7 = icmp eq i1 %6, true
//  br i1 %7, label %cond.store1, label %else2
//
// cond.store1:                                     ; preds = %else
//  %8 = extractelement <16 x i32> %val, i32 1
//  %9 = getelementptr i32* %1, i32 1
//  store i32 %8, i32* %9
//  br label %else2
//  . . .
static void scalarizeMaskedStore(CallInst *CI) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(Src->getType());
  assert(VecType && "Unexpected data type in masked store intrinsic");

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
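  // Use std::min, as in scalarizeMaskedLoad: a per-element store cannot claim
  // more alignment than the element size or the original vector access.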
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

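  // With a constant mask, emit a straight-line store for each set bit; no
  // control flow is needed.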
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  %to_store = icmp eq i1 %mask_1, true
    //  br i1 %to_store, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 Idx
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  %ToLoad0 = icmp eq i1 %Mask0, true
//  br i1 %ToLoad0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  %ToLoad1 = icmp eq i1 %Mask1, true
//  br i1 %ToLoad1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//  . . .
//  %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//  ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked gather intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants: load from each pointer
  // whose mask bit is set and select against the pass-through value.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
    //  %ToLoad1 = icmp eq i1 %Mask1, true
    //  br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    //  %Ptr = extractelement <16 x i32*> %Ptrs, i32 Idx
    //  %Elt = load i32, i32* %Ptr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  %ToStore0 = icmp eq i1 %Mask0, true
//  br i1 %ToStore0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  %ToStore1 = icmp eq i1 %Mask1, true
//  br i1 %ToStore1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//  . . .
static void scalarizeMaskedScatter(CallInst *CI) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Short-cut if the mask is a vector of constants: store each element whose
  // mask bit is set directly through its pointer.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //  %ToStore = icmp eq i1 %Mask1, true
    //  br i1 %ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    //  %Elt = extractelement <16 x i32> %Src, i32 Idx
    //  %Ptr = extractelement <16 x i32*> %Ptrs, i32 Idx
    //  store i32 %Elt, i32* %Ptr
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

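  // Iterate to a fixed point: scalarizing one intrinsic splits basic blocks,
  // so the block walk below is restarted whenever the CFG changes.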
  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
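    // Advance the iterator before handing the instruction to optimizeCallInst,
    // which erases scalarized intrinsics (and may split the block).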
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (!TTI->isLegalMaskedLoad(CI->getType())) {
        scalarizeMaskedLoad(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    case Intrinsic::masked_store:
      if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedStore(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    case Intrinsic::masked_gather:
      if (!TTI->isLegalMaskedGather(CI->getType())) {
        scalarizeMaskedGather(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    case Intrinsic::masked_scatter:
      if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedScatter(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
  }

  return false;
}