//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation --------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//


#include "NVPTXISelLowering.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCSectionELF.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream>

#undef DEBUG_TYPE
#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;

static unsigned int uniqueCallSite = 0;

static cl::opt<bool>
sched4reg("nvptx-sched4reg",
          cl::desc("NVPTX Specific: schedule for register pressure"),
          cl::init(false));

static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default: return false;
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
    return true;
  }
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
    : TargetLowering(TM, new NVPTXTargetObjectFile()),
      nvTM(&TM),
      nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {

  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemcpy = (unsigned)0xFFFFFFFF;
  MaxStoresPerMemmove = (unsigned)0xFFFFFFFF;

  setBooleanContents(ZeroOrNegativeOneBooleanContent);

  // Jump is expensive. Don't create extra control flow for 'and' and 'or'
  // conditions in branches.
  setJumpIsExpensive(true);

  // By default, use Source scheduling.
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);

  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
  addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);

  // Operations not directly supported by NVPTX.
  setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
  setOperationAction(ISD::BR_CC, MVT::f32, Expand);
  setOperationAction(ISD::BR_CC, MVT::f64, Expand);
  setOperationAction(ISD::BR_CC, MVT::i1, Expand);
  setOperationAction(ISD::BR_CC, MVT::i8, Expand);
  setOperationAction(ISD::BR_CC, MVT::i16, Expand);
  setOperationAction(ISD::BR_CC, MVT::i32, Expand);
  setOperationAction(ISD::BR_CC, MVT::i64, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);

  if (nvptxSubtarget.hasROT64()) {
    setOperationAction(ISD::ROTL, MVT::i64, Legal);
    setOperationAction(ISD::ROTR, MVT::i64, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i64, Expand);
    setOperationAction(ISD::ROTR, MVT::i64, Expand);
  }
  if (nvptxSubtarget.hasROT32()) {
    setOperationAction(ISD::ROTL, MVT::i32, Legal);
    setOperationAction(ISD::ROTR, MVT::i32, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i32, Expand);
    setOperationAction(ISD::ROTR, MVT::i32, Expand);
  }

  setOperationAction(ISD::ROTL, MVT::i16, Expand);
  setOperationAction(ISD::ROTR, MVT::i16, Expand);
  setOperationAction(ISD::ROTL, MVT::i8, Expand);
  setOperationAction(ISD::ROTR, MVT::i8, Expand);
  setOperationAction(ISD::BSWAP, MVT::i16, Expand);
  setOperationAction(ISD::BSWAP, MVT::i32, Expand);
  setOperationAction(ISD::BSWAP, MVT::i64, Expand);

  // Indirect branch is not supported.
  // This also disables jump table creation.
  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
  setOperationAction(ISD::BRIND, MVT::Other, Expand);

  setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);

  // We want to legalize constant-related memmove and memcpy intrinsics.
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);

  // Turn FP extload into load/fextend.
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
  // Turn FP truncstore into trunc + store.
  setTruncStoreAction(MVT::f64, MVT::f32, Expand);

  // PTX does not support load/store of predicate registers.
  setOperationAction(ISD::LOAD, MVT::i1, Custom);
  setOperationAction(ISD::STORE, MVT::i1, Custom);

  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
  setTruncStoreAction(MVT::i8, MVT::i1, Expand);

  // This is legal in NVPTX.
  setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f32, Legal);

  // TRAP can be lowered to PTX trap.
  setOperationAction(ISD::TRAP, MVT::Other, Legal);

  // Register custom handling for vector loads/stores.
  for (int i = MVT::FIRST_VECTOR_VALUETYPE;
       i <= MVT::LAST_VECTOR_VALUETYPE; ++i) {
    MVT VT = (MVT::SimpleValueType)i;
    if (IsPTXVectorType(VT)) {
      setOperationAction(ISD::LOAD, VT, Custom);
      setOperationAction(ISD::STORE, VT, Custom);
      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
    }
  }

  // Now deduce the information based on the above-mentioned actions.
  computeRegisterProperties();
}


const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
  switch (Opcode) {
  default: return 0;
  case NVPTXISD::CALL: return "NVPTXISD::CALL";
  case NVPTXISD::RET_FLAG: return "NVPTXISD::RET_FLAG";
  case NVPTXISD::Wrapper: return "NVPTXISD::Wrapper";
  case NVPTXISD::NVBuiltin: return "NVPTXISD::NVBuiltin";
  case NVPTXISD::DeclareParam: return "NVPTXISD::DeclareParam";
  case NVPTXISD::DeclareScalarParam:
    return "NVPTXISD::DeclareScalarParam";
  case NVPTXISD::DeclareRet: return "NVPTXISD::DeclareRet";
  case NVPTXISD::DeclareRetParam: return "NVPTXISD::DeclareRetParam";
  case NVPTXISD::PrintCall: return "NVPTXISD::PrintCall";
  case NVPTXISD::LoadParam: return "NVPTXISD::LoadParam";
  case NVPTXISD::StoreParam: return "NVPTXISD::StoreParam";
  case NVPTXISD::StoreParamS32: return "NVPTXISD::StoreParamS32";
  case NVPTXISD::StoreParamU32: return "NVPTXISD::StoreParamU32";
  case NVPTXISD::MoveToParam: return "NVPTXISD::MoveToParam";
  case NVPTXISD::CallArgBegin: return "NVPTXISD::CallArgBegin";
  case NVPTXISD::CallArg: return "NVPTXISD::CallArg";
  case NVPTXISD::LastCallArg: return "NVPTXISD::LastCallArg";
  case NVPTXISD::CallArgEnd: return "NVPTXISD::CallArgEnd";
  case NVPTXISD::CallVoid: return "NVPTXISD::CallVoid";
  case NVPTXISD::CallVal: return "NVPTXISD::CallVal";
  case NVPTXISD::CallSymbol: return "NVPTXISD::CallSymbol";
  case NVPTXISD::Prototype: return "NVPTXISD::Prototype";
  case NVPTXISD::MoveParam: return "NVPTXISD::MoveParam";
  case NVPTXISD::MoveRetval: return "NVPTXISD::MoveRetval";
  case NVPTXISD::MoveToRetval: return "NVPTXISD::MoveToRetval";
  case NVPTXISD::StoreRetval: return "NVPTXISD::StoreRetval";
  case NVPTXISD::PseudoUseParam: return "NVPTXISD::PseudoUseParam";
  case NVPTXISD::RETURN: return "NVPTXISD::RETURN";
  case NVPTXISD::CallSeqBegin: return "NVPTXISD::CallSeqBegin";
  case NVPTXISD::CallSeqEnd: return "NVPTXISD::CallSeqEnd";
  case NVPTXISD::LoadV2: return "NVPTXISD::LoadV2";
  case NVPTXISD::LoadV4: return "NVPTXISD::LoadV4";
  case NVPTXISD::LDGV2: return "NVPTXISD::LDGV2";
  case NVPTXISD::LDGV4: return "NVPTXISD::LDGV4";
  case NVPTXISD::LDUV2: return "NVPTXISD::LDUV2";
  case NVPTXISD::LDUV4: return "NVPTXISD::LDUV4";
  case NVPTXISD::StoreV2: return "NVPTXISD::StoreV2";
  case NVPTXISD::StoreV4: return "NVPTXISD::StoreV4";
  }
}

bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
  return VT == MVT::i1;
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  DebugLoc dl = Op.getDebugLoc();
  const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
  Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
  return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
}
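
// For example (illustrative), a (GlobalAddress @gvar) node becomes
// (NVPTXISD::Wrapper (TargetGlobalAddress @gvar)), and the wrapper is later
// printed as the plain symbol name in the PTX output.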

std::string
NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  unsigned retAlignment) const {

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  std::stringstream O;
  O << "prototype_" << uniqueCallSite << " : .callprototype ";

  if (retTy->getTypeID() == Type::VoidTyID)
    O << "()";
  else {
    O << "(";
    if (isABI) {
      if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
        unsigned size = 0;
        if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
          size = ITy->getBitWidth();
          if (size < 32) size = 32;
        } else {
          assert(retTy->isFloatingPointTy() &&
                 "Floating point type expected here");
          size = retTy->getPrimitiveSizeInBits();
        }

        O << ".param .b" << size << " _";
      } else if (isa<PointerType>(retTy))
        O << ".param .b" << getPointerTy().getSizeInBits() << " _";
      else {
        if ((retTy->getTypeID() == Type::StructTyID) ||
            isa<VectorType>(retTy)) {
          SmallVector<EVT, 16> vtparts;
          ComputeValueVTs(*this, retTy, vtparts);
          unsigned totalsz = 0;
          for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
            unsigned elems = 1;
            EVT elemtype = vtparts[i];
            if (vtparts[i].isVector()) {
              elems = vtparts[i].getVectorNumElements();
              elemtype = vtparts[i].getVectorElementType();
            }
            for (unsigned j = 0, je = elems; j != je; ++j) {
              unsigned sz = elemtype.getSizeInBits();
              if (elemtype.isInteger() && (sz < 8)) sz = 8;
              totalsz += sz / 8;
            }
          }
          O << ".param .align " << retAlignment << " .b8 _["
            << totalsz << "]";
        } else {
          assert(false && "Unknown return type");
        }
      }
    } else {
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*this, retTy, vtparts);
      unsigned idx = 0;
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger() && (sz < 32)) sz = 32;
          O << ".reg .b" << sz << " _";
          if (j < je - 1) O << ", ";
          ++idx;
        }
        if (i < e - 1)
          O << ", ";
      }
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;
  MVT thePointerTy = getPointerTy();

  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    const Type *Ty = Args[i].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (Outs[i].Flags.isByVal() == false) {
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        if (sz < 32) sz = 32;
      } else if (isa<PointerType>(Ty))
        sz = thePointerTy.getSizeInBits();
      else
        sz = Ty->getPrimitiveSizeInBits();
      if (isABI)
        O << ".param .b" << sz << " ";
      else
        O << ".reg .b" << sz << " ";
      O << "_";
      continue;
    }
    const PointerType *PTy = dyn_cast<PointerType>(Ty);
    assert(PTy && "Param with byval attribute should be a pointer type");
    Type *ETy = PTy->getElementType();

    if (isABI) {
      unsigned align = Outs[i].Flags.getByValAlign();
      unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
      O << ".param .align " << align << " .b8 ";
      O << "_";
      O << "[" << sz << "]";
      continue;
    } else {
      SmallVector<EVT, 16> vtparts;
      ComputeValueVTs(*this, ETy, vtparts);
      for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
        unsigned elems = 1;
        EVT elemtype = vtparts[i];
        if (vtparts[i].isVector()) {
          elems = vtparts[i].getVectorNumElements();
          elemtype = vtparts[i].getVectorElementType();
        }

        for (unsigned j = 0, je = elems; j != je; ++j) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger() && (sz < 32)) sz = 32;
          O << ".reg .b" << sz << " ";
          O << "_";
          if (j < je - 1) O << ", ";
        }
        if (i < e - 1)
          O << ", ";
      }
      continue;
    }
  }
  O << ");";
  return O.str();
}
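
// For illustration only: on a 64-bit target, an indirect ABI call to a
// function such as "float f(int, float *)" yields a prototype string roughly
// of the form
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _,
//                                                   .param .b64 _);
// The exact string depends on the pointer size, alignment, and the current
// uniqueCallSite counter.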


SDValue
NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                               SmallVectorImpl<SDValue> &InVals) const {
  SelectionDAG &DAG = CLI.DAG;
  DebugLoc &dl = CLI.DL;
  SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
  SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
  SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
  SDValue Callee = CLI.Callee;
  bool &isTailCall = CLI.IsTailCall;
  ArgListTy &Args = CLI.Args;
  Type *retTy = CLI.RetTy;
  ImmutableCallSite *CS = CLI.CS;

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  SDValue tempChain = Chain;
  Chain = DAG.getCALLSEQ_START(Chain,
                               DAG.getIntPtrConstant(uniqueCallSite, true));
  SDValue InFlag = Chain.getValue(1);

  assert((Outs.size() == Args.size()) &&
         "Unexpected number of arguments to function call");
  unsigned paramCount = 0;
  // Declare the .params or .reg needed to pass values to the function.
  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
    EVT VT = Outs[i].VT;

    if (Outs[i].Flags.isByVal() == false) {
      // Plain scalar:
      // for ABI,     declare .param .b<size> .param<n>;
      // for non-ABI, declare .reg .b<size> .param<n>.
      unsigned isReg = 1;
      if (isABI)
        isReg = 0;
      unsigned sz = VT.getSizeInBits();
      if (VT.isInteger() && (sz < 32)) sz = 32;
      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareParamOps[] = { Chain,
                                    DAG.getConstant(paramCount, MVT::i32),
                                    DAG.getConstant(sz, MVT::i32),
                                    DAG.getConstant(isReg, MVT::i32),
                                    InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                          DeclareParamOps, 5);
      InFlag = Chain.getValue(1);
      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                 DAG.getConstant(0, MVT::i32), OutVals[i],
                                 InFlag };

      unsigned opcode = NVPTXISD::StoreParam;
      if (isReg)
        opcode = NVPTXISD::MoveToParam;
      else {
        if (Outs[i].Flags.isZExt())
          opcode = NVPTXISD::StoreParamU32;
        else if (Outs[i].Flags.isSExt())
          opcode = NVPTXISD::StoreParamS32;
      }
      Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);

      InFlag = Chain.getValue(1);
      ++paramCount;
      continue;
    }
    // Struct or vector.
    SmallVector<EVT, 16> vtparts;
    const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
    assert(PTy && "Type of a byval parameter should be pointer");
    ComputeValueVTs(*this, PTy->getElementType(), vtparts);

    if (isABI) {
      // declare .param .align 16 .b8 .param<n>[<size>];
      unsigned sz = Outs[i].Flags.getByValSize();
      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      // The ByValAlign in Outs[i].Flags is always set at this point, so we
      // don't need to worry about natural alignment or not.
      // See TargetLowering::LowerCallTo().
      SDValue DeclareParamOps[] = { Chain,
                                    DAG.getConstant(Outs[i].Flags.getByValAlign(),
                                                    MVT::i32),
                                    DAG.getConstant(paramCount, MVT::i32),
                                    DAG.getConstant(sz, MVT::i32),
                                    InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                          DeclareParamOps, 5);
      InFlag = Chain.getValue(1);
      unsigned curOffset = 0;
      for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
        unsigned elems = 1;
        EVT elemtype = vtparts[j];
        if (vtparts[j].isVector()) {
          elems = vtparts[j].getVectorNumElements();
          elemtype = vtparts[j].getVectorElementType();
        }
        for (unsigned k = 0, ke = elems; k != ke; ++k) {
          unsigned sz = elemtype.getSizeInBits();
          if (elemtype.isInteger() && (sz < 8)) sz = 8;
          SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
                                        OutVals[i],
                                        DAG.getConstant(curOffset,
                                                        getPointerTy()));
          SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                       MachinePointerInfo(), false, false,
                                       false, 0);
          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(curOffset, MVT::i32),
                                     theVal, InFlag };
          Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
                              CopyParamOps, 5);
          InFlag = Chain.getValue(1);
          curOffset += sz / 8;
        }
      }
      ++paramCount;
      continue;
    }
    // Non-ABI struct or vector:
    // declare a bunch of .reg .b<size> .param<n>.
    unsigned curOffset = 0;
    for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
      unsigned elems = 1;
      EVT elemtype = vtparts[j];
      if (vtparts[j].isVector()) {
        elems = vtparts[j].getVectorNumElements();
        elemtype = vtparts[j].getVectorElementType();
      }
      for (unsigned k = 0, ke = elems; k != ke; ++k) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger() && (sz < 32)) sz = 32;
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain,
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32),
                                      DAG.getConstant(1, MVT::i32),
                                      InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                            DeclareParamOps, 5);
        InFlag = Chain.getValue(1);
        SDValue srcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
                                      DAG.getConstant(curOffset,
                                                      getPointerTy()));
        SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                     MachinePointerInfo(), false, false,
                                     false, 0);
        SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                   DAG.getConstant(0, MVT::i32), theVal,
                                   InFlag };
        Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
                            CopyParamOps, 5);
        InFlag = Chain.getValue(1);
        ++paramCount;
      }
    }
  }

  GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
  unsigned retAlignment = 0;

  // Handle the result.
  unsigned retCount = 0;
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> resvtparts;
    ComputeValueVTs(*this, retTy, resvtparts);

    // Declare one .param .align 16 .b8 func_retval0[<size>] for the ABI case,
    // or individual .reg .b<size> func_retval<0..> for the non-ABI case.
    unsigned resultsz = 0;
    for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) {
      unsigned elems = 1;
      EVT elemtype = resvtparts[i];
      if (resvtparts[i].isVector()) {
        elems = resvtparts[i].getVectorNumElements();
        elemtype = resvtparts[i].getVectorElementType();
      }
      for (unsigned j = 0, je = elems; j != je; ++j) {
        unsigned sz = elemtype.getSizeInBits();
        if (isABI == false) {
          if (elemtype.isInteger() && (sz < 32)) sz = 32;
        } else {
          if (elemtype.isInteger() && (sz < 8)) sz = 8;
        }
        if (isABI == false) {
          SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32),
                                      DAG.getConstant(retCount, MVT::i32),
                                      InFlag };
          Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                              DeclareRetOps, 5);
          InFlag = Chain.getValue(1);
          ++retCount;
        }
        resultsz += sz;
      }
    }
    if (isABI) {
      if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
          retTy->isPointerTy()) {
        // A scalar needs to be at least 32 bits wide.
        if (resultsz < 32)
          resultsz = 32;
        SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                                    DAG.getConstant(resultsz, MVT::i32),
                                    DAG.getConstant(0, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                            DeclareRetOps, 5);
        InFlag = Chain.getValue(1);
      } else {
        if (Func) { // direct call
          if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
            retAlignment = getDataLayout()->getABITypeAlignment(retTy);
        } else { // indirect call
          const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
          if (!llvm::getAlign(*CallI, 0, retAlignment))
            retAlignment = getDataLayout()->getABITypeAlignment(retTy);
        }
        SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareRetOps[] = { Chain,
                                    DAG.getConstant(retAlignment, MVT::i32),
                                    DAG.getConstant(resultsz / 8, MVT::i32),
                                    DAG.getConstant(0, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
                            DeclareRetOps, 5);
        InFlag = Chain.getValue(1);
      }
    }
  }

  if (!Func) {
    // This is the indirect function call case: PTX requires a prototype of
    // the form
    //   proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of the call
    // instruction. The prototype is embedded in a string and used as the
    // operand for an INLINEASM SDNode.
    SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
    const char *asmstr = nvTM->getManagedStrPool()->
        getManagedString(proto_string.c_str())->c_str();
    SDValue InlineAsmOps[] = { Chain,
                               DAG.getTargetExternalSymbol(asmstr,
                                                           getPointerTy()),
                               DAG.getMDNode(0),
                               DAG.getTargetConstant(0, MVT::i32), InFlag };
    Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
    InFlag = Chain.getValue(1);
  }
  // Op to just print "call".
  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue PrintCallOps[] = { Chain,
                             DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1)
                                                   : retCount, MVT::i32),
                             InFlag };
  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
                      dl, PrintCallVTs, PrintCallOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the function name.
  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallVoidOps[] = { Chain, Callee, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the param list.
  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgBeginOps[] = { Chain, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
                      CallArgBeginOps, 2);
  InFlag = Chain.getValue(1);

  for (unsigned i = 0, e = paramCount; i != e; ++i) {
    unsigned opcode;
    if (i == (e - 1))
      opcode = NVPTXISD::LastCallArg;
    else
      opcode = NVPTXISD::CallArg;
    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                             DAG.getConstant(i, MVT::i32),
                             InFlag };
    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
    InFlag = Chain.getValue(1);
  }
  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgEndOps[] = { Chain,
                              DAG.getConstant(Func ? 1 : 0, MVT::i32),
                              InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps,
                      3);
  InFlag = Chain.getValue(1);

  if (!Func) {
    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue PrototypeOps[] = { Chain,
                               DAG.getConstant(uniqueCallSite, MVT::i32),
                               InFlag };
    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
    InFlag = Chain.getValue(1);
  }

  // Generate loads from param memory/moves from registers for the result.
  if (Ins.size() > 0) {
    if (isABI) {
      unsigned resoffset = 0;
      for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
        unsigned sz = Ins[i].VT.getSizeInBits();
        if (Ins[i].VT.isInteger() && (sz < 8)) sz = 8;
        EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
        SDValue LoadRetOps[] = {
          Chain,
          DAG.getConstant(1, MVT::i32),
          DAG.getConstant(resoffset, MVT::i32),
          InFlag
        };
        SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
                                     LoadRetOps, array_lengthof(LoadRetOps));
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        InVals.push_back(retval);
        resoffset += sz / 8;
      }
    } else {
      SmallVector<EVT, 16> resvtparts;
      ComputeValueVTs(*this, retTy, resvtparts);

      assert(Ins.size() == resvtparts.size() &&
             "Unexpected number of return values in non-ABI case");
      unsigned paramNum = 0;
      for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
        assert(EVT(Ins[i].VT) == resvtparts[i] &&
               "Unexpected EVT type in non-ABI case");
        unsigned numelems = 1;
        EVT elemtype = Ins[i].VT;
        if (Ins[i].VT.isVector()) {
          numelems = Ins[i].VT.getVectorNumElements();
          elemtype = Ins[i].VT.getVectorElementType();
        }
        std::vector<SDValue> tempRetVals;
        for (unsigned j = 0; j < numelems; ++j) {
          EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
          SDValue MoveRetOps[] = {
            Chain,
            DAG.getConstant(0, MVT::i32),
            DAG.getConstant(paramNum, MVT::i32),
            InFlag
          };
          SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
                                       MoveRetOps, array_lengthof(MoveRetOps));
          Chain = retval.getValue(1);
          InFlag = retval.getValue(2);
          tempRetVals.push_back(retval);
          ++paramNum;
        }
        if (Ins[i].VT.isVector())
          InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
                                       &tempRetVals[0], tempRetVals.size()));
        else
          InVals.push_back(tempRetVals[0]);
      }
    }
  }
  Chain = DAG.getCALLSEQ_END(Chain,
                             DAG.getIntPtrConstant(uniqueCallSite, true),
                             DAG.getIntPtrConstant(uniqueCallSite + 1, true),
                             InFlag);
  uniqueCallSite++;

  // Set isTailCall to false for now, until we figure out how to express
  // tail call optimization in PTX.
  isTailCall = false;
  return Chain;
}

// By default, CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
// (see LegalizeDAG.cpp), which is slow and uses local memory. Instead, we use
// extract/insert/build_vector, just as LegalizeOp() did in LLVM 2.5.
SDValue NVPTXTargetLowering::
LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  DebugLoc dl = Node->getDebugLoc();
  SmallVector<SDValue, 8> Ops;
  unsigned NumOperands = Node->getNumOperands();
  for (unsigned i = 0; i < NumOperands; ++i) {
    SDValue SubOp = Node->getOperand(i);
    EVT VVT = SubOp.getNode()->getValueType(0);
    EVT EltVT = VVT.getVectorElementType();
    unsigned NumSubElem = VVT.getVectorNumElements();
    for (unsigned j = 0; j < NumSubElem; ++j) {
      Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
                                DAG.getIntPtrConstant(j)));
    }
  }
  return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0),
                     &Ops[0], Ops.size());
}
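
// Example (illustrative): concatenating two v2f32 values A and B becomes
//   BUILD_VECTOR(extractelt(A,0), extractelt(A,1),
//                extractelt(B,0), extractelt(B,1)) : v4f32
// rather than a round trip through a stack temporary.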

SDValue NVPTXTargetLowering::
LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  case ISD::RETURNADDR: return SDValue();
  case ISD::FRAMEADDR: return SDValue();
  case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN: return Op;
  case ISD::BUILD_VECTOR:
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::STORE: return LowerSTORE(Op, DAG);
  case ISD::LOAD: return LowerLOAD(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}


SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
  if (Op.getValueType() == MVT::i1)
    return LowerLOADi1(Op, DAG);
  else
    return SDValue();
}

// v = ld i1* addr
//   =>
// v1 = ld i8* addr
// v = trunc v1 to i1
SDValue NVPTXTargetLowering::
LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  LoadSDNode *LD = cast<LoadSDNode>(Node);
  DebugLoc dl = Node->getDebugLoc();
  assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
  assert(Node->getValueType(0) == MVT::i1 &&
         "Custom lowering for i1 load only");
  SDValue newLD = DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
                              LD->getPointerInfo(),
                              LD->isVolatile(), LD->isNonTemporal(),
                              LD->isInvariant(),
                              LD->getAlignment());
  SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
  // The legalizer (the caller) is expecting two values from the legalized
  // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
  // in LegalizeDAG.cpp, which also uses MergeValues.
  SDValue Ops[] = { result, LD->getChain() };
  return DAG.getMergeValues(Ops, 2, dl);
}

SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
  EVT ValVT = Op.getOperand(1).getValueType();
  if (ValVT == MVT::i1)
    return LowerSTOREi1(Op, DAG);
  else if (ValVT.isVector())
    return LowerSTOREVector(Op, DAG);
  else
    return SDValue();
}

SDValue
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
  SDNode *N = Op.getNode();
  SDValue Val = N->getOperand(1);
  DebugLoc DL = N->getDebugLoc();
  EVT ValVT = Val.getValueType();

  if (ValVT.isVector()) {
    // We only handle "native" vector sizes for now, e.g. <4 x double> is not
    // legal. We can (and should) split that into 2 stores of <2 x double>
    // here, but I'm leaving that as a TODO for now.
    if (!ValVT.isSimple())
      return SDValue();
    switch (ValVT.getSimpleVT().SimpleTy) {
    default: return SDValue();
    case MVT::v2i8:
    case MVT::v2i16:
    case MVT::v2i32:
    case MVT::v2i64:
    case MVT::v2f32:
    case MVT::v2f64:
    case MVT::v4i8:
    case MVT::v4i16:
    case MVT::v4i32:
    case MVT::v4f32:
      // This is a "native" vector type.
      break;
    }

    unsigned Opcode = 0;
    EVT EltVT = ValVT.getVectorElementType();
    unsigned NumElts = ValVT.getVectorNumElements();

    // Since StoreV2 is a target node, we cannot rely on DAG type
    // legalization. Therefore, we must ensure the type is legal. For i1 and
    // i8, we set the stored type to i16 and propagate the "real" type as the
    // memory type.
    bool NeedExt = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExt = true;

    switch (NumElts) {
    default: return SDValue();
    case 2:
      Opcode = NVPTXISD::StoreV2;
      break;
    case 4: {
      Opcode = NVPTXISD::StoreV4;
      break;
    }
    }

    SmallVector<SDValue, 8> Ops;

    // First is the chain.
    Ops.push_back(N->getOperand(0));

    // Then the split values.
    for (unsigned i = 0; i < NumElts; ++i) {
      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                   DAG.getIntPtrConstant(i));
      if (NeedExt)
        // ANY_EXTEND is correct here since the store will only look at the
        // lower-order bits anyway.
        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
      Ops.push_back(ExtVal);
    }

    // Then any remaining arguments.
    for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
      Ops.push_back(N->getOperand(i));
    }

    MemSDNode *MemSD = cast<MemSDNode>(N);

    SDValue NewSt = DAG.getMemIntrinsicNode(Opcode, DL,
                                            DAG.getVTList(MVT::Other), &Ops[0],
                                            Ops.size(), MemSD->getMemoryVT(),
                                            MemSD->getMemOperand());


    //return DCI.CombineTo(N, NewSt, true);
    return NewSt;
  }

  return SDValue();
}

// st i1 v, addr
//   =>
// v1 = zxt v to i8
// st i8, addr
SDValue NVPTXTargetLowering::
LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  DebugLoc dl = Node->getDebugLoc();
  StoreSDNode *ST = cast<StoreSDNode>(Node);
  SDValue Tmp1 = ST->getChain();
  SDValue Tmp2 = ST->getBasePtr();
  SDValue Tmp3 = ST->getValue();
  assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
  unsigned Alignment = ST->getAlignment();
  bool isVolatile = ST->isVolatile();
  bool isNonTemporal = ST->isNonTemporal();
  Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3);
  SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2,
                                ST->getPointerInfo(), isVolatile,
                                isNonTemporal, Alignment);
  return Result;
}


SDValue
NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, int idx,
                                EVT v) const {
  std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
  std::stringstream suffix;
  suffix << idx;
  *name += suffix.str();
  return DAG.getTargetExternalSymbol(name->c_str(), v);
}

SDValue
NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
  return getExtSymb(DAG, ".PARAM", idx, v);
}

SDValue
NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
  return getExtSymb(DAG, ".HLPPARAM", idx);
}

// Check to see if the kernel argument is image*_t or sampler_t.

bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
  static const char *const specialTypes[] = {
    "struct._image2d_t",
    "struct._image3d_t",
    "struct._sampler_t"
  };

  const Type *Ty = arg->getType();
  const PointerType *PTy = dyn_cast<PointerType>(Ty);

  if (!PTy)
    return false;

  if (!context)
    return false;

  const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
  const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";

  for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
    if (TypeName == specialTypes[i])
      return true;

  return false;
}
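
// For example (hypothetical IR): an OpenCL kernel parameter declared as
// "read_only image2d_t img" typically reaches the backend as a pointer to an
// opaque struct named "struct._image2d_t", which the check above recognizes.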

SDValue
NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
                                          CallingConv::ID CallConv,
                                          bool isVarArg,
                                          const SmallVectorImpl<ISD::InputArg> &Ins,
                                          DebugLoc dl, SelectionDAG &DAG,
                                          SmallVectorImpl<SDValue> &InVals) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const DataLayout *TD = getDataLayout();

  const Function *F = MF.getFunction();
  const AttributeSet &PAL = F->getAttributes();

  SDValue Root = DAG.getRoot();
  std::vector<SDValue> OutChains;

  bool isKernel = llvm::isKernelFunction(*F);
  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  std::vector<Type *> argTypes;
  std::vector<const Argument *> theArgs;
  for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
       I != E; ++I) {
    theArgs.push_back(I);
    argTypes.push_back(I->getType());
  }
  assert(argTypes.size() == Ins.size() &&
         "Ins types and function types did not match");

  int idx = 0;
  for (unsigned i = 0, e = Ins.size(); i != e; ++i, ++idx) {
    Type *Ty = argTypes[i];
    EVT ObjectVT = getValueType(Ty);
    assert(ObjectVT == Ins[i].VT && "Ins type did not match function type");

    // If the kernel argument is an image*_t or sampler_t, convert it to an
    // i32 constant holding the parameter position. This can later be matched
    // in the AsmPrinter to output the correct mangled name.
    if (isImageOrSamplerVal(theArgs[i],
                            (theArgs[i]->getParent() ?
                             theArgs[i]->getParent()->getParent() : 0))) {
      assert(isKernel && "Only kernels can have image/sampler params");
      InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
      continue;
    }

    if (theArgs[i]->use_empty()) {
      // Argument is dead.
      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
      continue;
    }

    // In the following cases, assign a node order of "idx+1" to newly-created
    // nodes. The SDNodes for params have to appear in the same order as their
    // order of appearance in the original function. "idx+1" holds that order.
    if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
      // A plain scalar.
      if (isABI || isKernel) {
        // If ABI, load from the param symbol.
        SDValue Arg = getParamSymbol(DAG, idx);
        // Conjure up a value that we can get the address space from.
        // FIXME: Using a constant here is a hack.
        Value *srcValue = Constant::getNullValue(PointerType::get(
            ObjectVT.getTypeForEVT(F->getContext()),
            llvm::ADDRESS_SPACE_PARAM));
        SDValue p = DAG.getLoad(ObjectVT, dl, Root, Arg,
                                MachinePointerInfo(srcValue), false, false,
                                false,
                                TD->getABITypeAlignment(
                                    ObjectVT.getTypeForEVT(F->getContext())));
        if (p.getNode())
          DAG.AssignOrdering(p.getNode(), idx + 1);
        InVals.push_back(p);
      } else {
        // If no ABI, just move the param symbol.
        SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
        SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
        if (p.getNode())
          DAG.AssignOrdering(p.getNode(), idx + 1);
        InVals.push_back(p);
      }
      continue;
    }

    // Param has the ByVal attribute.
    if (isABI || isKernel) {
      // Return MoveParam(param symbol).
      // Ideally, the param symbol could be returned directly, but when the
      // SDNode builder decides to use it in a CopyToReg(), the machine
      // instruction fails because TargetExternalSymbol (not lowered) is
      // target dependent, and CopyToReg assumes the source is lowered.
      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
      SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
      if (p.getNode())
        DAG.AssignOrdering(p.getNode(), idx + 1);
      if (isKernel)
        InVals.push_back(p);
      else {
        SDValue p2 = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
                                 DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen,
                                                 MVT::i32),
                                 p);
        InVals.push_back(p2);
      }
    } else {
      // Have to move a set of param symbols to registers, store them locally,
      // and return the local pointer in InVals.
      const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
      assert(elemPtrType && "Byval parameter should be a pointer type");
      Type *elemType = elemPtrType->getElementType();
      // Compute the constituent parts.
      SmallVector<EVT, 16> vtparts;
      SmallVector<uint64_t, 16> offsets;
      ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
      unsigned totalsize = 0;
      for (unsigned j = 0, je = vtparts.size(); j != je; ++j)
        totalsize += vtparts[j].getStoreSizeInBits();
      SDValue localcopy = DAG.getFrameIndex(
          MF.getFrameInfo()->CreateStackObject(totalsize / 8, 16, false),
          getPointerTy());
      unsigned sizesofar = 0;
      std::vector<SDValue> theChains;
      for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
        unsigned numElems = 1;
        if (vtparts[j].isVector())
          numElems = vtparts[j].getVectorNumElements();
        for (unsigned k = 0, ke = numElems; k != ke; ++k) {
          EVT tmpvt = vtparts[j];
          if (tmpvt.isVector()) tmpvt = tmpvt.getVectorElementType();
          SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
                                    getParamSymbol(DAG, idx, tmpvt));
          SDValue addr = DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
                                     DAG.getConstant(sizesofar,
                                                     getPointerTy()));
          theChains.push_back(DAG.getStore(Chain, dl, arg, addr,
                                           MachinePointerInfo(), false, false,
                                           0));
          sizesofar += tmpvt.getStoreSizeInBits() / 8;
          ++idx;
        }
      }
      --idx;
      Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
                          theChains.size());
      InVals.push_back(localcopy);
    }
  }

  // Clang will check explicit varargs and issue an error if any are present.
  // However, Clang will let code with implicit varargs, like f(), pass.
  // We treat this case as if the arg list is empty.
  //if (F.isVarArg()) {
  // assert(0 && "VarArg not supported yet!");
  //}

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
                            &OutChains[0], OutChains.size()));

  return Chain;
}

SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 DebugLoc dl, SelectionDAG &DAG) const {

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);

  unsigned sizesofar = 0;
  unsigned idx = 0;
  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
    SDValue theVal = OutVals[i];
    EVT theValType = theVal.getValueType();
    unsigned numElems = 1;
    if (theValType.isVector()) numElems = theValType.getVectorNumElements();
    for (unsigned j = 0, je = numElems; j != je; ++j) {
      SDValue tmpval = theVal;
      if (theValType.isVector())
        tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                             theValType.getVectorElementType(),
                             tmpval, DAG.getIntPtrConstant(j));
      Chain = DAG.getNode(isABI ? NVPTXISD::StoreRetval
                                : NVPTXISD::MoveToRetval,
                          dl, MVT::Other, Chain,
                          DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
                          tmpval);
      if (theValType.isVector())
        sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
      else
        sizesofar += theValType.getStoreSizeInBits() / 8;
      ++idx;
    }
  }

  return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
}

void
NVPTXTargetLowering::LowerAsmOperandForConstraint(SDValue Op,
                                                  std::string &Constraint,
                                                  std::vector<SDValue> &Ops,
                                                  SelectionDAG &DAG) const {
  if (Constraint.length() > 1)
    return;
  else
    TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
}

// NVPTX supports vectors of legal types of any length in intrinsics, because
// the NVPTX-specific type legalizer will legalize them to a PTX-supported
// length.
bool
NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
  if (isTypeLegal(VT))
    return true;
  if (VT.isVector()) {
    MVT eVT = VT.getVectorElementType();
    if (isTypeLegal(eVT))
      return true;
  }
  return false;
}


// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need information that is only available in the
// "Value" type of the destination pointer -- in particular, the address
// space information.
bool
NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                                        unsigned Intrinsic) const {
  switch (Intrinsic) {
  default:
    return false;

  case Intrinsic::nvvm_atomic_load_add_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_atomic_load_inc_32:
  case Intrinsic::nvvm_atomic_load_dec_32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p:

    Info.opc = ISD::INTRINSIC_W_CHAIN;
    if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
      Info.memVT = MVT::i32;
    else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
      Info.memVT = getPointerTy();
    else
      Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 0;
    return true;

  }
  return false;
}

/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target-specific optimizations, like loop strength reduction
/// (LoopStrengthReduce.cpp) and memory optimization for address modes
/// (CodeGenPrepare.cpp).
bool
NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
                                           Type *Ty) const {

  // AddrMode - This represents an addressing mode of:
  //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
  //
  // The legal address modes are
  // - [avar]
  // - [areg]
  // - [areg+immoff]
  // - [immAddr]

  if (AM.BaseGV) {
    if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
      return false;
    return true;
  }

  switch (AM.Scale) {
  case 0: // "r", "r+i" or "i" is allowed.
    break;
  case 1:
    if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
      return false;
    // Otherwise we have r+i.
    break;
  default:
    // No scale > 1 is allowed.
    return false;
  }
  return true;
}
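
// Illustrative PTX forms for the accepted modes (assuming a global-space
// load):
//   ld.global.f32 %f1, [gvar];    // [avar]
//   ld.global.f32 %f1, [%r1];     // [areg]
//   ld.global.f32 %f1, [%r1+4];   // [areg+immoff]
//   ld.global.f32 %f1, [4];       // [immAddr]
// A scaled index such as [%r1+%r2*4] has no PTX encoding, so any scale
// greater than 1, or a base register plus a scaled register, is rejected
// above.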

//===----------------------------------------------------------------------===//
//                         NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//

/// getConstraintType - Given a constraint letter, return the type of
/// constraint it is for this target.
NVPTXTargetLowering::ConstraintType
NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    default:
      break;
    case 'r':
    case 'h':
    case 'c':
    case 'l':
    case 'f':
    case 'd':
    case '0':
    case 'N':
      return C_RegisterClass;
    }
  }
  return TargetLowering::getConstraintType(Constraint);
}


std::pair<unsigned, const TargetRegisterClass*>
NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
                                                  EVT VT) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    case 'c':
      return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
    case 'h':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'r':
      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
    case 'l':
    case 'N':
      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
    case 'f':
      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
    case 'd':
      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
    }
  }
  return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
}



/// getFunctionAlignment - Return the Log2 alignment of this function.
unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
  return 4;
}

/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue> &Results) {
  EVT ResVT = N->getValueType(0);
  DebugLoc DL = N->getDebugLoc();

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal. We can (and should) split that into 2 loads of <2 x double> here,
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default: return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type.
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal. For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  switch (NumElts) {
  default: return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs, 5);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands.
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information.
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
                                          OtherOps.size(), LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0],
                                 NumElts);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}
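
// Example (illustrative): a load of <4 x float> is replaced by a single
// NVPTXISD::LoadV4 node producing (f32, f32, f32, f32, ch); the four scalars
// are then recombined with BUILD_VECTOR, so existing uses of the original
// v4f32 value are unaffected.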

static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(0);
  SDValue Intrin = N->getOperand(1);
  DebugLoc DL = N->getDebugLoc();

  // Get the intrinsic ID.
  unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default: return;
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU.

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization. Therefore, we must ensure the type is legal. For i1 and
      // i8, we set the loaded type to i16 and propagate the "real" type as
      // the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      switch (NumElts) {
      default: return;
      case 2:
        switch (IntrinNo) {
        default: return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV2;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV2;
          break;
        }
        LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
        break;
      case 4: {
        switch (IntrinNo) {
        default: return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV4;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV4;
          break;
        }
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(ListVTs, 5);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands.

      OtherOps.push_back(Chain); // Chain
      // Skip operand 1 (intrinsic ID).
      // Others.
      for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
        OtherOps.push_back(N->getOperand(i));

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs,
                                              &OtherOps[0], OtherOps.size(),
                                              MemSD->getMemoryVT(),
                                              MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(i);
        if (NeedTrunc)
          Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(),
                            Res);
        ScalarRes.push_back(Res);
      }

      SDValue LoadChain = NewLD.getValue(NumElts);

      SDValue BuildVec = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT,
                                     &ScalarRes[0], NumElts);

      Results.push_back(BuildVec);
      Results.push_back(LoadChain);
    } else {
      // i8 LDG/LDU.
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is.
      SmallVector<SDValue, 4> Ops;
      for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
        Ops.push_back(N->getOperand(i));

      // Force the output to i16.
      SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL,
                                              LdResVTs, &Ops[0],
                                              Ops.size(), MVT::i8,
                                              MemSD->getMemOperand());

      Results.push_back(NewLD.getValue(0));
      Results.push_back(NewLD.getValue(1));
    }
  }
  }
}

void NVPTXTargetLowering::ReplaceNodeResults(SDNode *N,
                                             SmallVectorImpl<SDValue> &Results,
                                             SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default: report_fatal_error("Unhandled custom legalization");
  case ISD::LOAD:
    ReplaceLoadVector(N, DAG, Results);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  }
}