• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 
15 #include "NVPTXISelLowering.h"
16 #include "NVPTX.h"
17 #include "NVPTXTargetMachine.h"
18 #include "NVPTXTargetObjectFile.h"
19 #include "NVPTXUtilities.h"
20 #include "llvm/CodeGen/Analysis.h"
21 #include "llvm/CodeGen/MachineFrameInfo.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineInstrBuilder.h"
24 #include "llvm/CodeGen/MachineRegisterInfo.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/GlobalValue.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/IR/Module.h"
32 #include "llvm/MC/MCSectionELF.h"
33 #include "llvm/Support/CallSite.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <sstream>
39 
40 #undef DEBUG_TYPE
41 #define DEBUG_TYPE "nvptx-lower"
42 
43 using namespace llvm;
44 
45 static unsigned int uniqueCallSite = 0;
46 
47 static cl::opt<bool>
48 sched4reg("nvptx-sched4reg",
49           cl::desc("NVPTX Specific: schedule for register pressue"),
50           cl::init(false));
51 
IsPTXVectorType(MVT VT)52 static bool IsPTXVectorType(MVT VT) {
53   switch (VT.SimpleTy) {
54   default: return false;
55   case MVT::v2i8:
56   case MVT::v4i8:
57   case MVT::v2i16:
58   case MVT::v4i16:
59   case MVT::v2i32:
60   case MVT::v4i32:
61   case MVT::v2i64:
62   case MVT::v2f32:
63   case MVT::v4f32:
64   case MVT::v2f64:
65   return true;
66   }
67 }
68 
69 // NVPTXTargetLowering Constructor.
NVPTXTargetLowering(NVPTXTargetMachine & TM)70 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
71 : TargetLowering(TM, new NVPTXTargetObjectFile()),
72   nvTM(&TM),
73   nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
74 
75   // always lower memset, memcpy, and memmove intrinsics to load/store
76   // instructions, rather
77   // then generating calls to memset, mempcy or memmove.
78   MaxStoresPerMemset = (unsigned)0xFFFFFFFF;
79   MaxStoresPerMemcpy = (unsigned)0xFFFFFFFF;
80   MaxStoresPerMemmove = (unsigned)0xFFFFFFFF;
81 
82   setBooleanContents(ZeroOrNegativeOneBooleanContent);
83 
84   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
85   // condition branches.
86   setJumpIsExpensive(true);
87 
88   // By default, use the Source scheduling
89   if (sched4reg)
90     setSchedulingPreference(Sched::RegPressure);
91   else
92     setSchedulingPreference(Sched::Source);
93 
94   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
95   addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
96   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
97   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
98   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
99   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
100   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
101 
102   // Operations not directly supported by NVPTX.
103   setOperationAction(ISD::SELECT_CC,         MVT::Other, Expand);
104   setOperationAction(ISD::BR_CC,             MVT::f32, Expand);
105   setOperationAction(ISD::BR_CC,             MVT::f64, Expand);
106   setOperationAction(ISD::BR_CC,             MVT::i1,  Expand);
107   setOperationAction(ISD::BR_CC,             MVT::i8,  Expand);
108   setOperationAction(ISD::BR_CC,             MVT::i16, Expand);
109   setOperationAction(ISD::BR_CC,             MVT::i32, Expand);
110   setOperationAction(ISD::BR_CC,             MVT::i64, Expand);
111   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
112   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
113   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
114   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Expand);
115   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1 , Expand);
116 
117   if (nvptxSubtarget.hasROT64()) {
118     setOperationAction(ISD::ROTL , MVT::i64, Legal);
119     setOperationAction(ISD::ROTR , MVT::i64, Legal);
120   }
121   else {
122     setOperationAction(ISD::ROTL , MVT::i64, Expand);
123     setOperationAction(ISD::ROTR , MVT::i64, Expand);
124   }
125   if (nvptxSubtarget.hasROT32()) {
126     setOperationAction(ISD::ROTL , MVT::i32, Legal);
127     setOperationAction(ISD::ROTR , MVT::i32, Legal);
128   }
129   else {
130     setOperationAction(ISD::ROTL , MVT::i32, Expand);
131     setOperationAction(ISD::ROTR , MVT::i32, Expand);
132   }
133 
134   setOperationAction(ISD::ROTL , MVT::i16, Expand);
135   setOperationAction(ISD::ROTR , MVT::i16, Expand);
136   setOperationAction(ISD::ROTL , MVT::i8, Expand);
137   setOperationAction(ISD::ROTR , MVT::i8, Expand);
138   setOperationAction(ISD::BSWAP , MVT::i16, Expand);
139   setOperationAction(ISD::BSWAP , MVT::i32, Expand);
140   setOperationAction(ISD::BSWAP , MVT::i64, Expand);
141 
142   // Indirect branch is not supported.
143   // This also disables Jump Table creation.
144   setOperationAction(ISD::BR_JT,             MVT::Other, Expand);
145   setOperationAction(ISD::BRIND,             MVT::Other, Expand);
146 
147   setOperationAction(ISD::GlobalAddress   , MVT::i32  , Custom);
148   setOperationAction(ISD::GlobalAddress   , MVT::i64  , Custom);
149 
150   // We want to legalize constant related memmove and memcopy
151   // intrinsics.
152   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
153 
154   // Turn FP extload into load/fextend
155   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
156   // Turn FP truncstore into trunc + store.
157   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
158 
159   // PTX does not support load / store predicate registers
160   setOperationAction(ISD::LOAD, MVT::i1, Custom);
161   setOperationAction(ISD::STORE, MVT::i1, Custom);
162 
163   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
164   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
165   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
166   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
167   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
168   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
169 
170   // This is legal in NVPTX
171   setOperationAction(ISD::ConstantFP,         MVT::f64, Legal);
172   setOperationAction(ISD::ConstantFP,         MVT::f32, Legal);
173 
174   // TRAP can be lowered to PTX trap
175   setOperationAction(ISD::TRAP,               MVT::Other, Legal);
176 
177   // Register custom handling for vector loads/stores
178   for (int i = MVT::FIRST_VECTOR_VALUETYPE;
179        i <= MVT::LAST_VECTOR_VALUETYPE; ++i) {
180     MVT VT = (MVT::SimpleValueType)i;
181     if (IsPTXVectorType(VT)) {
182       setOperationAction(ISD::LOAD, VT, Custom);
183       setOperationAction(ISD::STORE, VT, Custom);
184       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
185     }
186   }
187 
188   // Now deduce the information based on the above mentioned
189   // actions
190   computeRegisterProperties();
191 }
192 
193 
getTargetNodeName(unsigned Opcode) const194 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
195   switch (Opcode) {
196   default: return 0;
197   case NVPTXISD::CALL:            return "NVPTXISD::CALL";
198   case NVPTXISD::RET_FLAG:        return "NVPTXISD::RET_FLAG";
199   case NVPTXISD::Wrapper:         return "NVPTXISD::Wrapper";
200   case NVPTXISD::NVBuiltin:       return "NVPTXISD::NVBuiltin";
201   case NVPTXISD::DeclareParam:    return "NVPTXISD::DeclareParam";
202   case NVPTXISD::DeclareScalarParam:
203     return "NVPTXISD::DeclareScalarParam";
204   case NVPTXISD::DeclareRet:      return "NVPTXISD::DeclareRet";
205   case NVPTXISD::DeclareRetParam: return "NVPTXISD::DeclareRetParam";
206   case NVPTXISD::PrintCall:       return "NVPTXISD::PrintCall";
207   case NVPTXISD::LoadParam:       return "NVPTXISD::LoadParam";
208   case NVPTXISD::StoreParam:      return "NVPTXISD::StoreParam";
209   case NVPTXISD::StoreParamS32:   return "NVPTXISD::StoreParamS32";
210   case NVPTXISD::StoreParamU32:   return "NVPTXISD::StoreParamU32";
211   case NVPTXISD::MoveToParam:     return "NVPTXISD::MoveToParam";
212   case NVPTXISD::CallArgBegin:    return "NVPTXISD::CallArgBegin";
213   case NVPTXISD::CallArg:         return "NVPTXISD::CallArg";
214   case NVPTXISD::LastCallArg:     return "NVPTXISD::LastCallArg";
215   case NVPTXISD::CallArgEnd:      return "NVPTXISD::CallArgEnd";
216   case NVPTXISD::CallVoid:        return "NVPTXISD::CallVoid";
217   case NVPTXISD::CallVal:         return "NVPTXISD::CallVal";
218   case NVPTXISD::CallSymbol:      return "NVPTXISD::CallSymbol";
219   case NVPTXISD::Prototype:       return "NVPTXISD::Prototype";
220   case NVPTXISD::MoveParam:       return "NVPTXISD::MoveParam";
221   case NVPTXISD::MoveRetval:      return "NVPTXISD::MoveRetval";
222   case NVPTXISD::MoveToRetval:    return "NVPTXISD::MoveToRetval";
223   case NVPTXISD::StoreRetval:     return "NVPTXISD::StoreRetval";
224   case NVPTXISD::PseudoUseParam:  return "NVPTXISD::PseudoUseParam";
225   case NVPTXISD::RETURN:          return "NVPTXISD::RETURN";
226   case NVPTXISD::CallSeqBegin:    return "NVPTXISD::CallSeqBegin";
227   case NVPTXISD::CallSeqEnd:      return "NVPTXISD::CallSeqEnd";
228   case NVPTXISD::LoadV2:          return "NVPTXISD::LoadV2";
229   case NVPTXISD::LoadV4:          return "NVPTXISD::LoadV4";
230   case NVPTXISD::LDGV2:           return "NVPTXISD::LDGV2";
231   case NVPTXISD::LDGV4:           return "NVPTXISD::LDGV4";
232   case NVPTXISD::LDUV2:           return "NVPTXISD::LDUV2";
233   case NVPTXISD::LDUV4:           return "NVPTXISD::LDUV4";
234   case NVPTXISD::StoreV2:         return "NVPTXISD::StoreV2";
235   case NVPTXISD::StoreV4:         return "NVPTXISD::StoreV4";
236   }
237 }
238 
shouldSplitVectorElementType(EVT VT) const239 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
240   return VT == MVT::i1;
241 }
242 
243 SDValue
LowerGlobalAddress(SDValue Op,SelectionDAG & DAG) const244 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
245   DebugLoc dl = Op.getDebugLoc();
246   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
247   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
248   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
249 }
250 
getPrototype(Type * retTy,const ArgListTy & Args,const SmallVectorImpl<ISD::OutputArg> & Outs,unsigned retAlignment) const251 std::string NVPTXTargetLowering::getPrototype(Type *retTy,
252                                               const ArgListTy &Args,
253                                     const SmallVectorImpl<ISD::OutputArg> &Outs,
254                                               unsigned retAlignment) const {
255 
256   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
257 
258   std::stringstream O;
259   O << "prototype_" << uniqueCallSite << " : .callprototype ";
260 
261   if (retTy->getTypeID() == Type::VoidTyID)
262     O << "()";
263   else {
264     O << "(";
265     if (isABI) {
266       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
267         unsigned size = 0;
268         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
269           size = ITy->getBitWidth();
270           if (size < 32) size = 32;
271         }
272         else {
273           assert(retTy->isFloatingPointTy() &&
274                  "Floating point type expected here");
275           size = retTy->getPrimitiveSizeInBits();
276         }
277 
278         O << ".param .b" << size << " _";
279       }
280       else if (isa<PointerType>(retTy))
281         O << ".param .b" << getPointerTy().getSizeInBits()
282         << " _";
283       else {
284         if ((retTy->getTypeID() == Type::StructTyID) ||
285             isa<VectorType>(retTy)) {
286           SmallVector<EVT, 16> vtparts;
287           ComputeValueVTs(*this, retTy, vtparts);
288           unsigned totalsz = 0;
289           for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
290             unsigned elems = 1;
291             EVT elemtype = vtparts[i];
292             if (vtparts[i].isVector()) {
293               elems = vtparts[i].getVectorNumElements();
294               elemtype = vtparts[i].getVectorElementType();
295             }
296             for (unsigned j=0, je=elems; j!=je; ++j) {
297               unsigned sz = elemtype.getSizeInBits();
298               if (elemtype.isInteger() && (sz < 8)) sz = 8;
299               totalsz += sz/8;
300             }
301           }
302           O << ".param .align "
303               << retAlignment
304               << " .b8 _["
305               << totalsz << "]";
306         }
307         else {
308           assert(false &&
309                  "Unknown return type");
310         }
311       }
312     }
313     else {
314       SmallVector<EVT, 16> vtparts;
315       ComputeValueVTs(*this, retTy, vtparts);
316       unsigned idx = 0;
317       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
318         unsigned elems = 1;
319         EVT elemtype = vtparts[i];
320         if (vtparts[i].isVector()) {
321           elems = vtparts[i].getVectorNumElements();
322           elemtype = vtparts[i].getVectorElementType();
323         }
324 
325         for (unsigned j=0, je=elems; j!=je; ++j) {
326           unsigned sz = elemtype.getSizeInBits();
327           if (elemtype.isInteger() && (sz < 32)) sz = 32;
328           O << ".reg .b" << sz << " _";
329           if (j<je-1) O << ", ";
330           ++idx;
331         }
332         if (i < e-1)
333           O << ", ";
334       }
335     }
336     O << ") ";
337   }
338   O << "_ (";
339 
340   bool first = true;
341   MVT thePointerTy = getPointerTy();
342 
343   for (unsigned i=0,e=Args.size(); i!=e; ++i) {
344     const Type *Ty = Args[i].Ty;
345     if (!first) {
346       O << ", ";
347     }
348     first = false;
349 
350     if (Outs[i].Flags.isByVal() == false) {
351       unsigned sz = 0;
352       if (isa<IntegerType>(Ty)) {
353         sz = cast<IntegerType>(Ty)->getBitWidth();
354         if (sz < 32) sz = 32;
355       }
356       else if (isa<PointerType>(Ty))
357         sz = thePointerTy.getSizeInBits();
358       else
359         sz = Ty->getPrimitiveSizeInBits();
360       if (isABI)
361         O << ".param .b" << sz << " ";
362       else
363         O << ".reg .b" << sz << " ";
364       O << "_";
365       continue;
366     }
367     const PointerType *PTy = dyn_cast<PointerType>(Ty);
368     assert(PTy &&
369            "Param with byval attribute should be a pointer type");
370     Type *ETy = PTy->getElementType();
371 
372     if (isABI) {
373       unsigned align = Outs[i].Flags.getByValAlign();
374       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
375       O << ".param .align " << align
376           << " .b8 ";
377       O << "_";
378       O << "[" << sz << "]";
379       continue;
380     }
381     else {
382       SmallVector<EVT, 16> vtparts;
383       ComputeValueVTs(*this, ETy, vtparts);
384       for (unsigned i=0,e=vtparts.size(); i!=e; ++i) {
385         unsigned elems = 1;
386         EVT elemtype = vtparts[i];
387         if (vtparts[i].isVector()) {
388           elems = vtparts[i].getVectorNumElements();
389           elemtype = vtparts[i].getVectorElementType();
390         }
391 
392         for (unsigned j=0,je=elems; j!=je; ++j) {
393           unsigned sz = elemtype.getSizeInBits();
394           if (elemtype.isInteger() && (sz < 32)) sz = 32;
395           O << ".reg .b" << sz << " ";
396           O << "_";
397           if (j<je-1) O << ", ";
398         }
399         if (i<e-1)
400           O << ", ";
401       }
402       continue;
403     }
404   }
405   O << ");";
406   return O.str();
407 }
408 
409 
410 SDValue
LowerCall(TargetLowering::CallLoweringInfo & CLI,SmallVectorImpl<SDValue> & InVals) const411 NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
412                                SmallVectorImpl<SDValue> &InVals) const {
413   SelectionDAG &DAG                     = CLI.DAG;
414   DebugLoc &dl                          = CLI.DL;
415   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
416   SmallVector<SDValue, 32> &OutVals     = CLI.OutVals;
417   SmallVector<ISD::InputArg, 32> &Ins   = CLI.Ins;
418   SDValue Chain                         = CLI.Chain;
419   SDValue Callee                        = CLI.Callee;
420   bool &isTailCall                      = CLI.IsTailCall;
421   ArgListTy &Args                       = CLI.Args;
422   Type *retTy                           = CLI.RetTy;
423   ImmutableCallSite *CS                 = CLI.CS;
424 
425   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
426 
427   SDValue tempChain = Chain;
428   Chain = DAG.getCALLSEQ_START(Chain,
429                                DAG.getIntPtrConstant(uniqueCallSite, true));
430   SDValue InFlag = Chain.getValue(1);
431 
432   assert((Outs.size() == Args.size()) &&
433          "Unexpected number of arguments to function call");
434   unsigned paramCount = 0;
435   // Declare the .params or .reg need to pass values
436   // to the function
437   for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
438     EVT VT = Outs[i].VT;
439 
440     if (Outs[i].Flags.isByVal() == false) {
441       // Plain scalar
442       // for ABI,    declare .param .b<size> .param<n>;
443       // for nonABI, declare .reg .b<size> .param<n>;
444       unsigned isReg = 1;
445       if (isABI)
446         isReg = 0;
447       unsigned sz = VT.getSizeInBits();
448       if (VT.isInteger() && (sz < 32)) sz = 32;
449       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
450       SDValue DeclareParamOps[] = { Chain,
451                                     DAG.getConstant(paramCount, MVT::i32),
452                                     DAG.getConstant(sz, MVT::i32),
453                                     DAG.getConstant(isReg, MVT::i32),
454                                     InFlag };
455       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
456                           DeclareParamOps, 5);
457       InFlag = Chain.getValue(1);
458       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
459       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
460                              DAG.getConstant(0, MVT::i32), OutVals[i], InFlag };
461 
462       unsigned opcode = NVPTXISD::StoreParam;
463       if (isReg)
464         opcode = NVPTXISD::MoveToParam;
465       else {
466         if (Outs[i].Flags.isZExt())
467           opcode = NVPTXISD::StoreParamU32;
468         else if (Outs[i].Flags.isSExt())
469           opcode = NVPTXISD::StoreParamS32;
470       }
471       Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
472 
473       InFlag = Chain.getValue(1);
474       ++paramCount;
475       continue;
476     }
477     // struct or vector
478     SmallVector<EVT, 16> vtparts;
479     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
480     assert(PTy &&
481            "Type of a byval parameter should be pointer");
482     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
483 
484     if (isABI) {
485       // declare .param .align 16 .b8 .param<n>[<size>];
486       unsigned sz = Outs[i].Flags.getByValSize();
487       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
488       // The ByValAlign in the Outs[i].Flags is alway set at this point, so we
489       // don't need to
490       // worry about natural alignment or not. See TargetLowering::LowerCallTo()
491       SDValue DeclareParamOps[] = { Chain,
492                        DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
493                                     DAG.getConstant(paramCount, MVT::i32),
494                                     DAG.getConstant(sz, MVT::i32),
495                                     InFlag };
496       Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
497                           DeclareParamOps, 5);
498       InFlag = Chain.getValue(1);
499       unsigned curOffset = 0;
500       for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
501         unsigned elems = 1;
502         EVT elemtype = vtparts[j];
503         if (vtparts[j].isVector()) {
504           elems = vtparts[j].getVectorNumElements();
505           elemtype = vtparts[j].getVectorElementType();
506         }
507         for (unsigned k=0,ke=elems; k!=ke; ++k) {
508           unsigned sz = elemtype.getSizeInBits();
509           if (elemtype.isInteger() && (sz < 8)) sz = 8;
510           SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
511                                         OutVals[i],
512                                         DAG.getConstant(curOffset,
513                                                         getPointerTy()));
514           SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
515                                 MachinePointerInfo(), false, false, false, 0);
516           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
517           SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount,
518                                                             MVT::i32),
519                                            DAG.getConstant(curOffset, MVT::i32),
520                                                             theVal, InFlag };
521           Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
522                               CopyParamOps, 5);
523           InFlag = Chain.getValue(1);
524           curOffset += sz/8;
525         }
526       }
527       ++paramCount;
528       continue;
529     }
530     // Non-abi, struct or vector
531     // Declare a bunch or .reg .b<size> .param<n>
532     unsigned curOffset = 0;
533     for (unsigned j=0,je=vtparts.size(); j!=je; ++j) {
534       unsigned elems = 1;
535       EVT elemtype = vtparts[j];
536       if (vtparts[j].isVector()) {
537         elems = vtparts[j].getVectorNumElements();
538         elemtype = vtparts[j].getVectorElementType();
539       }
540       for (unsigned k=0,ke=elems; k!=ke; ++k) {
541         unsigned sz = elemtype.getSizeInBits();
542         if (elemtype.isInteger() && (sz < 32)) sz = 32;
543         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
544         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(paramCount,
545                                                              MVT::i32),
546                                                   DAG.getConstant(sz, MVT::i32),
547                                                    DAG.getConstant(1, MVT::i32),
548                                                              InFlag };
549         Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
550                             DeclareParamOps, 5);
551         InFlag = Chain.getValue(1);
552         SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
553                                       DAG.getConstant(curOffset,
554                                                       getPointerTy()));
555         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
556                                   MachinePointerInfo(), false, false, false, 0);
557         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
558         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
559                                    DAG.getConstant(0, MVT::i32), theVal,
560                                    InFlag };
561         Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
562                             CopyParamOps, 5);
563         InFlag = Chain.getValue(1);
564         ++paramCount;
565       }
566     }
567   }
568 
569   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
570   unsigned retAlignment = 0;
571 
572   // Handle Result
573   unsigned retCount = 0;
574   if (Ins.size() > 0) {
575     SmallVector<EVT, 16> resvtparts;
576     ComputeValueVTs(*this, retTy, resvtparts);
577 
578     // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
579     // individual .reg .b<size> func_retval<0..> for non ABI
580     unsigned resultsz = 0;
581     for (unsigned i=0,e=resvtparts.size(); i!=e; ++i) {
582       unsigned elems = 1;
583       EVT elemtype = resvtparts[i];
584       if (resvtparts[i].isVector()) {
585         elems = resvtparts[i].getVectorNumElements();
586         elemtype = resvtparts[i].getVectorElementType();
587       }
588       for (unsigned j=0,je=elems; j!=je; ++j) {
589         unsigned sz = elemtype.getSizeInBits();
590         if (isABI == false) {
591           if (elemtype.isInteger() && (sz < 32)) sz = 32;
592         }
593         else {
594           if (elemtype.isInteger() && (sz < 8)) sz = 8;
595         }
596         if (isABI == false) {
597           SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
598           SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
599                                       DAG.getConstant(sz, MVT::i32),
600                                       DAG.getConstant(retCount, MVT::i32),
601                                       InFlag };
602           Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
603                               DeclareRetOps, 5);
604           InFlag = Chain.getValue(1);
605           ++retCount;
606         }
607         resultsz += sz;
608       }
609     }
610     if (isABI) {
611       if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
612           retTy->isPointerTy() ) {
613         // Scalar needs to be at least 32bit wide
614         if (resultsz < 32)
615           resultsz = 32;
616         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
617         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
618                                     DAG.getConstant(resultsz, MVT::i32),
619                                     DAG.getConstant(0, MVT::i32), InFlag };
620         Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
621                             DeclareRetOps, 5);
622         InFlag = Chain.getValue(1);
623       }
624       else {
625         if (Func) { // direct call
626           if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
627             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
628         } else { // indirect call
629           const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
630           if (!llvm::getAlign(*CallI, 0, retAlignment))
631             retAlignment = getDataLayout()->getABITypeAlignment(retTy);
632         }
633         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
634         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(retAlignment,
635                                                            MVT::i32),
636                                           DAG.getConstant(resultsz/8, MVT::i32),
637                                          DAG.getConstant(0, MVT::i32), InFlag };
638         Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
639                             DeclareRetOps, 5);
640         InFlag = Chain.getValue(1);
641       }
642     }
643   }
644 
645   if (!Func) {
646     // This is indirect function call case : PTX requires a prototype of the
647     // form
648     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
649     // to be emitted, and the label has to used as the last arg of call
650     // instruction.
651     // The prototype is embedded in a string and put as the operand for an
652     // INLINEASM SDNode.
653     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
654     std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
655     const char *asmstr = nvTM->getManagedStrPool()->
656         getManagedString(proto_string.c_str())->c_str();
657     SDValue InlineAsmOps[] = { Chain,
658                                DAG.getTargetExternalSymbol(asmstr,
659                                                            getPointerTy()),
660                                                            DAG.getMDNode(0),
661                                    DAG.getTargetConstant(0, MVT::i32), InFlag };
662     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
663     InFlag = Chain.getValue(1);
664   }
665   // Op to just print "call"
666   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
667   SDValue PrintCallOps[] = { Chain,
668                              DAG.getConstant(isABI ? ((Ins.size()==0) ? 0 : 1)
669                                  : retCount, MVT::i32),
670                                    InFlag };
671   Chain = DAG.getNode(Func?(NVPTXISD::PrintCallUni):(NVPTXISD::PrintCall), dl,
672       PrintCallVTs, PrintCallOps, 3);
673   InFlag = Chain.getValue(1);
674 
675   // Ops to print out the function name
676   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
677   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
678   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
679   InFlag = Chain.getValue(1);
680 
681   // Ops to print out the param list
682   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
683   SDValue CallArgBeginOps[] = { Chain, InFlag };
684   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
685                       CallArgBeginOps, 2);
686   InFlag = Chain.getValue(1);
687 
688   for (unsigned i=0, e=paramCount; i!=e; ++i) {
689     unsigned opcode;
690     if (i==(e-1))
691       opcode = NVPTXISD::LastCallArg;
692     else
693       opcode = NVPTXISD::CallArg;
694     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
695     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
696                              DAG.getConstant(i, MVT::i32),
697                              InFlag };
698     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
699     InFlag = Chain.getValue(1);
700   }
701   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
702   SDValue CallArgEndOps[] = { Chain,
703                               DAG.getConstant(Func ? 1 : 0, MVT::i32),
704                               InFlag };
705   Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps,
706                       3);
707   InFlag = Chain.getValue(1);
708 
709   if (!Func) {
710     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
711     SDValue PrototypeOps[] = { Chain,
712                                DAG.getConstant(uniqueCallSite, MVT::i32),
713                                InFlag };
714     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
715     InFlag = Chain.getValue(1);
716   }
717 
718   // Generate loads from param memory/moves from registers for result
719   if (Ins.size() > 0) {
720     if (isABI) {
721       unsigned resoffset = 0;
722       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
723         unsigned sz = Ins[i].VT.getSizeInBits();
724         if (Ins[i].VT.isInteger() && (sz < 8)) sz = 8;
725         EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
726         SDValue LoadRetOps[] = {
727           Chain,
728           DAG.getConstant(1, MVT::i32),
729           DAG.getConstant(resoffset, MVT::i32),
730           InFlag
731         };
732         SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
733                                      LoadRetOps, array_lengthof(LoadRetOps));
734         Chain = retval.getValue(1);
735         InFlag = retval.getValue(2);
736         InVals.push_back(retval);
737         resoffset += sz/8;
738       }
739     }
740     else {
741       SmallVector<EVT, 16> resvtparts;
742       ComputeValueVTs(*this, retTy, resvtparts);
743 
744       assert(Ins.size() == resvtparts.size() &&
745              "Unexpected number of return values in non-ABI case");
746       unsigned paramNum = 0;
747       for (unsigned i=0,e=Ins.size(); i!=e; ++i) {
748         assert(EVT(Ins[i].VT) == resvtparts[i] &&
749                "Unexpected EVT type in non-ABI case");
750         unsigned numelems = 1;
751         EVT elemtype = Ins[i].VT;
752         if (Ins[i].VT.isVector()) {
753           numelems = Ins[i].VT.getVectorNumElements();
754           elemtype = Ins[i].VT.getVectorElementType();
755         }
756         std::vector<SDValue> tempRetVals;
757         for (unsigned j=0; j<numelems; ++j) {
758           EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
759           SDValue MoveRetOps[] = {
760             Chain,
761             DAG.getConstant(0, MVT::i32),
762             DAG.getConstant(paramNum, MVT::i32),
763             InFlag
764           };
765           SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
766                                        MoveRetOps, array_lengthof(MoveRetOps));
767           Chain = retval.getValue(1);
768           InFlag = retval.getValue(2);
769           tempRetVals.push_back(retval);
770           ++paramNum;
771         }
772         if (Ins[i].VT.isVector())
773           InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
774                                        &tempRetVals[0], tempRetVals.size()));
775         else
776           InVals.push_back(tempRetVals[0]);
777       }
778     }
779   }
780   Chain = DAG.getCALLSEQ_END(Chain,
781                              DAG.getIntPtrConstant(uniqueCallSite, true),
782                              DAG.getIntPtrConstant(uniqueCallSite+1, true),
783                              InFlag);
784   uniqueCallSite++;
785 
786   // set isTailCall to false for now, until we figure out how to express
787   // tail call optimization in PTX
788   isTailCall = false;
789   return Chain;
790 }
791 
792 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
793 // (see LegalizeDAG.cpp). This is slow and uses local memory.
794 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
795 SDValue NVPTXTargetLowering::
LowerCONCAT_VECTORS(SDValue Op,SelectionDAG & DAG) const796 LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
797   SDNode *Node = Op.getNode();
798   DebugLoc dl = Node->getDebugLoc();
799   SmallVector<SDValue, 8> Ops;
800   unsigned NumOperands = Node->getNumOperands();
801   for (unsigned i=0; i < NumOperands; ++i) {
802     SDValue SubOp = Node->getOperand(i);
803     EVT VVT = SubOp.getNode()->getValueType(0);
804     EVT EltVT = VVT.getVectorElementType();
805     unsigned NumSubElem = VVT.getVectorNumElements();
806     for (unsigned j=0; j < NumSubElem; ++j) {
807       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
808                                 DAG.getIntPtrConstant(j)));
809     }
810   }
811   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0),
812                      &Ops[0], Ops.size());
813 }
814 
815 SDValue NVPTXTargetLowering::
LowerOperation(SDValue Op,SelectionDAG & DAG) const816 LowerOperation(SDValue Op, SelectionDAG &DAG) const {
817   switch (Op.getOpcode()) {
818   case ISD::RETURNADDR: return SDValue();
819   case ISD::FRAMEADDR:  return SDValue();
820   case ISD::GlobalAddress:      return LowerGlobalAddress(Op, DAG);
821   case ISD::INTRINSIC_W_CHAIN: return Op;
822   case ISD::BUILD_VECTOR:
823   case ISD::EXTRACT_SUBVECTOR:
824     return Op;
825   case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG);
826   case ISD::STORE: return LowerSTORE(Op, DAG);
827   case ISD::LOAD: return LowerLOAD(Op, DAG);
828   default:
829     llvm_unreachable("Custom lowering not defined for operation");
830   }
831 }
832 
833 
LowerLOAD(SDValue Op,SelectionDAG & DAG) const834 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
835   if (Op.getValueType() == MVT::i1)
836     return LowerLOADi1(Op, DAG);
837   else
838     return SDValue();
839 }
840 
841 // v = ld i1* addr
842 //   =>
843 // v1 = ld i8* addr
844 // v = trunc v1 to i1
845 SDValue NVPTXTargetLowering::
LowerLOADi1(SDValue Op,SelectionDAG & DAG) const846 LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
847   SDNode *Node = Op.getNode();
848   LoadSDNode *LD = cast<LoadSDNode>(Node);
849   DebugLoc dl = Node->getDebugLoc();
850   assert(LD->getExtensionType() == ISD::NON_EXTLOAD) ;
851   assert(Node->getValueType(0) == MVT::i1 &&
852          "Custom lowering for i1 load only");
853   SDValue newLD = DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
854                               LD->getPointerInfo(),
855                               LD->isVolatile(), LD->isNonTemporal(),
856                               LD->isInvariant(),
857                               LD->getAlignment());
858   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
859   // The legalizer (the caller) is expecting two values from the legalized
860   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
861   // in LegalizeDAG.cpp which also uses MergeValues.
862   SDValue Ops[] = {result, LD->getChain()};
863   return DAG.getMergeValues(Ops, 2, dl);
864 }
865 
LowerSTORE(SDValue Op,SelectionDAG & DAG) const866 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
867   EVT ValVT = Op.getOperand(1).getValueType();
868   if (ValVT == MVT::i1)
869     return LowerSTOREi1(Op, DAG);
870   else if (ValVT.isVector())
871     return LowerSTOREVector(Op, DAG);
872   else
873     return SDValue();
874 }
875 
/// LowerSTOREVector - Lower a store of a "native" PTX vector type
/// (<2 x T> / <4 x T>) to an NVPTXISD::StoreV2/StoreV4 target node whose
/// operands are the chain, the individual scalar elements, and the
/// remaining original store operands (address etc.).
SDValue
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
  SDNode *N = Op.getNode();
  SDValue Val = N->getOperand(1);
  DebugLoc DL = N->getDebugLoc();
  EVT ValVT = Val.getValueType();

  if (ValVT.isVector()) {
    // We only handle "native" vector sizes for now, e.g. <4 x double> is not
    // legal.  We can (and should) split that into 2 stores of <2 x double> here
    // but I'm leaving that as a TODO for now.
    if (!ValVT.isSimple())
      return SDValue();
    switch (ValVT.getSimpleVT().SimpleTy) {
    default: return SDValue();
    case MVT::v2i8:
    case MVT::v2i16:
    case MVT::v2i32:
    case MVT::v2i64:
    case MVT::v2f32:
    case MVT::v2f64:
    case MVT::v4i8:
    case MVT::v4i16:
    case MVT::v4i32:
    case MVT::v4f32:
      // This is a "native" vector type
      break;
    }

    unsigned Opcode = 0;
    EVT EltVT = ValVT.getVectorElementType();
    unsigned NumElts = ValVT.getVectorNumElements();

    // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
    // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
    // stored type to i16 and propagate the "real" type as the memory type.
    bool NeedExt = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExt = true;

    // Pick the target opcode by element count (only 2 and 4 are supported).
    switch (NumElts) {
    default:  return SDValue();
    case 2:
      Opcode = NVPTXISD::StoreV2;
      break;
    case 4: {
      Opcode = NVPTXISD::StoreV4;
      break;
    }
    }

    SmallVector<SDValue, 8> Ops;

    // First is the chain
    Ops.push_back(N->getOperand(0));

    // Then the split values
    for (unsigned i = 0; i < NumElts; ++i) {
      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                   DAG.getIntPtrConstant(i));
      if (NeedExt)
        // ANY_EXTEND is correct here since the store will only look at the
        // lower-order bits anyway.
        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
      Ops.push_back(ExtVal);
    }

    // Then any remaining arguments
    for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
      Ops.push_back(N->getOperand(i));
    }

    MemSDNode *MemSD = cast<MemSDNode>(N);

    // Build the target mem-intrinsic node, carrying the original memory VT
    // and memory operand so alias analysis / scheduling still see the store.
    SDValue NewSt = DAG.getMemIntrinsicNode(Opcode, DL,
                                            DAG.getVTList(MVT::Other), &Ops[0],
                                            Ops.size(), MemSD->getMemoryVT(),
                                            MemSD->getMemOperand());


    //return DCI.CombineTo(N, NewSt, true);
    return NewSt;
  }

  return SDValue();
}
962 
963 // st i1 v, addr
964 //    =>
965 // v1 = zxt v to i8
966 // st i8, addr
967 SDValue NVPTXTargetLowering::
LowerSTOREi1(SDValue Op,SelectionDAG & DAG) const968 LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
969   SDNode *Node = Op.getNode();
970   DebugLoc dl = Node->getDebugLoc();
971   StoreSDNode *ST = cast<StoreSDNode>(Node);
972   SDValue Tmp1 = ST->getChain();
973   SDValue Tmp2 = ST->getBasePtr();
974   SDValue Tmp3 = ST->getValue();
975   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
976   unsigned Alignment = ST->getAlignment();
977   bool isVolatile = ST->isVolatile();
978   bool isNonTemporal = ST->isNonTemporal();
979   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl,
980                      MVT::i8, Tmp3);
981   SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2,
982                                 ST->getPointerInfo(), isVolatile,
983                                 isNonTemporal, Alignment);
984   return Result;
985 }
986 
987 
/// getExtSymb - Return a TargetExternalSymbol node whose name is
/// "<inname><idx>" (e.g. ".PARAM0").  The concatenated name lives in the
/// target machine's managed string pool, so the raw c_str() pointer stored
/// inside the symbol node outlives this function.
SDValue
NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, int idx,
                                EVT v) const {
  std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
  std::stringstream suffix;
  suffix << idx;
  // NOTE(review): this appends to the pooled string in place — assumes
  // getManagedString() hands back a fresh string per call rather than a
  // cached one keyed by `inname`; verify against ManagedStringPool.
  *name += suffix.str();
  return DAG.getTargetExternalSymbol(name->c_str(), v);
}
997 
998 SDValue
getParamSymbol(SelectionDAG & DAG,int idx,EVT v) const999 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1000   return getExtSymb(DAG, ".PARAM", idx, v);
1001 }
1002 
1003 SDValue
getParamHelpSymbol(SelectionDAG & DAG,int idx)1004 NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1005   return getExtSymb(DAG, ".HLPPARAM", idx);
1006 }
1007 
1008 // Check to see if the kernel argument is image*_t or sampler_t
1009 
isImageOrSamplerVal(const Value * arg,const Module * context)1010 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1011   static const char *const specialTypes[] = {
1012                                              "struct._image2d_t",
1013                                              "struct._image3d_t",
1014                                              "struct._sampler_t"
1015   };
1016 
1017   const Type *Ty = arg->getType();
1018   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1019 
1020   if (!PTy)
1021     return false;
1022 
1023   if (!context)
1024     return false;
1025 
1026   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1027   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1028 
1029   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1030     if (TypeName == specialTypes[i])
1031       return true;
1032 
1033   return false;
1034 }
1035 
/// LowerFormalArguments - Create DAG nodes that produce the incoming formal
/// parameter values of the current function.  Depending on whether the
/// param ABI is in use (sm_20+) and whether the function is a kernel, a
/// parameter is either loaded from the .param address space, moved from its
/// param symbol, or (non-ABI byval) copied piecewise into a local stack
/// object whose address is returned instead.
SDValue
NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
                                        CallingConv::ID CallConv, bool isVarArg,
                                      const SmallVectorImpl<ISD::InputArg> &Ins,
                                          DebugLoc dl, SelectionDAG &DAG,
                                       SmallVectorImpl<SDValue> &InVals) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const DataLayout *TD = getDataLayout();

  const Function *F = MF.getFunction();
  const AttributeSet &PAL = F->getAttributes();

  SDValue Root = DAG.getRoot();
  std::vector<SDValue> OutChains;

  bool isKernel = llvm::isKernelFunction(*F);
  // The param-space ABI is only used on sm_20 and later.
  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  // Collect the IR-level arguments and their types so they can be matched
  // positionally with the Ins array below.
  std::vector<Type *> argTypes;
  std::vector<const Argument *> theArgs;
  for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
      I != E; ++I) {
    theArgs.push_back(I);
    argTypes.push_back(I->getType());
  }
  assert(argTypes.size() == Ins.size() &&
         "Ins types and function types did not match");

  // idx tracks the param-symbol index; it normally advances in lock-step
  // with i, but the non-ABI byval path below consumes several symbols for
  // one argument and advances it further.
  int idx = 0;
  for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
    Type *Ty = argTypes[i];
    EVT ObjectVT = getValueType(Ty);
    assert(ObjectVT == Ins[i].VT &&
           "Ins type did not match function type");

    // If the kernel argument is image*_t or sampler_t, convert it to
    // a i32 constant holding the parameter position. This can later
    // matched in the AsmPrinter to output the correct mangled name.
    if (isImageOrSamplerVal(theArgs[i],
                           (theArgs[i]->getParent() ?
                               theArgs[i]->getParent()->getParent() : 0))) {
      assert(isKernel && "Only kernels can have image/sampler params");
      InVals.push_back(DAG.getConstant(i+1, MVT::i32));
      continue;
    }

    if (theArgs[i]->use_empty()) {
      // argument is dead
      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
      continue;
    }

    // In the following cases, assign a node order of "idx+1"
    // to newly created nodes. The SDNOdes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "idx+1" holds that order.
    if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
      // A plain scalar.
      if (isABI || isKernel) {
        // If ABI, load from the param symbol
        SDValue Arg = getParamSymbol(DAG, idx);
        // Conjure up a value that we can get the address space from.
        // FIXME: Using a constant here is a hack.
        Value *srcValue = Constant::getNullValue(PointerType::get(
                              ObjectVT.getTypeForEVT(F->getContext()),
                              llvm::ADDRESS_SPACE_PARAM));
        SDValue p = DAG.getLoad(ObjectVT, dl, Root, Arg,
                                MachinePointerInfo(srcValue), false, false,
                                false,
                                TD->getABITypeAlignment(ObjectVT.getTypeForEVT(
                                  F->getContext())));
        if (p.getNode())
          DAG.AssignOrdering(p.getNode(), idx+1);
        InVals.push_back(p);
      }
      else {
        // If no ABI, just move the param symbol
        SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
        SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
        if (p.getNode())
          DAG.AssignOrdering(p.getNode(), idx+1);
        InVals.push_back(p);
      }
      continue;
    }

    // Param has ByVal attribute
    if (isABI || isKernel) {
      // Return MoveParam(param symbol).
      // Ideally, the param symbol can be returned directly,
      // but when SDNode builder decides to use it in a CopyToReg(),
      // machine instruction fails because TargetExternalSymbol
      // (not lowered) is target dependent, and CopyToReg assumes
      // the source is lowered.
      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
      SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
      if (p.getNode())
        DAG.AssignOrdering(p.getNode(), idx+1);
      if (isKernel)
        InVals.push_back(p);
      else {
        // For non-kernel (device) functions, convert the byval pointer to
        // the generic address space before handing it to the body.
        SDValue p2 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
                    DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32),
                                 p);
        InVals.push_back(p2);
      }
    } else {
      // Have to move a set of param symbols to registers and
      // store them locally and return the local pointer in InVals
      const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
      assert(elemPtrType &&
             "Byval parameter should be a pointer type");
      Type *elemType = elemPtrType->getElementType();
      // Compute the constituent parts
      SmallVector<EVT, 16> vtparts;
      SmallVector<uint64_t, 16> offsets;
      ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
      unsigned totalsize = 0;
      for (unsigned j=0, je=vtparts.size(); j!=je; ++j)
        totalsize += vtparts[j].getStoreSizeInBits();
      SDValue localcopy =  DAG.getFrameIndex(MF.getFrameInfo()->
                                      CreateStackObject(totalsize/8, 16, false),
                                             getPointerTy());
      unsigned sizesofar = 0;
      std::vector<SDValue> theChains;
      for (unsigned j=0, je=vtparts.size(); j!=je; ++j) {
        unsigned numElems = 1;
        if (vtparts[j].isVector()) numElems = vtparts[j].getVectorNumElements();
        for (unsigned k=0, ke=numElems; k!=ke; ++k) {
          EVT tmpvt = vtparts[j];
          if (tmpvt.isVector()) tmpvt = tmpvt.getVectorElementType();
          SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
                                    getParamSymbol(DAG, idx, tmpvt));
          SDValue addr = DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
                                    DAG.getConstant(sizesofar, getPointerTy()));
          theChains.push_back(DAG.getStore(Chain, dl, arg, addr,
                                        MachinePointerInfo(), false, false, 0));
          sizesofar += tmpvt.getStoreSizeInBits()/8;
          // Each scalar part consumes its own param symbol.
          ++idx;
        }
      }
      // The outer for-loop header increments idx once per argument;
      // compensate for the extra increment done in the inner loops above.
      --idx;
      Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
                          theChains.size());
      InVals.push_back(localcopy);
    }
  }

  // Clang will check explicit VarArg and issue error if any. However, Clang
  // will let code with
  // implicit var arg like f() pass.
  // We treat this case as if the arg list is empty.
  //if (F.isVarArg()) {
  // assert(0 && "VarArg not supported yet!");
  //}

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
                            &OutChains[0], OutChains.size()));

  return Chain;
}
1198 
/// LowerReturn - Lower the return of the current function.  Each scalar
/// (or vector element, after extraction) is emitted either as a
/// StoreRetval at its byte offset (ABI, sm_20+) or as a MoveToRetval with
/// a sequential index (non-ABI), followed by a RET_FLAG node.
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 DebugLoc dl, SelectionDAG &DAG) const {

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  // sizesofar: running byte offset, used only in the ABI case.
  // idx: running element index, used only in the non-ABI case.
  unsigned sizesofar = 0;
  unsigned idx = 0;
  for (unsigned i=0, e=Outs.size(); i!=e; ++i) {
    SDValue theVal = OutVals[i];
    EVT theValType = theVal.getValueType();
    unsigned numElems = 1;
    if (theValType.isVector()) numElems = theValType.getVectorNumElements();
    for (unsigned j=0,je=numElems; j!=je; ++j) {
      SDValue tmpval = theVal;
      // Vectors are returned element by element.
      if (theValType.isVector())
        tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                             theValType.getVectorElementType(),
                             tmpval, DAG.getIntPtrConstant(j));
      Chain = DAG.getNode(isABI ? NVPTXISD::StoreRetval :NVPTXISD::MoveToRetval,
          dl, MVT::Other,
          Chain,
          DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
          tmpval);
      // Advance the byte offset by the store size of the element just
      // emitted (element type for vectors, full type otherwise).
      if (theValType.isVector())
        sizesofar += theValType.getVectorElementType().getStoreSizeInBits()/8;
      else
        sizesofar += theValType.getStoreSizeInBits()/8;
      ++idx;
    }
  }

  return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
}
1236 
1237 void
LowerAsmOperandForConstraint(SDValue Op,std::string & Constraint,std::vector<SDValue> & Ops,SelectionDAG & DAG) const1238 NVPTXTargetLowering::LowerAsmOperandForConstraint(SDValue Op,
1239                                                   std::string &Constraint,
1240                                                   std::vector<SDValue> &Ops,
1241                                                   SelectionDAG &DAG) const
1242 {
1243   if (Constraint.length() > 1)
1244     return;
1245   else
1246     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1247 }
1248 
// NVPTX supports vectors of legal types of any length in intrinsics, because
// the NVPTX-specific type legalizer will legalize them to a PTX-supported
// length.
1252 bool
isTypeSupportedInIntrinsic(MVT VT) const1253 NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1254   if (isTypeLegal(VT))
1255     return true;
1256   if (VT.isVector()) {
1257     MVT eVT = VT.getVectorElementType();
1258     if (isTypeLegal(eVT))
1259       return true;
1260   }
1261   return false;
1262 }
1263 
1264 
1265 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1266 // TgtMemIntrinsic
1267 // because we need the information that is only available in the "Value" type
1268 // of destination
1269 // pointer. In particular, the address space information.
1270 bool
getTgtMemIntrinsic(IntrinsicInfo & Info,const CallInst & I,unsigned Intrinsic) const1271 NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo& Info, const CallInst &I,
1272                                         unsigned Intrinsic) const {
1273   switch (Intrinsic) {
1274   default:
1275     return false;
1276 
1277   case Intrinsic::nvvm_atomic_load_add_f32:
1278     Info.opc = ISD::INTRINSIC_W_CHAIN;
1279     Info.memVT = MVT::f32;
1280     Info.ptrVal = I.getArgOperand(0);
1281     Info.offset = 0;
1282     Info.vol = 0;
1283     Info.readMem = true;
1284     Info.writeMem = true;
1285     Info.align = 0;
1286     return true;
1287 
1288   case Intrinsic::nvvm_atomic_load_inc_32:
1289   case Intrinsic::nvvm_atomic_load_dec_32:
1290     Info.opc = ISD::INTRINSIC_W_CHAIN;
1291     Info.memVT = MVT::i32;
1292     Info.ptrVal = I.getArgOperand(0);
1293     Info.offset = 0;
1294     Info.vol = 0;
1295     Info.readMem = true;
1296     Info.writeMem = true;
1297     Info.align = 0;
1298     return true;
1299 
1300   case Intrinsic::nvvm_ldu_global_i:
1301   case Intrinsic::nvvm_ldu_global_f:
1302   case Intrinsic::nvvm_ldu_global_p:
1303 
1304     Info.opc = ISD::INTRINSIC_W_CHAIN;
1305     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1306       Info.memVT = MVT::i32;
1307     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1308       Info.memVT = getPointerTy();
1309     else
1310       Info.memVT = MVT::f32;
1311     Info.ptrVal = I.getArgOperand(0);
1312     Info.offset = 0;
1313     Info.vol = 0;
1314     Info.readMem = true;
1315     Info.writeMem = false;
1316     Info.align = 0;
1317     return true;
1318 
1319   }
1320   return false;
1321 }
1322 
1323 /// isLegalAddressingMode - Return true if the addressing mode represented
1324 /// by AM is legal for this target, for a load/store of the specified type.
1325 /// Used to guide target specific optimizations, like loop strength reduction
1326 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1327 /// (CodeGenPrepare.cpp)
1328 bool
isLegalAddressingMode(const AddrMode & AM,Type * Ty) const1329 NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1330                                            Type *Ty) const {
1331 
1332   // AddrMode - This represents an addressing mode of:
1333   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1334   //
1335   // The legal address modes are
1336   // - [avar]
1337   // - [areg]
1338   // - [areg+immoff]
1339   // - [immAddr]
1340 
1341   if (AM.BaseGV) {
1342     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1343       return false;
1344     return true;
1345   }
1346 
1347   switch (AM.Scale) {
1348   case 0:  // "r", "r+i" or "i" is allowed
1349     break;
1350   case 1:
1351     if (AM.HasBaseReg)  // "r+r+i" or "r+r" is not allowed.
1352       return false;
1353     // Otherwise we have r+i.
1354     break;
1355   default:
1356     // No scale > 1 is allowed
1357     return false;
1358   }
1359   return true;
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 //                         NVPTX Inline Assembly Support
1364 //===----------------------------------------------------------------------===//
1365 
1366 /// getConstraintType - Given a constraint letter, return the type of
1367 /// constraint it is for this target.
1368 NVPTXTargetLowering::ConstraintType
getConstraintType(const std::string & Constraint) const1369 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
1370   if (Constraint.size() == 1) {
1371     switch (Constraint[0]) {
1372     default:
1373       break;
1374     case 'r':
1375     case 'h':
1376     case 'c':
1377     case 'l':
1378     case 'f':
1379     case 'd':
1380     case '0':
1381     case 'N':
1382       return C_RegisterClass;
1383     }
1384   }
1385   return TargetLowering::getConstraintType(Constraint);
1386 }
1387 
1388 
1389 std::pair<unsigned, const TargetRegisterClass*>
getRegForInlineAsmConstraint(const std::string & Constraint,EVT VT) const1390 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1391                                                   EVT VT) const {
1392   if (Constraint.size() == 1) {
1393     switch (Constraint[0]) {
1394     case 'c':
1395       return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1396     case 'h':
1397       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1398     case 'r':
1399       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1400     case 'l':
1401     case 'N':
1402       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1403     case 'f':
1404       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1405     case 'd':
1406       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1407     }
1408   }
1409   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
1410 }
1411 
1412 
1413 
1414 /// getFunctionAlignment - Return the Log2 alignment of this function.
getFunctionAlignment(const Function *) const1415 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
1416   return 4;
1417 }
1418 
/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
/// Produces an NVPTXISD::LoadV2/LoadV4 target node whose first NumElts
/// results are the scalar elements and whose last result is the chain;
/// pushes the rebuilt vector and the chain into Results.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue>& Results) {
  EVT ResVT = N->getValueType(0);
  DebugLoc DL = N->getDebugLoc();

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal.  We can (and should) split that into 2 loads of <2 x double> here
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default: return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  // Pick the target opcode and result-type list by element count.
  switch (NumElts) {
  default:  return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs, 5);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
                                          OtherOps.size(), LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  // Collect the scalar results, truncating widened i1/i8 elements back to
  // the original element type.
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  // The chain is the last result of the target load node.
  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}
