1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the PTXTargetLowering class.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PTX.h"
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Function.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SelectionDAG.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 using namespace llvm;
30
31 //===----------------------------------------------------------------------===//
32 // TargetLowering Implementation
33 //===----------------------------------------------------------------------===//
34
PTXTargetLowering(TargetMachine & TM)35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
36 : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
37 // Set up the register classes.
38 addRegisterClass(MVT::i1, PTX::RegPredRegisterClass);
39 addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
40 addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
41 addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
42 addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
43 addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
44
45 setBooleanContents(ZeroOrOneBooleanContent);
46 setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
47 setMinFunctionAlignment(2);
48
49 ////////////////////////////////////
50 /////////// Expansion //////////////
51 ////////////////////////////////////
52
53 // (any/zero/sign) extload => load + (any/zero/sign) extend
54
55 setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
56 setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
57 setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
58
59 // f32 extload => load + fextend
60
61 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
62
63 // f64 truncstore => trunc + store
64
65 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
66
67 // sign_extend_inreg => sign_extend
68
69 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
70
71 // br_cc => brcond
72
73 setOperationAction(ISD::BR_CC, MVT::Other, Expand);
74
75 // select_cc => setcc
76
77 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
78 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
79 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
80
81 ////////////////////////////////////
82 //////////// Legal /////////////////
83 ////////////////////////////////////
84
85 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
86 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
87
88 ////////////////////////////////////
89 //////////// Custom ////////////////
90 ////////////////////////////////////
91
92 // customise setcc to use bitwise logic if possible
93
94 setOperationAction(ISD::SETCC, MVT::i1, Custom);
95
96 // customize translation of memory addresses
97
98 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
99 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
100
101 // Compute derived properties from the register classes
102 computeRegisterProperties();
103 }
104
getSetCCResultType(EVT VT) const105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
106 return MVT::i1;
107 }
108
LowerOperation(SDValue Op,SelectionDAG & DAG) const109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
110 switch (Op.getOpcode()) {
111 default:
112 llvm_unreachable("Unimplemented operand");
113 case ISD::SETCC:
114 return LowerSETCC(Op, DAG);
115 case ISD::GlobalAddress:
116 return LowerGlobalAddress(Op, DAG);
117 }
118 }
119
getTargetNodeName(unsigned Opcode) const120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
121 switch (Opcode) {
122 default:
123 llvm_unreachable("Unknown opcode");
124 case PTXISD::COPY_ADDRESS:
125 return "PTXISD::COPY_ADDRESS";
126 case PTXISD::LOAD_PARAM:
127 return "PTXISD::LOAD_PARAM";
128 case PTXISD::STORE_PARAM:
129 return "PTXISD::STORE_PARAM";
130 case PTXISD::READ_PARAM:
131 return "PTXISD::READ_PARAM";
132 case PTXISD::WRITE_PARAM:
133 return "PTXISD::WRITE_PARAM";
134 case PTXISD::EXIT:
135 return "PTXISD::EXIT";
136 case PTXISD::RET:
137 return "PTXISD::RET";
138 case PTXISD::CALL:
139 return "PTXISD::CALL";
140 }
141 }
142
143 //===----------------------------------------------------------------------===//
144 // Custom Lower Operation
145 //===----------------------------------------------------------------------===//
146
LowerSETCC(SDValue Op,SelectionDAG & DAG) const147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
148 assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
149 SDValue Op0 = Op.getOperand(0);
150 SDValue Op1 = Op.getOperand(1);
151 SDValue Op2 = Op.getOperand(2);
152 DebugLoc dl = Op.getDebugLoc();
153 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
154
155 // Look for X == 0, X == 1, X != 0, or X != 1
156 // We can simplify these to bitwise logic
157
158 if (Op1.getOpcode() == ISD::Constant &&
159 (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
160 cast<ConstantSDNode>(Op1)->isNullValue()) &&
161 (CC == ISD::SETEQ || CC == ISD::SETNE)) {
162
163 return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
164 }
165
166 return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
167 }
168
169 SDValue PTXTargetLowering::
LowerGlobalAddress(SDValue Op,SelectionDAG & DAG) const170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
171 EVT PtrVT = getPointerTy();
172 DebugLoc dl = Op.getDebugLoc();
173 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
174
175 assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
176
177 SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
178 SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
179 dl,
180 PtrVT.getSimpleVT(),
181 targetGlobal);
182
183 return movInstr;
184 }
185
186 //===----------------------------------------------------------------------===//
187 // Calling Convention Implementation
188 //===----------------------------------------------------------------------===//
189
190 SDValue PTXTargetLowering::
LowerFormalArguments(SDValue Chain,CallingConv::ID CallConv,bool isVarArg,const SmallVectorImpl<ISD::InputArg> & Ins,DebugLoc dl,SelectionDAG & DAG,SmallVectorImpl<SDValue> & InVals) const191 LowerFormalArguments(SDValue Chain,
192 CallingConv::ID CallConv,
193 bool isVarArg,
194 const SmallVectorImpl<ISD::InputArg> &Ins,
195 DebugLoc dl,
196 SelectionDAG &DAG,
197 SmallVectorImpl<SDValue> &InVals) const {
198 if (isVarArg) llvm_unreachable("PTX does not support varargs");
199
200 MachineFunction &MF = DAG.getMachineFunction();
201 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
202 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
203 PTXParamManager &PM = MFI->getParamManager();
204
205 switch (CallConv) {
206 default:
207 llvm_unreachable("Unsupported calling convention");
208 break;
209 case CallingConv::PTX_Kernel:
210 MFI->setKernel(true);
211 break;
212 case CallingConv::PTX_Device:
213 MFI->setKernel(false);
214 break;
215 }
216
217 // We do one of two things here:
218 // IsKernel || SM >= 2.0 -> Use param space for arguments
219 // SM < 2.0 -> Use registers for arguments
220 if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
221 // We just need to emit the proper LOAD_PARAM ISDs
222 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
223 assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
224 "Kernels cannot take pred operands");
225
226 unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
227 unsigned Param = PM.addArgumentParam(ParamSize);
228 const std::string &ParamName = PM.getParamName(Param);
229 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
230 MVT::Other);
231 SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
232 ParamValue);
233 InVals.push_back(ArgValue);
234 }
235 }
236 else {
237 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
238 EVT RegVT = Ins[i].VT;
239 TargetRegisterClass* TRC = getRegClassFor(RegVT);
240
241 // Use a unique index in the instruction to prevent instruction folding.
242 // Yes, this is a hack.
243 SDValue Index = DAG.getTargetConstant(i, MVT::i32);
244 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
245 SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
246 Index);
247
248 InVals.push_back(ArgValue);
249
250 MFI->addArgReg(Reg);
251 }
252 }
253
254 return Chain;
255 }
256
257 SDValue PTXTargetLowering::
LowerReturn(SDValue Chain,CallingConv::ID CallConv,bool isVarArg,const SmallVectorImpl<ISD::OutputArg> & Outs,const SmallVectorImpl<SDValue> & OutVals,DebugLoc dl,SelectionDAG & DAG) const258 LowerReturn(SDValue Chain,
259 CallingConv::ID CallConv,
260 bool isVarArg,
261 const SmallVectorImpl<ISD::OutputArg> &Outs,
262 const SmallVectorImpl<SDValue> &OutVals,
263 DebugLoc dl,
264 SelectionDAG &DAG) const {
265 if (isVarArg) llvm_unreachable("PTX does not support varargs");
266
267 switch (CallConv) {
268 default:
269 llvm_unreachable("Unsupported calling convention.");
270 case CallingConv::PTX_Kernel:
271 assert(Outs.size() == 0 && "Kernel must return void.");
272 return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
273 case CallingConv::PTX_Device:
274 assert(Outs.size() <= 1 && "Can at most return one value.");
275 break;
276 }
277
278 MachineFunction& MF = DAG.getMachineFunction();
279 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
280 PTXParamManager &PM = MFI->getParamManager();
281
282 SDValue Flag;
283 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
284
285 if (ST.useParamSpaceForDeviceArgs()) {
286 assert(Outs.size() < 2 && "Device functions can return at most one value");
287
288 if (Outs.size() == 1) {
289 unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
290 unsigned Param = PM.addReturnParam(ParamSize);
291 const std::string &ParamName = PM.getParamName(Param);
292 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
293 MVT::Other);
294 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
295 ParamValue, OutVals[0]);
296 }
297 } else {
298 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
299 EVT RegVT = Outs[i].VT;
300 TargetRegisterClass* TRC = 0;
301
302 // Determine which register class we need
303 if (RegVT == MVT::i1) {
304 TRC = PTX::RegPredRegisterClass;
305 }
306 else if (RegVT == MVT::i16) {
307 TRC = PTX::RegI16RegisterClass;
308 }
309 else if (RegVT == MVT::i32) {
310 TRC = PTX::RegI32RegisterClass;
311 }
312 else if (RegVT == MVT::i64) {
313 TRC = PTX::RegI64RegisterClass;
314 }
315 else if (RegVT == MVT::f32) {
316 TRC = PTX::RegF32RegisterClass;
317 }
318 else if (RegVT == MVT::f64) {
319 TRC = PTX::RegF64RegisterClass;
320 }
321 else {
322 llvm_unreachable("Unknown parameter type");
323 }
324
325 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
326
327 SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
328 SDValue OutReg = DAG.getRegister(Reg, RegVT);
329
330 Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
331
332 MFI->addRetReg(Reg);
333 }
334 }
335
336 if (Flag.getNode() == 0) {
337 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
338 }
339 else {
340 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
341 }
342 }
343
344 SDValue
LowerCall(SDValue Chain,SDValue Callee,CallingConv::ID CallConv,bool isVarArg,bool & isTailCall,const SmallVectorImpl<ISD::OutputArg> & Outs,const SmallVectorImpl<SDValue> & OutVals,const SmallVectorImpl<ISD::InputArg> & Ins,DebugLoc dl,SelectionDAG & DAG,SmallVectorImpl<SDValue> & InVals) const345 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
346 CallingConv::ID CallConv, bool isVarArg,
347 bool &isTailCall,
348 const SmallVectorImpl<ISD::OutputArg> &Outs,
349 const SmallVectorImpl<SDValue> &OutVals,
350 const SmallVectorImpl<ISD::InputArg> &Ins,
351 DebugLoc dl, SelectionDAG &DAG,
352 SmallVectorImpl<SDValue> &InVals) const {
353
354 MachineFunction& MF = DAG.getMachineFunction();
355 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
356 PTXParamManager &PM = MFI->getParamManager();
357
358 assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
359 "Calls are not handled for the target device");
360
361 std::vector<SDValue> Ops;
362 // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
363 Ops.resize(Outs.size() + Ins.size() + 4);
364
365 Ops[0] = Chain;
366
367 // Identify the callee function
368 const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
369 assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
370 "PTX function calls must be to PTX device functions");
371 Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
372 Ops[Ins.size()+2] = Callee;
373
374 // Generate STORE_PARAM nodes for each function argument. In PTX, function
375 // arguments are explicitly stored into .param variables and passed as
376 // arguments. There is no register/stack-based calling convention in PTX.
377 Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
378 for (unsigned i = 0; i != OutVals.size(); ++i) {
379 unsigned Size = OutVals[i].getValueType().getSizeInBits();
380 unsigned Param = PM.addLocalParam(Size);
381 const std::string &ParamName = PM.getParamName(Param);
382 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
383 MVT::Other);
384 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
385 ParamValue, OutVals[i]);
386 Ops[i+Ins.size()+4] = ParamValue;
387 }
388
389 std::vector<SDValue> InParams;
390
391 // Generate list of .param variables to hold the return value(s).
392 Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32);
393 for (unsigned i = 0; i < Ins.size(); ++i) {
394 unsigned Size = Ins[i].VT.getStoreSizeInBits();
395 unsigned Param = PM.addLocalParam(Size);
396 const std::string &ParamName = PM.getParamName(Param);
397 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
398 MVT::Other);
399 Ops[i+2] = ParamValue;
400 InParams.push_back(ParamValue);
401 }
402
403 Ops[0] = Chain;
404
405 // Create the CALL node.
406 Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
407
408 // Create the LOAD_PARAM nodes that retrieve the function return value(s).
409 for (unsigned i = 0; i < Ins.size(); ++i) {
410 SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
411 InParams[i]);
412 InVals.push_back(Load);
413 }
414
415 return Chain;
416 }
417
getNumRegisters(LLVMContext & Context,EVT VT)418 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) {
419 // All arguments consist of one "register," regardless of the type.
420 return 1;
421 }
422
423