• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===--- Scalarizer.cpp - Scalarize vector operations ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass converts vector operations into scalar operations, in order
11 // to expose optimization opportunities on the individual scalar operations.
12 // It is mainly intended for targets that do not have vector units, but it
13 // may also be useful for revectorizing code to different vector widths.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Scalar.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/InstVisitor.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
23 
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "scalarizer"
27 
28 namespace {
29 // Used to store the scattered form of a vector.
30 typedef SmallVector<Value *, 8> ValueVector;
31 
32 // Used to map a vector Value to its scattered form.  We use std::map
33 // because we want iterators to persist across insertion and because the
34 // values are relatively large.
35 typedef std::map<Value *, ValueVector> ScatterMap;
36 
37 // Lists Instructions that have been replaced with scalar implementations,
38 // along with a pointer to their scattered forms.
39 typedef SmallVector<std::pair<Instruction *, ValueVector *>, 16> GatherList;
40 
41 // Provides a very limited vector-like interface for lazily accessing one
42 // component of a scattered vector or vector pointer.
43 class Scatterer {
44 public:
Scatterer()45   Scatterer() {}
46 
47   // Scatter V into Size components.  If new instructions are needed,
48   // insert them before BBI in BB.  If Cache is nonnull, use it to cache
49   // the results.
50   Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
51             ValueVector *cachePtr = nullptr);
52 
53   // Return component I, creating a new Value for it if necessary.
54   Value *operator[](unsigned I);
55 
56   // Return the number of components.
size() const57   unsigned size() const { return Size; }
58 
59 private:
60   BasicBlock *BB;
61   BasicBlock::iterator BBI;
62   Value *V;
63   ValueVector *CachePtr;
64   PointerType *PtrTy;
65   ValueVector Tmp;
66   unsigned Size;
67 };
68 
69 // FCmpSpliiter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
70 // called Name that compares X and Y in the same way as FCI.
71 struct FCmpSplitter {
FCmpSplitter__anon987e38e40111::FCmpSplitter72   FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
operator ()__anon987e38e40111::FCmpSplitter73   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
74                     const Twine &Name) const {
75     return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
76   }
77   FCmpInst &FCI;
78 };
79 
80 // ICmpSpliiter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
81 // called Name that compares X and Y in the same way as ICI.
82 struct ICmpSplitter {
ICmpSplitter__anon987e38e40111::ICmpSplitter83   ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
operator ()__anon987e38e40111::ICmpSplitter84   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
85                     const Twine &Name) const {
86     return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
87   }
88   ICmpInst &ICI;
89 };
90 
91 // BinarySpliiter(BO)(Builder, X, Y, Name) uses Builder to create
92 // a binary operator like BO called Name with operands X and Y.
93 struct BinarySplitter {
BinarySplitter__anon987e38e40111::BinarySplitter94   BinarySplitter(BinaryOperator &bo) : BO(bo) {}
operator ()__anon987e38e40111::BinarySplitter95   Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
96                     const Twine &Name) const {
97     return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
98   }
99   BinaryOperator &BO;
100 };
101 
102 // Information about a load or store that we're scalarizing.
103 struct VectorLayout {
VectorLayout__anon987e38e40111::VectorLayout104   VectorLayout() : VecTy(nullptr), ElemTy(nullptr), VecAlign(0), ElemSize(0) {}
105 
106   // Return the alignment of element I.
getElemAlign__anon987e38e40111::VectorLayout107   uint64_t getElemAlign(unsigned I) {
108     return MinAlign(VecAlign, I * ElemSize);
109   }
110 
111   // The type of the vector.
112   VectorType *VecTy;
113 
114   // The type of each element.
115   Type *ElemTy;
116 
117   // The alignment of the vector.
118   uint64_t VecAlign;
119 
120   // The size of each element.
121   uint64_t ElemSize;
122 };
123 
124 class Scalarizer : public FunctionPass,
125                    public InstVisitor<Scalarizer, bool> {
126 public:
127   static char ID;
128 
Scalarizer()129   Scalarizer() :
130     FunctionPass(ID) {
131     initializeScalarizerPass(*PassRegistry::getPassRegistry());
132   }
133 
134   bool doInitialization(Module &M) override;
135   bool runOnFunction(Function &F) override;
136 
137   // InstVisitor methods.  They return true if the instruction was scalarized,
138   // false if nothing changed.
visitInstruction(Instruction &)139   bool visitInstruction(Instruction &) { return false; }
140   bool visitSelectInst(SelectInst &SI);
141   bool visitICmpInst(ICmpInst &);
142   bool visitFCmpInst(FCmpInst &);
143   bool visitBinaryOperator(BinaryOperator &);
144   bool visitGetElementPtrInst(GetElementPtrInst &);
145   bool visitCastInst(CastInst &);
146   bool visitBitCastInst(BitCastInst &);
147   bool visitShuffleVectorInst(ShuffleVectorInst &);
148   bool visitPHINode(PHINode &);
149   bool visitLoadInst(LoadInst &);
150   bool visitStoreInst(StoreInst &);
151 
registerOptions()152   static void registerOptions() {
153     // This is disabled by default because having separate loads and stores
154     // makes it more likely that the -combiner-alias-analysis limits will be
155     // reached.
156     OptionRegistry::registerOption<bool, Scalarizer,
157                                  &Scalarizer::ScalarizeLoadStore>(
158         "scalarize-load-store",
159         "Allow the scalarizer pass to scalarize loads and store", false);
160   }
161 
162 private:
163   Scatterer scatter(Instruction *, Value *);
164   void gather(Instruction *, const ValueVector &);
165   bool canTransferMetadata(unsigned Kind);
166   void transferMetadata(Instruction *, const ValueVector &);
167   bool getVectorLayout(Type *, unsigned, VectorLayout &, const DataLayout &);
168   bool finish();
169 
170   template<typename T> bool splitBinary(Instruction &, const T &);
171 
172   ScatterMap Scattered;
173   GatherList Gathered;
174   unsigned ParallelLoopAccessMDKind;
175   bool ScalarizeLoadStore;
176 };
177 
178 char Scalarizer::ID = 0;
179 } // end anonymous namespace
180 
181 INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer",
182                              "Scalarize vector operations", false, false)
183 
Scatterer(BasicBlock * bb,BasicBlock::iterator bbi,Value * v,ValueVector * cachePtr)184 Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
185                      ValueVector *cachePtr)
186   : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
187   Type *Ty = V->getType();
188   PtrTy = dyn_cast<PointerType>(Ty);
189   if (PtrTy)
190     Ty = PtrTy->getElementType();
191   Size = Ty->getVectorNumElements();
192   if (!CachePtr)
193     Tmp.resize(Size, nullptr);
194   else if (CachePtr->empty())
195     CachePtr->resize(Size, nullptr);
196   else
197     assert(Size == CachePtr->size() && "Inconsistent vector sizes");
198 }
199 
200 // Return component I, creating a new Value for it if necessary.
operator [](unsigned I)201 Value *Scatterer::operator[](unsigned I) {
202   ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
203   // Try to reuse a previous value.
204   if (CV[I])
205     return CV[I];
206   IRBuilder<> Builder(BB, BBI);
207   if (PtrTy) {
208     if (!CV[0]) {
209       Type *Ty =
210         PointerType::get(PtrTy->getElementType()->getVectorElementType(),
211                          PtrTy->getAddressSpace());
212       CV[0] = Builder.CreateBitCast(V, Ty, V->getName() + ".i0");
213     }
214     if (I != 0)
215       CV[I] = Builder.CreateConstGEP1_32(nullptr, CV[0], I,
216                                          V->getName() + ".i" + Twine(I));
217   } else {
218     // Search through a chain of InsertElementInsts looking for element I.
219     // Record other elements in the cache.  The new V is still suitable
220     // for all uncached indices.
221     for (;;) {
222       InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
223       if (!Insert)
224         break;
225       ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
226       if (!Idx)
227         break;
228       unsigned J = Idx->getZExtValue();
229       V = Insert->getOperand(0);
230       if (I == J) {
231         CV[J] = Insert->getOperand(1);
232         return CV[J];
233       } else if (!CV[J]) {
234         // Only cache the first entry we find for each index we're not actively
235         // searching for. This prevents us from going too far up the chain and
236         // caching incorrect entries.
237         CV[J] = Insert->getOperand(1);
238       }
239     }
240     CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
241                                          V->getName() + ".i" + Twine(I));
242   }
243   return CV[I];
244 }
245 
doInitialization(Module & M)246 bool Scalarizer::doInitialization(Module &M) {
247   ParallelLoopAccessMDKind =
248       M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
249   ScalarizeLoadStore =
250       M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>();
251   return false;
252 }
253 
runOnFunction(Function & F)254 bool Scalarizer::runOnFunction(Function &F) {
255   if (skipFunction(F))
256     return false;
257   assert(Gathered.empty() && Scattered.empty());
258   for (BasicBlock &BB : F) {
259     for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) {
260       Instruction *I = &*II;
261       bool Done = visit(I);
262       ++II;
263       if (Done && I->getType()->isVoidTy())
264         I->eraseFromParent();
265     }
266   }
267   return finish();
268 }
269 
270 // Return a scattered form of V that can be accessed by Point.  V must be a
271 // vector or a pointer to a vector.
scatter(Instruction * Point,Value * V)272 Scatterer Scalarizer::scatter(Instruction *Point, Value *V) {
273   if (Argument *VArg = dyn_cast<Argument>(V)) {
274     // Put the scattered form of arguments in the entry block,
275     // so that it can be used everywhere.
276     Function *F = VArg->getParent();
277     BasicBlock *BB = &F->getEntryBlock();
278     return Scatterer(BB, BB->begin(), V, &Scattered[V]);
279   }
280   if (Instruction *VOp = dyn_cast<Instruction>(V)) {
281     // Put the scattered form of an instruction directly after the
282     // instruction.
283     BasicBlock *BB = VOp->getParent();
284     return Scatterer(BB, std::next(BasicBlock::iterator(VOp)),
285                      V, &Scattered[V]);
286   }
287   // In the fallback case, just put the scattered before Point and
288   // keep the result local to Point.
289   return Scatterer(Point->getParent(), Point->getIterator(), V);
290 }
291 
292 // Replace Op with the gathered form of the components in CV.  Defer the
293 // deletion of Op and creation of the gathered form to the end of the pass,
294 // so that we can avoid creating the gathered form if all uses of Op are
295 // replaced with uses of CV.
gather(Instruction * Op,const ValueVector & CV)296 void Scalarizer::gather(Instruction *Op, const ValueVector &CV) {
297   // Since we're not deleting Op yet, stub out its operands, so that it
298   // doesn't make anything live unnecessarily.
299   for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I)
300     Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType()));
301 
302   transferMetadata(Op, CV);
303 
304   // If we already have a scattered form of Op (created from ExtractElements
305   // of Op itself), replace them with the new form.
306   ValueVector &SV = Scattered[Op];
307   if (!SV.empty()) {
308     for (unsigned I = 0, E = SV.size(); I != E; ++I) {
309       Value *V = SV[I];
310       if (V == nullptr)
311         continue;
312 
313       Instruction *Old = cast<Instruction>(V);
314       CV[I]->takeName(Old);
315       Old->replaceAllUsesWith(CV[I]);
316       Old->eraseFromParent();
317     }
318   }
319   SV = CV;
320   Gathered.push_back(GatherList::value_type(Op, &SV));
321 }
322 
323 // Return true if it is safe to transfer the given metadata tag from
324 // vector to scalar instructions.
canTransferMetadata(unsigned Tag)325 bool Scalarizer::canTransferMetadata(unsigned Tag) {
326   return (Tag == LLVMContext::MD_tbaa
327           || Tag == LLVMContext::MD_fpmath
328           || Tag == LLVMContext::MD_tbaa_struct
329           || Tag == LLVMContext::MD_invariant_load
330           || Tag == LLVMContext::MD_alias_scope
331           || Tag == LLVMContext::MD_noalias
332           || Tag == ParallelLoopAccessMDKind);
333 }
334 
335 // Transfer metadata from Op to the instructions in CV if it is known
336 // to be safe to do so.
transferMetadata(Instruction * Op,const ValueVector & CV)337 void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) {
338   SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
339   Op->getAllMetadataOtherThanDebugLoc(MDs);
340   for (unsigned I = 0, E = CV.size(); I != E; ++I) {
341     if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
342       for (const auto &MD : MDs)
343         if (canTransferMetadata(MD.first))
344           New->setMetadata(MD.first, MD.second);
345       if (Op->getDebugLoc() && !New->getDebugLoc())
346         New->setDebugLoc(Op->getDebugLoc());
347     }
348   }
349 }
350 
351 // Try to fill in Layout from Ty, returning true on success.  Alignment is
352 // the alignment of the vector, or 0 if the ABI default should be used.
getVectorLayout(Type * Ty,unsigned Alignment,VectorLayout & Layout,const DataLayout & DL)353 bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment,
354                                  VectorLayout &Layout, const DataLayout &DL) {
355   // Make sure we're dealing with a vector.
356   Layout.VecTy = dyn_cast<VectorType>(Ty);
357   if (!Layout.VecTy)
358     return false;
359 
360   // Check that we're dealing with full-byte elements.
361   Layout.ElemTy = Layout.VecTy->getElementType();
362   if (DL.getTypeSizeInBits(Layout.ElemTy) !=
363       DL.getTypeStoreSizeInBits(Layout.ElemTy))
364     return false;
365 
366   if (Alignment)
367     Layout.VecAlign = Alignment;
368   else
369     Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy);
370   Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
371   return true;
372 }
373 
374 // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
375 // to create an instruction like I with operands X and Y and name Name.
376 template<typename Splitter>
splitBinary(Instruction & I,const Splitter & Split)377 bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) {
378   VectorType *VT = dyn_cast<VectorType>(I.getType());
379   if (!VT)
380     return false;
381 
382   unsigned NumElems = VT->getNumElements();
383   IRBuilder<> Builder(&I);
384   Scatterer Op0 = scatter(&I, I.getOperand(0));
385   Scatterer Op1 = scatter(&I, I.getOperand(1));
386   assert(Op0.size() == NumElems && "Mismatched binary operation");
387   assert(Op1.size() == NumElems && "Mismatched binary operation");
388   ValueVector Res;
389   Res.resize(NumElems);
390   for (unsigned Elem = 0; Elem < NumElems; ++Elem)
391     Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem],
392                       I.getName() + ".i" + Twine(Elem));
393   gather(&I, Res);
394   return true;
395 }
396 
visitSelectInst(SelectInst & SI)397 bool Scalarizer::visitSelectInst(SelectInst &SI) {
398   VectorType *VT = dyn_cast<VectorType>(SI.getType());
399   if (!VT)
400     return false;
401 
402   unsigned NumElems = VT->getNumElements();
403   IRBuilder<> Builder(&SI);
404   Scatterer Op1 = scatter(&SI, SI.getOperand(1));
405   Scatterer Op2 = scatter(&SI, SI.getOperand(2));
406   assert(Op1.size() == NumElems && "Mismatched select");
407   assert(Op2.size() == NumElems && "Mismatched select");
408   ValueVector Res;
409   Res.resize(NumElems);
410 
411   if (SI.getOperand(0)->getType()->isVectorTy()) {
412     Scatterer Op0 = scatter(&SI, SI.getOperand(0));
413     assert(Op0.size() == NumElems && "Mismatched select");
414     for (unsigned I = 0; I < NumElems; ++I)
415       Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I],
416                                     SI.getName() + ".i" + Twine(I));
417   } else {
418     Value *Op0 = SI.getOperand(0);
419     for (unsigned I = 0; I < NumElems; ++I)
420       Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I],
421                                     SI.getName() + ".i" + Twine(I));
422   }
423   gather(&SI, Res);
424   return true;
425 }
426 
visitICmpInst(ICmpInst & ICI)427 bool Scalarizer::visitICmpInst(ICmpInst &ICI) {
428   return splitBinary(ICI, ICmpSplitter(ICI));
429 }
430 
visitFCmpInst(FCmpInst & FCI)431 bool Scalarizer::visitFCmpInst(FCmpInst &FCI) {
432   return splitBinary(FCI, FCmpSplitter(FCI));
433 }
434 
visitBinaryOperator(BinaryOperator & BO)435 bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) {
436   return splitBinary(BO, BinarySplitter(BO));
437 }
438 
visitGetElementPtrInst(GetElementPtrInst & GEPI)439 bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
440   VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
441   if (!VT)
442     return false;
443 
444   IRBuilder<> Builder(&GEPI);
445   unsigned NumElems = VT->getNumElements();
446   unsigned NumIndices = GEPI.getNumIndices();
447 
448   Scatterer Base = scatter(&GEPI, GEPI.getOperand(0));
449 
450   SmallVector<Scatterer, 8> Ops;
451   Ops.resize(NumIndices);
452   for (unsigned I = 0; I < NumIndices; ++I)
453     Ops[I] = scatter(&GEPI, GEPI.getOperand(I + 1));
454 
455   ValueVector Res;
456   Res.resize(NumElems);
457   for (unsigned I = 0; I < NumElems; ++I) {
458     SmallVector<Value *, 8> Indices;
459     Indices.resize(NumIndices);
460     for (unsigned J = 0; J < NumIndices; ++J)
461       Indices[J] = Ops[J][I];
462     Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
463                                GEPI.getName() + ".i" + Twine(I));
464     if (GEPI.isInBounds())
465       if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
466         NewGEPI->setIsInBounds();
467   }
468   gather(&GEPI, Res);
469   return true;
470 }
471 
visitCastInst(CastInst & CI)472 bool Scalarizer::visitCastInst(CastInst &CI) {
473   VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
474   if (!VT)
475     return false;
476 
477   unsigned NumElems = VT->getNumElements();
478   IRBuilder<> Builder(&CI);
479   Scatterer Op0 = scatter(&CI, CI.getOperand(0));
480   assert(Op0.size() == NumElems && "Mismatched cast");
481   ValueVector Res;
482   Res.resize(NumElems);
483   for (unsigned I = 0; I < NumElems; ++I)
484     Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
485                                 CI.getName() + ".i" + Twine(I));
486   gather(&CI, Res);
487   return true;
488 }
489 
visitBitCastInst(BitCastInst & BCI)490 bool Scalarizer::visitBitCastInst(BitCastInst &BCI) {
491   VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
492   VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
493   if (!DstVT || !SrcVT)
494     return false;
495 
496   unsigned DstNumElems = DstVT->getNumElements();
497   unsigned SrcNumElems = SrcVT->getNumElements();
498   IRBuilder<> Builder(&BCI);
499   Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
500   ValueVector Res;
501   Res.resize(DstNumElems);
502 
503   if (DstNumElems == SrcNumElems) {
504     for (unsigned I = 0; I < DstNumElems; ++I)
505       Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
506                                      BCI.getName() + ".i" + Twine(I));
507   } else if (DstNumElems > SrcNumElems) {
508     // <M x t1> -> <N*M x t2>.  Convert each t1 to <N x t2> and copy the
509     // individual elements to the destination.
510     unsigned FanOut = DstNumElems / SrcNumElems;
511     Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut);
512     unsigned ResI = 0;
513     for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
514       Value *V = Op0[Op0I];
515       Instruction *VI;
516       // Look through any existing bitcasts before converting to <N x t2>.
517       // In the best case, the resulting conversion might be a no-op.
518       while ((VI = dyn_cast<Instruction>(V)) &&
519              VI->getOpcode() == Instruction::BitCast)
520         V = VI->getOperand(0);
521       V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
522       Scatterer Mid = scatter(&BCI, V);
523       for (unsigned MidI = 0; MidI < FanOut; ++MidI)
524         Res[ResI++] = Mid[MidI];
525     }
526   } else {
527     // <N*M x t1> -> <M x t2>.  Convert each group of <N x t1> into a t2.
528     unsigned FanIn = SrcNumElems / DstNumElems;
529     Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn);
530     unsigned Op0I = 0;
531     for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
532       Value *V = UndefValue::get(MidTy);
533       for (unsigned MidI = 0; MidI < FanIn; ++MidI)
534         V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
535                                         BCI.getName() + ".i" + Twine(ResI)
536                                         + ".upto" + Twine(MidI));
537       Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
538                                         BCI.getName() + ".i" + Twine(ResI));
539     }
540   }
541   gather(&BCI, Res);
542   return true;
543 }
544 
visitShuffleVectorInst(ShuffleVectorInst & SVI)545 bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
546   VectorType *VT = dyn_cast<VectorType>(SVI.getType());
547   if (!VT)
548     return false;
549 
550   unsigned NumElems = VT->getNumElements();
551   Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
552   Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
553   ValueVector Res;
554   Res.resize(NumElems);
555 
556   for (unsigned I = 0; I < NumElems; ++I) {
557     int Selector = SVI.getMaskValue(I);
558     if (Selector < 0)
559       Res[I] = UndefValue::get(VT->getElementType());
560     else if (unsigned(Selector) < Op0.size())
561       Res[I] = Op0[Selector];
562     else
563       Res[I] = Op1[Selector - Op0.size()];
564   }
565   gather(&SVI, Res);
566   return true;
567 }
568 
visitPHINode(PHINode & PHI)569 bool Scalarizer::visitPHINode(PHINode &PHI) {
570   VectorType *VT = dyn_cast<VectorType>(PHI.getType());
571   if (!VT)
572     return false;
573 
574   unsigned NumElems = VT->getNumElements();
575   IRBuilder<> Builder(&PHI);
576   ValueVector Res;
577   Res.resize(NumElems);
578 
579   unsigned NumOps = PHI.getNumOperands();
580   for (unsigned I = 0; I < NumElems; ++I)
581     Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
582                                PHI.getName() + ".i" + Twine(I));
583 
584   for (unsigned I = 0; I < NumOps; ++I) {
585     Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
586     BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
587     for (unsigned J = 0; J < NumElems; ++J)
588       cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
589   }
590   gather(&PHI, Res);
591   return true;
592 }
593 
visitLoadInst(LoadInst & LI)594 bool Scalarizer::visitLoadInst(LoadInst &LI) {
595   if (!ScalarizeLoadStore)
596     return false;
597   if (!LI.isSimple())
598     return false;
599 
600   VectorLayout Layout;
601   if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout,
602                        LI.getModule()->getDataLayout()))
603     return false;
604 
605   unsigned NumElems = Layout.VecTy->getNumElements();
606   IRBuilder<> Builder(&LI);
607   Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
608   ValueVector Res;
609   Res.resize(NumElems);
610 
611   for (unsigned I = 0; I < NumElems; ++I)
612     Res[I] = Builder.CreateAlignedLoad(Ptr[I], Layout.getElemAlign(I),
613                                        LI.getName() + ".i" + Twine(I));
614   gather(&LI, Res);
615   return true;
616 }
617 
visitStoreInst(StoreInst & SI)618 bool Scalarizer::visitStoreInst(StoreInst &SI) {
619   if (!ScalarizeLoadStore)
620     return false;
621   if (!SI.isSimple())
622     return false;
623 
624   VectorLayout Layout;
625   Value *FullValue = SI.getValueOperand();
626   if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout,
627                        SI.getModule()->getDataLayout()))
628     return false;
629 
630   unsigned NumElems = Layout.VecTy->getNumElements();
631   IRBuilder<> Builder(&SI);
632   Scatterer Ptr = scatter(&SI, SI.getPointerOperand());
633   Scatterer Val = scatter(&SI, FullValue);
634 
635   ValueVector Stores;
636   Stores.resize(NumElems);
637   for (unsigned I = 0; I < NumElems; ++I) {
638     unsigned Align = Layout.getElemAlign(I);
639     Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align);
640   }
641   transferMetadata(&SI, Stores);
642   return true;
643 }
644 
645 // Delete the instructions that we scalarized.  If a full vector result
646 // is still needed, recreate it using InsertElements.
finish()647 bool Scalarizer::finish() {
648   // The presence of data in Gathered or Scattered indicates changes
649   // made to the Function.
650   if (Gathered.empty() && Scattered.empty())
651     return false;
652   for (const auto &GMI : Gathered) {
653     Instruction *Op = GMI.first;
654     ValueVector &CV = *GMI.second;
655     if (!Op->use_empty()) {
656       // The value is still needed, so recreate it using a series of
657       // InsertElements.
658       Type *Ty = Op->getType();
659       Value *Res = UndefValue::get(Ty);
660       BasicBlock *BB = Op->getParent();
661       unsigned Count = Ty->getVectorNumElements();
662       IRBuilder<> Builder(Op);
663       if (isa<PHINode>(Op))
664         Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
665       for (unsigned I = 0; I < Count; ++I)
666         Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
667                                           Op->getName() + ".upto" + Twine(I));
668       Res->takeName(Op);
669       Op->replaceAllUsesWith(Res);
670     }
671     Op->eraseFromParent();
672   }
673   Gathered.clear();
674   Scattered.clear();
675   return true;
676 }
677 
createScalarizerPass()678 FunctionPass *llvm::createScalarizerPass() {
679   return new Scalarizer();
680 }
681