• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the cost model analysis. It provides a very basic cost
11 // estimation for LLVM-IR. This analysis uses the services of the codegen
12 // to approximate the cost of any IR instruction when lowered to machine
13 // instructions. The cost results are unit-less and the cost number represents
14 // the throughput of the machine assuming that all loads hit the cache, all
15 // branches are predicted, etc. The cost numbers can be added in order to
16 // compare two or more transformation alternatives.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Analysis/Passes.h"
22 #include "llvm/Analysis/TargetTransformInfo.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/Value.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/raw_ostream.h"
31 using namespace llvm;
32 
33 #define CM_NAME "cost-model"
34 #define DEBUG_TYPE CM_NAME
35 
36 static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
37                                      cl::Hidden,
38                                      cl::desc("Recognize reduction patterns."));
39 
40 namespace {
41   class CostModelAnalysis : public FunctionPass {
42 
43   public:
44     static char ID; // Class identification, replacement for typeinfo
CostModelAnalysis()45     CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
46       initializeCostModelAnalysisPass(
47         *PassRegistry::getPassRegistry());
48     }
49 
50     /// Returns the expected cost of the instruction.
51     /// Returns -1 if the cost is unknown.
52     /// Note, this method does not cache the cost calculation and it
53     /// can be expensive in some cases.
54     unsigned getInstructionCost(const Instruction *I) const;
55 
56   private:
57     void getAnalysisUsage(AnalysisUsage &AU) const override;
58     bool runOnFunction(Function &F) override;
59     void print(raw_ostream &OS, const Module*) const override;
60 
61     /// The function that we analyze.
62     Function *F;
63     /// Target information.
64     const TargetTransformInfo *TTI;
65   };
66 }  // End of anonymous namespace
67 
68 // Register this pass.
69 char CostModelAnalysis::ID = 0;
70 static const char cm_name[] = "Cost Model Analysis";
INITIALIZE_PASS_BEGIN(CostModelAnalysis,CM_NAME,cm_name,false,true)71 INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
72 INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)
73 
74 FunctionPass *llvm::createCostModelAnalysisPass() {
75   return new CostModelAnalysis();
76 }
77 
78 void
getAnalysisUsage(AnalysisUsage & AU) const79 CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
80   AU.setPreservesAll();
81 }
82 
83 bool
runOnFunction(Function & F)84 CostModelAnalysis::runOnFunction(Function &F) {
85  this->F = &F;
86  TTI = getAnalysisIfAvailable<TargetTransformInfo>();
87 
88  return false;
89 }
90 
isReverseVectorMask(SmallVectorImpl<int> & Mask)91 static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
92   for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
93     if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
94       return false;
95   return true;
96 }
97 
isAlternateVectorMask(SmallVectorImpl<int> & Mask)98 static bool isAlternateVectorMask(SmallVectorImpl<int> &Mask) {
99   bool isAlternate = true;
100   unsigned MaskSize = Mask.size();
101 
102   // Example: shufflevector A, B, <0,5,2,7>
103   for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
104     if (Mask[i] < 0)
105       continue;
106     isAlternate = Mask[i] == (int)((i & 1) ? MaskSize + i : i);
107   }
108 
109   if (isAlternate)
110     return true;
111 
112   isAlternate = true;
113   // Example: shufflevector A, B, <4,1,6,3>
114   for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
115     if (Mask[i] < 0)
116       continue;
117     isAlternate = Mask[i] == (int)((i & 1) ? i : MaskSize + i);
118   }
119 
120   return isAlternate;
121 }
122 
getOperandInfo(Value * V)123 static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
124   TargetTransformInfo::OperandValueKind OpInfo =
125     TargetTransformInfo::OK_AnyValue;
126 
127   // Check for a splat of a constant or for a non uniform vector of constants.
128   if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
129     OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
130     if (cast<Constant>(V)->getSplatValue() != nullptr)
131       OpInfo = TargetTransformInfo::OK_UniformConstantValue;
132   }
133 
134   return OpInfo;
135 }
136 
matchPairwiseShuffleMask(ShuffleVectorInst * SI,bool IsLeft,unsigned Level)137 static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
138                                      unsigned Level) {
139   // We don't need a shuffle if we just want to have element 0 in position 0 of
140   // the vector.
141   if (!SI && Level == 0 && IsLeft)
142     return true;
143   else if (!SI)
144     return false;
145 
146   SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
147 
148   // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
149   // we look at the left or right side.
150   for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
151     Mask[i] = val;
152 
153   SmallVector<int, 16> ActualMask = SI->getShuffleMask();
154   if (Mask != ActualMask)
155     return false;
156 
157   return true;
158 }
159 
matchPairwiseReductionAtLevel(const BinaryOperator * BinOp,unsigned Level,unsigned NumLevels)160 static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
161                                           unsigned Level, unsigned NumLevels) {
162   // Match one level of pairwise operations.
163   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
164   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
165   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
166   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
167   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
168   if (BinOp == nullptr)
169     return false;
170 
171   assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
172 
173   unsigned Opcode = BinOp->getOpcode();
174   Value *L = BinOp->getOperand(0);
175   Value *R = BinOp->getOperand(1);
176 
177   ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
178   if (!LS && Level)
179     return false;
180   ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
181   if (!RS && Level)
182     return false;
183 
184   // On level 0 we can omit one shufflevector instruction.
185   if (!Level && !RS && !LS)
186     return false;
187 
188   // Shuffle inputs must match.
189   Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
190   Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
191   Value *NextLevelOp = nullptr;
192   if (NextLevelOpR && NextLevelOpL) {
193     // If we have two shuffles their operands must match.
194     if (NextLevelOpL != NextLevelOpR)
195       return false;
196 
197     NextLevelOp = NextLevelOpL;
198   } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
199     // On the first level we can omit the shufflevector <0, undef,...>. So the
200     // input to the other shufflevector <1, undef> must match with one of the
201     // inputs to the current binary operation.
202     // Example:
203     //  %NextLevelOpL = shufflevector %R, <1, undef ...>
204     //  %BinOp        = fadd          %NextLevelOpL, %R
205     if (NextLevelOpL && NextLevelOpL != R)
206       return false;
207     else if (NextLevelOpR && NextLevelOpR != L)
208       return false;
209 
210     NextLevelOp = NextLevelOpL ? R : L;
211   } else
212     return false;
213 
214   // Check that the next levels binary operation exists and matches with the
215   // current one.
216   BinaryOperator *NextLevelBinOp = nullptr;
217   if (Level + 1 != NumLevels) {
218     if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
219       return false;
220     else if (NextLevelBinOp->getOpcode() != Opcode)
221       return false;
222   }
223 
224   // Shuffle mask for pairwise operation must match.
225   if (matchPairwiseShuffleMask(LS, true, Level)) {
226     if (!matchPairwiseShuffleMask(RS, false, Level))
227       return false;
228   } else if (matchPairwiseShuffleMask(RS, true, Level)) {
229     if (!matchPairwiseShuffleMask(LS, false, Level))
230       return false;
231   } else
232     return false;
233 
234   if (++Level == NumLevels)
235     return true;
236 
237   // Match next level.
238   return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
239 }
240 
matchPairwiseReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)241 static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
242                                    unsigned &Opcode, Type *&Ty) {
243   if (!EnableReduxCost)
244     return false;
245 
246   // Need to extract the first element.
247   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
248   unsigned Idx = ~0u;
249   if (CI)
250     Idx = CI->getZExtValue();
251   if (Idx != 0)
252     return false;
253 
254   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
255   if (!RdxStart)
256     return false;
257 
258   Type *VecTy = ReduxRoot->getOperand(0)->getType();
259   unsigned NumVecElems = VecTy->getVectorNumElements();
260   if (!isPowerOf2_32(NumVecElems))
261     return false;
262 
263   // We look for a sequence of shuffle,shuffle,add triples like the following
264   // that builds a pairwise reduction tree.
265   //
266   //  (X0, X1, X2, X3)
267   //   (X0 + X1, X2 + X3, undef, undef)
268   //    ((X0 + X1) + (X2 + X3), undef, undef, undef)
269   //
270   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
271   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
272   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
273   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
274   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
275   // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
276   //       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
277   // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
278   //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
279   // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
280   // %r = extractelement <4 x float> %bin.rdx8, i32 0
281   if (!matchPairwiseReductionAtLevel(RdxStart, 0,  Log2_32(NumVecElems)))
282     return false;
283 
284   Opcode = RdxStart->getOpcode();
285   Ty = VecTy;
286 
287   return true;
288 }
289 
290 static std::pair<Value *, ShuffleVectorInst *>
getShuffleAndOtherOprd(BinaryOperator * B)291 getShuffleAndOtherOprd(BinaryOperator *B) {
292 
293   Value *L = B->getOperand(0);
294   Value *R = B->getOperand(1);
295   ShuffleVectorInst *S = nullptr;
296 
297   if ((S = dyn_cast<ShuffleVectorInst>(L)))
298     return std::make_pair(R, S);
299 
300   S = dyn_cast<ShuffleVectorInst>(R);
301   return std::make_pair(L, S);
302 }
303 
matchVectorSplittingReduction(const ExtractElementInst * ReduxRoot,unsigned & Opcode,Type * & Ty)304 static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
305                                           unsigned &Opcode, Type *&Ty) {
306   if (!EnableReduxCost)
307     return false;
308 
309   // Need to extract the first element.
310   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
311   unsigned Idx = ~0u;
312   if (CI)
313     Idx = CI->getZExtValue();
314   if (Idx != 0)
315     return false;
316 
317   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
318   if (!RdxStart)
319     return false;
320   unsigned RdxOpcode = RdxStart->getOpcode();
321 
322   Type *VecTy = ReduxRoot->getOperand(0)->getType();
323   unsigned NumVecElems = VecTy->getVectorNumElements();
324   if (!isPowerOf2_32(NumVecElems))
325     return false;
326 
327   // We look for a sequence of shuffles and adds like the following matching one
328   // fadd, shuffle vector pair at a time.
329   //
330   // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
331   //                           <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
332   // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
333   // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
334   //                          <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
335   // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
336   // %r = extractelement <4 x float> %bin.rdx8, i32 0
337 
338   unsigned MaskStart = 1;
339   Value *RdxOp = RdxStart;
340   SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
341   unsigned NumVecElemsRemain = NumVecElems;
342   while (NumVecElemsRemain - 1) {
343     // Check for the right reduction operation.
344     BinaryOperator *BinOp;
345     if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
346       return false;
347     if (BinOp->getOpcode() != RdxOpcode)
348       return false;
349 
350     Value *NextRdxOp;
351     ShuffleVectorInst *Shuffle;
352     std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
353 
354     // Check the current reduction operation and the shuffle use the same value.
355     if (Shuffle == nullptr)
356       return false;
357     if (Shuffle->getOperand(0) != NextRdxOp)
358       return false;
359 
360     // Check that shuffle masks matches.
361     for (unsigned j = 0; j != MaskStart; ++j)
362       ShuffleMask[j] = MaskStart + j;
363     // Fill the rest of the mask with -1 for undef.
364     std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
365 
366     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
367     if (ShuffleMask != Mask)
368       return false;
369 
370     RdxOp = NextRdxOp;
371     NumVecElemsRemain /= 2;
372     MaskStart *= 2;
373   }
374 
375   Opcode = RdxOpcode;
376   Ty = VecTy;
377   return true;
378 }
379 
getInstructionCost(const Instruction * I) const380 unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
381   if (!TTI)
382     return -1;
383 
384   switch (I->getOpcode()) {
385   case Instruction::GetElementPtr:{
386     Type *ValTy = I->getOperand(0)->getType()->getPointerElementType();
387     return TTI->getAddressComputationCost(ValTy);
388   }
389 
390   case Instruction::Ret:
391   case Instruction::PHI:
392   case Instruction::Br: {
393     return TTI->getCFInstrCost(I->getOpcode());
394   }
395   case Instruction::Add:
396   case Instruction::FAdd:
397   case Instruction::Sub:
398   case Instruction::FSub:
399   case Instruction::Mul:
400   case Instruction::FMul:
401   case Instruction::UDiv:
402   case Instruction::SDiv:
403   case Instruction::FDiv:
404   case Instruction::URem:
405   case Instruction::SRem:
406   case Instruction::FRem:
407   case Instruction::Shl:
408   case Instruction::LShr:
409   case Instruction::AShr:
410   case Instruction::And:
411   case Instruction::Or:
412   case Instruction::Xor: {
413     TargetTransformInfo::OperandValueKind Op1VK =
414       getOperandInfo(I->getOperand(0));
415     TargetTransformInfo::OperandValueKind Op2VK =
416       getOperandInfo(I->getOperand(1));
417     return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
418                                        Op2VK);
419   }
420   case Instruction::Select: {
421     const SelectInst *SI = cast<SelectInst>(I);
422     Type *CondTy = SI->getCondition()->getType();
423     return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
424   }
425   case Instruction::ICmp:
426   case Instruction::FCmp: {
427     Type *ValTy = I->getOperand(0)->getType();
428     return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
429   }
430   case Instruction::Store: {
431     const StoreInst *SI = cast<StoreInst>(I);
432     Type *ValTy = SI->getValueOperand()->getType();
433     return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
434                                  SI->getAlignment(),
435                                  SI->getPointerAddressSpace());
436   }
437   case Instruction::Load: {
438     const LoadInst *LI = cast<LoadInst>(I);
439     return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
440                                  LI->getAlignment(),
441                                  LI->getPointerAddressSpace());
442   }
443   case Instruction::ZExt:
444   case Instruction::SExt:
445   case Instruction::FPToUI:
446   case Instruction::FPToSI:
447   case Instruction::FPExt:
448   case Instruction::PtrToInt:
449   case Instruction::IntToPtr:
450   case Instruction::SIToFP:
451   case Instruction::UIToFP:
452   case Instruction::Trunc:
453   case Instruction::FPTrunc:
454   case Instruction::BitCast:
455   case Instruction::AddrSpaceCast: {
456     Type *SrcTy = I->getOperand(0)->getType();
457     return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
458   }
459   case Instruction::ExtractElement: {
460     const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
461     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
462     unsigned Idx = -1;
463     if (CI)
464       Idx = CI->getZExtValue();
465 
466     // Try to match a reduction sequence (series of shufflevector and vector
467     // adds followed by a extractelement).
468     unsigned ReduxOpCode;
469     Type *ReduxType;
470 
471     if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
472       return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
473     else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
474       return TTI->getReductionCost(ReduxOpCode, ReduxType, true);
475 
476     return TTI->getVectorInstrCost(I->getOpcode(),
477                                    EEI->getOperand(0)->getType(), Idx);
478   }
479   case Instruction::InsertElement: {
480     const InsertElementInst * IE = cast<InsertElementInst>(I);
481     ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
482     unsigned Idx = -1;
483     if (CI)
484       Idx = CI->getZExtValue();
485     return TTI->getVectorInstrCost(I->getOpcode(),
486                                    IE->getType(), Idx);
487   }
488   case Instruction::ShuffleVector: {
489     const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
490     Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
491     unsigned NumVecElems = VecTypOp0->getVectorNumElements();
492     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
493 
494     if (NumVecElems == Mask.size()) {
495       if (isReverseVectorMask(Mask))
496         return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0,
497                                    0, nullptr);
498       if (isAlternateVectorMask(Mask))
499         return TTI->getShuffleCost(TargetTransformInfo::SK_Alternate,
500                                    VecTypOp0, 0, nullptr);
501     }
502 
503     return -1;
504   }
505   case Instruction::Call:
506     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
507       SmallVector<Type*, 4> Tys;
508       for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
509         Tys.push_back(II->getArgOperand(J)->getType());
510 
511       return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
512                                         Tys);
513     }
514     return -1;
515   default:
516     // We don't have any information on this instruction.
517     return -1;
518   }
519 }
520 
print(raw_ostream & OS,const Module *) const521 void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
522   if (!F)
523     return;
524 
525   for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
526     for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
527       Instruction *Inst = it;
528       unsigned Cost = getInstructionCost(Inst);
529       if (Cost != (unsigned)-1)
530         OS << "Cost Model: Found an estimated cost of " << Cost;
531       else
532         OS << "Cost Model: Unknown cost";
533 
534       OS << " for instruction: "<< *Inst << "\n";
535     }
536   }
537 }
538