1508 
ReplaceINTRINSIC_W_CHAIN(SDNode * N,SelectionDAG & DAG,SmallVectorImpl<SDValue> & Results)1509 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N,
1510                                      SelectionDAG &DAG,
1511                                      SmallVectorImpl<SDValue> &Results) {
1512   SDValue Chain = N->getOperand(0);
1513   SDValue Intrin = N->getOperand(1);
1514   DebugLoc DL = N->getDebugLoc();
1515 
1516   // Get the intrinsic ID
1517   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
1518   switch(IntrinNo) {
1519   default: return;
1520   case Intrinsic::nvvm_ldg_global_i:
1521   case Intrinsic::nvvm_ldg_global_f:
1522   case Intrinsic::nvvm_ldg_global_p:
1523   case Intrinsic::nvvm_ldu_global_i:
1524   case Intrinsic::nvvm_ldu_global_f:
1525   case Intrinsic::nvvm_ldu_global_p: {
1526     EVT ResVT = N->getValueType(0);
1527 
1528     if (ResVT.isVector()) {
1529       // Vector LDG/LDU
1530 
1531       unsigned NumElts = ResVT.getVectorNumElements();
1532       EVT EltVT = ResVT.getVectorElementType();
1533 
1534       // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization.
1535       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1536       // loaded type to i16 and propogate the "real" type as the memory type.
1537       bool NeedTrunc = false;
1538       if (EltVT.getSizeInBits() < 16) {
1539         EltVT = MVT::i16;
1540         NeedTrunc = true;
1541       }
1542 
1543       unsigned Opcode = 0;
1544       SDVTList LdResVTs;
1545 
1546       switch (NumElts) {
1547       default:  return;
1548       case 2:
1549         switch(IntrinNo) {
1550         default: return;
1551         case Intrinsic::nvvm_ldg_global_i:
1552         case Intrinsic::nvvm_ldg_global_f:
1553         case Intrinsic::nvvm_ldg_global_p:
1554           Opcode = NVPTXISD::LDGV2;
1555           break;
1556         case Intrinsic::nvvm_ldu_global_i:
1557         case Intrinsic::nvvm_ldu_global_f:
1558         case Intrinsic::nvvm_ldu_global_p:
1559           Opcode = NVPTXISD::LDUV2;
1560           break;
1561         }
1562         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1563         break;
1564       case 4: {
1565         switch(IntrinNo) {
1566         default: return;
1567         case Intrinsic::nvvm_ldg_global_i:
1568         case Intrinsic::nvvm_ldg_global_f:
1569         case Intrinsic::nvvm_ldg_global_p:
1570           Opcode = NVPTXISD::LDGV4;
1571           break;
1572         case Intrinsic::nvvm_ldu_global_i:
1573         case Intrinsic::nvvm_ldu_global_f:
1574         case Intrinsic::nvvm_ldu_global_p:
1575           Opcode = NVPTXISD::LDUV4;
1576           break;
1577         }
1578         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1579         LdResVTs = DAG.getVTList(ListVTs, 5);
1580         break;
1581       }
1582       }
1583 
1584       SmallVector<SDValue, 8> OtherOps;
1585 
1586       // Copy regular operands
1587 
1588       OtherOps.push_back(Chain); // Chain
1589       // Skip operand 1 (intrinsic ID)
1590       // Others
1591       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
1592         OtherOps.push_back(N->getOperand(i));
1593 
1594       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1595 
1596       SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
1597                                               OtherOps.size(), MemSD->getMemoryVT(),
1598                                               MemSD->getMemOperand());
1599 
1600       SmallVector<SDValue, 4> ScalarRes;
1601 
1602       for (unsigned i = 0; i < NumElts; ++i) {
1603         SDValue Res = NewLD.getValue(i);
1604         if (NeedTrunc)
1605           Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1606         ScalarRes.push_back(Res);
1607       }
1608 
1609       SDValue LoadChain = NewLD.getValue(NumElts);
1610 
1611       SDValue BuildVec = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1612 
1613       Results.push_back(BuildVec);
1614       Results.push_back(LoadChain);
1615     } else {
1616       // i8 LDG/LDU
1617       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
1618              "Custom handling of non-i8 ldu/ldg?");
1619 
1620       // Just copy all operands as-is
1621       SmallVector<SDValue, 4> Ops;
1622       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1623         Ops.push_back(N->getOperand(i));
1624 
1625       // Force output to i16
1626       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
1627 
1628       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1629 
1630       // We make sure the memory type is i8, which will be used during isel
1631       // to select the proper instruction.
1632       SDValue NewLD = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
1633                                               LdResVTs, &Ops[0],
1634                                               Ops.size(), MVT::i8,
1635                                               MemSD->getMemOperand());
1636 
1637       Results.push_back(NewLD.getValue(0));
1638       Results.push_back(NewLD.getValue(1));
1639     }
1640   }
1641   }
1642 }
1643 
ReplaceNodeResults(SDNode * N,SmallVectorImpl<SDValue> & Results,SelectionDAG & DAG) const1644 void NVPTXTargetLowering::ReplaceNodeResults(SDNode *N,
1645                                              SmallVectorImpl<SDValue> &Results,
1646                                              SelectionDAG &DAG) const {
1647   switch (N->getOpcode()) {
1648   default: report_fatal_error("Unhandled custom legalization");
1649   case ISD::LOAD:
1650     ReplaceLoadVector(N, DAG, Results);
1651     return;
1652   case ISD::INTRINSIC_W_CHAIN:
1653     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
1654     return;
1655   }
1656 }
1657